| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable gain.

    Intended as a drop-in replacement for LayerNorm in Pre-Norm blocks.  Each
    vector along the last dimension is multiplied by
    ``1 / sqrt(mean(x ** 2) + eps)`` and then by the element-wise ``weight``
    (initialised to ones).  No mean is subtracted and there is no bias.

    Args:
        d_model: Length of the learnable ``weight`` vector (size of the last
            input dimension).
        eps: Constant added to the mean square before the reciprocal square
            root is taken.
    """

    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x):
        """Return ``weight * x * rsqrt(mean(x ** 2, dim=-1) + eps)``."""
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        # Multiply in the order (weight * x) * inv_rms so rounding is unchanged.
        gained = self.weight * x
        return gained * inv_rms
|
|
|
|
def get_activation(name: str):
    """Build a fresh activation module from a case-insensitive name.

    Recognised names: ``"relu"``; ``"gelu"`` / ``"geglu"`` (both give
    ``nn.GELU``); ``"silu"`` / ``"swish"`` (both give ``nn.SiLU``).  A falsy
    name (``None`` or ``""``) and any unrecognised name fall back to
    ``nn.ReLU``.

    Args:
        name: Activation name, or a falsy value for the default.

    Returns:
        A newly constructed ``nn.Module`` instance.
    """
    key = name.lower() if name else "relu"
    factories = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "geglu": nn.GELU,
        "silu": nn.SiLU,
        "swish": nn.SiLU,
    }
    # Unknown names silently resolve to ReLU.
    return factories.get(key, nn.ReLU)()
|
|
|
|
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: ``proj(SiLU(a) * b)``.

    A single linear layer ``w12`` maps ``d_model`` to ``2 * d_ff`` features;
    its output is split in half along the last axis into ``a`` (gate) and
    ``b`` (value).  ``proj`` maps the gated ``d_ff`` features back to
    ``d_model``.  Both linear layers carry a bias.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size of each half of the fused projection.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Fused projection: the first half is the gate, the second the value.
        self.w12 = nn.Linear(in_features=d_model, out_features=2 * d_ff, bias=True)
        self.proj = nn.Linear(in_features=d_ff, out_features=d_model, bias=True)

    def forward(self, x):
        """Apply the gated feed-forward transform to the last axis of ``x``."""
        fused = self.w12(x)
        gate, value = torch.chunk(fused, 2, dim=-1)
        gated = F.silu(gate) * value
        return self.proj(gated)
|
|
|
|
| |
| |
| |
|
|
class ConvLayer(nn.Module):
    """Down-sampling block: circular Conv1d -> BatchNorm1d -> ELU -> MaxPool1d.

    The convolution keeps the channel count (``c_in`` in, ``c_in`` out,
    kernel size 3); the max-pool uses kernel 3, stride 2, padding 1.

    The convolution padding depends on the installed torch release: 1 for
    torch >= 1.5, otherwise 2.
    NOTE(review): presumably older releases interpreted circular padding
    differently, hence the 2 -- confirm before dropping the legacy branch.

    Args:
        c_in: Number of channels, i.e. the size of the last axis of the
            tensor passed to ``forward``.
    """

    def __init__(self, c_in):
        super(ConvLayer, self).__init__()
        self.downConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=c_in,
            kernel_size=3,
            padding=self._circular_padding(),
            padding_mode="circular",
        )
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    @staticmethod
    def _circular_padding(version=None):
        """Return the Conv1d padding to use for a torch version string.

        The (major, minor) pair is compared numerically.  A plain string
        comparison against ``"1.5.0"`` is lexicographic and would, for
        example, rank ``"1.10.0"`` below ``"1.5.0"``.

        Args:
            version: Version string to inspect; defaults to
                ``torch.__version__``.

        Returns:
            ``1`` when the version is at least 1.5, otherwise ``2``.
        """
        if version is None:
            version = torch.__version__
        # Drop any local build suffix such as "+cu117".
        release = str(version).split("+", 1)[0]
        numbers = []
        for part in release.split(".")[:2]:
            # Keep only the leading digits so "0a0"-style components parse.
            digits = part[: len(part) - len(part.lstrip("0123456789"))]
            numbers.append(int(digits) if digits else 0)
        while len(numbers) < 2:
            numbers.append(0)
        return 1 if tuple(numbers) >= (1, 5) else 2

    def forward(self, x):
        """Convolve, normalise, activate and pool ``x`` along its middle axis.

        Args:
            x: 3-D tensor whose last axis has ``c_in`` entries; the last two
                axes are swapped before the channel-first layers are applied
                and swapped back afterwards.

        Returns:
            A contiguous tensor with the same axis layout as ``x`` and the
            middle axis shortened by the stride-2 max-pool.
        """
        # Conv1d / BatchNorm1d / MaxPool1d expect channels on axis 1.
        y = self.downConv(x.permute(0, 2, 1))
        y = self.norm(y)
        y = self.activation(y)
        y = self.maxPool(y)
        return y.transpose(1, 2).contiguous()
|
|
|
|
| |
| |
| |
|
|
class EncoderLayer(nn.Module):
    """Transformer encoder layer: self-attention followed by a SwiGLU FFN.

    Keep the same signature:
        __init__(self, attention, d_model, d_ff=None, dropout=0.1,
                 activation="relu", ln_mode="pre", conv_layer=False)
        forward(self, x, attn_mask=None)

    Internals:
        - Pre-Norm by default (``ln_mode="pre"``); any value other than
          ``"post"`` selects the Pre-Norm path.
        - Pre-Norm uses RMSNorm; Post-Norm (``ln_mode="post"``) uses
          ``nn.LayerNorm`` applied after each residual sum.
        - Every residual branch is passed through dropout and scaled by
          ``1 / sqrt(2)`` before being added.
        - The FFN is SwiGLU.  ``activation`` is resolved and stored on
          ``self.activation`` but is not used in ``forward``.
        - The attention module is expected to have the signature
          ``(q, k, v, attn_mask=None) -> (new_x, attn)`` and to internally do
          ``q *= 1 / sqrt(d_head)``.

    Args:
        attention: Self-attention module with the signature described above.
        d_model: Feature size of the layer input and output.
        d_ff: Hidden size of the SwiGLU FFN; a falsy value selects
            ``4 * d_model``.
        dropout: Dropout probability applied to every residual branch.
        activation: Name handed to ``get_activation``.
        ln_mode: ``"post"`` for Post-Norm; anything else for Pre-Norm.
        conv_layer: When true, a ``ConvLayer(d_model)`` residual branch is
            applied before attention.
    """

    def __init__(
        self,
        attention,
        d_model,
        d_ff=None,
        dropout=0.1,
        activation="relu",
        ln_mode="pre",
        conv_layer=False,
    ):
        super(EncoderLayer, self).__init__()
        self.attention = attention
        self.conv_layer = ConvLayer(d_model) if conv_layer else None
        self.dropout = nn.Dropout(dropout)
        self.activation = get_activation(activation)
        self.ln_mode = ln_mode

        self.d_model = d_model
        self.d_ff = d_ff or 4 * d_model
        self.res_scale = 1.0 / math.sqrt(2.0)

        # Pre-Norm normalisers; created in both modes.
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)

        self.ff = SwiGLU(d_model, self.d_ff)

        # Post-Norm normalisers exist only when they are actually used.
        if self.ln_mode == "post":
            self.post_ln1 = nn.LayerNorm(d_model)
            self.post_ln2 = nn.LayerNorm(d_model)

    def forward(self, x, attn_mask=None):
        """Run the layer and return ``(output, attention_weights)``.

        Args:
            x: Input tensor, used as query, key and value of the attention.
            attn_mask: Optional mask forwarded unchanged to the attention
                module.

        Returns:
            Tuple of the transformed tensor and whatever the attention module
            returned as its second output.
        """
        if self.conv_layer is not None:
            # NOTE(review): ConvLayer max-pools with stride 2, so its output is
            # shorter than x along axis 1 and this sum only succeeds when the
            # two shapes broadcast -- confirm the intended use of this branch.
            x = x + self.dropout(self.conv_layer(x)) * self.res_scale

        if self.ln_mode == "post":
            # Post-Norm: residual sum first, LayerNorm afterwards.
            new_x, attn = self.attention(x, x, x, attn_mask=attn_mask)
            x = x + self.dropout(new_x) * self.res_scale
            x = self.post_ln1(x)

            y = self.ff(x)
            x = x + self.dropout(y) * self.res_scale
            x = self.post_ln2(x)
            return x, attn

        # Pre-Norm: normalise once and reuse the result for q, k and v.
        normed = self.norm1(x)
        h, attn = self.attention(normed, normed, normed, attn_mask=attn_mask)
        x = x + self.dropout(h) * self.res_scale

        y = self.ff(self.norm2(x))
        x = x + self.dropout(y) * self.res_scale

        return x, attn
|
|
|
|
| |
| |
| |
|
|
class Encoder(nn.Module):
    """Sequential stack of attention layers with optional interleaved layers.

    Keep the same signature:
        __init__(self, attn_layers, conv_layers=None, norm_layer=None)
        forward(self, x, attn_mask=None)

    Args:
        attn_layers: Iterable of modules called as
            ``layer(x, attn_mask=attn_mask)`` and returning ``(x, attn)``.
        conv_layers: Optional iterable of modules; the ``i``-th one is applied
            to the output of the ``i``-th attention layer while one exists.
        norm_layer: Optional module applied to the final output.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        if conv_layers is None:
            self.conv_layers = None
        else:
            self.conv_layers = nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        """Return ``(x, attns)`` with one attention output per attention layer."""
        num_convs = 0 if self.conv_layers is None else len(self.conv_layers)
        attns = []
        for idx, layer in enumerate(self.attn_layers):
            x, layer_attn = layer(x, attn_mask=attn_mask)
            attns.append(layer_attn)
            # Attention layers beyond the available conv layers run without one.
            if idx < num_convs:
                x = self.conv_layers[idx](x)

        if self.norm is None:
            return x, attns
        return self.norm(x), attns
