import functools
from typing import Sequence

import torch
import torchaudio
from torch import nn


class RFFTimeDomainImplicitFilter(nn.Module):
    """Implicit time-domain filterbank whose impulse responses come from an MLP.

    The MLP maps normalized time positions (optionally lifted through random
    Fourier features) to the taps of ``n_filters`` FIR filters, so the kernel
    can be regenerated at an arbitrary kernel size, and resampled to a new
    sample rate at inference time (see ``get_impulse_responses_oversampling``).
    """

    def __init__(
        self,
        n_filters: int,
        init_kernel_size: int,
        init_sample_rate: int,
        ch_list: Sequence[int] = (32, 32),
        n_RFFs: int = 32,
        nonlinearity: str = "relu",
        train_RFF: bool = False,
        use_layer_norm: bool = False,
    ) -> None:
        """n_filters: Number of filters.

        init_kernel_size: Initial kernel size.
        init_sample_rate: Initial sample rate.
        ch_list: Channel list of MLP.
        n_RFFs: Number of RFFs. If n_RFFs <= 0, do not use random Fourier
            feature inputs (i.e., directly input normalized time).
        nonlinearity (str): Nonlinearity (only "relu" is implemented).
        train_RFF (bool): If True, train RFFs.
        use_layer_norm (bool): If True, use layer norm.
        """
        super().__init__()
        self.n_filters = n_filters
        # Buffers (not parameters): they must follow .to(device)/state_dict
        # but are never trained. Stored as float tensors; convert back to int
        # before using them as sizes/rates.
        self.register_buffer("init_kernel_size", torch.tensor(init_kernel_size).float())
        self.register_buffer("init_sample_rate", torch.tensor(init_sample_rate).float())

        # Nonlinearity factory.
        if nonlinearity == "relu":
            NonlinearityClass = functools.partial(nn.ReLU, inplace=True)
        else:
            msg = f"unsupported nonlinearity: {nonlinearity!r}"
            raise NotImplementedError(msg)

        # MLP acting pointwise over time, realized with 1x1 Conv1d layers.
        # Input channels: 2*n_RFFs (sin & cos features) or 1 (raw time).
        layers = []
        in_ch_list = [n_RFFs * 2 if n_RFFs > 0 else 1, *list(ch_list)]
        out_ch_list = [*list(ch_list), n_filters]
        for i, (in_ch, out_ch) in enumerate(zip(in_ch_list, out_ch_list)):
            layers.append(nn.Conv1d(in_ch, out_ch, 1))
            # No norm/activation after the final (output) layer.
            if i < len(in_ch_list) - 1:
                if use_layer_norm:
                    # GroupNorm with one group normalizes across all channels,
                    # i.e. a channel-wise layer norm for conv features.
                    layers.append(nn.GroupNorm(1, out_ch))
                layers.append(NonlinearityClass())
        self.implicit_filter = nn.Sequential(*layers)

        if n_RFFs > 0:
            # Random Fourier feature frequencies, drawn once at construction;
            # only trainable when train_RFF is True.
            self.RFF_param = nn.Parameter(
                torch.zeros((n_RFFs,), dtype=torch.float).normal_(
                    0.0,
                    2.0 * torch.pi * 10.0,
                ),
                requires_grad=train_RFF,
            )
        else:
            self.RFF_param = None

        def set_zero_bias(m: nn.Module) -> None:
            # Zero-initialize every conv bias in the MLP.
            if isinstance(m, nn.Conv1d):
                if m.bias is None:
                    msg = "bias cannot be none"
                    raise ValueError(msg)
                m.bias.data.fill_(0.0)

        self.implicit_filter.apply(set_zero_bias)

    @property
    def device(self) -> torch.device:
        """Device that the MLP weights currently live on."""
        return self.implicit_filter[0].weight.device

    def _get_ir(self, normalized_time: torch.Tensor) -> torch.Tensor:
        """Return discrete-time impulse responses (corresponding to the
        weights of the convolutional layer) from MLP.

        Args:
            normalized_time (torch.Tensor): Normalized time (time).

        Return:
            torch.Tensor: Discrete-time impulse responses (n_filters x time)
        """
        if self.RFF_param is not None:
            # Outer product frequencies x time, then sin/cos feature lift.
            RFF = self.RFF_param[:, None] @ normalized_time[None, :]  # n_RFFs x time
            RFF = torch.cat((RFF.sin(), RFF.cos()), dim=0)  # n_RFFs*2 x time
            ir = self.implicit_filter(RFF[None, :, :])  # 1 x n_filters x time
        else:
            # Feed normalized time directly as a single input channel.
            ir = self.implicit_filter(
                normalized_time[None, None, :],
            )  # 1 x n_filters x time
        # Drop the singleton batch dimension -> n_filters x time.
        return ir.squeeze(0)

    def get_impulse_responses(self, sample_rate: int, kernel_size: int) -> torch.Tensor:
        """Calculating discrete-time impulse responses (corresponding to the
        weights of the convolutional layer) from MLP.

        In eval mode, an externally set ``use_oversampling`` attribute switches
        to the anti-aliased oversampling path; during training the kernel is
        always sampled directly at ``kernel_size`` points.
        """
        # `use_oversampling` is not set in __init__; it is an opt-in attribute
        # a caller may attach, honored only outside training.
        use_oversampling = (not self.training) and getattr(self, "use_oversampling", False)
        if use_oversampling:
            return self.get_impulse_responses_oversampling(sample_rate)
        normalized_time = torch.linspace(
            -1.0,
            1.0,
            kernel_size,
            device=self.device,
            requires_grad=False,
        )  # time
        return self._get_ir(normalized_time)

    def get_impulse_responses_oversampling(self, sample_rate: int) -> torch.Tensor:
        """Calculating discrete-time impulse responses (corresponding to the
        weights of the convolutional layer) from MLP with oversampling for
        anti-aliasing.

        First, calculate the discrete-time impulse responses with the trained
        sample rate. Then, resample the calculated discrete-time impulse
        responses at the input sample rate.
        """
        # BUG FIX: linspace's `steps` must be an int; the buffer is stored as
        # a float tensor, so bare .item() would pass a float and raise.
        normalized_time = torch.linspace(
            -1.0,
            1.0,
            int(self.init_kernel_size.item()),
            device=self.device,
            requires_grad=False,
        )  # time
        ir = self._get_ir(normalized_time)
        resampled_ir = torchaudio.functional.resample(
            ir,
            int(self.init_sample_rate.item()),
            int(sample_rate),
        )  # resampling
        # resample may change dtype/device depending on backend; normalize both.
        return resampled_ir.float().to(self.device)