| | |
| | |
| | """ |
| | FIR windowed sinc highpass and bandpass filters. |
| | Those are convenience wrappers around the filters defined in `julius.lowpass`. |
| | """ |
| |
|
| | from typing import Sequence, Optional |
| |
|
| | import torch |
| |
|
| | |
| | from .lowpass import lowpass_filter, lowpass_filters, LowPassFilter, LowPassFilters |
| | from .utils import simple_repr |
| |
|
| |
|
class HighPassFilters(torch.nn.Module):
    """
    Bank of high pass filters. See `julius.lowpass.LowPassFilters` for more
    details on the implementation.

    Each high pass output is obtained by subtracting the matching low pass
    output from the input, after the input has been trimmed and decimated so
    that it lines up sample-for-sample with the low pass output.

    Args:
        cutoffs (list[float]): list of cutoff frequencies, in [0, 0.5] expressed as `f/f_s` where
            f_s is the samplerate and `f` is the cutoff frequency.
            The upper limit is 0.5, because a signal sampled at `f_s` contains only
            frequencies under `f_s / 2`.
        stride (int): how much to decimate the output. Probably not a good idea
            to do so with high pass filters though...
        pad (bool): if True, appropriately pad the input with zero over the edge. If `stride=1`,
            the output will have the same length as the input.
        zeros (float): Number of zero crossings to keep.
            Controls the receptive field of the Finite Impulse Response filter.
            For filters with low cutoff frequency, e.g. 40Hz at 44.1kHz,
            it is a bad idea to set this to a high value.
            The default of 8 is likely appropriate for most uses. Lower values
            will result in a faster filter, but with a slower attenuation around the
            cutoff frequency.
        fft (bool or None): if True, uses `julius.fftconv` rather than PyTorch convolutions.
            If False, uses PyTorch convolutions. If None, either one will be chosen automatically
            depending on the effective filter size.

    .. warning::
        All the filters will use the same filter size, aligned on the lowest
        frequency provided. If you combine a lot of filters with very diverse frequencies, it might
        be more efficient to split them over multiple modules with similar frequencies.

    Shape:

        - Input: `[*, T]`
        - Output: `[F, *, T']`, with `T'=T` if `pad` is True and `stride` is 1, and
          `F` is the number of cutoff frequencies.

    >>> highpass = HighPassFilters([1/4])
    >>> x = torch.randn(4, 12, 21, 1024)
    >>> list(highpass(x).shape)
    [1, 4, 12, 21, 1024]
    """

    def __init__(self, cutoffs: Sequence[float], stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        # All the filtering work is delegated to a bank of low pass filters;
        # the high pass is recovered as `input - lowpass(input)` in `forward`.
        self._lowpasses = LowPassFilters(cutoffs, stride, pad, zeros, fft)

    # The configuration lives on the wrapped low pass bank; expose it read-only.
    @property
    def cutoffs(self):
        return self._lowpasses.cutoffs

    @property
    def stride(self):
        return self._lowpasses.stride

    @property
    def pad(self):
        return self._lowpasses.pad

    @property
    def zeros(self):
        return self._lowpasses.zeros

    @property
    def fft(self):
        return self._lowpasses.fft

    def forward(self, input):
        """
        Apply every high pass filter of the bank to `input`.

        Args:
            input (torch.Tensor): signal of shape `[*, T]`.

        Returns:
            torch.Tensor: high passed signals of shape `[F, *, T']`, one per cutoff.
        """
        lows = self._lowpasses(input)

        # Align the input with `lows` before subtracting. With padding the
        # whole input is kept. Without padding, `half_size` samples are dropped
        # on each side (presumably the filter half-width used by
        # `LowPassFilters`; confirm there).
        length = input.shape[-1]
        if self.pad:
            start, end = 0, length
        else:
            start = self._lowpasses.half_size
            # Use an explicit end index instead of `-start`: when
            # `half_size == 0`, `-0 == 0` would yield an empty slice.
            end = length - start
        aligned = input[..., start:end:self.stride]
        # `aligned` is `[*, T']` and broadcasts against `lows` (`[F, *, T']`).
        return aligned - lows

    def __repr__(self):
        return simple_repr(self)
| |
|
| |
|
class HighPassFilter(torch.nn.Module):
    """
    Same as `HighPassFilters` but applies a single high pass filter.

    Internally this is a one-element `HighPassFilters` bank whose leading
    filter dimension is dropped from the output.

    Shape:

        - Input: `[*, T]`
        - Output: `[*, T']`, with `T'=T` if `pad` is True and `stride` is 1.

    >>> highpass = HighPassFilter(1/4, stride=1)
    >>> x = torch.randn(4, 124)
    >>> list(highpass(x).shape)
    [4, 124]
    """

    def __init__(self, cutoff: float, stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        single_cutoff = [cutoff]
        self._highpasses = HighPassFilters(single_cutoff, stride, pad, zeros, fft)

    @property
    def cutoff(self):
        """The only cutoff frequency of the wrapped bank."""
        bank = self._highpasses
        return bank.cutoffs[0]

    @property
    def stride(self):
        """Decimation factor, forwarded from the wrapped bank."""
        bank = self._highpasses
        return bank.stride

    @property
    def pad(self):
        """Whether the input is padded, forwarded from the wrapped bank."""
        bank = self._highpasses
        return bank.pad

    @property
    def zeros(self):
        """Number of zero crossings, forwarded from the wrapped bank."""
        bank = self._highpasses
        return bank.zeros

    @property
    def fft(self):
        """FFT convolution setting, forwarded from the wrapped bank."""
        bank = self._highpasses
        return bank.fft

    def forward(self, input):
        # The bank returns `[1, *, T']`; keep the single filter's output.
        stacked = self._highpasses(input)
        return stacked[0]

    def __repr__(self):
        return simple_repr(self)
| |
|
| |
|
def highpass_filters(input: torch.Tensor, cutoffs: Sequence[float],
                     stride: int = 1, pad: bool = True,
                     zeros: float = 8, fft: Optional[bool] = None):
    """
    Functional version of `HighPassFilters`, refer to this class for more information.
    """
    bank = HighPassFilters(cutoffs, stride, pad, zeros, fft)
    # Move the freshly built module to the input's device/dtype before use.
    bank = bank.to(input)
    return bank(input)
| |
|
| |
|
def highpass_filter(input: torch.Tensor, cutoff: float,
                    stride: int = 1, pad: bool = True,
                    zeros: float = 8, fft: Optional[bool] = None):
    """
    Functional version of `HighPassFilter`, refer to this class for more information.
    Output will not have a dimension inserted in the front.
    """
    # Run a one-filter bank, then drop its leading filter dimension.
    filtered = highpass_filters(input, [cutoff], stride, pad, zeros, fft)
    return filtered[0]
| |
|
| |
|
class BandPassFilter(torch.nn.Module):
    """
    Single band pass filter, implemented as the difference of two lowpass filters.

    Args:
        cutoff_low (float): lower cutoff frequency, in [0, 0.5] expressed as `f/f_s` where
            f_s is the samplerate and `f` is the cutoff frequency.
            The upper limit is 0.5, because a signal sampled at `f_s` contains only
            frequencies under `f_s / 2`.
        cutoff_high (float): higher cutoff frequency, in [0, 0.5] expressed as `f/f_s`.
            This must be higher than or equal to `cutoff_low`. Note that because
            the filters are not perfect, the output will be non zero even if
            cutoff_high == cutoff_low.
        stride (int): how much to decimate the output.
        pad (bool): if True, appropriately pad the input with zero over the edge. If `stride=1`,
            the output will have the same length as the input.
        zeros (float): Number of zero crossings to keep.
            Controls the receptive field of the Finite Impulse Response filter.
            For filters with low cutoff frequency, e.g. 40Hz at 44.1kHz,
            it is a bad idea to set this to a high value.
            The default of 8 is likely appropriate for most uses. Lower values
            will result in a faster filter, but with a slower attenuation around the
            cutoff frequency.
        fft (bool or None): if True, uses `julius.fftconv` rather than PyTorch convolutions.
            If False, uses PyTorch convolutions. If None, either one will be chosen automatically
            depending on the effective filter size.

    Raises:
        ValueError: if `cutoff_low` is greater than `cutoff_high`.

    Shape:

        - Input: `[*, T]`
        - Output: `[*, T']`, with `T'=T` if `pad` is True and `stride` is 1.

    .. note::
        There is no BandPassFilters (bank of bandpasses) because its
        meaning would be the same as `julius.bands.SplitBands`.

    >>> bandpass = BandPassFilter(1/4, 1/3)
    >>> x = torch.randn(4, 12, 21, 1024)
    >>> list(bandpass(x).shape)
    [4, 12, 21, 1024]
    """

    def __init__(self, cutoff_low: float, cutoff_high: float, stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        # Equal cutoffs are allowed; only an inverted band is rejected.
        if cutoff_low > cutoff_high:
            raise ValueError(f"Lower cutoff {cutoff_low} should be less than "
                             f"higher cutoff {cutoff_high}.")
        # Index 0 holds the low cutoff and index 1 the high cutoff; the
        # properties and `forward` below rely on this order.
        self._lowpasses = LowPassFilters([cutoff_low, cutoff_high], stride, pad, zeros, fft)

    @property
    def cutoff_low(self):
        return self._lowpasses.cutoffs[0]

    @property
    def cutoff_high(self):
        return self._lowpasses.cutoffs[1]

    # The remaining configuration is forwarded from the wrapped low pass bank.
    @property
    def stride(self):
        return self._lowpasses.stride

    @property
    def pad(self):
        return self._lowpasses.pad

    @property
    def zeros(self):
        return self._lowpasses.zeros

    @property
    def fft(self):
        return self._lowpasses.fft

    def forward(self, input):
        """
        Band pass `input` between `cutoff_low` and `cutoff_high`.

        Args:
            input (torch.Tensor): signal of shape `[*, T]`.

        Returns:
            torch.Tensor: filtered signal of shape `[*, T']`.
        """
        lows = self._lowpasses(input)
        # Low pass at `cutoff_high` minus low pass at `cutoff_low` keeps the
        # band in between.
        return lows[1] - lows[0]

    def __repr__(self):
        return simple_repr(self)
| |
|
| |
|
def bandpass_filter(input: torch.Tensor, cutoff_low: float, cutoff_high: float,
                    stride: int = 1, pad: bool = True,
                    zeros: float = 8, fft: Optional[bool] = None):
    """
    Functional version of `BandPassFilter`, refer to this class for more information.
    Output will not have a dimension inserted in the front.

    Raises:
        ValueError: if `cutoff_low` is greater than `cutoff_high`.
    """
    # Build the module, move it to the input's device/dtype, then apply it.
    return BandPassFilter(cutoff_low, cutoff_high, stride, pad, zeros, fft).to(input)(input)
| |
|