| | import functools |
| | import math |
| | import warnings |
| | from distutils.version import LooseVersion |
| | from fractions import Fraction |
| | from typing import Optional |
| |
|
| | import numpy |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch import Tensor |
| | from torch.nn.modules.utils import _single |
| |
|
| |
|
| | def _get_sinc_resample_kernel( |
| | orig_freq: int, |
| | new_freq: int, |
| | gcd: int, |
| | lowpass_filter_width: int = 6, |
| | rolloff: float = 0.99, |
| | resampling_kernel: str = "sinc", |
| | sinc_window: str = "sinc_interpolation", |
| | beta: Optional[float] = None, |
| | device: torch.device = torch.device("cpu"), |
| | dtype: Optional[torch.dtype] = None, |
| | ): |
| | if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq): |
| | raise Exception( |
| | "Frequencies must be of integer type to ensure quality resampling computation. " |
| | "To work around this, manually convert both frequencies to integer values " |
| | "that maintain their resampling rate ratio before passing them into the function. " |
| | "Example: To downsample a 44100 hz waveform by a factor of 8, use " |
| | "`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5`. " |
| | "For more information, please refer to https://github.com/pytorch/audio/issues/1487." |
| | ) |
| |
|
| | orig_freq = int(orig_freq) // gcd |
| | new_freq = int(new_freq) // gcd |
| |
|
| | if lowpass_filter_width <= 0: |
| | raise ValueError("Low pass filter width should be positive.") |
| | base_freq = max(orig_freq, new_freq) |
| |
|
| | if resampling_kernel == "sinc": |
| | if sinc_window not in ["sinc_interpolation", "kaiser_window"]: |
| | raise ValueError("Invalid resampling method: {}".format(sinc_window)) |
| | base_freq *= rolloff |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | width = math.ceil(lowpass_filter_width * orig_freq / base_freq) |
| | |
| | |
| | |
| | |
| | idx_dtype = dtype if dtype is not None else torch.float64 |
| |
|
| | idx = ( |
| | torch.arange(-width, width + orig_freq, dtype=idx_dtype, device=device)[ |
| | None, None |
| | ] |
| | / orig_freq |
| | ) |
| |
|
| | t = ( |
| | torch.arange(0, -new_freq, -1, dtype=dtype, device=device)[:, None, None] |
| | / new_freq |
| | + idx |
| | ) |
| | t *= base_freq |
| | t = t.clamp_(-lowpass_filter_width, lowpass_filter_width) |
| |
|
| | |
| | |
| | if sinc_window == "sinc_interpolation": |
| | window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2 |
| | else: |
| | |
| | if beta is None: |
| | beta = 14.769656459379492 |
| | beta_tensor = torch.tensor(float(beta)) |
| | window = torch.i0( |
| | beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2) |
| | ) / torch.i0(beta_tensor) |
| |
|
| | t *= math.pi |
| |
|
| | kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t) |
| | kernels *= window |
| |
|
| | if dtype is None: |
| | kernels = kernels.to(dtype=torch.float32) |
| | else: |
| | raise NotImplementedError |
| |
|
| | return kernels, width |
| |
|
| |
|
| | def _apply_sinc_resample_kernel( |
| | waveform: Tensor, |
| | orig_freq: int, |
| | new_freq: int, |
| | gcd: int, |
| | kernel: Tensor, |
| | width: int, |
| | ): |
| | if not waveform.is_floating_point(): |
| | raise TypeError( |
| | f"Expected floating point type for waveform tensor, but received {waveform.dtype}." |
| | ) |
| |
|
| | orig_freq = int(orig_freq) // gcd |
| | new_freq = int(new_freq) // gcd |
| |
|
| | |
| | shape = waveform.size() |
| | waveform = waveform.view(-1, shape[-1]) |
| |
|
| | num_wavs, length = waveform.shape |
| | waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq)) |
| | resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq) |
| | resampled = resampled.transpose(1, 2).reshape(num_wavs, -1) |
| | target_length = int(math.ceil(new_freq * length / orig_freq)) |
| | resampled = resampled[..., :target_length] |
| |
|
| | |
| | resampled = resampled.view(shape[:-1] + resampled.shape[-1:]) |
| | return resampled |
| |
|
| |
|
def resample(
    waveform: Tensor,
    orig_freq: int,
    new_freq: int,
    lowpass_filter_width: int = 16,
    rolloff: float = 0.99,
    resampling_kernel: str = "sinc",
    sinc_window: str = "kaiser_window",
    beta: Optional[float] = None,
    kernel=None,
    width=None,
) -> Tensor:
    r"""Resamples the waveform at the new frequency using bandlimited interpolation. :cite:`RESAMPLE`.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Note:
        ``transforms.Resample`` precomputes and reuses the resampling kernel, so using it will result in
        more efficient computation if resampling multiple waveforms with the same resampling parameters.

    Args:
        waveform (Tensor): The input signal of dimension `(..., time)`
        orig_freq (int): The original frequency of the signal
        new_freq (int): The desired frequency
        lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
            but less efficient. (Default: ``16``)
        rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
            Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
        resampling_kernel (str, optional): The resampling kernel to use; only ``"sinc"`` triggers the
            roll-off and window validation in ``_get_sinc_resample_kernel``. (Default: ``"sinc"``)
        sinc_window (str, optional): The windowing method to use.
            Options: [``"sinc_interpolation"``, ``"kaiser_window"``] (Default: ``"kaiser_window"``)
        beta (float or None, optional): The shape parameter used for kaiser window.
        kernel (Tensor or None, optional): Precomputed resampling kernel; built from the other
            arguments when ``None``.
        width (int or None, optional): One-sided kernel width matching ``kernel``.

    Returns:
        Tensor: The waveform at the new frequency of dimension `(..., time).`
    """
    if orig_freq <= 0.0 or new_freq <= 0.0:
        raise ValueError("Original frequency and desired frequency should be positive")

    # Identical rates: resampling is the identity.
    if orig_freq == new_freq:
        return waveform

    gcd = math.gcd(int(orig_freq), int(new_freq))

    # Build the kernel lazily unless the caller supplied a precomputed one.
    if (kernel is None) or (width is None):
        kernel, width = _get_sinc_resample_kernel(
            orig_freq=orig_freq,
            new_freq=new_freq,
            gcd=gcd,
            lowpass_filter_width=lowpass_filter_width,
            rolloff=rolloff,
            resampling_kernel=resampling_kernel,
            sinc_window=sinc_window,
            beta=beta,
            device=waveform.device,
            dtype=waveform.dtype,
        )

    resampled = _apply_sinc_resample_kernel(
        waveform, orig_freq, new_freq, gcd, kernel, width
    )

    return resampled
