"""Sampling-frequency-independent (SFI) 1-D convolutional layers.

This module provides:

* a bandlimited-interpolation :func:`resample` (ported from torchaudio) and
  its sinc-kernel helpers;
* FIR design base classes whose weights are generated on the fly from
  *analog* (continuous) filter representations, either by sampling frequency
  responses (``_FreqRespSampFIRs``) or by sampling impulse responses
  (``_InvImpRespFIRs``);
* Hilbert-transform-paired convolution wrappers and a pseudo-inverse
  transposed convolution.
"""
import functools
import math
import warnings
from fractions import Fraction
from typing import Optional

import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.utils import _single

try:
    # ``distutils`` was removed from the standard library in Python 3.12
    # (PEP 632).  It is only needed for the legacy ``torch.rfft`` code path
    # (PyTorch < 1.7.0), which cannot occur on interpreters new enough to
    # lack distutils, so falling back to ``None`` is safe there.
    from distutils.version import LooseVersion
except ImportError:
    LooseVersion = None


def _get_sinc_resample_kernel(
    orig_freq: int,
    new_freq: int,
    gcd: int,
    lowpass_filter_width: int = 6,
    rolloff: float = 0.99,
    resampling_kernel: str = "sinc",
    sinc_window: str = "sinc_interpolation",
    beta: Optional[float] = None,
    device: torch.device = torch.device("cpu"),
    dtype: Optional[torch.dtype] = None,
):
    """Build the windowed-sinc FIR kernel used by :func:`resample`.

    Args:
        orig_freq (int): Original sample rate (must be integral).
        new_freq (int): Target sample rate (must be integral).
        gcd (int): ``math.gcd(orig_freq, new_freq)``; both rates are reduced by it.
        lowpass_filter_width (int): Number of sinc zero crossings kept per side.
        rolloff (float): Cutoff as a fraction of the (lower) Nyquist frequency.
        resampling_kernel (str): Only ``"sinc"`` is implemented.
        sinc_window (str): ``"sinc_interpolation"`` (squared-cosine window) or
            ``"kaiser_window"``.
        beta (float or None): Kaiser shape parameter; defaults to
            14.769656459379492 when ``None``.
        device: Device of the returned kernel.
        dtype: Computation dtype; ``None`` computes in float64 and casts the
            result to float32.

    Returns:
        (Tensor, int): kernel of shape
        ``(new_freq, 1, 2 * width + orig_freq)`` (reduced frequencies) and the
        half-width ``width``.

    Raises:
        Exception: If the frequencies are not integral.
        ValueError: If ``lowpass_filter_width`` is not positive or
            ``sinc_window`` is unknown.
        NotImplementedError: For any ``resampling_kernel`` other than "sinc".
    """
    if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
        raise Exception(
            "Frequencies must be of integer type to ensure quality resampling computation. "
            "To work around this, manually convert both frequencies to integer values "
            "that maintain their resampling rate ratio before passing them into the function. "
            "Example: To downsample a 44100 hz waveform by a factor of 8, use "
            "`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5`. "
            "For more information, please refer to https://github.com/pytorch/audio/issues/1487."
        )
    orig_freq = int(orig_freq) // gcd
    new_freq = int(new_freq) // gcd
    if lowpass_filter_width <= 0:
        raise ValueError("Low pass filter width should be positive.")
    # FIX: the anti-aliasing cutoff is the *lower* of the two Nyquist rates
    # (the original code used ``max``, which disables anti-aliasing when
    # downsampling).  This matches torchaudio's reference implementation.
    base_freq = min(orig_freq, new_freq)
    if resampling_kernel == "sinc":
        if sinc_window not in ["sinc_interpolation", "kaiser_window"]:
            raise ValueError("Invalid resampling method: {}".format(sinc_window))
        # Remove the highest frequencies to perform anti-aliasing filtering.
        base_freq *= rolloff
        # The key idea of the algorithm is that x(t) can be exactly reconstructed
        # from x[i] (tensor) using the sinc interpolation formula:
        #   x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
        # We can then sample the function x(t) with a different sample rate:
        #   y[j] = x(j / new_freq)
        # or,
        #   y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
        # We see here that y[j] is the convolution of x[i] with a specific filter,
        # for which we take an FIR approximation, stopping when we see at least
        # `lowpass_filter_width` zero crossings.  But y[j+1] is going to have a
        # different set of weights and so on, until y[j + new_freq].  Indeed:
        #   y[j + new_freq] = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
        # so y[j + new_freq] uses the same filter as y[j], but on a version of x
        # shifted by `orig_freq`.  This explains the F.conv1d below with a stride
        # of orig_freq.
        width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
        # If orig_freq is still big after GCD reduction, most filters will be
        # very unbalanced, i.e. they will have a lot of almost-zero values to
        # the left or to the right.  Evaluating those filters more efficiently
        # is left for future work.
        idx_dtype = dtype if dtype is not None else torch.float64
        idx = (
            torch.arange(-width, width + orig_freq, dtype=idx_dtype, device=device)[
                None, None
            ]
            / orig_freq
        )
        # FIX: use ``idx_dtype`` (not the possibly-None ``dtype``) so the
        # phase grid is computed in float64 when no dtype is requested.
        t = (
            torch.arange(0, -new_freq, -1, dtype=idx_dtype, device=device)[
                :, None, None
            ]
            / new_freq
            + idx
        )
        t *= base_freq
        t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
        # We do not use built-in torch windows here as we need to evaluate the
        # window at specific positions, not over a regular grid.
        if sinc_window == "sinc_interpolation":
            window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
        else:  # kaiser_window
            if beta is None:
                beta = 14.769656459379492
            beta_tensor = torch.tensor(float(beta))
            window = torch.i0(
                beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)
            ) / torch.i0(beta_tensor)
        # FIX: restore the gain normalization of the reference implementation;
        # without it the resampled amplitude is off by orig_freq / base_freq.
        scale = base_freq / orig_freq
        t *= math.pi
        kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
        kernels *= window * scale
        if dtype is None:
            kernels = kernels.to(dtype=torch.float32)
    else:
        raise NotImplementedError
    return kernels, width