|
|
|
|
| |
| |
| |
|
|
class EncoderStack(nn.Module):
    """Run several encoders on progressively shorter tails of the input.

    Keep the same signature:
        __init__(self, encoders, inp_lens, d_model)
        forward(self, x, attn_mask=None)

    Encoder ``k`` receives the last ``x.shape[1] // 2 ** inp_lens[k]``
    positions of ``x`` along axis 1; the encoder outputs are concatenated
    along the second-to-last axis.

    Args:
        encoders: Iterable of modules called as
            ``encoder(x, attn_mask=attn_mask)`` and returning ``(x, attn)``.
        inp_lens: Iterable of integer exponents, paired positionally with
            ``encoders`` (``zip`` stops at the shorter of the two).
        d_model: Stored on the instance; not used by ``forward``.
    """

    def __init__(self, encoders, inp_lens, d_model):
        super(EncoderStack, self).__init__()
        self.encoders = nn.ModuleList(encoders)
        self.inp_lens = inp_lens
        self.d_model = d_model

    def forward(self, x, attn_mask=None):
        """Return ``(concatenated_outputs, attns)``.

        Args:
            x: Input tensor, sliced along axis 1 for each encoder.
            attn_mask: Optional mask forwarded unchanged to every encoder.

        Returns:
            Tuple of the encoder outputs concatenated along ``dim=-2`` and a
            list holding each encoder's second output.
        """
        seq_len = x.shape[1]
        x_stack = []
        attns = []
        for i_len, encoder in zip(self.inp_lens, self.encoders):
            inp_len = seq_len // (2 ** i_len)
            # Slice from an explicit start index: ``x[:, -inp_len:]`` would
            # select the whole sequence when ``inp_len`` is 0, because -0 == 0.
            start = seq_len - inp_len
            x_s, attn = encoder(x[:, start:, :], attn_mask=attn_mask)
            x_stack.append(x_s)
            attns.append(attn)
        return torch.cat(x_stack, dim=-2), attns
|
|
|
|