# sfi_hubert/rff_filter.py
import functools
from typing import Sequence

import torch
import torchaudio
from torch import nn


class RFFTimeDomainImplicitFilter(nn.Module):
    def __init__(
        self,
        n_filters: int,
        init_kernel_size: int,
        init_sample_rate: int,
        ch_list: Sequence[int] = (32, 32),
        n_RFFs: int = 32,
        nonlinearity: str = "relu",
        train_RFF: bool = False,
        use_layer_norm: bool = False,
    ) -> None:
"""n_filters: Number of filters.
init_kernel_size: Initial kernel size.
init_sample_rate: Initial sample rate.
ch_list: Channel list of MLP.
n_RFFs: Number of RFFs. If n_RFFs <= 0, do not use random Fourier feature inputs (i.e., directly input normalized time).
nonlinearity (str): Nonlinearity
train_RFF (bool): If True, train RFFs.
use_layer_norm (bool): If True, use layer norm.
"""
        super().__init__()
        self.n_filters = n_filters
        self.register_buffer("init_kernel_size", torch.tensor(init_kernel_size).float())
        self.register_buffer("init_sample_rate", torch.tensor(init_sample_rate).float())
        # nonlinearity
        if nonlinearity == "relu":
            NonlinearityClass = functools.partial(nn.ReLU, inplace=True)
        else:
            raise NotImplementedError
        # MLP over the (embedded) time axis, built from 1x1 convolutions
        layers = []
        in_ch_list = [n_RFFs * 2 if n_RFFs > 0 else 1, *list(ch_list)]
        out_ch_list = [*list(ch_list), n_filters]
        for (i, in_ch), out_ch in zip(enumerate(in_ch_list), out_ch_list):
            layers.append(nn.Conv1d(in_ch, out_ch, 1))
            if i < len(in_ch_list) - 1:
                if use_layer_norm:
                    layers.append(nn.GroupNorm(1, out_ch))
                layers.append(NonlinearityClass())
        self.implicit_filter = nn.Sequential(*layers)
        # Random Fourier feature frequencies (fixed unless train_RFF=True)
        if n_RFFs > 0:
            self.RFF_param = nn.Parameter(
                torch.zeros((n_RFFs,), dtype=torch.float).normal_(
                    0.0,
                    2.0 * torch.pi * 10.0,
                ),
                requires_grad=train_RFF,
            )
        else:
            self.RFF_param = None

        def set_zero_bias(m: nn.Module) -> None:
            # Zero-initialize the bias of every Conv1d in the MLP.
            if isinstance(m, nn.Conv1d):
                if m.bias is None:
                    msg = "bias cannot be none"
                    raise ValueError(msg)
                m.bias.data.fill_(0.0)

        self.implicit_filter.apply(set_zero_bias)

    @property
    def device(self):
        return self.implicit_filter[0].weight.device

    def _get_ir(self, normalized_time):
        """Return discrete-time impulse responses (the weights of the convolutional layer) computed by the MLP.

        Args:
            normalized_time (torch.Tensor): Normalized time axis in [-1, 1] (shape: time).

        Returns:
            torch.Tensor: Discrete-time impulse responses (n_filters x time).
        """
        if self.RFF_param is not None:
            RFF = self.RFF_param[:, None] @ normalized_time[None, :]  # n_RFFs x time
            RFF = torch.cat((RFF.sin(), RFF.cos()), dim=0)  # n_RFFs*2 x time
            ir = self.implicit_filter(RFF[None, :, :])  # 1 x n_filters x time
        else:
            ir = self.implicit_filter(
                normalized_time[None, None, :],
            )  # 1 x n_filters x time
        return ir.view(*(ir.shape[1:]))  # n_filters x time
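
    # Note on _get_ir (explanatory comment added for clarity, not in the uploaded file):
    # with frequencies w_1, ..., w_K = self.RFF_param (drawn with mean 0 and standard
    # deviation 2*pi*10) and normalized time t in [-1, 1], the MLP input is the Fourier
    # feature embedding
    #     phi(t) = [sin(w_1 t), ..., sin(w_K t), cos(w_1 t), ..., cos(w_K t)],
    # a (2 * n_RFFs) x time tensor, so each filter tap varies smoothly with t.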

    def get_impulse_responses(self, sample_rate: int, kernel_size: int):
        """Calculate discrete-time impulse responses (the weights of the convolutional layer) from the MLP."""
        use_oversampling = False
        # `use_oversampling` is not set in __init__; it can be attached externally
        # (e.g., `module.use_oversampling = True`) to enable anti-aliased generation
        # at evaluation time.
        if not self.training and hasattr(self, "use_oversampling"):
            use_oversampling = self.use_oversampling
        if use_oversampling:
            ir = self.get_impulse_responses_oversampling(sample_rate)
        else:
            normalized_time = torch.linspace(
                -1.0,
                1.0,
                kernel_size,
                device=self.device,
                requires_grad=False,
            )  # time
            ir = self._get_ir(normalized_time)
        return ir

    def get_impulse_responses_oversampling(self, sample_rate: int):
        """Calculate discrete-time impulse responses (the weights of the convolutional layer) from the MLP, with oversampling for anti-aliasing.

        First, the discrete-time impulse responses are computed at the sample rate used during training.
        Then, they are resampled to the requested sample rate.
        """
        normalized_time = torch.linspace(
            -1.0,
            1.0,
            int(self.init_kernel_size.item()),  # linspace expects an integer number of steps
            device=self.device,
            requires_grad=False,
        )  # time
        ir = self._get_ir(normalized_time)
        resampled_ir = torchaudio.functional.resample(
            ir,
            int(self.init_sample_rate.item()),
            int(sample_rate),
        )  # resample from the trained rate to the requested rate
        return resampled_ir.float().to(self.device)
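

# A minimal usage sketch (added for illustration, not part of the uploaded file).
# The concrete values below (64 filters, a 400-tap kernel at 16 kHz, resampling to
# 8 kHz) are assumptions for demonstration only.
if __name__ == "__main__":
    filt = RFFTimeDomainImplicitFilter(
        n_filters=64,
        init_kernel_size=400,
        init_sample_rate=16000,
    )

    # Impulse responses at the trained sample rate: n_filters x kernel_size.
    ir = filt.get_impulse_responses(sample_rate=16000, kernel_size=400)
    print(ir.shape)  # torch.Size([64, 400])

    # At evaluation time, oversampling can be enabled by attaching the flag that
    # get_impulse_responses() checks via hasattr(); the IR is then generated at the
    # trained rate and resampled to the requested rate.
    filt.eval()
    filt.use_oversampling = True
    ir_8k = filt.get_impulse_responses(sample_rate=8000, kernel_size=200)
    print(ir_8k.shape)  # torch.Size([64, 200]): init_kernel_size scaled by 8000 / 16000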