| |
|
| |
|
def compute_Hilbert_transforms_of_filters(filters):
    """Compute the Hilbert transforms of the input filters.

    Args:
        filters (torch.Tensor): filter bank, n_filters x kernel_size
    Return
        torch.Tensor: Hilbert transforms of the filters, same shape as the input
    """
    # torch.rfft/irfft were removed in favor of the torch.fft module; the first
    # branch keeps compatibility with torch < 1.7.
    # NOTE(review): distutils.LooseVersion is deprecated (distutils removed in
    # Python 3.12) — consider packaging.version if this branch must be kept.
    if LooseVersion(torch.__version__) < LooseVersion("1.7.0"):
        ft_f = torch.rfft(
            filters.reshape(filters.shape[0], 1, filters.shape[1]), 1, normalized=True
        )
        # Multiply the positive-frequency spectrum by -j: (re, im) -> (im, -re).
        hft_f = torch.stack([ft_f[:, :, :, 1], -ft_f[:, :, :, 0]], dim=-1)
        hft_f = torch.irfft(hft_f, 1, normalized=True, signal_sizes=(filters.shape[1],))
    else:
        ft_f = torch.fft.rfft(filters, n=filters.shape[1], dim=1, norm="ortho")
        # Same -j rotation, expressed through complex views.
        hft_f = torch.view_as_complex(torch.stack((ft_f.imag, -ft_f.real), axis=-1))
        hft_f = torch.fft.irfft(hft_f, n=filters.shape[1], dim=1, norm="ortho")
    return hft_f.reshape(*(filters.shape))