def _apply_sinc_resample_kernel(
    waveform: Tensor,
    orig_freq: int,
    new_freq: int,
    gcd: int,
    kernel: Tensor,
    width: int,
):
    """Apply a precomputed sinc kernel to ``waveform`` (shape ``(..., time)``).

    Args:
        waveform (Tensor): Floating-point input of shape ``(..., time)``.
        orig_freq (int): Original sample rate.
        new_freq (int): Target sample rate.
        gcd (int): ``math.gcd(orig_freq, new_freq)``.
        kernel (Tensor): Kernel from :func:`_get_sinc_resample_kernel`.
        width (int): Half-width returned by the same function.

    Returns:
        Tensor: Resampled waveform of shape ``(..., ceil(time * new / orig))``.

    Raises:
        TypeError: If ``waveform`` is not floating point.
    """
    if not waveform.is_floating_point():
        raise TypeError(
            f"Expected floating point type for waveform tensor, but received {waveform.dtype}."
        )
    orig_freq = int(orig_freq) // gcd
    new_freq = int(new_freq) // gcd
    # Pack batch: flatten all leading dimensions into one.
    shape = waveform.size()
    waveform = waveform.view(-1, shape[-1])
    num_wavs, length = waveform.shape
    waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
    # One output channel per output phase; stride of orig_freq (see the
    # derivation in _get_sinc_resample_kernel).
    resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
    resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
    target_length = int(math.ceil(new_freq * length / orig_freq))
    resampled = resampled[..., :target_length]
    # Unpack batch.
    resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
    return resampled


def resample(
    waveform: Tensor,
    orig_freq: int,
    new_freq: int,
    lowpass_filter_width: int = 16,
    rolloff: float = 0.99,
    resampling_kernel: str = "sinc",
    sinc_window: str = "kaiser_window",
    beta: Optional[float] = None,
    kernel=None,
    width=None,
) -> Tensor:
    r"""Resamples the waveform at the new frequency using bandlimited interpolation. :cite:`RESAMPLE`.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Note:
        ``transforms.Resample`` precomputes and reuses the resampling kernel,
        so using it will result in more efficient computation if resampling
        multiple waveforms with the same resampling parameters.

    Args:
        waveform (Tensor): The input signal of dimension `(..., time)`
        orig_freq (int): The original frequency of the signal
        new_freq (int): The desired frequency
        lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
            but less efficient. (Default: ``16``)
        rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
            Lower values reduce anti-aliasing, but also reduce some of the highest frequencies.
            (Default: ``0.99``)
        resampling_kernel (str, optional): The resampling kernel to use.
            Only ``"sinc"`` is currently implemented. (Default: ``"sinc"``)
        sinc_window (str, optional): The resampling method to use.
            Options: [``"sinc_interpolation"``, ``"kaiser_window"``] (Default: ``"kaiser_window"``)
        beta (float or None, optional): The shape parameter used for kaiser window.
        kernel (Tensor or None, optional): Precomputed kernel from
            :func:`_get_sinc_resample_kernel`; computed on the fly when ``None``.
        width (int or None, optional): Half-width matching ``kernel``.

    Returns:
        Tensor: The waveform at the new frequency of dimension `(..., time).`
    """
    if orig_freq <= 0.0 or new_freq <= 0.0:
        raise ValueError("Original frequency and desired frequency should be positive")
    if orig_freq == new_freq:
        return waveform
    gcd = math.gcd(int(orig_freq), int(new_freq))
    if (kernel is None) or (width is None):
        kernel, width = _get_sinc_resample_kernel(
            orig_freq=orig_freq,
            new_freq=new_freq,
            gcd=gcd,
            lowpass_filter_width=lowpass_filter_width,
            rolloff=rolloff,
            resampling_kernel=resampling_kernel,
            sinc_window=sinc_window,
            beta=beta,
            device=waveform.device,
            dtype=waveform.dtype,
        )
    resampled = _apply_sinc_resample_kernel(
        waveform, orig_freq, new_freq, gcd, kernel, width
    )
    return resampled


