Spaces:
Paused
Paused
| import torch.nn as nn | |
| import torch | |
| import numpy as np | |
| import torch.nn.functional as F | |
| import math | |
| from torchlibrosa.stft import magphase | |
def init_layer(layer):
    """Give a layer Xavier-uniform weights and a zeroed bias.

    Args:
        layer: object exposing a ``weight`` tensor and, optionally, a
            ``bias`` tensor. A missing or ``None`` bias is left alone.
    """
    nn.init.xavier_uniform_(layer.weight)

    # Layers built without a bias expose ``bias`` as None (or not at all).
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.0)
def init_bn(bn):
    """Reset a batch-norm layer's affine parameters to the identity transform.

    Sets ``bias`` to 0 and ``weight`` to 1.

    Args:
        bn: object exposing ``weight`` and ``bias`` attributes. Either may be
            ``None`` (as on a batch-norm layer created with ``affine=False``),
            in which case it is skipped instead of raising ``AttributeError``.
    """
    if bn.bias is not None:
        bn.bias.data.fill_(0.0)
    if bn.weight is not None:
        bn.weight.data.fill_(1.0)
def init_embedding(layer):
    """Draw a layer's weights from the uniform distribution on [-1, 1].

    Args:
        layer: object exposing a ``weight`` tensor. If it also carries a
            non-``None`` ``bias``, that bias is filled with zeros.
    """
    nn.init.uniform_(layer.weight, a=-1.0, b=1.0)

    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.0)
def init_gru(rnn):
    """Initialize the forward-direction parameters of a GRU.

    For every layer index ``i`` in ``range(rnn.num_layers)``:

    * ``weight_ih_l{i}`` is split into three equal row chunks, each filled
      uniformly in ``[-sqrt(3 / fan_in), sqrt(3 / fan_in)]``.
    * ``weight_hh_l{i}`` is split the same way; the first two chunks get the
      same uniform fill and the third chunk is made orthogonal. (PyTorch
      documents the chunk order as reset, update, new gate.)
    * ``bias_ih_l{i}`` and ``bias_hh_l{i}`` are zeroed when they exist; a GRU
      built with ``bias=False`` has no such attributes and is accepted.

    NOTE(review): parameters carrying a ``_reverse`` suffix (bidirectional
    GRUs) are not touched here and keep whatever values they already had.
    Confirm that this is intended.

    Args:
        rnn: GRU-like module exposing ``num_layers`` and the per-layer
            parameters named above.
    """

    def _concat_init(tensor, init_funcs):
        # The gate matrices are stacked along dim 0, so each initializer
        # receives one equally sized block of rows.
        rows_per_chunk = tensor.shape[0] // len(init_funcs)
        for i, init_func in enumerate(init_funcs):
            init_func(tensor[i * rows_per_chunk : (i + 1) * rows_per_chunk, :])

    def _inner_uniform(tensor):
        # For a 2-D weight the fan-in is simply the number of columns.
        fan_in = tensor.shape[1]
        bound = math.sqrt(3 / fan_in)
        nn.init.uniform_(tensor, -bound, bound)

    def _zero_bias(name):
        # Biases are absent when the GRU was constructed with bias=False.
        bias = getattr(rnn, name, None)
        if bias is not None:
            nn.init.constant_(bias, 0)

    for i in range(rnn.num_layers):
        _concat_init(
            getattr(rnn, "weight_ih_l{}".format(i)),
            [_inner_uniform, _inner_uniform, _inner_uniform],
        )
        _zero_bias("bias_ih_l{}".format(i))

        _concat_init(
            getattr(rnn, "weight_hh_l{}".format(i)),
            [_inner_uniform, _inner_uniform, nn.init.orthogonal_],
        )
        _zero_bias("bias_hh_l{}".format(i))
def act(x, activation):
    """Apply the activation selected by name to ``x``.

    Args:
        x: input tensor. ``"relu"`` and ``"leaky_relu"`` overwrite it in
            place; ``"swish"`` leaves it untouched and returns a new tensor.
        activation: one of ``"relu"``, ``"leaky_relu"`` (negative slope
            0.01) or ``"swish"`` (``x * sigmoid(x)``).

    Returns:
        The activated tensor.

    Raises:
        Exception: if ``activation`` is none of the names above.
    """
    if activation == "relu":
        return F.relu_(x)

    if activation == "leaky_relu":
        return F.leaky_relu_(x, negative_slope=0.01)

    if activation == "swish":
        gate = torch.sigmoid(x)
        return x * gate

    raise Exception("Incorrect activation!")