| |
|
| |
|
| | class _FIRDesignBase(nn.Module): |
| | def __init__( |
| | self, |
| | in_channels, |
| | out_channels, |
| | ContFilterType, |
| | filter_params, |
| | use_Hilbert_transforms=False, |
| | transposed=False, |
| | ): |
| | super().__init__() |
| | self.in_channels = in_channels |
| | self.out_channels = out_channels |
| | n_filters = in_channels * out_channels |
| | self.use_Hilbert_transforms = use_Hilbert_transforms |
| | if self.use_Hilbert_transforms: |
| | if n_filters % 2 == 1: |
| | raise ValueError( |
| | f"n_filters must be even when using Hilbert transforms of filters [n_filters={n_filters}]" |
| | ) |
| | n_filters //= 2 |
| | self.continuous_filters = ContFilterType( |
| | n_filters=n_filters, |
| | **filter_params, |
| | ) |
| | self._transposed = transposed |
| |
|
| | def weight(self, weight): |
| | if self.use_Hilbert_transforms: |
| | weight = torch.cat( |
| | (weight, compute_Hilbert_transforms_of_filters(weight)), dim=0 |
| | ) |
| | if self._transposed: |
| | return weight.reshape(self.in_channels, self.out_channels, -1) |
| | else: |
| | return weight.reshape(self.out_channels, self.in_channels, -1) |
| |
|
| | def prepare( |
| | self, |
| | sample_rate: int, |
| | kernel_size, |
| | stride, |
| | padding: int = None, |
| | output_padding: int = 0, |
| | ): |
| | self.sample_rate = sample_rate |
| | self.kernel_size = _single(kernel_size) |
| | self.stride = _single(stride) |
| | self.trained_stride = stride |
| | if padding is None: |
| | self.padding = _single(int((self.kernel_size[0] - self.stride[0]) // 2)) |
| | else: |
| | self.padding = _single(padding) |
| | self.output_padding = (int(output_padding),) |
| |
|
| | def forward(self, input): |
| | raise NotImplementedError |
| |
|
| | def extra_repr(self): |
| | s = "{in_channels}, {out_channels}, sample_rate={sample_rate}" |
| | if hasattr(self, "kernel_size"): |
| | s += ", kernel_size={kernel_size}" |
| | if hasattr(self, "stride"): |
| | s += ", stride={stride}" |
| | if hasattr(self, "kernel_size"): |
| | s += ", padding={padding}" |
| | if hasattr(self, "output_padding"): |
| | s += ", output_padding={output_padding}" |
| | return s.format(**self.__dict__) |
| |
|
| | def precompute_weight(self): |
| | raise NotImplementedError |
| |
|
| | def convert(self): |
| | warnings.warn( |
| | f"Converting SFI to Non-SFI convolutional layers [sample_rate={self.sample_rate}, kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}, output_padding={self.output_padding}]" |
| | ) |
| | if self.is_transposed: |
| | return functools.partial( |
| | F.conv_transpose1d, |
| | weight=self.weight(), |
| | bias=None, |
| | stride=self.stride, |
| | padding=self.padding, |
| | output_padding=self.output_padding, |
| | dilation=_single(1), |
| | groups=1, |
| | ) |
| | else: |
| | return functools.partial( |
| | F.conv1d, |
| | weight=self.weight(), |
| | bias=None, |
| | stride=self.stride, |
| | padding=self.padding, |
| | dilation=_single(1), |
| | groups=1, |
| | ) |
| |
|
| | def prepare_for_time_interpolation(self, trained_stride): |
| | self.trained_stride = trained_stride |
| |
|
| | def convolution_time_interpolation(self, input, torchaudio_options={}): |
| | conv = F.conv1d( |
| | input, self.weight(), None, 1, int(self.padding[0]), _single(1), 1 |
| | ) |
| | |
| | |
| | orig_sr2new_sr = Fraction(self.trained_stride / 1) |
| | org_sr, new_sr = orig_sr2new_sr.as_integer_ratio() |
| | output = resample( |
| | conv, org_sr, new_sr, sinc_window="kaiser_window", **torchaudio_options |
| | ) |
| | return output |
| |
|
| | def new_convolution_time_interpolation_reverse(self, input, torchaudio_options={}): |
| | |
| | org_sr = numpy.round(self.sample_rate * self.trained_stride / self.stride[0]) |
| | new_sr = numpy.round(self.sample_rate * self.trained_stride) |
| |
|
| | output_list = [] |
| | for i in range(input.shape[0]): |
| | middle = resample( |
| | input[i : i + 1], |
| | org_sr, |
| | new_sr, |
| | sinc_window="kaiser_window", |
| | **torchaudio_options, |
| | ) |
| | output_tmp = F.conv_transpose1d( |
| | middle, |
| | self.weight(), |
| | None, |
| | 1, |
| | int(self.padding[0]), |
| | self.output_padding, |
| | 1, |
| | _single(1), |
| | ) |
| | output_list.append(output_tmp) |
| | output = torch.cat(output_list, dim=0) |
| |
|
| | return output |
| |
|
| |
|
class _FreqRespSampFIRs(_FIRDesignBase):
    """FIR design obtained by sampling the analog filters' frequency responses.

    The continuous filters are evaluated at ``n_samples`` angular frequencies
    and FIR coefficients are fitted in the least-squares sense through the
    pseudo-inverse of the cos/sin design matrix built in ``_compute_pinvW``.
    """

    # Valid entries for `frequency_sampling_strategy` (one per stage).
    SAMPLING_STRATEGY = [
        "fixed",
        "randomized",
        "completely_randomized",
        "fixed_for_noninteger_kernel_size",
    ]

    def __init__(
        self,
        in_channels,
        out_channels,
        n_samples,
        ContFilterType,
        filter_params,
        use_Hilbert_transforms=False,
        transposed=False,
        frequency_sampling_strategy=["fixed", "fixed"],
    ):
        """
        Args:
            in_channels (int): Number of channels of 1D sequence
            out_channels (int): Number of channels produced by the convolution
            n_samples (int): Number of sampled points for frequency sampling
            ContFilterType (Class): Continuous filter class
            filter_params (dict): Parameters of continuous filter class
            use_Hilbert_transforms (bool): If True, the latter half of the filters are the Hilbert pairs of the former half.
            transposed (bool): Whether this convolution is a transposed convolution or not.
            frequency_sampling_strategy (list[str]): Strategy of frequency sampling at training and inference stages. "fixed" means equally-spaced sampling and "randomized" means adding small noise to the equally-spaced sampled positions.
        """
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            ContFilterType=ContFilterType,
            filter_params=filter_params,
            use_Hilbert_transforms=use_Hilbert_transforms,
            transposed=transposed,
        )
        self.n_samples = n_samples
        for i in range(len(frequency_sampling_strategy)):
            if frequency_sampling_strategy[i] not in self.SAMPLING_STRATEGY:
                raise NotImplementedError(
                    f"Undefined frequency sampling strategy [{frequency_sampling_strategy[i]}]"
                )
        self.frequency_sampling_strategy = frequency_sampling_strategy
        # Cache of (ang_freqs, pinvW) keyed by (sample_rate, kernel_size,
        # stride); only populated for the deterministic "fixed*" strategies.
        self._cache = dict()

    def weight(self):
        """Return conv-ready weights by FIR-approximating the analog filters.

        NOTE(review): relies on ``self.continuous_filters.device`` existing on
        the concrete filter class — nn.Module has no such attribute itself.
        """
        weight = self.approximate_by_FIR(
            self.continuous_filters.device
        )
        return super().weight(weight)

    def precompute_weight(self):
        """Cache the current weights in ``_precomputed_weight``.

        NOTE(review): the assertion requires training mode while its message
        says the opposite; one of the two looks inverted — confirm intent.
        """
        assert self.training, "This function should not be called during training."
        self._precomputed_weight = self.weight()

    def _compute_pinvW(self, device):
        """Sample angular frequencies and build the pseudo-inverse design matrix.

        Returns:
            Tuple[Tensor, Tensor]: sampled angular frequencies of shape (M,)
            and the pseudo-inverse of the stacked cos/sin design matrix.
        """
        kernel_size = self.kernel_size[0]
        sample_rate = self.sample_rate
        # One-sided FIR order: taps run over [-P, P] for odd kernel sizes and
        # [-(P - 1), P] for even ones (see the seq_P construction below).
        P = (kernel_size - 1) // 2 if kernel_size % 2 == 1 else kernel_size // 2
        M = self.n_samples
        nyquist_rate = sample_rate / 2

        # At inference time, optionally force the strategy dedicated to
        # non-integer kernel sizes (flag set elsewhere on the instance).
        # NOTE(review): this mutates the shared strategy list in place.
        if not self.training and hasattr(self, "noninteger_kernel_size_adaptation"):
            self.frequency_sampling_strategy[1] = "fixed_for_noninteger_kernel_size"

        # Index 0 is the training-time strategy, index 1 the inference-time one.
        strategy = self.frequency_sampling_strategy[0 if self.training else 1]
        if strategy == "fixed":
            # Equally spaced samples from DC to the Nyquist angular frequency.
            ang_freqs = (
                torch.linspace(0, nyquist_rate * 2.0 * numpy.pi, M).float().to(device)
            )
        elif strategy == "randomized":
            # Equally spaced grid jittered by up to half a bin, with the two
            # endpoints nudged back inside [0, Nyquist].
            ang_freqs = torch.linspace(0, nyquist_rate * 2.0 * numpy.pi, M).float()
            ang_freqs.requires_grad_(False)
            delta_val = ang_freqs[1] - ang_freqs[0]
            delta = torch.zeros_like(ang_freqs).uniform_(-delta_val / 2, delta_val / 2)
            delta.requires_grad_(False)
            if delta[0] < 0:
                delta[0] = -delta[0]
            if delta[-1] > 0:
                delta[-1] = -delta[-1]
            ang_freqs = ang_freqs + delta
            ang_freqs = ang_freqs.to(device)
        elif strategy == "completely_randomized":
            # Uniform i.i.d. samples over [0, Nyquist], sorted ascending.
            ang_freqs = (
                torch.zeros((M,), device=device)
                .float()
                .uniform_(0.0, nyquist_rate * 2.0 * numpy.pi)
            )
            ang_freqs.requires_grad_(False)
            ang_freqs, _ = torch.sort(ang_freqs, descending=False)
        elif strategy == "fixed_for_noninteger_kernel_size":
            # Shrink the top frequency by int(kernel_size) / kernel_size.
            max_freq = nyquist_rate / kernel_size * int(kernel_size)
            ang_freqs = (
                torch.linspace(0, max_freq * 2.0 * numpy.pi, M).float().to(device)
            )
        else:
            raise NotImplementedError(
                f"Undefined frequency sampling strategy [{strategy}]"
            )
        # Design matrix rows: cos(-w n) and sin(-w n) over the tap indices n.
        normalized_ang_freqs = ang_freqs / float(sample_rate)
        if kernel_size % 2 == 1:
            seq_P = torch.arange(-P, P + 1).float()[None, :].to(device)
            ln_W = -normalized_ang_freqs[:, None] * seq_P
        else:
            seq_P = torch.arange(-(P - 1), P + 1).float()[None, :].to(device)
            ln_W = -normalized_ang_freqs[:, None] * seq_P
        ln_W = ln_W.to(device)
        W = torch.cat((torch.cos(ln_W), torch.sin(ln_W)), dim=0)

        # Least-squares fit operator; gradients flow through the filter
        # responses, not through the (fixed) sampling grid.
        pinvW = torch.pinverse(W)
        pinvW.requires_grad_(False)
        ang_freqs.requires_grad_(False)
        return ang_freqs, pinvW

    def approximate_by_FIR(self, device):
        """Approximate frequency responses of analog filters with those of digital filters

        Args:
            device (torch.Device): Computation device

        Return:
            torch.Tensor: Time-reversed impulse responses of digital filters (n_filters x filter_degree (-P to P))
        """
        strategy = self.frequency_sampling_strategy[0 if self.training else 1]
        if strategy == "fixed" or strategy == "fixed_for_noninteger_kernel_size":
            # Deterministic grids: reuse the cached design matrix (kept on
            # CPU) for this (sample_rate, kernel_size, stride) geometry.
            cache_tag = (self.sample_rate, self.kernel_size, self.stride)
            if cache_tag in self._cache:
                ang_freqs, pinvW = self._cache[cache_tag]
                ang_freqs = ang_freqs.detach().to(device)
                pinvW = pinvW.detach().to(device)
            else:
                ang_freqs, pinvW = self._compute_pinvW(device)
                self._cache[cache_tag] = (
                    ang_freqs.detach().cpu(),
                    pinvW.detach().cpu(),
                )
        elif strategy == "randomized" or strategy == "completely_randomized":
            # Stochastic grids change every call, so nothing can be cached.
            ang_freqs, pinvW = self._compute_pinvW(device)
        else:
            raise NotImplementedError(
                f"Undefined frequency sampling strategy [{strategy}]"
            )

        # Stack real and imaginary responses to match the cos/sin rows of W.
        resp_r, resp_i = self.continuous_filters.get_frequency_responses(
            ang_freqs
        )
        resp = torch.cat((resp_r, resp_i), dim=1)

        # Least-squares FIR coefficients, then time-reversed (per the Return
        # doc above) by the descending index below.
        fir_coeffs = (pinvW[None, :, :] @ resp[:, :, None])[:, :, 0]
        # NOTE(review): rounds the kernel size down to an even value, so for
        # odd kernel sizes the last coefficient is dropped — confirm intended.
        kernel_size = int(self.kernel_size[0] / 2) * 2
        return fir_coeffs[
            :, torch.arange(kernel_size - 1, -1, -1)
        ]

    def extra_repr(self):
        """Summary string for repr(); geometry fields appear after prepare()."""
        s = "{in_channels}, {out_channels}, sample_rate={sample_rate}, n_samples={n_samples}"
        if hasattr(self, "kernel_size"):
            s += ", kernel_size={kernel_size}"
        if hasattr(self, "stride"):
            s += ", stride={stride}"
        if hasattr(self, "kernel_size"):
            s += ", padding={padding}"
        return s.format(**self.__dict__)

    def get_analog_freqresp_for_visualization(self, ang_freqs):
        """Get frequency responses of analog filters for visualization

        Args:
            ang_freqs (torch.Tensor): Unnormalized angular frequency [rad] (n_angfreqs)

        Return:
            torch.Tensor[cfloat]: Frequency responses of analog filters (n_filters x n_angfreqs)
        """
        resp_r, resp_i = self.continuous_filters.get_frequency_responses(ang_freqs)
        resp = torch.stack((resp_r, resp_i), axis=-1)
        return torch.view_as_complex(resp)
| |
|
| |
|
class FreqRespSampConv1d(_FreqRespSampFIRs):
    """Frequency-response-sampling SFI layer realized as a 1-D convolution."""

    def __init__(
        self,
        in_channels,
        out_channels,
        n_samples,
        ContFilterType,
        filter_params,
        use_Hilbert_transforms=False,
        frequency_sampling_strategy=["fixed", "fixed"],
    ):
        base_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            n_samples=n_samples,
            ContFilterType=ContFilterType,
            filter_params=filter_params,
            use_Hilbert_transforms=use_Hilbert_transforms,
            transposed=False,
            frequency_sampling_strategy=frequency_sampling_strategy,
        )
        super().__init__(**base_kwargs)

    def forward(
        self,
        input,
        torchaudio_options={},
    ):
        """Delegate to the base class's convolution-plus-resampling path."""
        return self.convolution_time_interpolation(
            input,
            torchaudio_options=torchaudio_options,
        )

    @property
    def is_transposed(self):
        """This layer is a plain (non-transposed) convolution."""
        return False