def compute_Hilbert_transforms_of_filters(filters):
    """Compute the Hilbert transforms of the input filters.

    Args:
        filters (torch.Tensor): Filter bank, ``n_filters x kernel_size``.

    Return:
        torch.Tensor: Hilbert transforms of the filters, same shape as input.
    """
    if LooseVersion is not None and LooseVersion(torch.__version__) < LooseVersion(
        "1.7.0"
    ):
        # Legacy API (PyTorch < 1.7): complex values are (..., 2) real tensors.
        ft_f = torch.rfft(
            filters.reshape(filters.shape[0], 1, filters.shape[1]), 1, normalized=True
        )
        # Multiply the half-spectrum by -i: (re, im) -> (im, -re).
        hft_f = torch.stack([ft_f[:, :, :, 1], -ft_f[:, :, :, 0]], dim=-1)
        hft_f = torch.irfft(hft_f, 1, normalized=True, signal_sizes=(filters.shape[1],))
    else:
        # New PyTorch version has fft module.
        ft_f = torch.fft.rfft(filters, n=filters.shape[1], dim=1, norm="ortho")
        hft_f = torch.view_as_complex(torch.stack((ft_f.imag, -ft_f.real), axis=-1))
        hft_f = torch.fft.irfft(hft_f, n=filters.shape[1], dim=1, norm="ortho")
    return hft_f.reshape(*(filters.shape))


