File size: 7,088 Bytes
093b0a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import math
import re

import torch
import torch.nn as nn
import torch.nn.functional as F


# --------------------------
# Utilities / Norm / Activations
# --------------------------

class RMSNorm(nn.Module):
    """RMSNorm with learnable weight. Drop-in for LayerNorm when using Pre-Norm.

    Computes ``weight * x / sqrt(mean(x**2, dim=-1) + eps)`` over the last axis.

    The mean-square statistic is always accumulated in float32 and the inverse
    RMS is cast back to ``x.dtype`` before scaling. For float32 inputs this is
    exactly the plain formula. For float16 it avoids two failure modes:
    ``x**2`` overflowing (|x| > ~256) and ``eps=1e-8`` rounding to 0, which
    made all-zero rows produce ``0 * inf = NaN``.

    Args:
        d_model: size of the last (normalized) axis; length of ``weight``.
        eps: added to the mean square before ``rsqrt``.
    """
    def __init__(self, d_model: int, eps: float = 1e-8):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., D)
        # .float() is a no-op for float32 inputs, so fp32 results are unchanged.
        mean_sq = x.float().pow(2).mean(dim=-1, keepdim=True)
        # rsqrt(eps) = 1e4 is the largest possible value, which fits in fp16.
        inv_rms = mean_sq.add(self.eps).rsqrt().to(x.dtype)
        return self.weight * x * inv_rms


def get_activation(name: str):
    """Build a new activation module from a case-insensitive name.

    ``None`` or an empty name selects ReLU. "gelu" and "geglu" both map to
    plain ``nn.GELU`` (no gating). "silu" and "swish" map to ``nn.SiLU``.
    Any unrecognized name falls back to ReLU instead of raising.
    """
    factories = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "geglu": nn.GELU,
        "silu": nn.SiLU,
        "swish": nn.SiLU,
    }
    key = (name or "relu").lower()
    factory = factories.get(key, nn.ReLU)
    return factory()


class SwiGLU(nn.Module):
    """Gated feed-forward block computing ``proj(silu(gate) * value)``.

    One fused linear layer (``w12``) produces ``2 * d_ff`` features. The first
    half is the gate and the second half the value. Their gated product
    (width ``d_ff``) is projected back to ``d_model`` by ``proj``.
    """
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w12 = nn.Linear(d_model, 2 * d_ff, bias=True)
        self.proj = nn.Linear(d_ff, d_model, bias=True)

    def forward(self, x):
        fused = self.w12(x)
        gate, value = torch.chunk(fused, 2, dim=-1)
        gated = F.silu(gate) * value
        return self.proj(gated)


# --------------------------
# Conv Layer (kept signature)
# --------------------------

class ConvLayer(nn.Module):
    """Distilling block: circular Conv1d -> BatchNorm1d -> ELU -> MaxPool1d.

    Takes and returns ``[B, L, D]`` tensors (channels = ``c_in`` = D). The
    max-pool (kernel 3, stride 2, padding 1) shortens the sequence to
    ``(L - 1) // 2 + 1``; the convolution itself keeps the length.
    """
    def __init__(self, c_in):
        super(ConvLayer, self).__init__()
        self.downConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=c_in,
            kernel_size=3,
            padding=self._circular_padding(),
            padding_mode="circular",
        )
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    @staticmethod
    def _circular_padding(version=None):
        """Return the Conv1d ``padding`` for this PyTorch version.

        Keeps the original rule: 2 for PyTorch < 1.5, 1 otherwise. The version
        is compared numerically. A plain string comparison
        (``"1.10.0" >= "1.5.0"`` is False) wrongly picked 2 for 1.10-1.13,
        which changed the conv output length.

        Args:
            version: version string to inspect; defaults to ``torch.__version__``.
        """
        text = str(torch.__version__ if version is None else version)
        match = re.match(r"(\d+)\.(\d+)", text)
        if match is None:
            # Unparseable version string: assume a modern PyTorch.
            return 1
        major, minor = int(match.group(1)), int(match.group(2))
        return 1 if (major, minor) >= (1, 5) else 2

    def forward(self, x):
        # x: [B, L, D]
        x = x.permute(0, 2, 1)                       # B, D, L
        y = self.downConv(x)
        y = self.norm(y)
        y = self.activation(y)
        y = self.maxPool(y)
        y = y.transpose(1, 2).contiguous()           # B, L', D
        return y


# --------------------------
# Encoder Layer (kept signature)
# --------------------------

