# NOTE: page-capture artifacts ("Spaces:", "Runtime error") removed from this header.
import math
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from .unet import UNet
def batchify(tensor: Tensor, T: int) -> Tensor:
    """
    Chop ``tensor`` along its last axis into segments of length ``T`` and
    stack the segments along the batch axis, zero-padding the tail so the
    final segment is full.

    Args:
        tensor (Tensor): B x C x F x L input.

    Returns:
        Tensor of shape (B * ceil(L / T)) x C x F x T.
    """
    # amount of right-padding needed to reach the next multiple of T
    # ((-L) % T is 0 when L already divides evenly)
    tail = (-tensor.size(-1)) % T
    if tail:
        tensor = F.pad(tensor, [0, tail])
    # split into ceil(L/T) chunks of width T and fold them into the batch dim
    return torch.cat(tensor.split(T, dim=-1), dim=0)
class Splitter(nn.Module):
    """
    Spleeter-style source separator.

    One U-Net per stem predicts a soft ratio mask over STFT magnitude
    segments; masks are normalized across stems (Wiener-like) and applied
    to the mixture STFT.
    """

    def __init__(self, stem_names: Optional[List[str]] = None):
        """
        Args:
            stem_names: names of the stems to separate; one U-Net is built
                per name. Defaults to ['vocals', 'accompaniment'].
        """
        super(Splitter, self).__init__()
        # stft config
        self.F = 1024           # frequency bins kept from the STFT
        self.T = 512            # segment length (frames) fed to each U-Net
        self.win_length = 4096
        self.hop_length = 1024
        # fixed (non-trainable) analysis/synthesis window
        self.win = nn.Parameter(torch.hann_window(self.win_length), requires_grad=False)
        # fix: stem_names was previously accepted but ignored and the stems
        # dict hard-coded; honor the parameter, keeping the old default.
        stem_names = stem_names or ["vocals", "accompaniment"]
        self.stems = nn.ModuleDict({name: UNet(in_channels=2) for name in stem_names})

    def compute_stft(self, wav: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Computes the STFT of ``wav`` and its magnitude.

        Args:
            wav (Tensor): B x L waveform.

        Returns:
            (stft, mag): stft is B x F x frames x 2 (real/imag in the last
            axis), mag is B x F x frames; only the first self.F frequency
            bins are kept.
        """
        # NOTE(review): return_complex=False is deprecated in recent torch
        # versions — may need migration to complex tensors; confirm against
        # the torch version this project pins.
        stft = torch.stft(
            wav,
            n_fft=self.win_length,
            hop_length=self.hop_length,
            window=self.win,
            center=True,
            return_complex=False,
            pad_mode="constant",
        )
        # only keep freqs smaller than self.F
        stft = stft[:, : self.F, :, :]
        real = stft[:, :, :, 0]
        im = stft[:, :, :, 1]
        mag = torch.sqrt(real ** 2 + im ** 2)
        return stft, mag

    def inverse_stft(self, stft: Tensor) -> Tensor:
        """
        Inverts an STFT (real/imag layout, truncated to self.F bins) back
        to a waveform.
        """
        # re-pad the frequency axis to the full n_fft // 2 + 1 bins that
        # compute_stft truncated away
        pad = self.win_length // 2 + 1 - stft.size(1)
        stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
        # NOTE(review): newer torch requires a complex tensor for istft
        # (torch.view_as_complex) — verify against the pinned torch version.
        wav = torch.istft(
            stft,
            self.win_length,
            hop_length=self.hop_length,
            center=True,
            window=self.win,
        )
        return wav.detach()

    def forward(self, wav: Tensor) -> Dict[str, Tensor]:
        """
        Separates stereo wav into different tracks (1 predicted track per stem).

        Args:
            wav (Tensor): 2 x L stereo waveform.

        Returns:
            masked stfts by track name (each 2 x F x L x 2).
        """
        # stft     - 2 x F x L x 2
        # stft_mag - 2 x F x L
        stft, stft_mag = self.compute_stft(wav.squeeze())
        L = stft.size(2)
        # 1 x 2 x F x T
        stft_mag = stft_mag.unsqueeze(-1).permute([3, 0, 1, 2])
        stft_mag = batchify(stft_mag, self.T)  # B x 2 x F x T
        stft_mag = stft_mag.transpose(2, 3)  # B x 2 x T x F
        # each stem's U-Net predicts a soft mask over the magnitude segments
        masks = {name: net(stft_mag) for name, net in self.stems.items()}
        # denominator of the Wiener-like normalization across stems;
        # epsilon avoids division by zero where all masks vanish
        mask_sum = sum([m ** 2 for m in masks.values()])
        mask_sum += 1e-10

        def apply_mask(mask):
            # ratio mask: mask^2 (+ eps/2) normalized by the sum over stems
            mask = (mask ** 2 + 1e-10 / 2) / (mask_sum)
            mask = mask.transpose(2, 3)  # B x 2 x F x T
            # undo batchify: re-concatenate the time segments, then trim the
            # zero-padded tail back to the original L frames
            mask = torch.cat(torch.split(mask, 1, dim=0), dim=3)
            mask = mask.squeeze(0)[:, :, :L].unsqueeze(-1)  # 2 x F x L x 1
            stft_masked = stft * mask
            return stft_masked

        return {name: apply_mask(m) for name, m in masks.items()}

    def separate(self, wav: Tensor) -> Dict[str, Tensor]:
        """
        Separates stereo wav into different tracks (1 predicted track per stem).

        Args:
            wav (Tensor): 2 x L stereo waveform.

        Returns:
            waveforms by track name.
        """
        stft_masks = self.forward(wav)
        return {
            name: self.inverse_stft(stft_masked)
            for name, stft_masked in stft_masks.items()
        }

    @classmethod  # fix: decorator was missing, so cls was misbound when called
    def from_pretrained(cls, model_path: str) -> "Splitter":
        """
        Builds a Splitter and loads weights from a checkpoint.

        Args:
            model_path: path to a saved state-dict file.

        Returns:
            a Splitter with the checkpoint weights loaded.
        """
        # map_location="cpu" so CPU-only machines can load GPU-saved
        # checkpoints; the caller can .to(device) afterwards
        checkpoint = torch.load(model_path, map_location="cpu")
        model = cls()
        model.load_state_dict(checkpoint)
        return model