class _FIRDesignBase(nn.Module):
    """Base class for convolutions whose FIR weights come from analog filters.

    Subclasses provide a no-argument ``weight()`` that materializes the digital
    filter bank from ``self.continuous_filters``, plus an ``is_transposed``
    property.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        ContFilterType,
        filter_params,
        use_Hilbert_transforms=False,
        transposed=False,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            ContFilterType (Class): Continuous filter class.
            filter_params (dict): Parameters of the continuous filter class.
            use_Hilbert_transforms (bool): If True, the latter half of the
                filters are the Hilbert pairs of the former half.
            transposed (bool): Whether this is a transposed convolution.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        n_filters = in_channels * out_channels
        self.use_Hilbert_transforms = use_Hilbert_transforms
        if self.use_Hilbert_transforms:
            if n_filters % 2 == 1:
                raise ValueError(
                    f"n_filters must be even when using Hilbert transforms of filters [n_filters={n_filters}]"
                )
            # Only half the filters are free parameters; the rest are pairs.
            n_filters //= 2
        self.continuous_filters = ContFilterType(
            n_filters=n_filters,
            **filter_params,
        )
        self._transposed = transposed

    def weight(self, weight):
        """Append Hilbert pairs (if enabled) and reshape to conv weight layout.

        Args:
            weight (torch.Tensor): Flat filter bank, ``n_filters x kernel_size``.

        Returns:
            torch.Tensor: ``(in, out, k)`` when transposed, else ``(out, in, k)``.
        """
        if self.use_Hilbert_transforms:
            weight = torch.cat(
                (weight, compute_Hilbert_transforms_of_filters(weight)), dim=0
            )
        if self._transposed:
            return weight.reshape(self.in_channels, self.out_channels, -1)
        else:
            return weight.reshape(self.out_channels, self.in_channels, -1)

    def prepare(
        self,
        sample_rate: int,
        kernel_size,
        stride,
        padding: Optional[int] = None,
        output_padding: int = 0,
    ):
        """Configure the layer for a concrete sample rate / geometry.

        When ``padding`` is None, a "same"-style padding of
        ``(kernel_size - stride) // 2`` is used.
        """
        self.sample_rate = sample_rate
        self.kernel_size = _single(kernel_size)
        self.stride = _single(stride)
        self.trained_stride = stride
        if padding is None:
            self.padding = _single(int((self.kernel_size[0] - self.stride[0]) // 2))
        else:
            self.padding = _single(padding)
        self.output_padding = (int(output_padding),)

    def forward(self, input):
        raise NotImplementedError

    def extra_repr(self):
        s = "{in_channels}, {out_channels}, sample_rate={sample_rate}"
        if hasattr(self, "kernel_size"):
            s += ", kernel_size={kernel_size}"
        if hasattr(self, "stride"):
            s += ", stride={stride}"
        # FIX: gate on "padding" (was mistakenly "kernel_size").
        if hasattr(self, "padding"):
            s += ", padding={padding}"
        if hasattr(self, "output_padding"):
            s += ", output_padding={output_padding}"
        return s.format(**self.__dict__)

    def precompute_weight(self):
        raise NotImplementedError

    def convert(self):
        """Freeze the current weights into a plain conv1d/conv_transpose1d callable."""
        warnings.warn(
            f"Converting SFI to Non-SFI convolutional layers [sample_rate={self.sample_rate}, kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}, output_padding={self.output_padding}]"
        )
        if self.is_transposed:
            return functools.partial(
                F.conv_transpose1d,
                weight=self.weight(),
                bias=None,
                stride=self.stride,
                padding=self.padding,
                output_padding=self.output_padding,
                dilation=_single(1),
                groups=1,
            )
        else:
            return functools.partial(
                F.conv1d,
                weight=self.weight(),
                bias=None,
                stride=self.stride,
                padding=self.padding,
                dilation=_single(1),
                groups=1,
            )

    def prepare_for_time_interpolation(self, trained_stride):
        """Record the stride used at training time (ratio basis for resampling)."""
        self.trained_stride = trained_stride

    def convolution_time_interpolation(
        self, input, torchaudio_options=None, interp_type="cubic", other_options=None
    ):
        """Stride-1 convolution followed by resampling to emulate a strided conv.

        Args:
            input (Tensor): Input of shape ``(batch, in_channels, time)``.
            torchaudio_options (dict or None): Extra kwargs for :func:`resample`.
            interp_type (str): Accepted for caller compatibility; currently unused.
            other_options (dict or None): Accepted for caller compatibility;
                currently unused.

        Returns:
            Tensor: Convolved and resampled output.
        """
        # FIX: callers pass interp_type/other_options keywords; the original
        # signature rejected them with a TypeError.
        conv = F.conv1d(
            input, self.weight(), None, 1, int(self.padding[0]), _single(1), 1
        )
        # org_sr: new_sr = 1 : 1/stride = stride : 1, org_sr and new_sr must be integers
        # org_sr : new_sr = 1 : 1 / S^{(test)} = S^{(test)} : 1 = S^{(train)} * F_s^{(test)} / F_s^{(train)} : 1 = S^{(train)} * F_s^{(test)} : F_s^{(train)}
        orig_sr2new_sr = Fraction(self.trained_stride / 1)
        org_sr, new_sr = orig_sr2new_sr.as_integer_ratio()
        output = resample(
            conv,
            org_sr,
            new_sr,
            sinc_window="kaiser_window",
            **(torchaudio_options or {}),
        )
        return output

    def new_convolution_time_interpolation_reverse(
        self, input, torchaudio_options=None, interp_type="cubic", other_options=None
    ):
        """Resample the input, then apply a stride-1 transposed convolution.

        Args:
            input (Tensor): Input of shape ``(batch, channels, time)``.
            torchaudio_options (dict or None): Extra kwargs for :func:`resample`.
            interp_type (str): Accepted for caller compatibility; currently unused.
            other_options (dict or None): Accepted for caller compatibility;
                currently unused.

        Returns:
            Tensor: Transposed-convolution output.
        """
        # org_sr : new_sr = 1 : 1 / S^{(test)} = S^{(test)} : 1 = S^{(train)} * F_s^{(test)} / F_s^{(train)} : 1 = S^{(train)} * F_s^{(test)} : F_s^{(train)}
        org_sr = numpy.round(self.sample_rate * self.trained_stride / self.stride[0])
        new_sr = numpy.round(self.sample_rate * self.trained_stride)
        output_list = []
        # Per-item loop keeps the peak memory of the resampling step small.
        for i in range(input.shape[0]):
            middle = resample(
                input[i : i + 1],
                org_sr,
                new_sr,
                sinc_window="kaiser_window",
                **(torchaudio_options or {}),
            )
            output_tmp = F.conv_transpose1d(
                middle,
                self.weight(),
                None,
                1,
                int(self.padding[0]),
                self.output_padding,
                1,
                _single(1),
            )
            output_list.append(output_tmp)
        output = torch.cat(output_list, dim=0)
        return output


class _FreqRespSampFIRs(_FIRDesignBase):
    """FIR design by sampling analog frequency responses and least-squares fitting."""

    SAMPLING_STRATEGY = [
        "fixed",
        "randomized",
        "completely_randomized",
        "fixed_for_noninteger_kernel_size",
    ]

    def __init__(
        self,
        in_channels,
        out_channels,
        n_samples,
        ContFilterType,
        filter_params,
        use_Hilbert_transforms=False,
        transposed=False,
        frequency_sampling_strategy=["fixed", "fixed"],
    ):
        """
        Args:
            in_channels (int): Number of channels of 1D sequence
            out_channels (int): Number of channels produced by the convolution
            n_samples (int): Number of sampled points for frequency sampling
            ContFilterType (Class): Continuous filter class
            filter_params (dict): Parameters of continuous filter class
            use_Hilbert_transforms (bool): If True, the latter half of the
                filters are the Hilbert pairs of the former half.
            transposed (bool): Whether this convolution is a transposed
                convolution or not.
            frequency_sampling_strategy (list[str]): Strategy of frequency
                sampling at training and inference stages. "fixed" means
                equally-spaced sampling and "randomized" means adding small
                noise to the equally-spaced sampled positions.
        """
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            ContFilterType=ContFilterType,
            filter_params=filter_params,
            use_Hilbert_transforms=use_Hilbert_transforms,
            transposed=transposed,
        )
        self.n_samples = n_samples
        for i in range(len(frequency_sampling_strategy)):
            if frequency_sampling_strategy[i] not in self.SAMPLING_STRATEGY:
                raise NotImplementedError(
                    f"Undefined frequency sampling strategy [{frequency_sampling_strategy[i]}]"
                )
        self.frequency_sampling_strategy = frequency_sampling_strategy
        # Cache of (ang_freqs, pinvW) keyed by (sample_rate, kernel_size, stride),
        # valid only for the deterministic ("fixed*") strategies.
        self._cache = dict()

    def weight(self):
        weight = self.approximate_by_FIR(
            self.continuous_filters.device
        )  # n_filters (or n_filters//2) x kernel_size
        return super().weight(weight)

    def precompute_weight(self):
        # NOTE(review): the assert condition (`self.training`) contradicts its
        # message ("should not be called during training") — one of the two is
        # inverted in the original source; behavior kept as-is, confirm intent.
        assert self.training, "This function should not be called during training."
        self._precomputed_weight = self.weight()

    def _compute_pinvW(self, device):
        """Sample angular frequencies and build the pseudo-inverse design matrix.

        Returns:
            (Tensor, Tensor): sampled angular frequencies (``M``) and
            ``pinv(W)`` of shape ``2P x 2M`` used for least-squares FIR fitting.
        """
        kernel_size = self.kernel_size[0]
        sample_rate = self.sample_rate
        P = (kernel_size - 1) // 2 if kernel_size % 2 == 1 else kernel_size // 2
        M = self.n_samples
        nyquist_rate = sample_rate / 2
        # NOTE(review): disabled override kept verbatim from the original source.
        # if not self.training and hasattr(self, "noninteger_kernel_size_adaptation"): self.frequency_sampling_strategy[1] = "fixed_for_noninteger_kernel_size"
        strategy = self.frequency_sampling_strategy[0 if self.training else 1]
        if strategy == "fixed":
            ang_freqs = (
                torch.linspace(0, nyquist_rate * 2.0 * numpy.pi, M).float().to(device)
            )
        elif strategy == "randomized":
            # Equally spaced grid plus bounded jitter; endpoints are folded
            # back inside [0, Nyquist].
            ang_freqs = torch.linspace(0, nyquist_rate * 2.0 * numpy.pi, M).float()
            ang_freqs.requires_grad_(False)
            delta_val = ang_freqs[1] - ang_freqs[0]
            delta = torch.zeros_like(ang_freqs).uniform_(-delta_val / 2, delta_val / 2)
            delta.requires_grad_(False)
            if delta[0] < 0:
                delta[0] = -delta[0]
            if delta[-1] > 0:
                delta[-1] = -delta[-1]
            ang_freqs = ang_freqs + delta
            ang_freqs = ang_freqs.to(device)
        elif strategy == "completely_randomized":
            ang_freqs = (
                torch.zeros((M,), device=device)
                .float()
                .uniform_(0.0, nyquist_rate * 2.0 * numpy.pi)
            )
            ang_freqs.requires_grad_(False)
            ang_freqs, _ = torch.sort(ang_freqs, descending=False)
        elif strategy == "fixed_for_noninteger_kernel_size":
            max_freq = nyquist_rate / kernel_size * int(kernel_size)
            ang_freqs = (
                torch.linspace(0, max_freq * 2.0 * numpy.pi, M).float().to(device)
            )
        else:
            raise NotImplementedError(
                f"Undefined frequency sampling strategy [{strategy}]"
            )
        normalized_ang_freqs = ang_freqs / float(sample_rate)
        if kernel_size % 2 == 1:
            seq_P = torch.arange(-P, P + 1).float()[None, :].to(device)
            ln_W = -normalized_ang_freqs[:, None] * seq_P  # M x 2P+1
        else:
            seq_P = torch.arange(-(P - 1), P + 1).float()[None, :].to(device)
            ln_W = -normalized_ang_freqs[:, None] * seq_P  # M x 2P
        ln_W = ln_W.to(device)
        # Real/imaginary parts stacked: rows 0..M-1 are cosines, M..2M-1 sines.
        W = torch.cat((torch.cos(ln_W), torch.sin(ln_W)), dim=0)  # 2*M x 2P
        ###
        pinvW = torch.pinverse(W)  # 2P x 2M
        pinvW.requires_grad_(False)
        ang_freqs.requires_grad_(False)
        return ang_freqs, pinvW

    def approximate_by_FIR(self, device):
        """Approximate frequency responses of analog filters with those of digital filters

        Args:
            device (torch.Device): Computation device

        Return:
            torch.Tensor: Time-reversed impulse responses of digital filters
                (n_filters x filter_degree (-P to P))
        """
        strategy = self.frequency_sampling_strategy[0 if self.training else 1]
        if strategy == "fixed" or strategy == "fixed_for_noninteger_kernel_size":
            # Deterministic strategies: reuse the cached design matrix.
            cache_tag = (self.sample_rate, self.kernel_size, self.stride)
            if cache_tag in self._cache:
                ang_freqs, pinvW = self._cache[cache_tag]
                ang_freqs = ang_freqs.detach().to(device)
                pinvW = pinvW.detach().to(device)
            else:
                ang_freqs, pinvW = self._compute_pinvW(device)
                self._cache[cache_tag] = (
                    ang_freqs.detach().cpu(),
                    pinvW.detach().cpu(),
                )
        elif strategy == "randomized" or strategy == "completely_randomized":
            ang_freqs, pinvW = self._compute_pinvW(device)
        else:
            raise NotImplementedError(
                f"Undefined frequency sampling strategy [{strategy}]"
            )
        ###
        resp_r, resp_i = self.continuous_filters.get_frequency_responses(
            ang_freqs
        )  # n_filters x M
        resp = torch.cat((resp_r, resp_i), dim=1)  # n_filters x 2M
        ###
        fir_coeffs = (pinvW[None, :, :] @ resp[:, :, None])[:, :, 0]  # n_filters x 2P
        kernel_size = int(self.kernel_size[0] / 2) * 2
        return fir_coeffs[
            :, torch.arange(kernel_size - 1, -1, -1)
        ]  # time-reversed impulse response

    def extra_repr(self):
        s = "{in_channels}, {out_channels}, sample_rate={sample_rate}, n_samples={n_samples}"
        if hasattr(self, "kernel_size"):
            s += ", kernel_size={kernel_size}"
        if hasattr(self, "stride"):
            s += ", stride={stride}"
        # FIX: gate on "padding" (was mistakenly "kernel_size").
        if hasattr(self, "padding"):
            s += ", padding={padding}"
        return s.format(**self.__dict__)

    def get_analog_freqresp_for_visualization(self, ang_freqs):
        """Get frequency responses of analog filters for visualization

        Args:
            ang_freqs (torch.Tensor): Unnormalized angular frequency [rad] (n_angfreqs)

        Return:
            torch.Tensor[cfloat]: Frequency responses of analog filters
                (n_filters x n_angfreqs)
        """
        resp_r, resp_i = self.continuous_filters.get_frequency_responses(ang_freqs)
        resp = torch.stack((resp_r, resp_i), axis=-1)
        return torch.view_as_complex(resp)


class FreqRespSampConv1d(_FreqRespSampFIRs):
    """Frequency-response-sampling SFI convolution (forward direction)."""

    def __init__(
        self,
        in_channels,
        out_channels,
        n_samples,
        ContFilterType,
        filter_params,
        use_Hilbert_transforms=False,
        frequency_sampling_strategy=["fixed", "fixed"],
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            n_samples=n_samples,
            ContFilterType=ContFilterType,
            filter_params=filter_params,
            use_Hilbert_transforms=use_Hilbert_transforms,
            transposed=False,
            frequency_sampling_strategy=frequency_sampling_strategy,
        )

    def forward(
        self,
        input,
        torchaudio_options=None,
    ):
        return self.convolution_time_interpolation(
            input,
            torchaudio_options=torchaudio_options,
        )

    @property
    def is_transposed(self):
        return False


class FreqRespSampConvTranspose1d(_FreqRespSampFIRs):
    """Frequency-response-sampling SFI transposed convolution."""

    def __init__(
        self,
        in_channels,
        out_channels,
        n_samples,
        ContFilterType,
        filter_params,
        use_Hilbert_transforms=False,
        frequency_sampling_strategy=["fixed", "fixed"],
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            n_samples=n_samples,
            ContFilterType=ContFilterType,
            filter_params=filter_params,
            use_Hilbert_transforms=use_Hilbert_transforms,
            transposed=True,
            frequency_sampling_strategy=frequency_sampling_strategy,
        )

    def forward(
        self,
        input,
        time_interpolation=False,
        interp_type="cubic",
        torchaudio_options=None,
        other_options=None,
    ):
        if time_interpolation:
            return self.new_convolution_time_interpolation_reverse(
                input,
                interp_type=interp_type,
                torchaudio_options=torchaudio_options,
                other_options=other_options,
            )
        else:
            return F.conv_transpose1d(
                input,
                self.weight(),
                None,
                self.stride,
                self.padding,
                self.output_padding,
                1,
                _single(1),
            )

    @property
    def is_transposed(self):
        return True


## invariant impulse response method
class _InvImpRespFIRs(_FIRDesignBase):
    """FIR design by directly sampling the analog impulse responses."""

    def weight(self):
        weight = self.continuous_filters.get_impulse_responses(
            self.sample_rate, self.kernel_size[0]
        )
        return super().weight(weight)

    def precompute_weight(self):
        # NOTE(review): assert condition and message disagree (see
        # _FreqRespSampFIRs.precompute_weight); kept as in the original.
        assert self.training, "This function should not be called during training."
        self._precomputed_weight = self.weight()

    def get_analog_impulse_resp_for_visualization(self, sample_rate, kernel_size):
        return self.continuous_filters.get_impulse_responses(sample_rate, kernel_size)

    def get_analog_impulse_resp_for_visualization_oversampling(
        self, sample_rate, kernel_size
    ):
        return self.continuous_filters.get_impulse_responses_oversampling(
            sample_rate, kernel_size
        )


class InvImpRespConv1d(_InvImpRespFIRs):
    """Impulse-response-sampling SFI convolution (forward direction)."""

    def __init__(
        self,
        in_channels,
        out_channels,
        ContFilterType,
        filter_params,
        use_Hilbert_transforms=False,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            ContFilterType=ContFilterType,
            filter_params=filter_params,
            use_Hilbert_transforms=use_Hilbert_transforms,
            transposed=False,
        )

    def forward(
        self,
        input,
        time_interpolation=False,
        interp_type="cubic",
        torchaudio_options=None,
        other_options=None,
    ):
        if time_interpolation:
            return self.convolution_time_interpolation(
                input,
                interp_type=interp_type,
                torchaudio_options=torchaudio_options,
                other_options=other_options,
            )
        else:
            return F.conv1d(
                input, self.weight(), None, self.stride, self.padding, _single(1), 1
            )

    @property
    def is_transposed(self):
        return False


class InvImpRespConvTranspose1d(_InvImpRespFIRs):
    """Impulse-response-sampling SFI transposed convolution."""

    def __init__(
        self,
        in_channels,
        out_channels,
        ContFilterType,
        filter_params,
        use_Hilbert_transforms=False,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            ContFilterType=ContFilterType,
            filter_params=filter_params,
            use_Hilbert_transforms=use_Hilbert_transforms,
            transposed=True,
        )

    def forward(
        self,
        input,
        time_interpolation=False,
        interp_type="cubic",
        torchaudio_options=None,
        other_options=None,
    ):
        if time_interpolation:
            return self.new_convolution_time_interpolation_reverse(
                input,
                interp_type=interp_type,
                torchaudio_options=torchaudio_options,
                other_options=other_options,
            )
        else:
            return F.conv_transpose1d(
                input,
                self.weight(),
                None,
                self.stride,
                self.padding,
                self.output_padding,
                1,
                _single(1),
            )

    @property
    def is_transposed(self):
        return True


##############################
class PinvConvTranspose1d(nn.Module):
    """Transposed convolution whose weights are the pseudo-inverse of an encoder's."""

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def prepare(self, sample_rate: int, kernel_size: int, stride: int):
        self.sample_rate = sample_rate
        self.kernel_size = kernel_size
        self.stride = stride

    def weight(self):
        """Computes pseudo inverse filterbank of given filters."""
        scale = self.stride / self.kernel_size
        filters = self.encoder.weight()
        shape = filters.shape
        ifilt = torch.pinverse(filters.squeeze()).transpose(-1, -2).view(shape)
        # Compensate for the overlap-add.
        return ifilt * scale

    def forward(self, input):
        return F.conv_transpose1d(
            input,
            self.weight(),
            None,
            self.stride,
            (self.kernel_size - self.stride) // 2,
            _single(0),
            1,
            _single(1),
        )


