| | |
| | |
| | |
| | |
| | |
| |
|
| | import random |
| | import torch as th |
| | from torch import nn |
| |
|
| |
|
class Shift(nn.Module):
    """
    Randomly shift audio in time by up to `shift` samples.

    When `shift > 0` the output is always `shift` samples shorter than the input
    along the last axis, in both training and eval mode:

    - training: each (batch item, source) pair gets its own random offset in
      [0, shift), shared across channels, and a window of `time - shift`
      samples starting at that offset is extracted;
    - eval: the first `time - shift` samples are kept (no randomness).

    When `shift <= 0` the input is returned unchanged.
    """
    def __init__(self, shift=8192):
        super().__init__()
        self.shift = shift

    def forward(self, wav):
        """
        Args:
            wav: 4-D tensor (batch, sources, channels, time); the 4-way unpack
                below rejects any other rank.

        Returns:
            Tensor of shape (batch, sources, channels, time - shift) if
            shift > 0, otherwise `wav` itself.

        Raises:
            ValueError: if shift > 0 and the input has fewer than `shift` samples.
        """
        batch, sources, channels, time = wav.size()
        if self.shift <= 0:
            return wav
        length = time - self.shift
        if length < 0:
            # A negative length would make the eval slice below trim from the
            # end (wrong output length) and the training arange fail.
            raise ValueError(
                f"Input of length {time} is shorter than the maximum shift {self.shift}")
        if not self.training:
            return wav[..., :length]
        offsets = th.randint(self.shift, [batch, sources, 1, 1], device=wav.device)
        # Same offset for every channel so stereo images stay aligned.
        offsets = offsets.expand(-1, -1, channels, -1)
        indexes = th.arange(length, device=wav.device)
        return wav.gather(3, indexes + offsets)
| |
|
| |
|
class FlipChannels(nn.Module):
    """
    Randomly swap the left and right channels of stereo audio.

    Only active in training mode and only when the channel axis (dim 2) has
    exactly two entries. Each (batch item, source) pair independently either
    keeps its channel order or has its two channels exchanged. Otherwise the
    input tensor is returned as-is.
    """
    def forward(self, wav):
        n_batch, n_sources, n_channels, n_time = wav.size()
        is_stereo = wav.size(2) == 2
        if not (self.training and is_stereo):
            return wav
        # Index (0 or 1) of the channel that ends up first, per (batch, source).
        first = th.randint(2, (n_batch, n_sources, 1, 1), device=wav.device)
        first = first.expand(-1, -1, -1, n_time)
        second = 1 - first
        picked = [wav.gather(2, index) for index in (first, second)]
        return th.cat(picked, dim=2)
| |
|
| |
|
class FlipSign(nn.Module):
    """
    Random sign flip.

    In training mode each (batch item, source) pair of a
    (batch, sources, channels, time) tensor is independently multiplied by +1
    or -1, with the same factor across channels and time. In eval mode the
    input is returned unchanged. Floating-point inputs keep their dtype.
    """
    def forward(self, wav):
        batch, sources, channels, time = wav.size()
        if self.training:
            signs = th.randint(2, (batch, sources, 1, 1), device=wav.device, dtype=th.float32)
            signs = 2 * signs - 1
            if wav.is_floating_point():
                # Match the input precision so half/bfloat16 audio is not
                # promoted to float32 by the multiplication below.
                signs = signs.to(wav.dtype)
            wav = wav * signs
        return wav
| |
|
| |
|
class Remix(nn.Module):
    """
    Shuffle sources across batch items to build new mixes.
    """
    def __init__(self, group_size=4):
        """
        Set up source shuffling within a batch.

        The batch is cut into consecutive groups of `group_size` items. For every
        source, the items are permuted independently inside each group. Working
        on fixed-size groups keeps the shuffling distribution the same regardless
        of how many GPUs the batch is split across. Shuffling over a smaller
        per-GPU batch would otherwise make it more likely that two sources from
        the same track stay together, which can hurt performance. A falsy
        `group_size` (e.g. 0 or None) treats the whole batch as one group.
        """
        super().__init__()
        self.group_size = group_size

    def forward(self, wav):
        batch, streams, channels, time = wav.size()
        device = wav.device
        if not self.training:
            return wav

        group_size = self.group_size or batch
        if batch % group_size != 0:
            raise ValueError(f"Batch size {batch} must be divisible by group size {group_size}")
        n_groups = batch // group_size
        grouped = wav.view(n_groups, group_size, streams, channels, time)
        # Sorting i.i.d. uniform noise along the group axis yields an independent
        # random permutation for every (group, stream).
        noise = th.rand(n_groups, group_size, streams, 1, 1, device=device)
        order = noise.argsort(dim=1).expand(-1, -1, -1, channels, time)
        shuffled = grouped.gather(1, order)
        return shuffled.view(batch, streams, channels, time)
| |
|
| |
|
class Scale(nn.Module):
    """
    Randomly rescale the amplitude of each source.

    In training mode, with probability `proba` (one draw per call from Python's
    `random` module), every (batch item, source) pair of a
    (batch, streams, channels, time) tensor is multiplied by its own gain,
    sampled uniformly in [min, max) and shared across channels and time.
    The input tensor is never modified in place. Floating-point inputs keep
    their dtype.
    """
    def __init__(self, proba=1., min=0.25, max=1.25):
        super().__init__()
        self.proba = proba
        self.min = min
        self.max = max

    def forward(self, wav):
        batch, streams, channels, time = wav.size()
        device = wav.device
        if self.training and random.random() < self.proba:
            scales = th.empty(batch, streams, 1, 1, device=device).uniform_(self.min, self.max)
            if wav.is_floating_point():
                # Keep half/bfloat16 inputs in their own precision.
                scales = scales.to(wav.dtype)
            # Out-of-place so the caller's tensor (e.g. a cached dataset item)
            # is left untouched.
            wav = wav * scales
        return wav
| |
|