| """! |
| @author Yi Luo (oulyluo) |
| @copyright Tencent AI Lab |
| """ |
|
|
| from __future__ import print_function |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from torch.utils.checkpoint import checkpoint_sequential |
| from thop import profile, clever_format |
|
|
class RMVN(nn.Module):
    """Rescaled mean-variance normalisation over (groups of) channels.

    The channel axis of a ``(B, N, T)`` input is split into ``groups`` equal
    chunks. Each chunk is normalised to zero mean and unit variance
    independently for every batch item and every frame, and the result is
    then scaled and shifted by learnable per-channel parameters.

    Args:
        dimension: number of channels ``N``.
        groups: number of channel groups; must divide ``N``.
    """

    def __init__(self, dimension, groups=1):
        super(RMVN, self).__init__()

        # Learnable per-channel shift ("mean") and scale ("std").
        self.mean = nn.Parameter(torch.zeros(dimension))
        self.std = nn.Parameter(torch.ones(dimension))
        self.groups = groups
        self.eps = torch.finfo(torch.float32).eps

    def forward(self, input):
        """Normalise ``input`` of shape ``(B, N, T)``; the shape is preserved."""
        batch, channels, frames = input.shape
        assert channels % self.groups == 0

        # (B, groups, N // groups, T): statistics are taken over axis 2.
        grouped = input.view(batch, self.groups, -1, frames)
        centred = grouped - grouped.mean(2, keepdim=True)
        # torch's default (unbiased) variance estimator is used here.
        scale = (grouped.var(2, keepdim=True) + self.eps).sqrt()
        normalised = (centred / scale).view(batch, channels, frames)

        gain = self.std.view(1, -1, 1)
        shift = self.mean.view(1, -1, 1)
        return normalised * gain + shift
|
|
class ConvActNorm1d(nn.Module):
    """Residual 1-D block: temporal conv -> RMVN -> pointwise GLU MLP.

    The only layer that looks across time is the first convolution; the
    remaining layers are ``RMVN`` (default ``groups=1``) and 1x1 convolutions
    around a GLU, all of which act on each frame independently.

    Args:
        in_channel: number of input (and output) channels.
        hidden_channel: width of the gated hidden layer (the 1x1 conv before
            the GLU produces ``2 * hidden_channel`` channels).
        kernel: kernel size of the temporal convolution. In non-causal mode
            the padding is ``(kernel - 1) // 2``, so the sequence length is
            only preserved for odd kernels.
        causal: if True, pad ``kernel - 1`` frames on both sides and drop the
            trailing ``kernel - 1`` output frames so that every output frame
            only sees the current and earlier input frames.
    """

    def __init__(self, in_channel, hidden_channel, kernel=7, causal=False):
        super(ConvActNorm1d, self).__init__()

        self.in_channel = in_channel
        self.kernel = kernel
        self.causal = causal

        # The causal and non-causal variants differ only in the padding of the
        # first convolution. The layer order must stay fixed because the
        # state_dict keys (conv.0, conv.1, ...) depend on it.
        padding = kernel - 1 if causal else (kernel - 1) // 2
        self.conv = nn.Sequential(
            nn.Conv1d(in_channel, in_channel, kernel, padding=padding),
            RMVN(in_channel),
            nn.Conv1d(in_channel, hidden_channel * 2, 1),
            nn.GLU(dim=1),
            nn.Conv1d(hidden_channel, in_channel, 1),
        )

    def forward(self, input):
        """Return ``input + block(input)`` for a ``(B, in_channel, T)`` input."""
        output = self.conv(input)
        # Remove the look-ahead frames introduced by the symmetric padding.
        # With kernel == 1 there is nothing to trim; slicing with ``:-0``
        # would wrongly produce an empty tensor and break the residual sum.
        if self.causal and self.kernel > 1:
            output = output[..., :-(self.kernel - 1)].contiguous()
        return input + output
|
|
class ICB(nn.Module):
    """Three ``ConvActNorm1d`` residual blocks applied one after another.

    Args:
        in_channel: channel count of the input; every block uses a hidden
            width of ``4 * in_channel``.
        kernel: temporal kernel size forwarded to each block.
        causal: causal flag forwarded to each block.
    """

    def __init__(self, in_channel, kernel=7, causal=False):
        super(ICB, self).__init__()

        stages = [
            ConvActNorm1d(in_channel, in_channel * 4, kernel, causal=causal)
            for _ in range(3)
        ]
        self.blocks = nn.Sequential(*stages)

    def forward(self, input):
        """Run the input through the stacked blocks."""
        return self.blocks(input)
