| | |
| | |
| |
|
| | """ |
| | Implementation of a FFT based 1D convolution in PyTorch. |
| | While FFT is used in CUDNN for small kernel sizes, it is not the case for long ones, e.g. 512. |
| | This module implements efficient FFT based convolutions for such convolutions. A typical |
application is for evaluating FIR filters with a long receptive field, typically
| | evaluated with a stride of 1. |
| | """ |
| | from typing import Optional |
| |
|
| | import torch |
| | try: |
| | import torch.fft as new_fft |
| | except ImportError: |
| | new_fft = None |
| | from torch.nn import functional as F |
| |
|
| | from .core import pad_to, unfold |
| | from .utils import simple_repr |
| |
|
| |
|
| | |
| | def _new_rfft(x: torch.Tensor): |
| | z = new_fft.rfft(x, dim=-1) |
| | return torch.view_as_real(z) |
| |
|
| |
|
def _old_rfft(x: torch.Tensor):
    """Real FFT using the legacy API (PyTorch < 1.8); `torch.rfft` already
    returns the spectrum with a trailing (real, imaginary) dimension."""
    spectrum = torch.rfft(x, 1)
    return spectrum
| |
|
| |
|
def _old_irfft(x: torch.Tensor, length: int):
    """Inverse real FFT using the legacy API (PyTorch < 1.8), recovering
    exactly `length` real samples along the last axis."""
    return torch.irfft(x, 1, signal_sizes=(length,))
| |
|
| |
|
| | def _new_irfft(x: torch.Tensor, length: int): |
| | x = torch.view_as_complex(x) |
| | return new_fft.irfft(x, length, dim=-1) |
| |
|
| |
|
# Pick the FFT backend once at import time: fall back to the legacy
# `torch.rfft`/`torch.irfft` API when `torch.fft` is unavailable
# (PyTorch < 1.8), otherwise use the modern complex-tensor API.
_rfft = _old_rfft if new_fft is None else _new_rfft
_irfft = _old_irfft if new_fft is None else _new_irfft
| |
|
| |
|
| | def _compl_mul_conjugate(a: torch.Tensor, b: torch.Tensor): |
| | """ |
| | Given a and b two tensors of dimension 4 |
| | with the last dimension being the real and imaginary part, |
| | returns a multiplied by the conjugate of b, the multiplication |
| | being with respect to the second dimension. |
| | |
| | """ |
| | |
| | |
| |
|
| | op = "bcft,dct->bdft" |
| | return torch.stack([ |
| | torch.einsum(op, a[..., 0], b[..., 0]) + torch.einsum(op, a[..., 1], b[..., 1]), |
| | torch.einsum(op, a[..., 1], b[..., 0]) - torch.einsum(op, a[..., 0], b[..., 1]) |
| | ], |
| | dim=-1) |
| |
|
| |
|
def fft_conv1d(
        input: torch.Tensor, weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None, stride: int = 1, padding: int = 0,
        block_ratio: float = 5):
    """
    Same as `torch.nn.functional.conv1d` but using FFT for the convolution.
    Please check PyTorch documentation for more information.

    Args:
        input (Tensor): input signal of shape `[B, C, T]`.
        weight (Tensor): weight of the convolution `[D, C, K]` with `D` the number
            of output channels.
        bias (Tensor or None): if not None, bias term for the convolution.
        stride (int): stride of convolution.
        padding (int): padding to apply to the input.
        block_ratio (float): can be tuned for speed. The input is split in chunks
            with a size of `int(block_ratio * kernel_size)`.

    Shape:

        - Inputs: `input` is `[B, C, T]`, `weight` is `[D, C, K]` and bias is `[D]`.
        - Output: `(*, T)`

    Raises:
        RuntimeError: if the (padded) input is shorter than the kernel, or if
            `block_ratio` is smaller than 1.

    ..note::
        This function is faster than `torch.nn.functional.conv1d` only in specific cases.
        Typically, the kernel size should be of the order of 256 to see any real gain,
        for a stride of 1.

    ..Warning::
        Dilation and groups are not supported at the moment. This function might use
        more memory than the default Conv1d implementation.
    """
    input = F.pad(input, (padding, padding))
    batch, channels, length = input.shape
    out_channels, _, kernel_size = weight.shape

    if length < kernel_size:
        raise RuntimeError(f"Input should be at least as large as the kernel size {kernel_size}, "
                           f"but it is only {length} samples long.")
    if block_ratio < 1:
        raise RuntimeError("Block ratio must be greater than 1.")

    # Overlap-save evaluation: split the input into overlapping frames of
    # `block_size` samples, convolve each frame in the frequency domain, and
    # keep only the valid outputs of each frame.
    block_size: int = min(int(kernel_size * block_ratio), length)
    fold_stride = block_size - kernel_size + 1
    weight = pad_to(weight, block_size)
    weight_z = _rfft(weight)

    # frames have a last dimension of `block_size`, advancing by `fold_stride`.
    frames = unfold(input, block_size, fold_stride)

    frames_z = _rfft(frames)
    # Multiplying by the conjugate spectrum computes a correlation, which
    # matches `conv1d`'s convention of not flipping the kernel.
    out_z = _compl_mul_conjugate(frames_z, weight_z)
    out = _irfft(out_z, block_size)
    # Each frame yields exactly `fold_stride = block_size - kernel_size + 1`
    # valid outputs. BUGFIX: the previous slice `out[..., :-kernel_size + 1]`
    # is equivalent for `kernel_size >= 2` but degenerates to `out[..., :0]`
    # (an empty tensor) when `kernel_size == 1`.
    out = out[..., :fold_stride]
    out = out.reshape(batch, out_channels, -1)
    out = out[..., ::stride]
    target_length = (length - kernel_size) // stride + 1
    out = out[..., :target_length]
    if bias is not None:
        # bias is `[D]`; broadcast over batch and time.
        out += bias[:, None]
    return out
| |
|
| |
|
class FFTConv1d(torch.nn.Module):
    """
    Same as `torch.nn.Conv1d` but based on `fft_conv1d`.
    Please check PyTorch documentation for more information.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        kernel_size (int): kernel size of convolution.
        stride (int): stride of convolution.
        padding (int): padding to apply to the input.
        bias (bool): if True, use a bias term.

    ..note::
        This module is faster than `torch.nn.Conv1d` only in specific cases.
        Typically, `kernel_size` should be of the order of 256 to see any real gain,
        for a stride of 1.

    ..warning::
        Dilation and groups are not supported at the moment. This module might use
        more memory than the default Conv1d implementation.

    >>> fftconv = FFTConv1d(12, 24, 128, 4)
    >>> x = torch.randn(4, 12, 1024)
    >>> print(list(fftconv(x).shape))
    [4, 24, 225]
    """
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 stride: int = 1, padding: int = 0, bias: bool = True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # Borrow weight (and optional bias) from a regular Conv1d so this
        # module gets exactly the same parameter initialization.
        reference = torch.nn.Conv1d(in_channels, out_channels, kernel_size, bias=bias)
        self.weight = reference.weight
        self.bias = reference.bias

    def forward(self, input: torch.Tensor):
        return fft_conv1d(input, self.weight, self.bias, self.stride, self.padding)

    def __repr__(self):
        has_bias = self.bias is not None
        return simple_repr(self, overrides={"bias": has_bias})
| |
|