| |
| |
| """ |
| Differentiable, Pytorch based resampling. |
| Implementation of Julius O. Smith algorithm for resampling. |
| See https://ccrma.stanford.edu/~jos/resample/ for details. |
| This implementation is specially optimized for when new_sr / old_sr is a fraction |
| with a small numerator and denominator when removing the gcd (e.g. new_sr = 700, old_sr = 500). |
| |
| Very similar to [bmcfee/resampy](https://github.com/bmcfee/resampy) except this implementation |
| is optimized for the case mentioned before, while resampy is slower but more general. |
| |
| """ |
|
|
| import math |
| from typing import Optional |
|
|
| import torch |
| from torch.nn import functional as F |
|
|
| from .core import sinc |
| from .utils import simple_repr |
|
|
|
|
class ResampleFrac(torch.nn.Module):
    """
    Resampling from the sample rate `old_sr` to `new_sr`.
    """
    def __init__(self, old_sr: int, new_sr: int, zeros: int = 24, rolloff: float = 0.945):
        """
        Args:
            old_sr (int): sample rate of the input signal x.
            new_sr (int): sample rate of the output.
            zeros (int): number of zero crossing to keep in the sinc filter.
            rolloff (float): use a lowpass filter with a cutoff of
                `rolloff * min(old_sr, new_sr) / 2`, to ensure sufficient margin due to
                the imperfection of the FIR filter used. Lowering this value reduces
                aliasing, but also removes some of the highest frequencies.

        Raises:
            ValueError: if `old_sr` or `new_sr` is not an integer, or is not
                strictly positive.

        Shape:

            - Input: `[*, T]`
            - Output: `[*, T']` with `T' = int(new_sr * T / old_sr)`


        .. caution::
            After dividing `old_sr` and `new_sr` by their GCD, both should be small
            for this implementation to be fast.

        >>> import torch
        >>> resample = ResampleFrac(4, 5)
        >>> x = torch.randn(1000)
        >>> print(len(resample(x)))
        1250
        """
        super().__init__()
        if not isinstance(old_sr, int) or not isinstance(new_sr, int):
            raise ValueError("old_sr and new_sr should be integers")
        # A zero rate would divide by zero below (in the GCD reduction or in the
        # kernel width), and a negative one would silently build meaningless filters.
        if old_sr <= 0 or new_sr <= 0:
            raise ValueError("old_sr and new_sr should be strictly positive")
        # Only the ratio matters: reduce the fraction so that the number of filters
        # (`new_sr`) and the convolution stride (`old_sr`) are as small as possible.
        gcd = math.gcd(old_sr, new_sr)
        self.old_sr = old_sr // gcd
        self.new_sr = new_sr // gcd
        self.zeros = zeros
        self.rolloff = rolloff

        self._init_kernels()

    def _init_kernels(self):
        """
        Precompute one windowed-sinc filter per output phase and register them as the
        `kernel` buffer, with shape `[new_sr, 1, 2 * _width + old_sr]`.
        """
        if self.old_sr == self.new_sr:
            # Identity resampling: `forward` returns its input, no filter is needed.
            return

        kernels = []
        # The lowpass follows the lowest of the two (reduced) rates, scaled by `rolloff`.
        sr = min(self.new_sr, self.old_sr)
        sr *= self.rolloff

        # Half-width of each filter in input samples: enough to cover `zeros` zero
        # crossings of a sinc running at rate `sr` when sampled at rate `old_sr`.
        self._width = math.ceil(self.zeros * self.old_sr / sr)
        # Input sample offsets covered by one filter. The extra `old_sr` taps on the
        # right match the extra right padding and the stride used in `forward`.
        idx = torch.arange(-self._width, self._width + self.old_sr).float()
        for i in range(self.new_sr):
            # Distance between each input sample and the i-th output sample of a
            # period, measured in periods of the cutoff rate `sr`.
            t = (-i/self.new_sr + idx/self.old_sr) * sr
            # Past `zeros` zero crossings the window below is cos(pi / 2) ** 2, i.e.
            # (numerically) zero, so clamping truncates the filter support.
            t = t.clamp_(-self.zeros, self.zeros)
            t *= math.pi
            window = torch.cos(t/self.zeros/2)**2
            kernel = sinc(t) * window
            # Make the taps sum to one, so that a constant signal is preserved.
            kernel.div_(kernel.sum())
            kernels.append(kernel)

        self.register_buffer("kernel", torch.stack(kernels).view(self.new_sr, 1, -1))

    def forward(self, x: torch.Tensor, output_length: Optional[int] = None, full: bool = False):
        """
        Resample x.
        Args:
            x (Tensor): signal to resample, time should be the last dimension
            output_length (None or int): This can be set to the desired output length
                (last dimension). Allowed values are between 0 and
                ceil(length * new_sr / old_sr). When None (default) is specified, the
                floored output length will be used. In order to select the largest possible
                size, use the `full` argument.
            full (bool): return the longest possible output from the input. This can be useful
                if you chain resampling operations, and want to give the `output_length` only
                for the last one, while passing `full=True` to all the other ones.

        Returns:
            Tensor with the same leading dimensions as `x` and a resampled last dimension.
            When `old_sr == new_sr`, `x` itself is returned and `output_length` and `full`
            are ignored.

        Raises:
            ValueError: if `output_length` is outside of the allowed range, or if it is
                given together with `full=True`.
        """
        if self.old_sr == self.new_sr:
            return x
        shape = x.shape
        length = x.shape[-1]
        # Flatten all leading dimensions into a batch of single channel signals.
        x = x.reshape(-1, length)
        # Edges are extended by repeating the first / last sample.
        x = F.pad(x[:, None], (self._width, self._width + self.old_sr), mode='replicate')
        # One output channel per phase; each step of `old_sr` input samples produces
        # `new_sr` output samples.
        ys = F.conv1d(x, self.kernel, stride=self.old_sr)
        # Interleave the phases along time and restore the leading dimensions.
        y = ys.transpose(1, 2).reshape(list(shape[:-1]) + [-1])

        # Exact integer floor / ceil of `new_sr * length / old_sr`, which avoids any
        # float rounding for very long signals.
        scaled_length = self.new_sr * length
        default_output_length = scaled_length // self.old_sr
        max_output_length = -(-scaled_length // self.old_sr)
        if output_length is None:
            output_length = max_output_length if full else default_output_length
        elif output_length < 0 or output_length > max_output_length:
            raise ValueError(f"output_length must be between 0 and {max_output_length}")
        elif full:
            raise ValueError("You cannot pass both full=True and output_length")
        return y[..., :output_length]

    def __repr__(self):
        return simple_repr(self)
|
|
|
|
def resample_frac(x: torch.Tensor, old_sr: int, new_sr: int,
                  zeros: int = 24, rolloff: float = 0.945,
                  output_length: Optional[int] = None, full: bool = False):
    """
    Functional version of `ResampleFrac`, refer to its documentation for more information.

    Args:
        x (Tensor): signal to resample, time should be the last dimension.
        old_sr (int): sample rate of `x`.
        new_sr (int): sample rate of the output.
        zeros (int): forwarded to the `ResampleFrac` constructor.
        rolloff (float): forwarded to the `ResampleFrac` constructor.
        output_length (None or int): forwarded to `ResampleFrac.forward`.
        full (bool): forwarded to `ResampleFrac.forward`.

    .. warning::
        If you repeatedly call this function with the same sample rates, the
        resampling kernel will be recomputed every time. For best performance, you
        should create and cache an instance of `ResampleFrac` instead.
    """
    resampler = ResampleFrac(old_sr, new_sr, zeros, rolloff)
    # Match the dtype and device of the input before applying the filters.
    resampler = resampler.to(x)
    return resampler(x, output_length, full)
|
|
|
|
| |
| |
|
|
def _kernel_upsample2_downsample2(zeros):
    """
    Build the windowed-sinc kernel used by both `_upsample2` and `_downsample2`.

    Args:
        zeros (int): number of zero crossing to keep in the sinc filter.

    Returns:
        Tensor of shape `[1, 1, 2 * zeros]`.
    """
    # Symmetric Hann window over `4 * zeros + 1` points; keeping the odd indices
    # leaves `2 * zeros` values.
    hann = torch.hann_window(4 * zeros + 1, periodic=False)
    odd_taps = hann[1::2]
    # Half-integer positions -zeros + 0.5, ..., zeros - 0.5, scaled by pi before the
    # call to `sinc` (NOTE(review): implies `sinc` is the unnormalized sin(x) / x
    # from `.core` -- confirm).
    positions = torch.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros)
    positions *= math.pi
    return (sinc(positions) * odd_taps).view(1, 1, -1)