| |
class ResRNN(nn.Module):
    """Residual LSTM layer: RMVN -> single-layer LSTM -> linear projection.

    Args:
        input_size: number of input (and output) channels.
        hidden_size: LSTM hidden size per direction.
        bidirectional: whether the LSTM runs in both directions.
    """

    def __init__(self, input_size, hidden_size, bidirectional=False):
        super(ResRNN, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.eps = torch.finfo(torch.float32).eps

        self.norm = RMVN(input_size)
        self.rnn = nn.LSTM(
            input_size,
            hidden_size,
            1,
            batch_first=True,
            bidirectional=bidirectional,
        )

        # The LSTM emits hidden_size features per direction.
        rnn_width = hidden_size * (int(bidirectional) + 1)
        self.proj = nn.Linear(rnn_width, input_size)

    def forward(self, input, use_head=1):
        """Process a ``(B, N, T)`` input along its last axis.

        ``use_head`` is accepted but not used by this implementation.
        """
        batch, _, steps = input.shape

        # The LSTM is batch-first, so it expects (B, T, N).
        sequence = self.norm(input).transpose(1, 2).contiguous()
        rnn_output, _ = self.rnn(sequence)

        # Project every time step back to the input width.
        flat = rnn_output.contiguous().view(-1, rnn_output.shape[2])
        projected = self.proj(flat).view(batch, steps, -1)
        residual = projected.transpose(1, 2).contiguous()

        return input + residual
|
|
class BSNet(nn.Module):
    """Band-split layer: per-band temporal modelling, then cross-band modelling.

    ``seq_net`` (an ``ICB``) runs along time for every band separately;
    ``band_net`` (a bidirectional ``ResRNN``) runs along the band axis for
    every frame separately.

    Args:
        feature_dim: channel count ``N`` of every band.
        kernel: kernel size forwarded to the ``ICB``.
        causal: causal flag forwarded to the ``ICB``.
    """

    def __init__(self, feature_dim, kernel=7, causal=False):
        super(BSNet, self).__init__()

        self.feature_dim = feature_dim

        self.seq_net = ICB(self.feature_dim, kernel=kernel, causal=causal)
        self.band_net = ResRNN(self.feature_dim, self.feature_dim * 2, bidirectional=True)

    def forward(self, input):
        """Transform a ``(B, nband, N, T)`` tensor; the shape is preserved."""
        batch, nband, channels, frames = input.shape

        # Temporal pass: fold the bands into the batch axis.
        per_band = input.view(batch * nband, channels, frames)
        per_band = self.seq_net(per_band).view(batch, nband, -1, frames)

        # Cross-band pass: fold time into the batch axis so that the band
        # axis becomes the sequence axis, giving (B*T, N, nband).
        per_frame = per_band.permute(0, 3, 2, 1).contiguous()
        per_frame = per_frame.view(batch * frames, -1, nband)
        mixed = self.band_net(per_frame).view(batch, frames, -1, nband)

        # Restore the (B, nband, N, T) layout.
        restored = mixed.permute(0, 3, 2, 1).contiguous()
        return restored.view(batch, nband, channels, frames)
| |
| |
class VQEmbeddingEMA(nn.Module):
    """Vector-quantisation codebook updated by exponential moving averages.

    The codebook is held in buffers (not parameters) and is only modified
    while the module is in training mode.

    Args:
        num_code: number of code vectors.
        code_dim: dimensionality of each code vector.
        decay: EMA decay applied to the running counts and running sums.
        layer: index of this codebook inside a residual stack; the random
            initialisation is divided by ``(layer + 1) * code_dim``.
    """

    def __init__(self, num_code, code_dim, decay=0.99, layer=0):
        super(VQEmbeddingEMA, self).__init__()

        self.num_code = num_code
        self.code_dim = code_dim
        self.decay = decay
        self.layer = layer
        # A code that goes unused for this many consecutive training steps
        # is re-initialised from the current batch.
        self.stale_tolerance = 100
        self.eps = torch.finfo(torch.float32).eps

        init_scale = (layer + 1) * code_dim
        embedding = torch.empty(num_code, code_dim).normal_() / init_scale
        self.register_buffer("embedding", embedding)
        self.register_buffer("ema_weight", self.embedding.clone())
        self.register_buffer("ema_count", torch.zeros(self.num_code))
        self.register_buffer("stale_counter", torch.zeros(self.num_code))

    def forward(self, input):
        """Quantise every frame of ``input`` to its nearest code vector.

        Args:
            input: tensor of shape ``(B, code_dim, T)``. It is detached, so no
                gradient flows through this module.

        Returns:
            A tuple ``(quantized, indices, perplexity)`` where ``quantized``
            is ``(B, code_dim, T)``, ``indices`` is ``(B, T)`` and
            ``perplexity`` is a scalar tensor computed from the code usage
            of this batch.
        """
        B, N, T = input.shape
        assert N == self.code_dim

        # One row per frame: (B*T, N).
        frames = input.detach().mT.contiguous().view(B * T, N)

        # Squared Euclidean distance, expanded as |x|^2 + |e|^2 - 2 x.e
        frame_energy = frames.pow(2).sum(-1).unsqueeze(-1)
        code_energy = self.embedding.pow(2).sum(-1).unsqueeze(0)
        eu_dis = frame_energy + code_energy
        eu_dis = eu_dis - 2 * frames.mm(self.embedding.T)

        # Nearest-code lookup. Indexing copies the rows, so the codebook
        # update below does not affect the returned tensor.
        indices = torch.argmin(eu_dis, dim=-1)
        quantized = self.embedding[indices].view(B, T, N).mT.contiguous()

        # Perplexity of the empirical code distribution of this batch.
        encodings = F.one_hot(indices, self.num_code).float()
        avg_probs = encodings.mean(0)
        entropy = -torch.sum(avg_probs * torch.log(avg_probs + self.eps), -1)
        perplexity = torch.exp(entropy)

        if self.training:
            self._ema_update(frames, encodings)
            self._replace_stale_codes(frames, encodings)

        return quantized, indices.view(B, T), perplexity

    def _ema_update(self, frames, encodings):
        """Update the running counts/sums and recompute the codebook."""
        keep = self.decay
        blend = 1 - self.decay

        self.ema_count = keep * self.ema_count + blend * torch.sum(encodings, dim=0)

        # Sum of the frames assigned to each code: (num_code, code_dim).
        update_direction = encodings.T.mm(frames)
        self.ema_weight = keep * self.ema_weight + blend * update_direction

        # Additive smoothing of the counts so the division below is never
        # by exactly zero once at least one frame has been seen.
        n = torch.sum(self.ema_count, dim=-1, keepdim=True)
        self.ema_count = (self.ema_count + self.eps) / (n + self.num_code * self.eps) * n

        self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1)

    def _replace_stale_codes(self, frames, encodings):
        """Track unused codes and re-seed those that hit the stale limit."""
        # Unused codes have their counter incremented, used codes reset to 0.
        unused = (encodings.sum(0) == 0).float()
        self.stale_counter = self.stale_counter * unused + unused

        replace_code = (self.stale_counter == self.stale_tolerance).float()
        if not bool(replace_code.sum(-1) > 0):
            return

        # Candidate replacements: a random permutation of the batch frames,
        # tiled if the batch holds fewer frames than there are codes.
        shuffled = frames[torch.randperm(frames.shape[0])]
        if shuffled.shape[0] < self.num_code:
            copies = self.num_code // shuffled.shape[0] + 1
            shuffled = torch.cat([shuffled] * copies, 0)
        shuffled = shuffled[:self.num_code].contiguous()

        retain = 1 - replace_code
        swap_rows = replace_code.unsqueeze(-1)
        retain_rows = retain.unsqueeze(-1)

        self.embedding = self.embedding * retain_rows + shuffled * swap_rows
        self.ema_weight = self.ema_weight * retain_rows + shuffled * swap_rows
        # NOTE(review): the replaced codes restart with a zero running count
        # while ema_weight holds a full sample, so the next EMA update divides
        # by a very small count -- confirm this is the intended behaviour.
        self.ema_count = self.ema_count * retain
        self.stale_counter = self.stale_counter * retain
