# sfi_hubert / conv_any_stride.py
# Origin: HF Hub, commit "add non_integer stride" (f634d4b, verified), author: Wataru
import functools
import math
import warnings
from distutils.version import LooseVersion
from fractions import Fraction
from typing import Optional
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.utils import _single
def _get_sinc_resample_kernel(
orig_freq: int,
new_freq: int,
gcd: int,
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
resampling_kernel: str = "sinc",
sinc_window: str = "sinc_interpolation",
beta: Optional[float] = None,
device: torch.device = torch.device("cpu"),
dtype: Optional[torch.dtype] = None,
):
if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
raise Exception(
"Frequencies must be of integer type to ensure quality resampling computation. "
"To work around this, manually convert both frequencies to integer values "
"that maintain their resampling rate ratio before passing them into the function. "
"Example: To downsample a 44100 hz waveform by a factor of 8, use "
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5`. "
"For more information, please refer to https://github.com/pytorch/audio/issues/1487."
)
orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd
if lowpass_filter_width <= 0:
raise ValueError("Low pass filter width should be positive.")
base_freq = max(orig_freq, new_freq)
if resampling_kernel == "sinc":
if sinc_window not in ["sinc_interpolation", "kaiser_window"]:
raise ValueError("Invalid resampling method: {}".format(sinc_window))
base_freq *= rolloff
# The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
# using the sinc interpolation formula:
# x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
# We can then sample the function x(t) with a different sample rate:
# y[j] = x(j / new_freq)
# or,
# y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
# We see here that y[j] is the convolution of x[i] with a specific filter, for which
# we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing.
# But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
# Indeed:
# y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq))
# = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
# = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
# so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
# This will explain the F.conv1d after, with a stride of orig_freq.
width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
# If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
idx_dtype = dtype if dtype is not None else torch.float64
idx = (
torch.arange(-width, width + orig_freq, dtype=idx_dtype, device=device)[
None, None
]
/ orig_freq
)
t = (
torch.arange(0, -new_freq, -1, dtype=dtype, device=device)[:, None, None]
/ new_freq
+ idx
)
t *= base_freq
t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
# we do not use built in torch windows here as we need to evaluate the window
# at specific positions, not over a regular grid.
if sinc_window == "sinc_interpolation":
window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
else:
# kaiser_window
if beta is None:
beta = 14.769656459379492
beta_tensor = torch.tensor(float(beta))
window = torch.i0(
beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)
) / torch.i0(beta_tensor)
t *= math.pi
kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
kernels *= window
if dtype is None:
kernels = kernels.to(dtype=torch.float32)
else:
raise NotImplementedError
return kernels, width
def _apply_sinc_resample_kernel(
waveform: Tensor,
orig_freq: int,
new_freq: int,
gcd: int,
kernel: Tensor,
width: int,
):
if not waveform.is_floating_point():
raise TypeError(
f"Expected floating point type for waveform tensor, but received {waveform.dtype}."
)
orig_freq = int(orig_freq) // gcd
new_freq = int(new_freq) // gcd
# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])
num_wavs, length = waveform.shape
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
target_length = int(math.ceil(new_freq * length / orig_freq))
resampled = resampled[..., :target_length]
# unpack batch
resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
return resampled
def resample(
    waveform: Tensor,
    orig_freq: int,
    new_freq: int,
    lowpass_filter_width: int = 16,
    rolloff: float = 0.99,
    resampling_kernel: str = "sinc",
    sinc_window: str = "kaiser_window",
    beta: Optional[float] = None,
    kernel=None,
    width=None,
) -> Tensor:
    r"""Resamples the waveform at the new frequency using bandlimited interpolation. :cite:`RESAMPLE`.
    .. devices:: CPU CUDA
    .. properties:: Autograd TorchScript
    Note:
        ``transforms.Resample`` precomputes and reuses the resampling kernel, so using it will result in
        more efficient computation if resampling multiple waveforms with the same resampling parameters.
    Args:
        waveform (Tensor): The input signal of dimension `(..., time)`
        orig_freq (int): The original frequency of the signal
        new_freq (int): The desired frequency
        lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
            but less efficient. (Default: ``16``)
        rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
            Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
        resampling_kernel (str, optional): The resampling kernel to use; only ``"sinc"``
            is implemented by this module. (Default: ``"sinc"``)
        sinc_window (str, optional): The resampling method to use.
            Options: [``"sinc_interpolation"``, ``"kaiser_window"``] (Default: ``"kaiser_window"``)
        beta (float or None, optional): The shape parameter used for kaiser window.
        kernel (Tensor or None, optional): Precomputed kernel from
            :func:`_get_sinc_resample_kernel`; computed here when omitted.
        width (int or None, optional): Half-width matching ``kernel``.
    Returns:
        Tensor: The waveform at the new frequency of dimension `(..., time).`
    """
    # Validate before the identity shortcut so bad inputs always raise.
    if orig_freq <= 0.0 or new_freq <= 0.0:
        # BUGFIX: typo "frequecy" -> "frequency" in the error message.
        raise ValueError("Original frequency and desired frequency should be positive")
    if orig_freq == new_freq:
        return waveform
    gcd = math.gcd(int(orig_freq), int(new_freq))
    # Recompute the kernel only when the caller did not precompute it.
    if (kernel is None) or (width is None):
        kernel, width = _get_sinc_resample_kernel(
            orig_freq=orig_freq,
            new_freq=new_freq,
            gcd=gcd,
            lowpass_filter_width=lowpass_filter_width,
            rolloff=rolloff,
            resampling_kernel=resampling_kernel,
            sinc_window=sinc_window,
            beta=beta,
            device=waveform.device,
            dtype=waveform.dtype,
        )
    resampled = _apply_sinc_resample_kernel(
        waveform, orig_freq, new_freq, gcd, kernel, width
    )
    return resampled
def compute_Hilbert_transforms_of_filters(filters):
    """Compute the Hilbert transforms of the input filters.

    Args:
        filters (torch.Tensor): filter impulse responses, n_filters x kernel_size
    Return:
        torch.Tensor: Hilbert transforms of the filters, same shape as ``filters``
    """
    # BUGFIX/modernization: feature-detect the torch.fft module instead of
    # comparing version strings with distutils LooseVersion (distutils was
    # deprecated and removed in Python 3.12). In torch < 1.7 `torch.fft` was
    # a plain function, so the attribute check below selects the legacy path.
    if hasattr(torch, "fft") and hasattr(torch.fft, "rfft"):
        # H(x) has spectrum -i * sign(w) * X(w): swap (re, im) -> (im, -re).
        ft_f = torch.fft.rfft(filters, n=filters.shape[1], dim=1, norm="ortho")
        hft_f = torch.view_as_complex(torch.stack((ft_f.imag, -ft_f.real), dim=-1))
        hft_f = torch.fft.irfft(hft_f, n=filters.shape[1], dim=1, norm="ortho")
    else:
        # Legacy torch < 1.7: real FFT returned (..., 2) real/imag pairs.
        ft_f = torch.rfft(
            filters.reshape(filters.shape[0], 1, filters.shape[1]), 1, normalized=True
        )
        hft_f = torch.stack([ft_f[:, :, :, 1], -ft_f[:, :, :, 0]], dim=-1)
        hft_f = torch.irfft(hft_f, 1, normalized=True, signal_sizes=(filters.shape[1],))
    return hft_f.reshape(*(filters.shape))
class _FIRDesignBase(nn.Module):
def __init__(
self,
in_channels,
out_channels,
ContFilterType,
filter_params,
use_Hilbert_transforms=False,
transposed=False,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
n_filters = in_channels * out_channels
self.use_Hilbert_transforms = use_Hilbert_transforms
if self.use_Hilbert_transforms:
if n_filters % 2 == 1:
raise ValueError(
f"n_filters must be even when using Hilbert transforms of filters [n_filters={n_filters}]"
)
n_filters //= 2
self.continuous_filters = ContFilterType(
n_filters=n_filters,
**filter_params,
)
self._transposed = transposed
def weight(self, weight):
if self.use_Hilbert_transforms:
weight = torch.cat(
(weight, compute_Hilbert_transforms_of_filters(weight)), dim=0
)
if self._transposed:
return weight.reshape(self.in_channels, self.out_channels, -1)
else:
return weight.reshape(self.out_channels, self.in_channels, -1)
def prepare(
self,
sample_rate: int,
kernel_size,
stride,
padding: int = None,
output_padding: int = 0,
):
self.sample_rate = sample_rate
self.kernel_size = _single(kernel_size)
self.stride = _single(stride)
self.trained_stride = stride
if padding is None:
self.padding = _single(int((self.kernel_size[0] - self.stride[0]) // 2))
else:
self.padding = _single(padding)
self.output_padding = (int(output_padding),)
def forward(self, input):
raise NotImplementedError
def extra_repr(self):
s = "{in_channels}, {out_channels}, sample_rate={sample_rate}"
if hasattr(self, "kernel_size"):
s += ", kernel_size={kernel_size}"
if hasattr(self, "stride"):
s += ", stride={stride}"
if hasattr(self, "kernel_size"):
s += ", padding={padding}"
if hasattr(self, "output_padding"):
s += ", output_padding={output_padding}"
return s.format(**self.__dict__)
def precompute_weight(self):
raise NotImplementedError
def convert(self):
warnings.warn(
f"Converting SFI to Non-SFI convolutional layers [sample_rate={self.sample_rate}, kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}, output_padding={self.output_padding}]"
)
if self.is_transposed:
return functools.partial(
F.conv_transpose1d,
weight=self.weight(),
bias=None,
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
dilation=_single(1),
groups=1,
)
else:
return functools.partial(
F.conv1d,
weight=self.weight(),
bias=None,
stride=self.stride,
padding=self.padding,
dilation=_single(1),
groups=1,
)
def prepare_for_time_interpolation(self, trained_stride):
self.trained_stride = trained_stride
def convolution_time_interpolation(self, input, torchaudio_options={}):
conv = F.conv1d(
input, self.weight(), None, 1, int(self.padding[0]), _single(1), 1
)
# org_sr: new_sr = 1 : 1/stride = stride : 1, org_sr and new_sr must be integers
# org_sr : new_sr = 1 : 1 / S^{(test)} = S^{(test)} : 1 = S^{(train)} * F_s^{(test)} / F_s^{(train)} : 1 = S^{(train)} * F_s^{(test)} : F_s^{(train)}
orig_sr2new_sr = Fraction(self.trained_stride / 1)
org_sr, new_sr = orig_sr2new_sr.as_integer_ratio()
output = resample(
conv, org_sr, new_sr, sinc_window="kaiser_window", **torchaudio_options
)
return output
def new_convolution_time_interpolation_reverse(self, input, torchaudio_options={}):
# org_sr : new_sr = 1 : 1 / S^{(test)} = S^{(test)} : 1 = S^{(train)} * F_s^{(test)} / F_s^{(train)} : 1 = S^{(train)} * F_s^{(test)} : F_s^{(train)}
org_sr = numpy.round(self.sample_rate * self.trained_stride / self.stride[0])
new_sr = numpy.round(self.sample_rate * self.trained_stride)
output_list = []
for i in range(input.shape[0]):
middle = resample(
input[i : i + 1],
org_sr,
new_sr,
sinc_window="kaiser_window",
**torchaudio_options,
)
output_tmp = F.conv_transpose1d(
middle,
self.weight(),
None,
1,
int(self.padding[0]),
self.output_padding,
1,
_single(1),
)
output_list.append(output_tmp)
output = torch.cat(output_list, dim=0)
return output
class _FreqRespSampFIRs(_FIRDesignBase):
    """SFI FIR layers designed by frequency-response sampling.

    FIR coefficients are fit by least squares: sample the continuous filters'
    frequency responses at M points and multiply by the pseudo-inverse of the
    real-valued DFT design matrix.
    """

    # Supported strategies for placing the frequency sampling points.
    SAMPLING_STRATEGY = [
        "fixed",
        "randomized",
        "completely_randomized",
        "fixed_for_noninteger_kernel_size",
    ]

    def __init__(
        self,
        in_channels,
        out_channels,
        n_samples,
        ContFilterType,
        filter_params,
        use_Hilbert_transforms=False,
        transposed=False,
        frequency_sampling_strategy=["fixed", "fixed"],
    ):
        """
        Args:
            in_channels (int): Number of channels of 1D sequence
            out_channels (int): Number of channels produced by the convolution
            n_samples (int): Number of sampled points for frequency sampling
            ContFilterType (Class): Continuous filter class
            filter_params (dict): Parameters of continuous filter class
            use_Hilbert_transforms (bool): If True, the latter half of the filters are the Hilbert pairs of the former half.
            transposed (bool): Whether this convolution is a transposed convolution or not.
            frequency_sampling_strategy (list[str]): Strategy of frequency sampling at training and inference stages. "fixed" means equally-spaced sampling and "randomized" means adding small noise to the equally-spaced sampled positions.
        """
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            ContFilterType=ContFilterType,
            filter_params=filter_params,
            use_Hilbert_transforms=use_Hilbert_transforms,
            transposed=transposed,
        )
        self.n_samples = n_samples
        for strategy in frequency_sampling_strategy:
            if strategy not in self.SAMPLING_STRATEGY:
                raise NotImplementedError(
                    f"Undefined frequency sampling strategy [{strategy}]"
                )
        # BUGFIX: store a copy. _compute_pinvW() mutates element [1] of this
        # list in the non-integer-kernel-size path; without the copy that
        # write leaked into the caller's list and, worse, into the shared
        # mutable default ["fixed", "fixed"], silently changing the strategy
        # of every instance constructed afterwards.
        self.frequency_sampling_strategy = list(frequency_sampling_strategy)
        self._cache = dict()

    def weight(self):
        """FIR weights in conv layout, designed for the current sample rate."""
        weight = self.approximate_by_FIR(
            self.continuous_filters.device
        )  # n_filters (or n_filters//2) x kernel_size
        return super().weight(weight)

    def precompute_weight(self):
        """Cache the designed weight (call right before switching to eval)."""
        assert self.training, "This function should not be called during training."
        self._precomputed_weight = self.weight()

    def _compute_pinvW(self, device):
        """Sampling frequencies and pseudo-inverse of the FIR design matrix.

        Returns:
            (ang_freqs, pinvW): angular frequencies [rad/s] of shape (M,) and
            the pseudo-inverse mapping stacked (real; imag) response samples
            (length 2M) to FIR coefficients (length 2P or 2P+1).
        """
        kernel_size = self.kernel_size[0]
        sample_rate = self.sample_rate
        P = (kernel_size - 1) // 2 if kernel_size % 2 == 1 else kernel_size // 2
        M = self.n_samples
        nyquist_rate = sample_rate / 2
        #
        if not self.training and hasattr(self, "noninteger_kernel_size_adaptation"):
            self.frequency_sampling_strategy[1] = "fixed_for_noninteger_kernel_size"
        # Index 0 is the training-time strategy, index 1 the inference one.
        strategy = self.frequency_sampling_strategy[0 if self.training else 1]
        if strategy == "fixed":
            ang_freqs = (
                torch.linspace(0, nyquist_rate * 2.0 * numpy.pi, M).float().to(device)
            )
        elif strategy == "randomized":
            # Equally-spaced grid jittered by up to half the grid spacing;
            # the endpoints are reflected back into [0, Nyquist].
            ang_freqs = torch.linspace(0, nyquist_rate * 2.0 * numpy.pi, M).float()
            ang_freqs.requires_grad_(False)
            delta_val = ang_freqs[1] - ang_freqs[0]
            delta = torch.zeros_like(ang_freqs).uniform_(-delta_val / 2, delta_val / 2)
            delta.requires_grad_(False)
            if delta[0] < 0:
                delta[0] = -delta[0]
            if delta[-1] > 0:
                delta[-1] = -delta[-1]
            ang_freqs = ang_freqs + delta
            ang_freqs = ang_freqs.to(device)
        elif strategy == "completely_randomized":
            ang_freqs = (
                torch.zeros((M,), device=device)
                .float()
                .uniform_(0.0, nyquist_rate * 2.0 * numpy.pi)
            )
            ang_freqs.requires_grad_(False)
            ang_freqs, _ = torch.sort(ang_freqs, descending=False)
        elif strategy == "fixed_for_noninteger_kernel_size":
            # Shrink the sampled band so the grid stays consistent with the
            # integer part of a non-integer kernel length.
            max_freq = nyquist_rate / kernel_size * int(kernel_size)
            ang_freqs = (
                torch.linspace(0, max_freq * 2.0 * numpy.pi, M).float().to(device)
            )
        else:
            raise NotImplementedError(
                f"Undefined frequency sampling strategy [{strategy}]"
            )
        normalized_ang_freqs = ang_freqs / float(sample_rate)
        if kernel_size % 2 == 1:
            seq_P = torch.arange(-P, P + 1).float()[None, :].to(device)
            ln_W = -normalized_ang_freqs[:, None] * seq_P  # M x 2P+1
        else:
            seq_P = torch.arange(-(P - 1), P + 1).float()[None, :].to(device)
            ln_W = -normalized_ang_freqs[:, None] * seq_P  # M x 2P
        ln_W = ln_W.to(device)
        W = torch.cat((torch.cos(ln_W), torch.sin(ln_W)), dim=0)  # 2*M x 2P
        ###
        pinvW = torch.pinverse(W)  # 2P x 2M
        pinvW.requires_grad_(False)
        ang_freqs.requires_grad_(False)
        return ang_freqs, pinvW

    def approximate_by_FIR(self, device):
        """Approximate frequency responses of analog filters with those of digital filters
        Args:
            device (torch.Device): Computation device
        Return:
            torch.Tensor: Time-reversed impulse responses of digital filters (n_filters x filter_degree (-P to P))
        """
        strategy = self.frequency_sampling_strategy[0 if self.training else 1]
        if strategy == "fixed" or strategy == "fixed_for_noninteger_kernel_size":
            # Deterministic grids: cache the expensive pseudo-inverse per
            # (sample_rate, kernel_size, stride) configuration, stored on CPU.
            cache_tag = (self.sample_rate, self.kernel_size, self.stride)
            if cache_tag in self._cache:
                ang_freqs, pinvW = self._cache[cache_tag]
                ang_freqs = ang_freqs.detach().to(device)
                pinvW = pinvW.detach().to(device)
            else:
                ang_freqs, pinvW = self._compute_pinvW(device)
                self._cache[cache_tag] = (
                    ang_freqs.detach().cpu(),
                    pinvW.detach().cpu(),
                )
        elif strategy == "randomized" or strategy == "completely_randomized":
            ang_freqs, pinvW = self._compute_pinvW(device)
        else:
            raise NotImplementedError(
                f"Undefined frequency sampling strategy [{strategy}]"
            )
        ###
        resp_r, resp_i = self.continuous_filters.get_frequency_responses(
            ang_freqs
        )  # n_filters x M
        resp = torch.cat((resp_r, resp_i), dim=1)  # n_filters x 2M
        ###
        fir_coeffs = (pinvW[None, :, :] @ resp[:, :, None])[:, :, 0]  # n_filters x 2P
        # NOTE(review): for an odd integer kernel size this rounds down to the
        # next even length and drops the last fitted coefficient; presumably
        # intended for non-integer kernel sizes -- confirm for odd ones.
        kernel_size = int(self.kernel_size[0] / 2) * 2
        return fir_coeffs[
            :, torch.arange(kernel_size - 1, -1, -1)
        ]  # time-reversed impulse response

    def extra_repr(self):
        s = "{in_channels}, {out_channels}, sample_rate={sample_rate}, n_samples={n_samples}"
        if hasattr(self, "kernel_size"):
            s += ", kernel_size={kernel_size}"
        if hasattr(self, "stride"):
            s += ", stride={stride}"
        if hasattr(self, "kernel_size"):
            s += ", padding={padding}"
        return s.format(**self.__dict__)

    def get_analog_freqresp_for_visualization(self, ang_freqs):
        """Get frequency responses of analog filters for visualization
        Args:
            ang_freqs (torch.Tensor): Unnormalized angular frequency [rad] (n_angfreqs)
        Return:
            torch.Tensor[cfloat]: Frequency responses of analog filters (n_filters x n_angfreqs)
        """
        resp_r, resp_i = self.continuous_filters.get_frequency_responses(ang_freqs)
        resp = torch.stack((resp_r, resp_i), axis=-1)
        return torch.view_as_complex(resp)
class FreqRespSampConv1d(_FreqRespSampFIRs):
    """Frequency-response-sampling SFI 1-D convolution (non-transposed)."""

    def __init__(
        self,
        in_channels,
        out_channels,
        n_samples,
        ContFilterType,
        filter_params,
        use_Hilbert_transforms=False,
        frequency_sampling_strategy=["fixed", "fixed"],
    ):
        # Everything is forwarded to the base class; this subclass only fixes
        # transposed=False.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            n_samples=n_samples,
            ContFilterType=ContFilterType,
            filter_params=filter_params,
            use_Hilbert_transforms=use_Hilbert_transforms,
            transposed=False,
            frequency_sampling_strategy=frequency_sampling_strategy,
        )

    def forward(self, input, torchaudio_options={}):
        """Stride-1 convolution followed by sinc resampling by 1/stride."""
        return self.convolution_time_interpolation(
            input, torchaudio_options=torchaudio_options
        )

    @property
    def is_transposed(self):
        # Plain (non-transposed) convolution.
        return False
class FreqRespSampConvTranspose1d(_FreqRespSampFIRs):
    """Frequency-response-sampling SFI transposed 1-D convolution."""

    def __init__(
        self,
        in_channels,
        out_channels,
        n_samples,
        ContFilterType,
        filter_params,
        use_Hilbert_transforms=False,
        frequency_sampling_strategy=["fixed", "fixed"],
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            n_samples=n_samples,
            ContFilterType=ContFilterType,
            filter_params=filter_params,
            use_Hilbert_transforms=use_Hilbert_transforms,
            transposed=True,
            frequency_sampling_strategy=frequency_sampling_strategy,
        )

    def forward(
        self,
        input,
        time_interpolation=False,
        interp_type="cubic",
        torchaudio_options={},
        other_options={},
    ):
        """Transposed convolution; resampling-based path when requested.

        ``interp_type`` / ``other_options`` remain in the signature for
        backward compatibility; the sinc-resampling path does not use them.
        """
        if time_interpolation:
            # BUGFIX: the base-class method only accepts `torchaudio_options`;
            # forwarding `interp_type`/`other_options` raised TypeError.
            return self.new_convolution_time_interpolation_reverse(
                input,
                torchaudio_options=torchaudio_options,
            )
        else:
            return F.conv_transpose1d(
                input,
                self.weight(),
                None,
                self.stride,
                self.padding,
                self.output_padding,
                1,
                _single(1),
            )

    @property
    def is_transposed(self):
        return True
## invariant impulse response method
class _InvImpRespFIRs(_FIRDesignBase):
    """SFI FIR design that directly samples the analog impulse responses."""

    def weight(self):
        """Discretize the continuous filters at the prepared rate and length."""
        impulse_responses = self.continuous_filters.get_impulse_responses(
            self.sample_rate, self.kernel_size[0]
        )
        return super().weight(impulse_responses)

    def precompute_weight(self):
        """Cache the weight so eval-time forwards can reuse it."""
        assert self.training, "This function should not be called during training."
        self._precomputed_weight = self.weight()

    def get_analog_impulse_resp_for_visualization(self, sample_rate, kernel_size):
        """Sampled impulse responses, for plotting."""
        return self.continuous_filters.get_impulse_responses(sample_rate, kernel_size)

    def get_analog_impulse_resp_for_visualization_oversampling(
        self, sample_rate, kernel_size
    ):
        """Oversampled impulse responses, for plotting."""
        return self.continuous_filters.get_impulse_responses_oversampling(
            sample_rate, kernel_size
        )
class InvImpRespConv1d(_InvImpRespFIRs):
    """SFI Conv1d whose FIR weights sample the analog impulse responses."""

    def __init__(
        self,
        in_channels,
        out_channels,
        ContFilterType,
        filter_params,
        use_Hilbert_transforms=False,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            ContFilterType=ContFilterType,
            filter_params=filter_params,
            use_Hilbert_transforms=use_Hilbert_transforms,
            transposed=False,
        )

    def forward(
        self,
        input,
        time_interpolation=False,
        interp_type="cubic",
        torchaudio_options={},
        other_options={},
    ):
        """Convolve; with ``time_interpolation`` use the resampling path.

        ``interp_type`` / ``other_options`` remain in the signature for
        backward compatibility; the sinc-resampling path does not use them.
        """
        if time_interpolation:
            # BUGFIX: the base-class method only accepts `torchaudio_options`;
            # forwarding `interp_type`/`other_options` raised TypeError.
            return self.convolution_time_interpolation(
                input,
                torchaudio_options=torchaudio_options,
            )
        else:
            return F.conv1d(
                input, self.weight(), None, self.stride, self.padding, _single(1), 1
            )

    @property
    def is_transposed(self):
        return False
class InvImpRespConvTranspose1d(_InvImpRespFIRs):
    """SFI ConvTranspose1d whose FIR weights sample the analog impulse responses."""

    def __init__(
        self,
        in_channels,
        out_channels,
        ContFilterType,
        filter_params,
        use_Hilbert_transforms=False,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            ContFilterType=ContFilterType,
            filter_params=filter_params,
            use_Hilbert_transforms=use_Hilbert_transforms,
            transposed=True,
        )

    def forward(
        self,
        input,
        time_interpolation=False,
        interp_type="cubic",
        torchaudio_options={},
        other_options={},
    ):
        """Transposed convolution; resampling-based path when requested.

        ``interp_type`` / ``other_options`` remain in the signature for
        backward compatibility; the sinc-resampling path does not use them.
        """
        if time_interpolation:
            # BUGFIX: the base-class method only accepts `torchaudio_options`;
            # forwarding `interp_type`/`other_options` raised TypeError.
            return self.new_convolution_time_interpolation_reverse(
                input,
                torchaudio_options=torchaudio_options,
            )
        else:
            return F.conv_transpose1d(
                input,
                self.weight(),
                None,
                self.stride,
                self.padding,
                self.output_padding,
                1,
                _single(1),
            )

    @property
    def is_transposed(self):
        return True
##############################
class PinvConvTranspose1d(nn.Module):
    """Transposed convolution whose weights are the pseudo-inverse of an
    encoder's filterbank, i.e. a decoder approximately inverting the encoder.
    """

    def __init__(self, encoder):
        super().__init__()
        # The encoder must expose a zero-argument weight() method.
        self.encoder = encoder

    def prepare(self, sample_rate: int, kernel_size: int, stride: int):
        """Record the geometry shared with the encoder."""
        self.sample_rate = sample_rate
        self.kernel_size = kernel_size
        self.stride = stride

    def weight(self):
        """Computes pseudo inverse filterbank of given filters."""
        encoder_filters = self.encoder.weight()
        original_shape = encoder_filters.shape
        pinv = torch.pinverse(encoder_filters.squeeze())
        decoder_filters = pinv.transpose(-1, -2).view(original_shape)
        # Compensate for the overlap-add: windows overlap by a factor of
        # kernel_size / stride, so rescale to keep unit gain.
        return decoder_filters * (self.stride / self.kernel_size)

    def forward(self, input):
        return F.conv_transpose1d(
            input,
            self.weight(),
            None,
            self.stride,
            (self.kernel_size - self.stride) // 2,
            _single(0),
            1,
            _single(1),
        )
class Conv1dWithHilbertTransforms(nn.Module):
    """Conv1d whose output channels are augmented with the Hilbert transforms
    of its (half-sized) learned filters."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        use_Hilbert_transforms=True,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Output channels of the full layer; only half
                of them are learned when ``use_Hilbert_transforms``.
            kernel_size (int): Filter length.
            stride (int): Convolution stride.
            use_Hilbert_transforms (bool): Append Hilbert pairs when True.
        """
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels // 2 if use_Hilbert_transforms else out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=False,
            padding=(kernel_size - stride) // 2,
        )
        self.use_Hilbert_transforms = use_Hilbert_transforms

    def forward(self, input):
        if self.use_Hilbert_transforms:
            weight = self.conv.weight  # out_ch//2 x in_ch x kernel
            weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
            weight = torch.cat(
                (weight, compute_Hilbert_transforms_of_filters(weight)), dim=0
            )
            # BUGFIX: Conv1d weights are laid out (out, in, k), so the Hilbert
            # copies double the OUTPUT-channel axis (the first out_channels//2
            # channels are the learned filters, the rest their Hilbert pairs).
            # The previous reshape doubled the input-channel axis instead,
            # which made F.conv1d reject any input with the declared
            # in_channels.
            return self._forward(
                input,
                weight.reshape(
                    self.conv.weight.shape[0] * 2,
                    self.conv.weight.shape[1],
                    self.conv.weight.shape[2],
                ),
            )
        else:
            return self.conv(input)

    def _forward(self, input, weight):
        """conv1d with an explicitly supplied weight, honoring padding_mode."""
        if self.conv.padding_mode != "zeros":
            # Non-zero padding modes must be applied manually before conv1d.
            return F.conv1d(
                F.pad(
                    input,
                    self.conv._padding_repeated_twice,
                    mode=self.conv.padding_mode,
                ),
                weight,
                None,
                self.conv.stride,
                _single(0),
                self.conv.dilation,
                self.conv.groups,
            )
        return F.conv1d(
            input,
            weight,
            None,
            self.conv.stride,
            self.conv.padding,
            self.conv.dilation,
            self.conv.groups,
        )
class ConvTranspose1dWithHilbertTransforms(nn.Module):
    """ConvTranspose1d whose output channels are augmented with the Hilbert
    transforms of its (half-sized) learned filters."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        use_Hilbert_transforms=True,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Output channels of the full layer; only half
                of them are learned when ``use_Hilbert_transforms``.
            kernel_size (int): Filter length.
            stride (int): Upsampling stride.
            use_Hilbert_transforms (bool): Append Hilbert pairs when True.
        """
        super().__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels,
            out_channels // 2 if use_Hilbert_transforms else out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=False,
            padding=(kernel_size - stride) // 2,
        )
        self.use_Hilbert_transforms = use_Hilbert_transforms

    def forward(self, input):
        if self.use_Hilbert_transforms:
            weight = self.conv.weight  # in_ch x out_ch//2 x kernel (transpose layout)
            weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
            weight = torch.cat(
                (weight, compute_Hilbert_transforms_of_filters(weight)), dim=0
            )
            # NOTE(review): this reshape pairs original/Hilbert filters
            # correctly only when in_channels == 1; for in_channels > 1 the
            # channel assignment is scrambled -- confirm intended usage.
            return self._forward(
                input,
                weight.reshape(
                    self.conv.weight.shape[0],
                    self.conv.weight.shape[1] * 2,
                    self.conv.weight.shape[2],
                ),
            )
        else:
            return self.conv(input)

    def _forward(self, input, weight, output_size=None):
        """conv_transpose1d with an explicitly supplied weight."""
        if self.conv.padding_mode != "zeros":
            raise ValueError(
                "Only `zeros` padding mode is supported for ConvTranspose1d"
            )
        # BUGFIX: torch >= 1.12 made `num_spatial_dims` a required argument of
        # nn.ConvTransposeNd._output_padding; fall back to the pre-1.12
        # signature for older torch versions.
        try:
            output_padding = self.conv._output_padding(
                input,
                output_size,
                self.conv.stride,
                self.conv.padding,
                self.conv.kernel_size,
                1,  # num_spatial_dims for 1-D convolution
            )
        except TypeError:
            output_padding = self.conv._output_padding(
                input,
                output_size,
                self.conv.stride,
                self.conv.padding,
                self.conv.kernel_size,
            )
        return F.conv_transpose1d(
            input,
            weight,
            None,
            self.conv.stride,
            self.conv.padding,
            output_padding,
            self.conv.groups,
            self.conv.dilation,
        )