| | |
| | |
| | """ |
| | FIR windowed sinc lowpass filters. |
| | """ |
| |
|
| | import math |
| | from typing import Sequence, Optional |
| |
|
| | import torch |
| | from torch.nn import functional as F |
| |
|
| | from .core import sinc |
| | from .fftconv import fft_conv1d |
| | from .utils import simple_repr |
| |
|
| |
|
class LowPassFilters(torch.nn.Module):
    """
    Bank of low pass filters. Note that a high pass or band pass filter can easily
    be implemented by subtracting a same signal processed with low pass filters with different
    frequencies (see `julius.bands.SplitBands` for instance).
    This uses a windowed sinc filter, very similar to the one used in
    `julius.resample`. However, because we do not change the sample rate here,
    this filter can be much more efficiently implemented using the FFT convolution from
    `julius.fftconv`.

    Args:
        cutoffs (list[float]): list of cutoff frequencies, in [0, 0.5] expressed as `f/f_s` where
            f_s is the samplerate and `f` is the cutoff frequency.
            The upper limit is 0.5, because a signal sampled at `f_s` contains only
            frequencies under `f_s / 2`. Any iterable is accepted; it is read once.
        stride (int): how much to decimate the output. Keep in mind that decimation
            of the output is only acceptable if the cutoff frequency is under `1 / (2 * stride)`
            of the original sampling rate.
        pad (bool): if True, pad the input on both edges by replicating its first and
            last samples, so that the filter can be applied over the edges.
            If `stride=1`, the output will have the same length as the input.
        zeros (float): Number of zero crossings to keep.
            Controls the receptive field of the Finite Impulse Response filter.
            For lowpass filters with low cutoff frequency, e.g. 40Hz at 44.1kHz,
            it is a bad idea to set this to a high value.
            The default of 8 is likely appropriate for most uses. Lower values
            will result in a faster filter, but with a slower attenuation around the
            cutoff frequency.
        fft (bool or None): if True, uses `julius.fftconv` rather than PyTorch convolutions.
            If False, uses PyTorch convolutions. If None, either one will be chosen automatically
            depending on the effective filter size.

    Raises:
        ValueError: if `cutoffs` is empty, or if any cutoff lies outside [0, 0.5].

    .. warning::
        All the filters will use the same filter size, aligned on the lowest
        non-zero frequency provided. If you combine a lot of filters with very diverse
        frequencies, it might be more efficient to split them over multiple modules with
        similar frequencies.

    .. note::
        A lowpass with a cutoff frequency of 0 is defined as the null function
        by convention here. This allows for a highpass with a cutoff of 0 to
        be equal to identity, as defined in `julius.filters.HighPassFilters`.
        If every cutoff is 0, the filters reduce to a single zero tap.

    Shape:

        - Input: `[*, T]`
        - Output: `[F, *, T']`, with `T'=T` if `pad` is True and `stride` is 1, and
          `F` is the number of cutoff frequencies.

    >>> lowpass = LowPassFilters([1/4])
    >>> x = torch.randn(4, 12, 21, 1024)
    >>> list(lowpass(x).shape)
    [1, 4, 12, 21, 1024]
    """

    def __init__(self, cutoffs: Sequence[float], stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        # Materialize once: `cutoffs` may be a one-shot iterable, and every later
        # use (validation, filter construction, forward) must see the same values.
        self.cutoffs = list(cutoffs)
        if not self.cutoffs:
            raise ValueError("At least one cutoff frequency must be provided.")
        if min(self.cutoffs) < 0:
            raise ValueError("Minimum cutoff must be non-negative.")
        if max(self.cutoffs) > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.stride = stride
        self.pad = pad
        self.zeros = zeros
        # All filters share one length, driven by the lowest non-zero cutoff.
        # Zero cutoffs give the null filter, which needs no support, so a bank
        # made only of zero cutoffs gets a single tap instead of crashing.
        positive_cutoffs = [c for c in self.cutoffs if c > 0]
        if positive_cutoffs:
            self.half_size = int(zeros / min(positive_cutoffs) / 2)
        else:
            self.half_size = 0
        if fft is None:
            # FFT convolution is only worth it for long filters.
            fft = self.half_size > 32
        self.fft = fft
        window = torch.hann_window(2 * self.half_size + 1, periodic=False)
        time = torch.arange(-self.half_size, self.half_size + 1)
        filters = []
        for cutoff in self.cutoffs:
            if cutoff == 0:
                # Float zeros matching the window dtype (`time` is an integer
                # tensor), and no normalization: the sum is 0, so dividing
                # would produce NaN.
                filter_ = torch.zeros_like(window)
            else:
                filter_ = 2 * cutoff * window * sinc(2 * cutoff * math.pi * time)
                # Normalize the taps to sum to 1 so the constant (DC) component
                # of the input passes with unit gain instead of leaking slightly.
                filter_ /= filter_.sum()
            filters.append(filter_)
        # Shape [F, 1, 2 * half_size + 1]: one single-channel kernel per cutoff.
        self.register_buffer("filters", torch.stack(filters)[:, None])

    def forward(self, input):
        shape = list(input.shape)
        # Fold every leading dimension into the batch axis. `reshape` (rather
        # than `view`) also accepts non-contiguous inputs, copying only if needed.
        input = input.reshape(-1, 1, shape[-1])
        if self.pad:
            # Replicate edge samples; with stride 1 this keeps T' == T since the
            # kernel spans 2 * half_size + 1 samples.
            input = F.pad(input, (self.half_size, self.half_size), mode='replicate')
        if self.fft:
            out = fft_conv1d(input, self.filters, stride=self.stride)
        else:
            out = F.conv1d(input, self.filters, stride=self.stride)
        # `out` is [batch, F, T']: put the filter axis first and restore the
        # original leading dimensions.
        shape.insert(0, len(self.cutoffs))
        shape[-1] = out.shape[-1]
        return out.permute(1, 0, 2).reshape(shape)

    def __repr__(self):
        return simple_repr(self)
| |
|
| |
|
class LowPassFilter(torch.nn.Module):
    """
    Single low pass filter.

    This is a thin wrapper around a one-element `LowPassFilters` bank: the
    filtering is delegated to it, and the leading filter dimension that the
    bank inserts is dropped from the output.

    Args:
        cutoff (float): cutoff frequency, expressed as `f/f_s`.
        stride, pad, zeros, fft: forwarded unchanged to `LowPassFilters`;
            refer to that class for their meaning.

    Shape:

        - Input: `[*, T]`
        - Output: `[*, T']`, with `T'=T` if `pad` is True and `stride` is 1.

    >>> lowpass = LowPassFilter(1/4, stride=2)
    >>> x = torch.randn(4, 124)
    >>> list(lowpass(x).shape)
    [4, 62]
    """

    def __init__(self, cutoff: float, stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        # The attribute name `_lowpasses` appears in the state_dict keys
        # (e.g. "_lowpasses.filters"), so it must not be renamed.
        self._lowpasses = LowPassFilters(
            cutoffs=[cutoff],
            stride=stride,
            pad=pad,
            zeros=zeros,
            fft=fft,
        )

    # The read-only views below mirror the constructor arguments; they are
    # read back from the wrapped bank rather than stored twice.

    @property
    def cutoff(self):
        """The single cutoff frequency held by the wrapped bank."""
        bank_cutoffs = self._lowpasses.cutoffs
        return bank_cutoffs[0]

    @property
    def stride(self):
        """Output decimation factor of the wrapped bank."""
        return self._lowpasses.stride

    @property
    def pad(self):
        """Whether the wrapped bank pads its input."""
        return self._lowpasses.pad

    @property
    def zeros(self):
        """Number of zero crossings kept by the wrapped bank."""
        return self._lowpasses.zeros

    @property
    def fft(self):
        """Whether the wrapped bank uses FFT convolution."""
        return self._lowpasses.fft

    def forward(self, input):
        bank_output = self._lowpasses(input)
        # The bank returns [1, *, T']; keep only the sole filter's output.
        return bank_output[0]

    def __repr__(self):
        return simple_repr(self)
| |
|
| |
|
def lowpass_filters(input: torch.Tensor, cutoffs: Sequence[float],
                    stride: int = 1, pad: bool = True,
                    zeros: float = 8, fft: Optional[bool] = None):
    """
    Functional version of `LowPassFilters`, refer to this class for more information.

    A new filter bank is built on every call and moved to the device and dtype
    of `input` before being applied. When filtering repeatedly with the same
    settings, reuse a `LowPassFilters` module instead to avoid rebuilding it.
    """
    bank = LowPassFilters(cutoffs, stride, pad, zeros, fft)
    # Module.to(tensor) matches both device and dtype of `input`.
    bank = bank.to(input)
    return bank(input)
| |
|
| |
|
def lowpass_filter(input: torch.Tensor, cutoff: float,
                   stride: int = 1, pad: bool = True,
                   zeros: float = 8, fft: Optional[bool] = None):
    """
    Same as `lowpass_filters` but with a single cutoff frequency.
    Output will not have a dimension inserted in the front.
    """
    stacked = lowpass_filters(input, [cutoff], stride, pad, zeros, fft)
    # `stacked` has a leading axis of size 1 (one cutoff); strip it.
    return stacked[0]
| |
|