|
|
import functools |
|
|
from typing import Sequence |
|
|
|
|
|
import torch |
|
|
import torchaudio |
|
|
from torch import nn |
|
|
|
|
|
|
|
|
class RFFTimeDomainImplicitFilter(nn.Module):
    """Implicit time-domain filter bank parameterized by a pointwise MLP.

    The impulse responses of ``n_filters`` FIR filters are represented
    implicitly: a stack of 1x1 ``Conv1d`` layers maps a normalized time axis
    in [-1, 1] -- optionally lifted through random Fourier features (RFFs)
    as sin/cos pairs -- to the filter taps. At evaluation time, responses can
    optionally be generated at the training sample rate and resampled to the
    requested rate for anti-aliasing (see ``get_impulse_responses_oversampling``).
    """

    def __init__(
        self,
        n_filters: int,
        init_kernel_size: int,
        init_sample_rate: int,
        ch_list: Sequence[int] = (32, 32),
        n_RFFs: int = 32,
        nonlinearity: str = "relu",
        train_RFF: bool = False,
        use_layer_norm: bool = False,
    ) -> None:
        """n_filters: Number of filters.
        init_kernel_size: Initial kernel size (kernel size at the training sample rate).
        init_sample_rate: Initial sample rate (sample rate the module is trained at).
        ch_list: Channel list of the hidden MLP layers.
        n_RFFs: Number of RFFs. If n_RFFs <= 0, do not use random Fourier feature inputs (i.e., directly input normalized time).
        nonlinearity (str): Nonlinearity. Only "relu" is supported.
        train_RFF (bool): If True, train RFFs (otherwise they are frozen at init).
        use_layer_norm (bool): If True, use layer norm (GroupNorm with one group).
        """
        super().__init__()
        self.n_filters = n_filters
        # Stored as float buffers so they follow the module across devices
        # (.to()/.cuda()) and are persisted in state_dict.
        self.register_buffer("init_kernel_size", torch.tensor(init_kernel_size).float())
        self.register_buffer("init_sample_rate", torch.tensor(init_sample_rate).float())

        if nonlinearity == "relu":
            nonlinearity_factory = functools.partial(nn.ReLU, inplace=True)
        else:
            # Include the value so the failure is diagnosable at the call site.
            msg = f"Unsupported nonlinearity: {nonlinearity!r}"
            raise NotImplementedError(msg)

        # Input is either 2*n_RFFs channels (sin/cos of each random frequency)
        # or a single channel carrying the raw normalized time.
        in_ch_list = [n_RFFs * 2 if n_RFFs > 0 else 1, *list(ch_list)]
        out_ch_list = [*list(ch_list), n_filters]
        n_layers = len(in_ch_list)
        layers = []
        for i, (in_ch, out_ch) in enumerate(zip(in_ch_list, out_ch_list)):
            layers.append(nn.Conv1d(in_ch, out_ch, 1))
            # No normalization/activation after the final (output) layer.
            if i < n_layers - 1:
                if use_layer_norm:
                    # GroupNorm with a single group normalizes over all
                    # channels, i.e. layer norm for (N, C, T) activations.
                    layers.append(nn.GroupNorm(1, out_ch))
                layers.append(nonlinearity_factory())
        self.implicit_filter = nn.Sequential(*layers)

        if n_RFFs > 0:
            # Random frequencies drawn once at init; only optimized when
            # train_RFF is set.
            self.RFF_param = nn.Parameter(
                torch.zeros((n_RFFs,), dtype=torch.float).normal_(
                    0.0,
                    2.0 * torch.pi * 10.0,
                ),
                requires_grad=train_RFF,
            )
        else:
            self.RFF_param = None

        def set_zero_bias(m: nn.Module) -> None:
            # Initialize every conv bias to zero so the initial impulse
            # responses are driven purely by the weights.
            if isinstance(m, nn.Conv1d):
                if m.bias is None:
                    msg = "bias cannot be none"
                    raise ValueError(msg)
                m.bias.data.fill_(0.0)

        self.implicit_filter.apply(set_zero_bias)

    @property
    def device(self) -> torch.device:
        """Device of the module, inferred from the first conv layer's weight."""
        return self.implicit_filter[0].weight.device

    def _get_ir(self, normalized_time: torch.Tensor) -> torch.Tensor:
        """Return discrete-time impulse responses (corresponding to the weights of the convolutional layer) from MLP.

        Args:
            normalized_time (torch.Tensor): Normalized time in [-1, 1], shape (time,).

        Return:
            torch.Tensor: Discrete-time impulse responses (n_filters x time)

        """
        if self.RFF_param is not None:
            # Outer product: (n_RFFs, 1) @ (1, T) -> (n_RFFs, T), then
            # stack sin/cos along the channel axis -> (2*n_RFFs, T).
            RFF = self.RFF_param[:, None] @ normalized_time[None, :]
            RFF = torch.cat((RFF.sin(), RFF.cos()), dim=0)
            ir = self.implicit_filter(RFF[None, :, :])
        else:
            # Single-channel input: the raw normalized time axis.
            ir = self.implicit_filter(
                normalized_time[None, None, :],
            )
        # Drop the singleton batch dim: (1, n_filters, T) -> (n_filters, T).
        return ir.view(*(ir.shape[1:]))

    def get_impulse_responses(self, sample_rate: int, kernel_size: int) -> torch.Tensor:
        """Calculating discrete-time impulse responses (corresponding to the weights of the convolutional layer) from MLP."""
        # NOTE(review): ``use_oversampling`` is an optional attribute set
        # externally (not created in __init__); oversampling is only ever
        # taken in eval mode.
        use_oversampling = False
        if not self.training and hasattr(self, "use_oversampling"):
            use_oversampling = self.use_oversampling

        if use_oversampling:
            ir = self.get_impulse_responses_oversampling(sample_rate)
        else:
            normalized_time = torch.linspace(
                -1.0,
                1.0,
                kernel_size,
                device=self.device,
                requires_grad=False,
            )
            ir = self._get_ir(normalized_time)
        return ir

    def get_impulse_responses_oversampling(self, sample_rate: int) -> torch.Tensor:
        """Calculating discrete-time impulse responses (corresponding to the weights of the convolutional layer) from MLP with oversampling for anti-aliasing.

        First, calculate the discrete-time impulse responses with the trained sample rate.
        Then, resample the calculated discrete-time impulse responses at the input sample rate.
        """
        normalized_time = torch.linspace(
            -1.0,
            1.0,
            # The buffer is stored as a float tensor, so .item() yields a
            # Python float; linspace's ``steps`` must be an int.
            int(self.init_kernel_size.item()),
            device=self.device,
            requires_grad=False,
        )
        ir = self._get_ir(normalized_time)
        resampled_ir = torchaudio.functional.resample(
            ir,
            int(self.init_sample_rate.item()),
            int(sample_rate),
        )
        return resampled_ir.float().to(self.device)
|
|
|