class Base:
    """Spectrogram helpers layered on an STFT supplied by the host class.

    This class defines neither ``stft`` nor ``istft``. Its methods call
    ``self.stft(waveform)``, expecting a ``(real, imag)`` pair back, and
    ``self.istft(real, imag, length)``, so the object using these helpers
    must provide both.
    """

    def __init__(self):
        pass

    def spectrogram(self, input, eps=0.):
        """Return the STFT magnitude of ``input``.

        Args:
            input: waveform forwarded unchanged to ``self.stft``.
            eps: lower bound applied to the power ``real**2 + imag**2``
                before the square root is taken.

        Returns:
            Magnitude tensor with the same shape as the STFT components.
        """
        real, imag = self.stft(input)
        power = real ** 2 + imag ** 2
        return torch.clamp(power, min=eps) ** 0.5

    def spectrogram_phase(self, input, eps=0.):
        """Return STFT magnitude together with the phase as cosine and sine.

        Args:
            input: waveform forwarded unchanged to ``self.stft``.
            eps: lower bound applied to the power before the square root.
                With the default of 0 a bin whose power is exactly zero
                yields a zero magnitude, and the divisions below then
                produce NaN for that bin.

        Returns:
            Tuple ``(mag, cos, sin)`` where ``cos = real / mag`` and
            ``sin = imag / mag``.
        """
        real, imag = self.stft(input)
        power = real ** 2 + imag ** 2
        mag = torch.clamp(power, min=eps) ** 0.5
        return mag, real / mag, imag / mag

    def wav_to_spectrogram_phase(self, input, eps=1e-10):
        """Compute magnitude, cosine and sine spectrograms for every channel.

        Args:
            input: 3-D waveform tensor whose axis 1 is the channel axis,
                i.e. (batch_size, channels_num, segment_samples).
            eps: forwarded to ``spectrogram_phase``.

        Returns:
            Tuple ``(sps, coss, sins)``; each is the per-channel results
            concatenated along dim 1. Presumably
            (batch_size, channels_num, time_steps, freq_bins), assuming
            ``self.stft`` yields (batch_size, 1, time_steps, freq_bins).
            TODO confirm against the stft implementation.
        """
        per_channel = [
            self.spectrogram_phase(input[:, ch, :], eps=eps)
            for ch in range(input.shape[1])
        ]
        sps = torch.cat([mag for (mag, _, _) in per_channel], dim=1)
        coss = torch.cat([cos for (_, cos, _) in per_channel], dim=1)
        sins = torch.cat([sin for (_, _, sin) in per_channel], dim=1)
        return sps, coss, sins

    def wav_to_spectrogram(self, input, eps=0.):
        """Compute the magnitude spectrogram of every channel.

        Args:
            input: 3-D waveform tensor whose axis 1 is the channel axis,
                i.e. (batch_size, channels_num, segment_samples).
            eps: forwarded to ``spectrogram``.

        Returns:
            Per-channel magnitudes concatenated along dim 1. Presumably
            (batch_size, channels_num, time_steps, freq_bins), assuming
            ``self.stft`` yields (batch_size, 1, time_steps, freq_bins).
            TODO confirm against the stft implementation.
        """
        magnitudes = [
            self.spectrogram(input[:, ch, :], eps=eps)
            for ch in range(input.shape[1])
        ]
        return torch.cat(magnitudes, dim=1)

    def spectrogram_to_wav(self, input, spectrogram, length=None):
        """Rebuild waveforms from magnitudes, reusing the phase of ``input``.

        For each channel the phase (cos, sin) is taken from the STFT of the
        matching channel of ``input`` and combined with the supplied
        magnitude before calling ``self.istft``.

        Args:
            input: 3-D waveform tensor whose axis 1 is the channel axis,
                i.e. (batch_size, channels_num, segment_samples).
            spectrogram: 4-D magnitude tensor whose axis 1 is the channel
                axis.
            length: forwarded to ``self.istft`` as its third argument.

        Returns:
            The per-channel ``self.istft`` results stacked along a new
            dim 1. Presumably (batch_size, channels_num, segment_samples)
            if ``istft`` returns (batch_size, segment_samples).
            TODO confirm against the istft implementation.
        """
        waveforms = []
        for ch in range(input.shape[1]):
            real, imag = self.stft(input[:, ch, :])
            _, cos, sin = magphase(real, imag)
            # Keep the channel axis (size 1) on the magnitude slice.
            channel_sp = spectrogram[:, ch : ch + 1, :, :]
            waveforms.append(self.istft(channel_sp * cos, channel_sp * sin, length))
        return torch.stack(waveforms, dim=1)