Spaces:
Runtime error
Runtime error
| from torch import nn as nn | |
| from torch.nn import functional as F | |
class ConvLayer(nn.Module):
    """1-D convolution (optionally transposed) followed by norm + activation.

    The layer never pads on the regular path and uses ``padding=kernel_size-1``
    on the transposed path, so the temporal length changes with every call.
    ``get_input_size`` / ``get_output_size`` expose that length arithmetic so
    that valid sizes can be computed without running the layer.

    Args:
        n_inputs: Number of input channels.
        n_outputs: Number of output channels.
        kernel_size: Width of the convolution kernel.
        stride: Stride of the (transposed) convolution.
        conv_type: ``"gn"`` (GroupNorm + ReLU), ``"bn"`` (BatchNorm + ReLU) or
            ``"normal"`` (LeakyReLU, no normalisation).
        transpose: If True, use ``nn.ConvTranspose1d`` instead of ``nn.Conv1d``.

    Raises:
        ValueError: If ``conv_type`` is ``"gn"`` and ``n_outputs`` is not
            divisible by ``NORM_CHANNELS``.
    """

    # How many channels should be normalised as one group if GroupNorm is activated.
    # WARNING: Number of output channels has to be divisible by this number!
    NORM_CHANNELS = 8

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, conv_type, transpose=False):
        super(ConvLayer, self).__init__()
        self.transpose = transpose
        self.stride = stride
        self.kernel_size = kernel_size
        self.conv_type = conv_type

        if self.transpose:
            # padding=kernel_size-1 makes the transposed conv behave like
            # "insert stride-1 zeros between samples, then unpadded conv",
            # which is the arithmetic get_output_size() mirrors.
            self.filter = nn.ConvTranspose1d(
                n_inputs, n_outputs, self.kernel_size, stride, padding=kernel_size - 1
            )
        else:
            self.filter = nn.Conv1d(n_inputs, n_outputs, self.kernel_size, stride)

        if conv_type == "gn":
            # Explicit raise instead of assert: the check must survive
            # "python -O", otherwise a wrong group layout is built silently.
            if n_outputs % self.NORM_CHANNELS != 0:
                raise ValueError(
                    "n_outputs ({}) must be divisible by NORM_CHANNELS ({}) "
                    "when conv_type is 'gn'".format(n_outputs, self.NORM_CHANNELS)
                )
            self.norm = nn.GroupNorm(n_outputs // self.NORM_CHANNELS, n_outputs)
        elif conv_type == "bn":
            self.norm = nn.BatchNorm1d(n_outputs, momentum=0.01)
        # Add your own types of variations here!

    def forward(self, x):
        """Apply the convolution, then the normalisation/activation of ``conv_type``.

        Args:
            x: Input tensor for the underlying 1-D convolution.

        Returns:
            The activated output tensor.

        Raises:
            ValueError: If ``conv_type`` is none of ``"gn"``, ``"bn"``, ``"normal"``.
        """
        if self.conv_type in ("gn", "bn"):
            return F.relu(self.norm(self.filter(x)))
        # Add your own variations here with elifs conditioned on "conv_type" parameter!
        if self.conv_type == "normal":
            return F.leaky_relu(self.filter(x))
        # Explicit raise (not assert) so an unsupported type can never fall
        # through to an arbitrary activation when assertions are disabled.
        raise ValueError(
            "Unsupported conv_type {!r}; expected 'gn', 'bn' or 'normal'".format(self.conv_type)
        )

    def get_input_size(self, output_size):
        """Return the input length that yields ``output_size`` output samples.

        Inverse of ``get_output_size``.

        Raises:
            AssertionError: If no valid input length exists for ``output_size``.
                NOTE(review): kept as ``assert`` so the exception type is
                unchanged for callers that probe sizes - confirm before
                converting to another exception.
        """
        # Strided conv/decimation
        if not self.transpose:
            curr_size = (output_size - 1) * self.stride + 1  # o = (i-1)//s + 1 => i = (o - 1)*s + 1
        else:
            curr_size = output_size

        # Conv
        curr_size = curr_size + self.kernel_size - 1  # o = i + p - k + 1

        # Transposed
        if self.transpose:
            # We need to have a value at the beginning and end
            assert ((curr_size - 1) % self.stride == 0)
            curr_size = ((curr_size - 1) // self.stride) + 1
        assert (curr_size > 0)
        return curr_size

    def get_output_size(self, input_size):
        """Return the output length produced for ``input_size`` input samples.

        Raises:
            AssertionError: If ``input_size`` is too short for the kernel, is
                not greater than 1 on the transposed path, or does not align
                with the stride on the regular path.
                NOTE(review): kept as ``assert`` so the exception type is
                unchanged for callers that probe sizes - confirm before
                converting to another exception.
        """
        # Transposed
        if self.transpose:
            assert (input_size > 1)
            curr_size = (input_size - 1) * self.stride + 1  # o = (i-1)//s + 1 => i = (o - 1)*s + 1
        else:
            curr_size = input_size

        # Conv
        curr_size = curr_size - self.kernel_size + 1  # o = i + p - k + 1
        assert (curr_size > 0)

        # Strided conv/decimation
        if not self.transpose:
            # We need to have a value at the beginning and end
            assert ((curr_size - 1) % self.stride == 0)
            curr_size = ((curr_size - 1) // self.stride) + 1

        return curr_size