| from abc import ABCMeta |
|
|
| import torch |
| import torch.nn as nn |
| from pytorch_lightning import LightningModule |
| from .modules import TFC_TDF |
|
|
# Size constant used by Mixer: its linear layer maps (dim_s + 1) * 2 input
# features to dim_s * 2 output features, and its output is reshaped to
# (dim_s, 2, -1).
# NOTE(review): presumably the number of separated sources - TODO confirm.
dim_s = 4
|
|
class AbstractMDXNet(LightningModule):
    """Base LightningModule holding the settings shared by MDX-Net models.

    The constructor stores its arguments as attributes, derives the number
    of frequency bins from ``n_fft`` and creates two frozen tensors
    (``window`` and ``freq_pad``). ``configure_optimizers`` builds the
    optimizer selected by name.
    """

    # NOTE(review): ``__metaclass__`` is the Python 2 spelling; under
    # Python 3 it is only an ordinary class attribute and does not change
    # the metaclass. Kept as-is so the class itself is unchanged.
    __metaclass__ = ABCMeta

    def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap):
        """Store the shared configuration and build the frozen tensors.

        Args:
            target_name: Stored as ``self.target_name``.
            lr: Learning rate handed to the optimizer.
            optimizer: Optimizer name; ``'rmsprop'`` and ``'adamw'`` are
                the values recognised by ``configure_optimizers``.
            dim_c: Channel count (size of dimension 1 of ``freq_pad``).
            dim_f: Number of frequency bins kept; ``freq_pad`` covers the
                remaining ``n_bins - dim_f`` bins.
            dim_t: Size of the last dimension of ``freq_pad``.
            n_fft: Length of the Hann window; ``n_bins = n_fft // 2 + 1``.
            hop_length: Stored as ``self.hop_length``.
            overlap: Accepted for interface compatibility; not used by
                this constructor.
        """
        super().__init__()
        self.target_name = target_name
        self.lr = lr
        self.optimizer = optimizer
        self.dim_c = dim_c
        self.dim_f = dim_f
        self.dim_t = dim_t
        self.n_fft = n_fft
        self.n_bins = n_fft // 2 + 1
        self.hop_length = hop_length

        # Registered as parameters with requires_grad=False so they follow
        # the module across devices and appear in its state dict without
        # being trained.
        self.window = nn.Parameter(
            torch.hann_window(window_length=self.n_fft, periodic=True),
            requires_grad=False,
        )
        self.freq_pad = nn.Parameter(
            torch.zeros([1, dim_c, self.n_bins - self.dim_f, self.dim_t]),
            requires_grad=False,
        )

    def configure_optimizers(self):
        """Return the optimizer selected by ``self.optimizer``.

        Returns:
            ``torch.optim.RMSprop`` for ``'rmsprop'`` or
            ``torch.optim.AdamW`` for ``'adamw'``, both built over
            ``self.parameters()`` with learning rate ``self.lr``.

        Raises:
            ValueError: If ``self.optimizer`` is neither ``'rmsprop'`` nor
                ``'adamw'``. Falling through and returning ``None`` would
                make Lightning run without any optimizer.
        """
        if self.optimizer == 'rmsprop':
            return torch.optim.RMSprop(self.parameters(), self.lr)

        if self.optimizer == 'adamw':
            return torch.optim.AdamW(self.parameters(), self.lr)

        raise ValueError(
            f"Unsupported optimizer {self.optimizer!r}; expected 'rmsprop' or 'adamw'"
        )