| |
|
| |
|
class FreqRespSampConvTranspose1d(_FreqRespSampFIRs):
    """Frequency-response-sampling SFI layer realized as a transposed 1-D convolution."""

    def __init__(
        self,
        in_channels,
        out_channels,
        n_samples,
        ContFilterType,
        filter_params,
        use_Hilbert_transforms=False,
        frequency_sampling_strategy=["fixed", "fixed"],
    ):
        base_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            n_samples=n_samples,
            ContFilterType=ContFilterType,
            filter_params=filter_params,
            use_Hilbert_transforms=use_Hilbert_transforms,
            transposed=True,
            frequency_sampling_strategy=frequency_sampling_strategy,
        )
        super().__init__(**base_kwargs)

    def forward(
        self,
        input,
        time_interpolation=False,
        interp_type="cubic",
        torchaudio_options={},
        other_options={},
    ):
        """Transposed convolution; optionally via the time-interpolation path."""
        if not time_interpolation:
            return F.conv_transpose1d(
                input,
                self.weight(),
                None,
                self.stride,
                self.padding,
                self.output_padding,
                1,
                _single(1),
            )
        return self.new_convolution_time_interpolation_reverse(
            input,
            interp_type=interp_type,
            torchaudio_options=torchaudio_options,
            other_options=other_options,
        )

    @property
    def is_transposed(self):
        """This layer is a transposed convolution."""
        return True