|
|
class RVQEmbedding(nn.Module):
    """Residual vector quantiser built from a stack of ``VQEmbeddingEMA``.

    Each stage quantises the residual left over by the previous stages.

    Args:
        code_dim: dimensionality of the vectors to quantise.
        decay: EMA decay forwarded to every stage.
        bit: sequence with one entry per stage; stage ``i`` uses
            ``2 ** bit[i]`` codes. Defaults to a single 10-bit stage.
    """

    def __init__(self, code_dim, decay=0.99, bit=None):
        super(RVQEmbedding, self).__init__()

        # A None sentinel replaces the former mutable list default.
        if bit is None:
            bit = [10]

        self.code_dim = code_dim
        self.decay = decay
        self.eps = torch.finfo(torch.float32).eps

        self.VQEmbedding = nn.ModuleList(
            [
                VQEmbeddingEMA(2 ** n_bit, code_dim, decay, layer=depth)
                for depth, n_bit in enumerate(bit)
            ]
        )

    def forward(self, input):
        """Quantise ``input`` of shape ``(B, code_dim, T)``.

        Returns:
            A tuple ``(quantized, indices, ppl, latent_loss)``:

            * ``quantized`` -- ``(B, code_dim, T, n_stage)``; slice ``i`` is
              the cumulative reconstruction using stages ``0..i``.
            * ``indices`` -- ``(B, T, n_stage)`` selected code indices.
            * ``ppl`` -- ``(n_stage,)`` per-stage perplexities.
            * ``latent_loss`` -- sum over stages of the MSE between ``input``
              and the (detached) cumulative reconstruction.
        """
        cumulative = []
        indices = []
        ppl = []

        residual_input = input
        for stage in self.VQEmbedding:
            stage_quantized, stage_indices, stage_ppl = stage(residual_input)
            indices.append(stage_indices)
            ppl.append(stage_ppl)
            residual_input = residual_input - stage_quantized
            if cumulative:
                cumulative.append(cumulative[-1] + stage_quantized)
            else:
                cumulative.append(stage_quantized)

        quantized = torch.stack(cumulative, -1)
        indices = torch.stack(indices, -1)
        ppl = torch.stack(ppl, -1)

        # The reconstructions are detached, so this loss can only push
        # gradient into ``input``.
        latent_loss = 0
        for partial in quantized.detach().unbind(-1):
            latent_loss = latent_loss + F.mse_loss(input, partial)

        return quantized, indices, ppl, latent_loss