|
|
class ConvTDFNet(AbstractMDXNet):
    """Encoder/decoder network of TFC_TDF blocks with multiplicative skips.

    Layout, with ``n = num_blocks // 2``:

    * ``first_conv``: 1x1 convolution from ``dim_c`` to ``g`` channels,
      followed by normalisation and ReLU.
    * ``n`` encoder levels: a TFC_TDF block, then a stride-2 convolution
      that adds ``g`` channels; the tracked frequency size is halved.
    * ``bottleneck_block``: one TFC_TDF block.
    * ``n`` decoder levels: a stride-2 transposed convolution that removes
      ``g`` channels, an element-wise product with the matching encoder
      output, then a TFC_TDF block.
    * ``final_conv``: 1x1 convolution back to ``dim_c`` channels.
    """

    def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length,
                 num_blocks, l, g, k, bn, bias, overlap):
        """Build all layers of the network.

        Args:
            target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft,
            hop_length, overlap: Forwarded to ``AbstractMDXNet.__init__``.
                ``optimizer`` also selects the normalisation layer:
                ``'rmsprop'`` uses ``nn.BatchNorm2d`` and ``'adamw'`` uses
                ``nn.GroupNorm`` with 2 groups.
            num_blocks: ``num_blocks // 2`` encoder levels and the same
                number of decoder levels are created.
            l, k, bn, bias: Passed through to every ``TFC_TDF`` block.
            g: Channels produced by ``first_conv`` and added (removed) at
                each encoder (decoder) level.

        Raises:
            ValueError: If ``optimizer`` is neither ``'rmsprop'`` nor
                ``'adamw'``, because no normalisation layer is defined
                for it.
        """
        super().__init__(
            target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap)
        # Must run inside __init__ so Lightning can capture the arguments.
        self.save_hyperparameters()

        self.num_blocks = num_blocks
        self.l = l
        self.g = g
        self.k = k
        self.bn = bn
        self.bias = bias

        # The normalisation layer is tied to the optimizer choice. Fail
        # early with a clear message instead of leaving ``norm`` unbound.
        if optimizer == 'rmsprop':
            norm = nn.BatchNorm2d
        elif optimizer == 'adamw':
            def norm(num_channels):
                return nn.GroupNorm(2, num_channels)
        else:
            raise ValueError(
                f"Unsupported optimizer {optimizer!r}; expected 'rmsprop' or 'adamw'"
            )

        self.n = num_blocks // 2
        scale = (2, 2)

        self.first_conv = nn.Sequential(
            nn.Conv2d(in_channels=self.dim_c, out_channels=g, kernel_size=(1, 1)),
            norm(g),
            nn.ReLU(),
        )

        # ``f`` and ``c`` track the frequency size and channel count that
        # the next block is built for.
        f = self.dim_f
        c = g

        self.encoding_blocks = nn.ModuleList()
        self.ds = nn.ModuleList()
        for _ in range(self.n):
            self.encoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))
            self.ds.append(
                nn.Sequential(
                    nn.Conv2d(in_channels=c, out_channels=c + g, kernel_size=scale, stride=scale),
                    norm(c + g),
                    nn.ReLU(),
                )
            )
            f = f // 2
            c += g

        self.bottleneck_block = TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm)

        self.decoding_blocks = nn.ModuleList()
        self.us = nn.ModuleList()
        for _ in range(self.n):
            self.us.append(
                nn.Sequential(
                    nn.ConvTranspose2d(in_channels=c, out_channels=c - g, kernel_size=scale, stride=scale),
                    norm(c - g),
                    nn.ReLU(),
                )
            )
            f = f * 2
            c -= g

            self.decoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))

        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channels=c, out_channels=self.dim_c, kernel_size=(1, 1)),
        )

    def forward(self, x):
        """Run the network on ``x`` and return the result of ``final_conv``.

        Args:
            x: Input tensor with ``dim_c`` channels in dimension 1.
                NOTE(review): presumably (batch, dim_c, dim_f, dim_t) -
                confirm against the caller.

        Returns:
            Tensor with ``dim_c`` channels produced by ``final_conv``.
        """
        x = self.first_conv(x)

        # The blocks operate on the input with its last two axes swapped.
        x = x.transpose(-1, -2)

        skips = []
        for encode, downsample in zip(self.encoding_blocks, self.ds):
            x = encode(x)
            skips.append(x)
            x = downsample(x)

        x = self.bottleneck_block(x)

        # Decoder levels consume the encoder outputs in reverse order.
        for upsample, decode, skip in zip(self.us, self.decoding_blocks, reversed(skips)):
            x = upsample(x)
            # Out-of-place product: same values as ``x *= skip`` but it
            # does not overwrite the activation output that autograd may
            # still need for the backward pass.
            x = x * skip
            x = decode(x)

        # Undo the axis swap before the final projection.
        x = x.transpose(-1, -2)

        return self.final_conv(x)
| |
class Mixer(nn.Module):
    """Bias-free linear layer whose weights are loaded from a file.

    The layer maps ``(dim_s + 1) * 2`` input features to ``dim_s * 2``
    output features.
    """

    def __init__(self, device, mixer_path):
        """Create the linear layer and load its weights.

        Args:
            device: Passed to ``torch.load`` as ``map_location``.
            mixer_path: Path of the saved state dict for this module.
        """
        super().__init__()

        in_features = (dim_s + 1) * 2
        out_features = dim_s * 2
        self.linear = nn.Linear(in_features, out_features, bias=False)

        # NOTE(review): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        state = torch.load(mixer_path, map_location=device)
        self.load_state_dict(state)

    def forward(self, x):
        """Apply the linear layer along the row axis of ``x``.

        ``x`` is reshaped to ``(1, (dim_s + 1) * 2, -1)``, the linear layer
        is applied to each column, and the result is returned with shape
        ``(dim_s, 2, -1)``.
        """
        stacked = x.reshape(1, (dim_s + 1) * 2, -1)
        # nn.Linear acts on the last axis, so move the feature rows there.
        mixed = self.linear(stacked.transpose(-1, -2))
        return mixed.transpose(-1, -2).reshape(dim_s, 2, -1)