| |
|
| |
|
| | |
class _InvImpRespFIRs(_FIRDesignBase):
    """FIR design that samples the continuous filters' impulse responses directly."""

    def weight(self):
        """Return conv-ready weights sampled from the analog impulse responses."""
        sampled = self.continuous_filters.get_impulse_responses(
            self.sample_rate, self.kernel_size[0]
        )
        return super().weight(sampled)

    def precompute_weight(self):
        """Cache the current weights in ``_precomputed_weight``.

        NOTE(review): the assertion requires training mode while its message
        says the opposite; one of the two looks inverted — confirm intent.
        """
        assert self.training, "This function should not be called during training."
        self._precomputed_weight = self.weight()

    def get_analog_impulse_resp_for_visualization(self, sample_rate, kernel_size):
        """Impulse responses at an arbitrary rate/size, for visualization."""
        return self.continuous_filters.get_impulse_responses(sample_rate, kernel_size)

    def get_analog_impulse_resp_for_visualization_oversampling(
        self, sample_rate, kernel_size
    ):
        """Oversampled impulse responses, for visualization."""
        return self.continuous_filters.get_impulse_responses_oversampling(
            sample_rate, kernel_size
        )
| |
|
| |
|
class InvImpRespConv1d(_InvImpRespFIRs):
    """Impulse-response-sampling SFI layer realized as a 1-D convolution."""

    def __init__(
        self,
        in_channels,
        out_channels,
        ContFilterType,
        filter_params,
        use_Hilbert_transforms=False,
    ):
        base_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            ContFilterType=ContFilterType,
            filter_params=filter_params,
            use_Hilbert_transforms=use_Hilbert_transforms,
            transposed=False,
        )
        super().__init__(**base_kwargs)

    def forward(
        self,
        input,
        time_interpolation=False,
        interp_type="cubic",
        torchaudio_options={},
        other_options={},
    ):
        """Convolve ``input``; optionally via the time-interpolation path."""
        if not time_interpolation:
            return F.conv1d(
                input, self.weight(), None, self.stride, self.padding, _single(1), 1
            )
        return self.convolution_time_interpolation(
            input,
            interp_type=interp_type,
            torchaudio_options=torchaudio_options,
            other_options=other_options,
        )

    @property
    def is_transposed(self):
        """This layer is a plain (non-transposed) convolution."""
        return False