|
|
class Codec(nn.Module):
    """Band-split VAE codec operating on STFT sub-bands.

    The waveform is transformed with an STFT and the frequency axis is split
    into fixed sub-bands. Each band is embedded, processed by a stack of
    ``BSNet`` layers and mapped to per-band Gaussian latents (``mu`` and
    ``logvar``). A residual VQ codebook quantises the detached latent
    statistics. A second ``BSNet`` stack decodes latents back to a complex
    spectrogram, which is inverted with an ISTFT.

    Args:
        nch: number of audio channels.
        sr: sample rate in Hz.
        win: STFT window length in milliseconds; the hop is half a window.
        feature_dim: channel width of every band inside the BSNet stacks.
        vae_dim: latent dimensionality per band.
        enc_layer: number of ``BSNet`` layers in the encoder.
        dec_layer: number of ``BSNet`` layers in the decoder.
        bit: sequence with the codebook size (in bits) of every residual VQ
            stage. Defaults to five 8-bit stages.
        causal: causal flag forwarded to every ``BSNet``.

    Raises:
        ValueError: if the fixed band layout needs more STFT bins than
            ``sr`` and ``win`` provide.
    """

    # (band width in Hz, number of consecutive bands of that width). Whatever
    # bins remain above these bands form one final band.
    _BAND_LAYOUT = ((50, 20), (100, 30), (200, 20), (400, 10), (500, 19))

    def __init__(self, nch=1, sr=44100, win=100, feature_dim=128, vae_dim=2, enc_layer=12, dec_layer=12, bit=None, causal=True):
        super(Codec, self).__init__()

        self.nch = nch
        self.sr = sr
        self.win = int(sr / 1000 * win)
        self.stride = self.win // 2
        self.enc_dim = self.win // 2 + 1
        self.feature_dim = feature_dim
        self.vae_dim = vae_dim
        # A None sentinel replaces the former mutable list default.
        self.bit = [8] * 5 if bit is None else bit
        self.eps = torch.finfo(torch.float32).eps

        # Band widths expressed in STFT bins.
        self.band_width = []
        for width_hz, count in self._BAND_LAYOUT:
            bins = int(np.floor(width_hz / (sr / 2.) * self.enc_dim))
            self.band_width += [bins] * count
        # Builtin sum keeps the entry a plain Python int (np.sum would leak a
        # numpy scalar into every later size computation).
        remainder = self.enc_dim - sum(self.band_width)
        if remainder < 0:
            raise ValueError(
                "band layout needs {} STFT bins but only {} are available "
                "(sr={}, win={} ms)".format(sum(self.band_width), self.enc_dim, sr, win)
            )
        self.band_width.append(remainder)
        self.nband = len(self.band_width)
        print(self.band_width, self.nband)

        # Per-band input embedding. Features per band are the normalised real
        # part, the normalised imaginary part and one log-power value.
        self.VAE_BN = nn.ModuleList([])
        for width in self.band_width:
            in_dim = (width * 2 + 1) * self.nch
            self.VAE_BN.append(
                nn.Sequential(RMVN(in_dim), nn.Conv1d(in_dim, self.feature_dim, 1))
            )

        self.VAE_encoder = nn.Sequential(
            *[BSNet(self.feature_dim, kernel=7, causal=causal) for _ in range(enc_layer)]
        )

        # Grouped layers keep the bands independent: each band maps its
        # feature_dim channels to 2 * vae_dim values (mu and logvar).
        self.vae_FC = nn.Sequential(
            RMVN(self.nband * self.feature_dim, groups=self.nband),
            nn.Conv1d(self.nband * self.feature_dim, self.nband * self.vae_dim * 2, 1, groups=self.nband),
        )
        self.codebook = RVQEmbedding(self.nband * self.vae_dim * 2, bit=self.bit)
        self.vae_reshape = nn.Conv1d(self.nband * self.vae_dim, self.nband * self.feature_dim, 1, groups=self.nband)

        self.VAE_decoder = nn.Sequential(
            *[BSNet(self.feature_dim, kernel=7, causal=causal) for _ in range(dec_layer)]
        )

        # Per-band output head: the GLU halves the channels, leaving
        # 2 * band_width * nch values (real and imaginary parts).
        self.VAE_output = nn.ModuleList([])
        for width in self.band_width:
            self.VAE_output.append(
                nn.Sequential(
                    RMVN(self.feature_dim),
                    nn.Conv1d(self.feature_dim, width * 4 * self.nch, 1),
                    nn.GLU(dim=1),
                )
            )

    def _window(self, device):
        """Return the Hann analysis/synthesis window on ``device``."""
        return torch.hann_window(self.win).to(device)

    @staticmethod
    def _run_stack(stack, feature):
        """Run ``stack`` with one activation checkpoint per layer.

        An empty stack is the identity; ``checkpoint_sequential`` cannot be
        called with zero segments.
        """
        if len(stack) == 0:
            return feature
        return checkpoint_sequential(stack, len(stack), feature)

    def spec_band_split(self, input):
        """Compute the STFT of ``input`` and split it into sub-bands.

        Args:
            input: waveform tensor of shape ``(B, nch, nsample)``.

        Returns:
            A tuple ``(subband_spec, subband_spec_norm, subband_power)``:

            * ``subband_spec`` -- list of complex ``(B*nch, width_i, T)``.
            * ``subband_spec_norm`` -- list of ``[real, imag]`` pairs, each
              divided by the band's per-frame magnitude.
            * ``subband_power`` -- ``(B*nch, nband, T)`` per-frame band
              magnitudes (square root of the summed squared magnitudes).
        """
        B, nch, nsample = input.shape

        spec = torch.stft(
            input.view(B * nch, nsample).float(),
            n_fft=self.win,
            hop_length=self.stride,
            window=self._window(input.device),
            return_complex=True,
        )

        subband_spec = []
        subband_spec_norm = []
        subband_power = []
        band_idx = 0
        for width in self.band_width:
            this_spec = spec[:, band_idx:band_idx + width]
            this_power = (this_spec.abs().pow(2).sum(1) + self.eps).sqrt().unsqueeze(1)
            subband_spec.append(this_spec)
            subband_power.append(this_power)
            subband_spec_norm.append([this_spec.real / this_power, this_spec.imag / this_power])
            band_idx += width
        subband_power = torch.cat(subband_power, 1)

        return subband_spec, subband_spec_norm, subband_power

    def feature_extractor(self, input):
        """Embed every sub-band; returns ``(B, nband, feature_dim, T)``."""
        _, subband_spec_norm, subband_power = self.spec_band_split(input)

        subband_feature = []
        for i in range(self.nband):
            real_part, imag_part = subband_spec_norm[i]
            log_power = torch.log(subband_power[:, i].unsqueeze(1))
            concat_spec = torch.cat([real_part, imag_part, log_power], 1)
            # Fold the audio channels into the feature axis.
            concat_spec = concat_spec.view(-1, (self.band_width[i] * 2 + 1) * self.nch, concat_spec.shape[-1])
            subband_feature.append(self.VAE_BN[i](concat_spec.type(input.type())))

        return torch.stack(subband_feature, 1)

    def vae_sample(self, input):
        """Encode ``input`` and draw a latent sample.

        Args:
            input: waveform tensor of shape ``(B, nch, nsample)``.

        Returns:
            A tuple ``(reparam_feature, quantized_emb, mu_var, indices, ppl,
            latent_loss, vae_loss)``. ``reparam_feature`` is the
            reparameterised sample ``(B, nband, vae_dim, T)``; ``mu_var``
            stacks ``mu`` and ``logvar`` as ``(B, 2*nband*vae_dim, T)``; the
            codebook outputs are computed from ``mu_var.detach()``;
            ``vae_loss`` is the KL term against a standard normal.
        """
        B = input.shape[0]

        subband_feature = self.feature_extractor(input)

        enc_output = self._run_stack(self.VAE_encoder, subband_feature)
        enc_output = self.vae_FC(enc_output.view(B, self.nband * self.feature_dim, -1))
        enc_output = enc_output.view(B, self.nband, 2, self.vae_dim, -1)
        mu = enc_output[:, :, 0].contiguous()
        logvar = enc_output[:, :, 1].contiguous()

        # Reparameterisation trick; note that noise is drawn regardless of
        # the train/eval mode.
        reparam_feature = mu + torch.randn_like(logvar) * torch.exp(0.5 * logvar)
        vae_loss = (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(2)).mean()

        # The codebook only sees detached statistics, so it does not
        # influence the encoder gradients.
        mu_var = torch.stack([mu, logvar], 1).view(B, self.nband * self.vae_dim * 2, -1)
        quantized_emb, indices, ppl, latent_loss = self.codebook(mu_var.detach())

        return reparam_feature, quantized_emb, mu_var, indices, ppl, latent_loss, vae_loss

    def vae_decode(self, vae_feature, nsample=None):
        """Decode latents to a waveform.

        Args:
            vae_feature: latent tensor that can be viewed as
                ``(B, nband * vae_dim, T)``.
            nsample: optional output length in samples, forwarded to
                ``torch.istft``; ``None`` keeps the ISTFT's natural length.

        Returns:
            Waveform of shape ``(B, nch, -1)`` with the tensor type of
            ``vae_feature``.
        """
        B = vae_feature.shape[0]
        dec_input = self.vae_reshape(vae_feature.contiguous().view(B, self.nband * self.vae_dim, -1))
        dec_output = self._run_stack(
            self.VAE_decoder, dec_input.view(B, self.nband, self.feature_dim, -1)
        )

        est_spec = []
        for i in range(self.nband):
            # (B*nch, 2, width_i, T): index 0 is real, index 1 is imaginary.
            this_RI = self.VAE_output[i](dec_output[:, i]).view(B * self.nch, 2, self.band_width[i], -1)
            est_spec.append(torch.complex(this_RI[:, 0].float(), this_RI[:, 1].float()))
        est_spec = torch.cat(est_spec, 1)

        # length=None is torch.istft's default, so one call covers both the
        # fixed-length and the free-length case.
        output = torch.istft(
            est_spec,
            n_fft=self.win,
            hop_length=self.stride,
            window=self._window(vae_feature.device),
            length=nsample,
        ).view(B, self.nch, -1)

        return output.type(vae_feature.type())

    def forward(self, input):
        """Reconstruct ``input`` of shape ``(B, nch, nsample)``.

        The reconstruction is decoded from the reparameterised VAE sample,
        not from the quantised codes, and has the same shape as ``input``.
        """
        nch, nsample = input.shape[1], input.shape[2]
        assert nch == self.nch

        vae_feature = self.vae_sample(input)[0]
        return self.vae_decode(vae_feature, nsample=nsample).view(input.shape)
|
|
def get_bsrnnvae(ckpt):
    """Build the mono ``Codec`` used by this file and load its weights.

    Args:
        ckpt: checkpoint path (or file-like object) holding a state dict that
            matches the model built here.

    Returns:
        The ``Codec`` instance, switched to evaluation mode, on the CPU.
    """
    hparams = dict(
        nch=1,
        win=100,
        feature_dim=128,
        vae_dim=8,
        bit=[14] * 5,
        causal=True,
    )
    model = Codec(**hparams)

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    weight = torch.load(ckpt, map_location='cpu')
    model.load_state_dict(weight)
    return model.eval()
|
|