import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# --------------------------
# Utilities / Norm / Activations
# --------------------------


class RMSNorm(nn.Module):
    """RMSNorm with a learnable per-channel weight.

    Drop-in replacement for ``nn.LayerNorm`` in Pre-Norm blocks: the last
    dimension is rescaled by the reciprocal root-mean-square (no mean
    centering, no bias).

    Args:
        d_model: size of the normalized (last) dimension.
        eps: added to the mean square before ``rsqrt`` to avoid division by 0.
    """

    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # x: (..., D). The statistic is computed in at least float32: in
        # float16, x.pow(2) overflows for |x| > ~256 and eps=1e-8 is below the
        # smallest float16 subnormal (it would round to 0). float64 is kept.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_c = x.to(compute_dtype)
        inv_rms = x_c.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return self.weight * (x_c * inv_rms).type_as(x)


def get_activation(name: str):
    """Return a fresh activation module for ``name`` (case-insensitive).

    ``None``/empty means ``"relu"``. Unknown names fall back to ``nn.ReLU``.
    Note that ``"geglu"`` maps to plain ``nn.GELU`` (no gating).
    """
    factories = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "geglu": nn.GELU,
        "silu": nn.SiLU,
        "swish": nn.SiLU,
    }
    return factories.get((name or "relu").lower(), nn.ReLU)()


class SwiGLU(nn.Module):
    """SwiGLU FFN: ``proj(SiLU(a) * b)`` where ``a, b`` split one linear output."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w12 = nn.Linear(d_model, 2 * d_ff, bias=True)
        self.proj = nn.Linear(d_ff, d_model, bias=True)

    def forward(self, x):
        a, b = self.w12(x).chunk(2, dim=-1)
        return self.proj(F.silu(a) * b)


# --------------------------
# Conv Layer (kept signature)
# --------------------------


class ConvLayer(nn.Module):
    """Distilling block: circular Conv1d -> BatchNorm -> ELU -> MaxPool(stride 2).

    Maps ``[B, L, D]`` to ``[B, ceil(L/2), D]``; the MaxPool stride halves
    the sequence length.
    """

    def __init__(self, c_in):
        super(ConvLayer, self).__init__()
        # Preserved from the original: PyTorch < 1.5 used padding=2 here.
        padding = 1 if torch.__version__ >= "1.5.0" else 2
        self.downConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=c_in,
            kernel_size=3,
            padding=padding,
            padding_mode="circular",
        )
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        # x: [B, L, D]; Conv1d/BatchNorm1d expect channels on dim 1.
        y = x.permute(0, 2, 1)  # B, D, L
        y = self.downConv(y)
        y = self.norm(y)
        y = self.activation(y)
        y = self.maxPool(y)
        return y.transpose(1, 2).contiguous()  # B, L', D


# --------------------------
# Encoder Layer (kept signature)
# --------------------------


class EncoderLayer(nn.Module):
    """Transformer encoder layer (attention + SwiGLU FFN).

    Signature kept:
        __init__(self, attention, d_model, d_ff=None, dropout=0.1,
                 activation="relu", ln_mode="pre", conv_layer=False)
        forward(self, x, attn_mask=None)

    Internals:
        - ``ln_mode="pre"`` (default, any value other than "post"): Pre-Norm
          with RMSNorm (``norm1``/``norm2``).
        - ``ln_mode="post"``: Post-Norm with ``nn.LayerNorm``
          (``post_ln1``/``post_ln2``).
        - Residual branches are scaled by 1/sqrt(2); dropout is applied to
          each branch output.
        - FFN is SwiGLU; ``activation`` only populates ``self.activation``
          (kept for compatibility, not used in ``forward``).
        - ``conv_layer=True`` appends a :class:`ConvLayer` distilling step
          that halves the sequence length of the layer's output.
        - ``attention`` is called as ``attention(q, k, v, attn_mask=...)`` and
          must return ``(new_x, attn)``; it is expected to scale q by
          1/sqrt(d_head) internally.
    """

    def __init__(
        self,
        attention,
        d_model,
        d_ff=None,
        dropout=0.1,
        activation="relu",
        ln_mode="pre",
        conv_layer=False,
    ):
        super(EncoderLayer, self).__init__()
        self.attention = attention
        self.conv_layer = ConvLayer(d_model) if conv_layer else None
        self.dropout = nn.Dropout(dropout)
        self.activation = get_activation(activation)
        self.ln_mode = ln_mode

        # Core hyperparams
        self.d_model = d_model
        self.d_ff = d_ff or 4 * d_model
        self.res_scale = 1.0 / math.sqrt(2.0)

        # Pre-Norm normalizers (RMSNorm); names norm1/norm2 kept for callers.
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)

        self.ff = SwiGLU(d_model, self.d_ff)

        # Post-Norm mode normalizes with these LayerNorms instead of norm1/norm2.
        if self.ln_mode == "post":
            self.post_ln1 = nn.LayerNorm(d_model)
            self.post_ln2 = nn.LayerNorm(d_model)

    def forward(self, x, attn_mask=None):
        """Apply attention + FFN (and the optional distilling conv).

        Args:
            x: input of shape [B, L, D].
            attn_mask: forwarded unchanged to the attention module.

        Returns:
            ``(x, attn)``: ``x`` is [B, L, D], or [B, ceil(L/2), D] when the
            layer was built with ``conv_layer=True``; ``attn`` is whatever the
            attention module returned.
        """
        if self.ln_mode == "post":
            # -------- Post-LN path --------
            new_x, attn = self.attention(x, x, x, attn_mask=attn_mask)
            x = self.post_ln1(x + self.dropout(new_x) * self.res_scale)
            x = self.post_ln2(x + self.dropout(self.ff(x)) * self.res_scale)
        else:
            # -------- Default: Pre-Norm --------
            # Normalize once and share the tensor as q/k/v (the post-LN path
            # likewise passes one tensor three times).
            h = self.norm1(x)
            new_x, attn = self.attention(h, h, h, attn_mask=attn_mask)
            x = x + self.dropout(new_x) * self.res_scale
            x = x + self.dropout(self.ff(self.norm2(x))) * self.res_scale

        if self.conv_layer is not None:
            # ConvLayer halves the sequence length (MaxPool stride 2), so it
            # cannot be added as a residual branch (shape mismatch); apply it
            # as a distilling step after attention, where attn_mask still
            # matches the input length.
            x = self.conv_layer(x)
        return x, attn


# --------------------------
# Encoder (kept signature)
# --------------------------


class Encoder(nn.Module):
    """Stack of encoder layers with optional interleaved conv layers.

    Signature kept:
        __init__(self, attn_layers, conv_layers=None, norm_layer=None)
        forward(self, x, attn_mask=None)

    ``conv_layers[i]`` (if present) runs right after ``attn_layers[i]``;
    ``norm_layer`` (if given) is applied to the final output.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer  # None or an nn.Module

    def forward(self, x, attn_mask=None):
        """Run all layers on ``x`` ([B, L, D]); return ``(x, list_of_attn)``."""
        attns = []
        for i, attn_layer in enumerate(self.attn_layers):
            x, attn = attn_layer(x, attn_mask=attn_mask)
            attns.append(attn)
            if self.conv_layers is not None and i < len(self.conv_layers):
                x = self.conv_layers[i](x)

        if self.norm is not None:
            x = self.norm(x)
        return x, attns