|
|
|
|
def _upsample2(x, zeros=24):
    """
    Upsample x by a factor of two. The output will be exactly twice as long as the input.
    Args:
        x (Tensor): signal to upsample, time should be the last dimension
        zeros (int): number of zero crossing to keep in the sinc filter.

    Returns:
        Tensor with the same leading dimensions as `x` and a last dimension twice as long.

    This function is kept only for reference, you should use the more generic `resample_frac`
    one. This function does not perform anti-aliasing filtering.
    """
    *other, time = x.shape
    kernel = _kernel_upsample2_downsample2(zeros).to(x)
    # `reshape` rather than `view`, so that non-contiguous inputs (e.g. a transposed
    # tensor) are accepted as well.
    flat = x.reshape(-1, 1, time)
    # The kernel has `2 * zeros` taps, so with `padding=zeros` the convolution
    # returns `time + 1` values; the first one is dropped to keep `time` of them.
    out = F.conv1d(flat, kernel, padding=zeros)[..., 1:].reshape(*other, time)
    # Interleave: x[0], out[0], x[1], out[1], ...
    y = torch.stack([x, out], dim=-1)
    return y.reshape(*other, -1)
|
|
|
|
def _downsample2(x, zeros=24):
    """
    Downsample x by a factor of two. The output length is half of the input, ceiled.
    Args:
        x (Tensor): signal to downsample, time should be the last dimension
        zeros (int): number of zero crossing to keep in the sinc filter.

    Returns:
        Tensor with the same leading dimensions as `x` and a last dimension of
        `ceil(T / 2)`, `T` being the length of `x`.

    This function is kept only for reference, you should use the more generic `resample_frac`
    one. This function does not perform anti-aliasing filtering.
    """
    if x.shape[-1] % 2 != 0:
        # Append a single zero so that even and odd samples have the same length.
        x = F.pad(x, (0, 1))
    xeven = x[..., ::2]
    xodd = x[..., 1::2]
    *other, time = xodd.shape
    kernel = _kernel_upsample2_downsample2(zeros).to(x)
    # `reshape` rather than `view`: the strided slice of a non-contiguous input
    # cannot always be viewed with its leading dimensions flattened.
    flat_odd = xodd.reshape(-1, 1, time)
    # The kernel has `2 * zeros` taps, so with `padding=zeros` the convolution
    # returns `time + 1` values; the last one is dropped to keep `time` of them.
    filtered_odd = F.conv1d(flat_odd, kernel, padding=zeros)[..., :-1].reshape(*other, time)
    # The sum already has shape `[*other, time]`, no further reshaping is needed.
    out = xeven + filtered_odd
    return out.mul(0.5)
|
|