"""
Code to apply a model to a mix. It will handle chunking with overlaps and
interpolation between chunks, as well as the "shift trick".
"""
from concurrent.futures import ThreadPoolExecutor
import random
import typing as tp
from multiprocessing import Process, Queue, Pipe

import torch as th
from torch import nn
from torch.nn import functional as F
import tqdm
import tkinter as tk

from .demucs import Demucs
from .hdemucs import HDemucs
from .utils import center_trim, DummyPoolExecutor

Model = tp.Union[Demucs, HDemucs]

progress_bar_num = 0
|
|
class BagOfModels(nn.Module):
    def __init__(self, models: tp.List[Model],
                 weights: tp.Optional[tp.List[tp.List[float]]] = None,
                 segment: tp.Optional[float] = None):
        """
        Represents a bag of models with specific weights.
        You should call `apply_model` rather than calling the forward directly, for
        optimal performance.

        Args:
            models (list[nn.Module]): list of Demucs/HDemucs models.
            weights (list[list[float]]): list of weights. If None, assumed to
                be all ones, otherwise it should be a list of N lists (N = number of models),
                each containing S floats (S = number of sources).
            segment (None or float): overrides the `segment` attribute of each model
                (this is performed in place, so be careful if you reuse the models passed).
        """
        super().__init__()
        assert len(models) > 0
        first = models[0]
        for other in models:
            assert other.sources == first.sources
            assert other.samplerate == first.samplerate
            assert other.audio_channels == first.audio_channels
            if segment is not None:
                other.segment = segment

        self.audio_channels = first.audio_channels
        self.samplerate = first.samplerate
        self.sources = first.sources
        self.models = nn.ModuleList(models)

        if weights is None:
            weights = [[1. for _ in first.sources] for _ in models]
        else:
            assert len(weights) == len(models)
            for weight in weights:
                assert len(weight) == len(first.sources)
        self.weights = weights

    def forward(self, x):
        raise NotImplementedError("Call `apply_model` on this.")
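

# Illustrative sketch (not executed): building a bag and separating a mix with it.
# `load_demucs_checkpoint` and `checkpoint_paths` are hypothetical stand-ins for
# whatever loader produced the individual Demucs/HDemucs models.
#
#     models = [load_demucs_checkpoint(path) for path in checkpoint_paths]
#     bag = BagOfModels(models)          # equal weights, keep each model's own segment
#     sources = apply_model(bag, mix)    # averages the sub-models' weighted predictions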
|
|
class TensorChunk:
    def __init__(self, tensor, offset=0, length=None):
        total_length = tensor.shape[-1]
        assert offset >= 0
        assert offset < total_length

        if length is None:
            length = total_length - offset
        else:
            length = min(total_length - offset, length)

        if isinstance(tensor, TensorChunk):
            self.tensor = tensor.tensor
            self.offset = offset + tensor.offset
        else:
            self.tensor = tensor
            self.offset = offset
        self.length = length
        self.device = tensor.device

    @property
    def shape(self):
        shape = list(self.tensor.shape)
        shape[-1] = self.length
        return shape

    def padded(self, target_length):
        delta = target_length - self.length
        total_length = self.tensor.shape[-1]
        assert delta >= 0

        start = self.offset - delta // 2
        end = start + target_length

        correct_start = max(0, start)
        correct_end = min(total_length, end)

        pad_left = correct_start - start
        pad_right = end - correct_end

        out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
        assert out.shape[-1] == target_length
        return out


def tensor_chunk(tensor_or_chunk):
    if isinstance(tensor_or_chunk, TensorChunk):
        return tensor_or_chunk
    else:
        assert isinstance(tensor_or_chunk, th.Tensor)
        return TensorChunk(tensor_or_chunk)
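

# Minimal behaviour sketch for `TensorChunk.padded` (illustrative values only):
# it returns a window of `target_length` samples centred on the chunk, zero-padded
# wherever the window runs past the borders of the underlying tensor.
#
#     x = th.arange(10.).view(1, 1, 10)            # a 10-sample "track"
#     chunk = TensorChunk(x, offset=8, length=4)   # only 2 real samples remain after offset 8
#     window = chunk.padded(6)                     # -> [6., 7., 8., 9., 0., 0.]
#     assert window.shape[-1] == 6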
|
|
def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power=1.,
                static_shifts=1, set_progress_bar=None, device=None, progress=False,
                num_workers=0, pool=None):
    """
    Apply model to a given mixture.

    Args:
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` times and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down into segments of
            `model.segment` seconds, predictions will be performed individually on each,
            and the results concatenated. Useful for models with a large memory footprint,
            such as Tasnet.
        progress (bool): if True, show a progress bar (requires split=True).
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.
    """

    global fut_length
    global bag_num
    global prog_bar

    if device is None:
        device = mix.device
    else:
        device = th.device(device)
    if pool is None:
        if num_workers > 0 and device.type == 'cpu':
            pool = ThreadPoolExecutor(num_workers)
        else:
            pool = DummyPoolExecutor()

    kwargs = {
        'shifts': shifts,
        'split': split,
        'overlap': overlap,
        'transition_power': transition_power,
        'progress': progress,
        'device': device,
        'pool': pool,
        'set_progress_bar': set_progress_bar,
        'static_shifts': static_shifts,
    }

    if isinstance(model, BagOfModels):
        # Special treatment for a bag of models: `apply_model` is called once per
        # sub-model so that the random shifts differ between them, and the weighted
        # predictions are averaged source by source.
        estimates = 0
        totals = [0] * len(model.sources)
        bag_num = len(model.models)
        fut_length = 0
        prog_bar = 0
        current_model = 0
        for sub_model, weight in zip(model.models, model.weights):
            original_model_device = next(iter(sub_model.parameters())).device
            sub_model.to(device)
            fut_length += fut_length
            current_model += 1
            out = apply_model(sub_model, mix, **kwargs)
            sub_model.to(original_model_device)
            for k, inst_weight in enumerate(weight):
                out[:, k, :, :] *= inst_weight
                totals[k] += inst_weight
            estimates += out
            del out

        for k in range(estimates.shape[1]):
            estimates[:, k, :, :] /= totals[k]
        return estimates
|
|
    model.to(device)
    model.eval()
    assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
    batch, channels, length = mix.shape

    if shifts:
        kwargs['shifts'] = 0
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0
        for _ in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
            shifted_out = apply_model(model, shifted, **kwargs)
            out += shifted_out[..., max_shift - offset:]
        out /= shifts
        return out
    elif split:
        kwargs['split'] = False
        out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
        sum_weight = th.zeros(length, device=mix.device)
        segment = int(model.samplerate * model.segment)
        stride = int((1 - overlap) * segment)
        offsets = range(0, length, stride)
        scale = float(format(stride / model.samplerate, ".2f"))
        # The weight is a triangle with its maximum in the middle of the segment; it is
        # normalised and raised to `transition_power`. Larger values of `transition_power`
        # lead to sharper transitions between overlapping chunks.
        weight = th.cat([th.arange(1, segment // 2 + 1, device=device),
                         th.arange(segment - segment // 2, 0, -1, device=device)])
        assert len(weight) == segment
        weight = (weight / weight.max())**transition_power
        futures = []
        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment)
            future = pool.submit(apply_model, model, chunk, **kwargs)
            futures.append((future, offset))
            offset += segment
        if progress:
            futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
        for future, offset in futures:
            if set_progress_bar:
                fut_length = (len(futures) * bag_num * static_shifts)
                prog_bar += 1
                set_progress_bar(0.1, (0.8 / fut_length * prog_bar))
            chunk_out = future.result()
            chunk_length = chunk_out.shape[-1]
            out[..., offset:offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device)
            sum_weight[offset:offset + segment] += weight[:chunk_length].to(mix.device)
        assert sum_weight.min() > 0
        out /= sum_weight
        return out
    else:
        if hasattr(model, 'valid_length'):
            valid_length = model.valid_length(length)
        else:
            valid_length = length
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(valid_length).to(device)
        with th.no_grad():
            out = model(padded_mix)
        return center_trim(out, length)
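

# Illustrative usage sketch (not executed). `get_model` is a hypothetical stand-in
# for whatever loader produced the Demucs/HDemucs model or BagOfModels; the input
# follows the (batch, channels, samples) convention used above.
#
#     model = get_model()
#     mix = th.randn(1, model.audio_channels, 10 * model.samplerate)   # 10 s of noise
#     device = 'cuda' if th.cuda.is_available() else 'cpu'
#     sources = apply_model(model, mix, shifts=1, split=True, overlap=0.25, device=device)
#     # -> tensor of shape (batch, len(model.sources), channels, samples)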


def demucs_segments(demucs_segment, demucs_model):
    """
    Overrides the segment length of `demucs_model` (or of every sub-model in a
    BagOfModels) with `demucs_segment`, unless it is 'Default' or not a valid integer.
    """
    if demucs_segment == 'Default':
        segment = None
    else:
        try:
            segment = int(demucs_segment)
        except (ValueError, TypeError):
            segment = None

    if segment is not None:
        if isinstance(demucs_model, BagOfModels):
            for sub in demucs_model.models:
                sub.segment = segment
        else:
            demucs_model.segment = segment

    return demucs_model
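

# Illustrative sketch (not executed): `demucs_segment` typically arrives as a string,
# e.g. from a UI setting (this is an assumption, not part of this module).
#
#     model = demucs_segments('Default', model)   # keeps each model's own segment length
#     model = demucs_segments('40', model)        # forces 40-second segments on the model(s)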