class EncoderLayer(nn.Module):
    """
    Transformer encoder layer: self-attention + SwiGLU FFN with scaled residuals.

    Keep the same signature:
      __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu", ln_mode="pre", conv_layer=False)
      forward(self, x, attn_mask=None)

    Internals:
      - ln_mode == "post": Post-Norm with nn.LayerNorm (post_ln1 / post_ln2).
        Any other value: Pre-Norm with RMSNorm (norm1 / norm2). norm1/norm2
        are always created, but the post path does not use them (they get no
        gradients in that mode).
      - Every residual branch is dropout(branch) * 1/sqrt(2).
      - FFN is SwiGLU with hidden width d_ff (default 4 * d_model).
      - `activation` is resolved via get_activation and stored as
        self.activation, but forward does not use it (the FFN is SwiGLU).
      - conv_layer=True runs a ConvLayer first. ConvLayer max-pools with
        stride 2, so for sequence length >= 2 its output is shorter than x and
        cannot be added as a residual. In that case the conv output replaces x
        (sequence distilling, as Encoder does with its conv_layers). It is only
        added as a scaled residual when the shapes match.
      - `attention` is called as attention(q, k, v, attn_mask=...) and must
        return (new_x, attn). It is expected to apply its own
        1/sqrt(d_head) query scaling.
    """
    def __init__(
        self,
        attention,
        d_model,
        d_ff=None,
        dropout=0.1,
        activation="relu",
        ln_mode="pre",
        conv_layer=False,
    ):
        super(EncoderLayer, self).__init__()
        self.attention = attention
        self.conv_layer = ConvLayer(d_model) if conv_layer else None
        self.dropout = nn.Dropout(dropout)
        self.activation = get_activation(activation)
        self.ln_mode = ln_mode  # only "post" selects Post-Norm; anything else is Pre-Norm

        # Core hyperparams
        self.d_model = d_model
        self.d_ff = d_ff or 4 * d_model
        self.res_scale = 1.0 / math.sqrt(2.0)

        # RMSNorm used by the Pre-Norm path (names kept as norm1/norm2).
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)

        # FFN: SwiGLU
        self.ff = SwiGLU(d_model, self.d_ff)

        # LayerNorms used only by the Post-Norm path.
        if self.ln_mode == "post":
            self.post_ln1 = nn.LayerNorm(d_model)
            self.post_ln2 = nn.LayerNorm(d_model)

    def _apply_conv(self, x):
        """Run self.conv_layer. Add it as a residual only if the shape is unchanged."""
        c = self.conv_layer(x)
        if c.shape == x.shape:
            return x + self.dropout(c) * self.res_scale
        # Sequence was shortened by the stride-2 pool: use it as the new input.
        return c

    def forward(self, x, attn_mask=None):
        """Apply the layer to x ([B, L, D]); return (output, attn from self.attention)."""
        if self.conv_layer is not None:
            x = self._apply_conv(x)

        if self.ln_mode == "post":
            # -------- Post-Norm path --------
            new_x, attn = self.attention(x, x, x, attn_mask=attn_mask)
            x = self.post_ln1(x + self.dropout(new_x) * self.res_scale)

            y = self.ff(x)
            x = self.post_ln2(x + self.dropout(y) * self.res_scale)
            return x, attn

        # -------- Default: Pre-Norm --------
        # Normalize once and reuse for q, k and v.
        h_in = self.norm1(x)
        h, attn = self.attention(h_in, h_in, h_in, attn_mask=attn_mask)
        x = x + self.dropout(h) * self.res_scale

        y = self.ff(self.norm2(x))
        x = x + self.dropout(y) * self.res_scale

        return x, attn


# --------------------------
# Encoder (kept signature)
# --------------------------

class Encoder(nn.Module):
    """
    Stack of encoder layers with optional between-layer conv layers and final norm.

    Keep the same signature:
      __init__(self, attn_layers, conv_layers=None, norm_layer=None)
      forward(self, x, attn_mask=None)

    After attention layer i, conv_layers[i] is applied when it exists, so
    surplus attention layers simply run without a following conv. The
    optional norm_layer is applied once to the final output. forward returns
    (x, list of the attn value from every attention layer).
    """
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        if conv_layers is None:
            self.conv_layers = None
        else:
            self.conv_layers = nn.ModuleList(conv_layers)
        self.norm = norm_layer  # None or an nn.Module

    def forward(self, x, attn_mask=None):
        # x: [B, L, D]
        convs = self.conv_layers
        num_convs = 0 if convs is None else len(convs)
        attention_maps = []
        for idx, layer in enumerate(self.attn_layers):
            x, attn = layer(x, attn_mask=attn_mask)
            attention_maps.append(attn)
            if idx < num_convs:
                x = convs[idx](x)

        if self.norm is not None:
            x = self.norm(x)
        return x, attention_maps


# --------------------------
# Encoder Stack (kept signature)
# --------------------------

class EncoderStack(nn.Module):
    """
    Multi-scale encoder stack over trailing windows of the input.

    Keep the same signature:
      __init__(self, encoders, inp_lens, d_model)
      forward(self, x, attn_mask=None)

    For each (i_len, encoder) pair (zipped, so the shorter list wins), the
    encoder receives the last ``L // 2**i_len`` time steps of x ([B, L, D]).
    The outputs are concatenated along the sequence axis (dim -2).

    Raises:
        ValueError: if a pyramid level would get zero time steps
            (L < 2**i_len). Previously ``x[:, -0:, :]`` silently fed the
            whole sequence to that encoder instead.
    """
    def __init__(self, encoders, inp_lens, d_model):
        super(EncoderStack, self).__init__()
        self.encoders = nn.ModuleList(encoders)
        self.inp_lens = inp_lens
        self.d_model = d_model

    def forward(self, x, attn_mask=None):
        # x: [B, L, D]
        seq_len = x.shape[1]
        x_stack = []
        attns = []
        for i_len, encoder in zip(self.inp_lens, self.encoders):
            inp_len = seq_len // (2 ** i_len)
            if inp_len < 1:
                raise ValueError(
                    f"EncoderStack: sequence length {seq_len} is too short for "
                    f"pyramid level {i_len} (needs at least {2 ** i_len} steps)"
                )
            # Explicit start index: a negative slice would break for inp_len == 0.
            x_s, attn = encoder(x[:, seq_len - inp_len:, :], attn_mask=attn_mask)
            x_stack.append(x_s)
            attns.append(attn)
        x_stack = torch.cat(x_stack, dim=-2)  # concat on sequence length axis
        return x_stack, attns