|
|
import os |
|
|
import warnings |
|
|
import torch |
|
|
import numpy as np |
|
|
import soundfile as sf |
|
|
|
|
|
|
|
|
def get_device(tensor_or_module, default=None):
    """Return the ``torch.device`` a tensor or module lives on.

    Args:
        tensor_or_module: Object exposing either a ``device`` attribute
            (e.g. a tensor) or a ``parameters()`` iterator (e.g. an
            ``nn.Module``; the device of its first parameter is used).
        default: Fallback device spec (anything ``torch.device`` accepts)
            used when neither attribute is present. ``None`` means no
            fallback.

    Returns:
        torch.device: The detected (or default) device.

    Raises:
        TypeError: If no device can be inferred and ``default`` is ``None``.
    """
    # Guard clauses, most specific probe first.
    if hasattr(tensor_or_module, "device"):
        return tensor_or_module.device
    if hasattr(tensor_or_module, "parameters"):
        # Modules: take the device of the first parameter.
        return next(tensor_or_module.parameters()).device
    if default is not None:
        return torch.device(default)
    raise TypeError(
        f"Don't know how to get device of {type(tensor_or_module)} object"
    )
|
|
|
|
|
|
|
|
class Separator:
    """Abstract base for separation models usable with :func:`separate`.

    Subclasses are expected to implement :meth:`forward_wav` (waveform-domain
    separation) and :meth:`sample_rate` (the rate the model operates at).
    """

    def forward_wav(self, wav, **kwargs):
        """Separate a waveform tensor; must be overridden by subclasses."""
        raise NotImplementedError

    def sample_rate(self):
        """Operating sample rate of the model; must be overridden by subclasses."""
        raise NotImplementedError
|
|
|
|
|
|
|
|
def separate(model, wav, **kwargs):
    """Dispatch separation on ``wav`` based on its type.

    Args:
        model: Separation model (see :class:`Separator`).
        wav: Input mixture, either a ``np.ndarray`` or a ``torch.Tensor``.
        **kwargs: Forwarded to the underlying separation function.

    Returns:
        Separated sources, same container type as ``wav``.

    Raises:
        ValueError: If ``wav`` is neither a numpy array nor a torch tensor.
    """
    if isinstance(wav, np.ndarray):
        return numpy_separate(model, wav, **kwargs)
    elif isinstance(wav, torch.Tensor):
        return torch_separate(model, wav, **kwargs)
    else:
        # Fixed message: no filename (str) branch exists here, so the old
        # "Only support filenames, ..." text was misleading.
        raise ValueError(
            f"Only support numpy arrays and torch tensors, received {type(wav)}"
        )
|
|
|
|
|
|
|
|
@torch.no_grad()
def torch_separate(model: Separator, wav: torch.Tensor, **kwargs) -> torch.Tensor:
    """Core logic of `separate`.

    Args:
        model: Separation model exposing ``in_channels`` (``None`` means any
            channel count) and, preferably, a ``forward_wav`` method;
            otherwise the model itself is called.
        wav: Mixture tensor; the second-to-last dimension is treated as the
            channel axis (assumed layout (..., channels, time) — inferred from
            the ``in_channels`` check; confirm against callers).
        **kwargs: Forwarded to the model's separation function.

    Returns:
        torch.Tensor: Separated estimates, moved back to ``wav``'s original
        device and rescaled so their total absolute sum matches the input's.

    Raises:
        RuntimeError: If ``wav``'s channel count doesn't match
            ``model.in_channels``.
    """
    if model.in_channels is not None and wav.shape[-2] != model.in_channels:
        # Bug fix: the two f-string fragments were concatenated without a
        # separating space ("...channels.Please match...").
        raise RuntimeError(
            f"Model supports {model.in_channels}-channel inputs but found audio with {wav.shape[-2]} channels. "
            f"Please match the number of channels."
        )

    # Run on the model's device; remember where the input came from so the
    # result can be moved back at the end.
    input_device = get_device(wav, default="cpu")
    model_device = get_device(model, default="cpu")
    wav = wav.to(model_device)

    # Prefer the dedicated `forward_wav` entry point; fall back to calling
    # the model directly when it doesn't define one.
    separate_func = getattr(model, "forward_wav", model)
    out_wavs = separate_func(wav, **kwargs)

    # Rescale estimates so their total absolute magnitude matches the input
    # mixture. NOTE(review): divides by the estimates' absolute sum — an
    # all-zero model output would yield inf/nan here; presumed acceptable
    # upstream, confirm if silent outputs are possible.
    out_wavs *= wav.abs().sum() / (out_wavs.abs().sum())

    # Back to the caller's original device.
    out_wavs = out_wavs.to(input_device)
    return out_wavs
|
|
|
|
|
|
|
|
def numpy_separate(model: Separator, wav: np.ndarray, **kwargs) -> np.ndarray:
    """Numpy interface to `separate`.

    Args:
        model: Separation model (see :class:`Separator`).
        wav: Input mixture as a numpy array.
        **kwargs: Forwarded to :func:`torch_separate`.

    Returns:
        np.ndarray: Separated estimates as a numpy array.
    """
    wav = torch.from_numpy(wav)
    out_wavs = torch_separate(model, wav, **kwargs)
    # `.cpu().numpy()` replaces the deprecated `.data.numpy()` idiom.
    # `torch_separate` runs under no_grad and returns the tensor on the
    # input (CPU) device, so this conversion is gradient-free and copy-safe.
    out_wavs = out_wavs.cpu().numpy()
    return out_wavs
|
|
|