# Provenance: Hugging Face upload by buchi-stdesign — "Upload 18 files", commit 1ee91f8 (verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
class LayerNorm(nn.Module):
    """Layer normalization over the channel axis of a (batch, channels, time) tensor.

    ``nn.LayerNorm`` normalizes the *last* dimension, so the input is
    transposed to (batch, time, channels) first, normalized, then
    transposed back to the original layout.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        # Keep attribute name ``ln`` so state-dict keys stay compatible.
        self.ln = nn.LayerNorm(channels, eps=eps)

    def forward(self, x):
        # (B, C, T) -> (B, T, C) -> normalize over C -> (B, C, T)
        return self.ln(x.transpose(1, 2)).transpose(1, 2)
class ConvReluNorm(nn.Module):
    """A stack of ``n_layers`` Conv1d layers with LayerNorm + ReLU between them.

    Layout: Conv -> LayerNorm -> ReLU, repeated ``n_layers - 1`` times in
    total, followed by a final plain Conv1d projecting to ``out_channels``.
    Padding of ``kernel_size // 2`` preserves the time dimension for odd
    kernel sizes.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, bias):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.bias = bias

        pad = kernel_size // 2
        # Entry layer: project in_channels -> hidden_channels.
        layers = [
            nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=pad, bias=bias),
            LayerNorm(hidden_channels),
            nn.ReLU(),
        ]
        # Middle layers: hidden -> hidden, (n_layers - 2) of them.
        for _ in range(n_layers - 2):
            layers += [
                nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=pad, bias=bias),
                LayerNorm(hidden_channels),
                nn.ReLU(),
            ]
        # Output projection without a trailing norm/activation.
        layers.append(nn.Conv1d(hidden_channels, out_channels, kernel_size, padding=pad, bias=bias))
        # Keep attribute name ``main`` so state-dict keys stay compatible.
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stack to ``x`` of shape (batch, in_channels, time)."""
        return self.main(x)