class Conv1dWithHilbertTransforms(nn.Module):
    """Conv1d where half of the effective filters are Hilbert pairs of the rest."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        use_Hilbert_transforms=True,
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels // 2 if use_Hilbert_transforms else out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=False,
            padding=(kernel_size - stride) // 2,
        )
        self.use_Hilbert_transforms = use_Hilbert_transforms

    def forward(self, input):
        if self.use_Hilbert_transforms:
            weight = self.conv.weight  # out_ch x in_ch x weight
            weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
            weight = torch.cat(
                (weight, compute_Hilbert_transforms_of_filters(weight)), dim=0
            )
            # NOTE(review): reshaping the doubled filter bank to
            # (out_ch, 2 * in_ch, k) looks unusual for a Conv1d weight layout —
            # kept verbatim from the original; verify against callers.
            return self._forward(
                input,
                weight.reshape(
                    self.conv.weight.shape[0],
                    self.conv.weight.shape[1] * 2,
                    self.conv.weight.shape[2],
                ),
            )
        else:
            return self.conv(input)

    def _forward(self, input, weight):
        if self.conv.padding_mode != "zeros":
            return F.conv1d(
                F.pad(
                    input,
                    self.conv._padding_repeated_twice,
                    mode=self.conv.padding_mode,
                ),
                weight,
                None,
                self.conv.stride,
                _single(0),
                self.conv.dilation,
                self.conv.groups,
            )
        return F.conv1d(
            input,
            weight,
            None,
            self.conv.stride,
            self.conv.padding,
            self.conv.dilation,
            self.conv.groups,
        )


class ConvTranspose1dWithHilbertTransforms(nn.Module):
    """ConvTranspose1d where half of the effective filters are Hilbert pairs."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        use_Hilbert_transforms=True,
    ):
        super().__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels,
            out_channels // 2 if use_Hilbert_transforms else out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=False,
            padding=(kernel_size - stride) // 2,
        )
        self.use_Hilbert_transforms = use_Hilbert_transforms

    def forward(self, input):
        if self.use_Hilbert_transforms:
            weight = self.conv.weight  # out_ch x in_ch x weight
            weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
            weight = torch.cat(
                (weight, compute_Hilbert_transforms_of_filters(weight)), dim=0
            )
            # For ConvTranspose1d the weight layout is (in, out, k), so the
            # doubled bank reshapes to (in, 2 * out, k).
            return self._forward(
                input,
                weight.reshape(
                    self.conv.weight.shape[0],
                    self.conv.weight.shape[1] * 2,
                    self.conv.weight.shape[2],
                ),
            )
        else:
            return self.conv(input)

    def _forward(self, input, weight, output_size=None):
        if self.conv.padding_mode != "zeros":
            raise ValueError(
                "Only `zeros` padding mode is supported for ConvTranspose1d"
            )
        # FIX: ``_output_padding`` gained a required ``num_spatial_dims``
        # argument in newer PyTorch; try the modern arity first and fall back
        # to the legacy one for old versions.
        try:
            output_padding = self.conv._output_padding(
                input,
                output_size,
                self.conv.stride,
                self.conv.padding,
                self.conv.kernel_size,
                1,
                self.conv.dilation,
            )
        except TypeError:
            output_padding = self.conv._output_padding(
                input,
                output_size,
                self.conv.stride,
                self.conv.padding,
                self.conv.kernel_size,
            )
        return F.conv_transpose1d(
            input,
            weight,
            None,
            self.conv.stride,
            self.conv.padding,
            output_padding,
            self.conv.groups,
            self.conv.dilation,
        )