|
| |
|
class InvImpRespConvTranspose1d(_InvImpRespFIRs):
    """Impulse-response-sampling SFI layer realized as a transposed 1-D convolution."""

    def __init__(
        self,
        in_channels,
        out_channels,
        ContFilterType,
        filter_params,
        use_Hilbert_transforms=False,
    ):
        base_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            ContFilterType=ContFilterType,
            filter_params=filter_params,
            use_Hilbert_transforms=use_Hilbert_transforms,
            transposed=True,
        )
        super().__init__(**base_kwargs)

    def forward(
        self,
        input,
        time_interpolation=False,
        interp_type="cubic",
        torchaudio_options={},
        other_options={},
    ):
        """Transposed convolution; optionally via the time-interpolation path."""
        if not time_interpolation:
            return F.conv_transpose1d(
                input,
                self.weight(),
                None,
                self.stride,
                self.padding,
                self.output_padding,
                1,
                _single(1),
            )
        return self.new_convolution_time_interpolation_reverse(
            input,
            interp_type=interp_type,
            torchaudio_options=torchaudio_options,
            other_options=other_options,
        )

    @property
    def is_transposed(self):
        """This layer is a transposed convolution."""
        return True
| |
|
| |
|
| | |
class PinvConvTranspose1d(nn.Module):
    """Transposed convolution whose kernel is the pseudo-inverse of an encoder's filterbank."""

    def __init__(self, encoder):
        super().__init__()
        # Handle to the analysis filterbank whose weights are inverted lazily.
        self.encoder = encoder

    def prepare(self, sample_rate: int, kernel_size: int, stride: int):
        """Record the geometry used by weight() and forward()."""
        self.sample_rate = sample_rate
        self.kernel_size = kernel_size
        self.stride = stride

    def weight(self):
        """Computes pseudo inverse filterbank of given filters."""
        enc_filters = self.encoder.weight()
        original_shape = enc_filters.shape
        pinv = torch.pinverse(enc_filters.squeeze())
        decoder = pinv.transpose(-1, -2).view(original_shape)
        # stride/kernel_size scaling — presumably compensates the overlap-add
        # gain of the transposed convolution.
        return decoder * (self.stride / self.kernel_size)

    def forward(self, input):
        pad = (self.kernel_size - self.stride) // 2
        return F.conv_transpose1d(
            input,
            self.weight(),
            None,
            self.stride,
            pad,
            _single(0),
            1,
            _single(1),
        )
