import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
    """Layer normalization over dimension 1 of a channels-first tensor.

    ``nn.LayerNorm`` normalizes over the trailing dimension(s), so dim 1 is
    swapped to the end, normalized, and swapped back.

    Args:
        channels: Size of dimension 1 of the input.
        eps: Value added to the variance for numerical stability.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.ln = nn.LayerNorm(channels, eps=eps)

    def forward(self, x):
        """Normalize ``x`` over dim 1 and return a tensor of the same shape.

        Swapping dim 1 with the *last* dim (rather than hard-coding dim 2)
        gives the same result for 3-D input. It also normalizes the correct
        axis for 2-D and higher-rank inputs.
        """
        x = x.transpose(1, -1)
        x = self.ln(x)
        return x.transpose(1, -1)


class ConvReluNorm(nn.Module):
    """Stack of ``n_layers`` ``nn.Conv1d`` layers with LayerNorm + ReLU between them.

    Module layout inside ``self.main``::

        [Conv1d -> LayerNorm -> ReLU] * (n_layers - 1) -> Conv1d

    The first conv maps ``in_channels -> hidden_channels``, intermediate convs
    map ``hidden_channels -> hidden_channels``, and the last conv maps
    ``hidden_channels -> out_channels``.

    NOTE: every conv uses ``padding=kernel_size // 2``. This keeps the
    sequence length unchanged only for odd ``kernel_size``; with an even
    kernel each conv lengthens the sequence by one.

    Args:
        in_channels: Channels of the input tensor (dim 1).
        hidden_channels: Channels between convolutions.
        out_channels: Channels of the output tensor (dim 1).
        kernel_size: Kernel size of every convolution.
        n_layers: Total number of convolutions; must be at least 2.
        bias: Whether the convolutions have a bias term.

    Raises:
        ValueError: If ``n_layers < 2``. Previously, values of 0 or 1
            silently produced a two-conv stack.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, bias=True):
        super().__init__()
        if n_layers < 2:
            raise ValueError(f"n_layers must be at least 2, got {n_layers}")
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.bias = bias

        padding = kernel_size // 2
        layers = []
        channels = in_channels
        # Same module order (and therefore same state_dict keys) as before:
        # conv, norm, relu repeated n_layers - 1 times, then the output conv.
        for _ in range(n_layers - 1):
            layers.append(nn.Conv1d(channels, hidden_channels, kernel_size, padding=padding, bias=bias))
            layers.append(LayerNorm(hidden_channels))
            layers.append(nn.ReLU())
            channels = hidden_channels
        layers.append(nn.Conv1d(hidden_channels, out_channels, kernel_size, padding=padding, bias=bias))
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the conv stack to ``x`` (input to the first ``nn.Conv1d``)."""
        return self.main(x)