# --------------------------
# Encoder Stack (kept signature)
# --------------------------


class EncoderStack(nn.Module):
    """Pyramid of encoders, each fed a tail slice of the input.

    Signature kept:
        __init__(self, encoders, inp_lens, d_model)
        forward(self, x, attn_mask=None)

    Encoder ``k`` sees the last ``L // 2**inp_lens[k]`` time steps; outputs
    are concatenated along the sequence axis.
    """

    def __init__(self, encoders, inp_lens, d_model):
        super(EncoderStack, self).__init__()
        self.encoders = nn.ModuleList(encoders)
        self.inp_lens = inp_lens
        self.d_model = d_model

    def forward(self, x, attn_mask=None):
        """Return ``(concatenated_outputs, list_of_attns)`` for ``x`` [B, L, D].

        Raises:
            ValueError: if ``L`` is too short for a pyramid level
                (``L // 2**i_len == 0``).
        """
        seq_len = x.shape[1]
        x_stack = []
        attns = []
        for i_len, encoder in zip(self.inp_lens, self.encoders):
            inp_len = seq_len // (2 ** i_len)
            if inp_len < 1:
                # x[:, -0:, :] would silently select the WHOLE sequence.
                raise ValueError(
                    f"sequence length {seq_len} is too short for pyramid level "
                    f"{i_len} (needs at least {2 ** i_len} steps)"
                )
            x_s, attn = encoder(x[:, seq_len - inp_len:, :], attn_mask=attn_mask)
            x_stack.append(x_s)
            attns.append(attn)
        x_stack = torch.cat(x_stack, dim=-2)  # concat on sequence length axis
        return x_stack, attns