| |
|
| |
|
class Conv1dWithHilbertTransforms(nn.Module):
    """nn.Conv1d wrapper where half the output filters are Hilbert transforms.

    When ``use_Hilbert_transforms`` is set, only ``out_channels // 2`` filters
    are learned; at forward time they are concatenated with their Hilbert
    transforms before the convolution.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        use_Hilbert_transforms=True,
    ):
        super().__init__()
        learned_channels = out_channels // 2 if use_Hilbert_transforms else out_channels
        self.conv = nn.Conv1d(
            in_channels,
            learned_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=False,
            padding=(kernel_size - stride) // 2,
        )
        self.use_Hilbert_transforms = use_Hilbert_transforms

    def forward(self, input):
        if not self.use_Hilbert_transforms:
            return self.conv(input)
        base = self.conv.weight
        flat = base.reshape(base.shape[0] * base.shape[1], base.shape[2])
        augmented = torch.cat(
            (flat, compute_Hilbert_transforms_of_filters(flat)), dim=0
        )
        # Rebuild an (out, 2 * in, k) weight from the augmented filter stack.
        stacked = augmented.reshape(base.shape[0], base.shape[1] * 2, base.shape[2])
        return self._forward(input, stacked)

    def _forward(self, input, weight):
        conv = self.conv
        if conv.padding_mode != "zeros":
            # Apply the configured padding mode manually, then convolve unpadded.
            padded = F.pad(
                input,
                conv._padding_repeated_twice,
                mode=conv.padding_mode,
            )
            return F.conv1d(
                padded,
                weight,
                None,
                conv.stride,
                _single(0),
                conv.dilation,
                conv.groups,
            )
        return F.conv1d(
            input,
            weight,
            None,
            conv.stride,
            conv.padding,
            conv.dilation,
            conv.groups,
        )
| |
|
| |
|
class ConvTranspose1dWithHilbertTransforms(nn.Module):
    """nn.ConvTranspose1d wrapper where half the filters are Hilbert transforms.

    When ``use_Hilbert_transforms`` is set, only ``out_channels // 2`` filters
    are learned; at forward time they are concatenated with their Hilbert
    transforms before the transposed convolution.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        use_Hilbert_transforms=True,
    ):
        super().__init__()
        learned_channels = out_channels // 2 if use_Hilbert_transforms else out_channels
        self.conv = nn.ConvTranspose1d(
            in_channels,
            learned_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=False,
            padding=(kernel_size - stride) // 2,
        )
        self.use_Hilbert_transforms = use_Hilbert_transforms

    def forward(self, input):
        if not self.use_Hilbert_transforms:
            return self.conv(input)
        base = self.conv.weight
        flat = base.reshape(base.shape[0] * base.shape[1], base.shape[2])
        augmented = torch.cat(
            (flat, compute_Hilbert_transforms_of_filters(flat)), dim=0
        )
        # Rebuild an (in, 2 * out, k) weight from the augmented filter stack.
        stacked = augmented.reshape(base.shape[0], base.shape[1] * 2, base.shape[2])
        return self._forward(input, stacked)

    def _forward(self, input, weight, output_size=None):
        if self.conv.padding_mode != "zeros":
            raise ValueError(
                "Only `zeros` padding mode is supported for ConvTranspose1d"
            )

        # NOTE(review): nn.ConvTranspose1d._output_padding gained extra
        # arguments (num_spatial_dims, dilation) in newer torch releases —
        # confirm this call against the pinned torch version.
        output_padding = self.conv._output_padding(
            input,
            output_size,
            self.conv.stride,
            self.conv.padding,
            self.conv.kernel_size,
        )
        return F.conv_transpose1d(
            input,
            weight,
            None,
            self.conv.stride,
            self.conv.padding,
            output_padding,
            self.conv.groups,
            self.conv.dilation,
        )
| |
|