import functools
from typing import Sequence
import torch
import torchaudio
from torch import nn
class RFFTimeDomainImplicitFilter(nn.Module):
    """MLP-based implicit filter producing time-domain impulse responses.

    A stack of 1x1 convolutions maps normalized time positions (optionally
    encoded as random Fourier features, RFFs) to ``n_filters`` discrete-time
    impulse responses, which can be used as the weights of a convolutional
    layer.
    """

    def __init__(
        self,
        n_filters: int,
        init_kernel_size: int,
        init_sample_rate: int,
        ch_list: Sequence[int] = (32, 32),
        n_RFFs: int = 32,
        nonlinearity: str = "relu",
        train_RFF: bool = False,
        use_layer_norm: bool = False,
    ) -> None:
        """n_filters: Number of filters.
        init_kernel_size: Initial kernel size (used by the oversampling path).
        init_sample_rate: Initial sample rate (used by the oversampling path).
        ch_list: Channel list of the MLP hidden layers.
        n_RFFs: Number of RFFs. If n_RFFs <= 0, do not use random Fourier feature inputs (i.e., directly input normalized time).
        nonlinearity (str): Nonlinearity ("relu" is the only supported value).
        train_RFF (bool): If True, train RFFs.
        use_layer_norm (bool): If True, use layer norm.
        """
        super().__init__()
        self.n_filters = n_filters
        # Buffers (not parameters): move with the module across devices and
        # are stored in state_dict, but are never trained.
        self.register_buffer("init_kernel_size", torch.tensor(init_kernel_size).float())
        self.register_buffer("init_sample_rate", torch.tensor(init_sample_rate).float())
        # nonlinearity
        if nonlinearity == "relu":
            NonlinearityClass = functools.partial(nn.ReLU, inplace=True)
        else:
            raise NotImplementedError
        # MLP of 1x1 convolutions: input channels are 2*n_RFFs (sin+cos) or 1
        # (raw normalized time); output channels are n_filters.
        layers = []
        in_ch_list = [n_RFFs * 2 if n_RFFs > 0 else 1, *list(ch_list)]
        out_ch_list = [*list(ch_list), n_filters]
        for i, (in_ch, out_ch) in enumerate(zip(in_ch_list, out_ch_list)):
            layers.append(nn.Conv1d(in_ch, out_ch, 1))
            # No normalization/activation after the final (output) layer.
            if i < len(in_ch_list) - 1:
                if use_layer_norm:
                    # GroupNorm with one group == layer norm over channels.
                    layers.append(nn.GroupNorm(1, out_ch))
                layers.append(NonlinearityClass())
        self.implicit_filter = nn.Sequential(*layers)
        if n_RFFs > 0:
            # RFF frequencies drawn from N(0, (2*pi*10)^2); frozen unless
            # train_RFF is set.
            self.RFF_param = nn.Parameter(
                torch.zeros((n_RFFs,), dtype=torch.float).normal_(
                    0.0,
                    2.0 * torch.pi * 10.0,
                ),
                requires_grad=train_RFF,
            )
        else:
            self.RFF_param = None

        def set_zero_bias(m) -> None:
            # Initialize all conv biases to zero; a missing bias indicates a
            # misconfigured layer, so fail loudly.
            if isinstance(m, nn.Conv1d):
                if m.bias is None:
                    msg = "bias cannot be none"
                    raise ValueError(msg)
                m.bias.data.fill_(0.0)

        self.implicit_filter.apply(set_zero_bias)

    @property
    def device(self):
        """Device of the MLP parameters (proxy: first layer's weight)."""
        return self.implicit_filter[0].weight.device

    def _get_ir(self, normalized_time):
        """Return discrete-time impulse responses (corresponding to the weights of the convolutional layer) from MLP.

        Args:
            normalized_time (torch.Tensor): Normalized time (time).

        Return:
            torch.Tensor: Discrete-time impulse responses (n_filters x time)
        """
        if self.RFF_param is not None:
            # Outer product (n_RFFs x 1) @ (1 x time) -> n_RFFs x time.
            RFF = self.RFF_param[:, None] @ normalized_time[None, :]
            RFF = torch.cat((RFF.sin(), RFF.cos()), dim=0)  # n_RFFs*2 x time
            ir = self.implicit_filter(RFF[None, :, :])  # 1 x n_filters x time
        else:
            ir = self.implicit_filter(
                normalized_time[None, None, :],
            )  # 1 x n_filters x time
        return ir.squeeze(0)  # drop batch dim -> n_filters x time

    def get_impulse_responses(self, sample_rate: int, kernel_size: int):
        """Calculating discrete-time impulse responses (corresponding to the weights of the convolutional layer) from MLP.

        At eval time, if the (externally set) ``use_oversampling`` attribute
        exists and is truthy, delegates to the anti-aliasing oversampling
        path; otherwise evaluates the MLP directly on ``kernel_size`` points.
        """
        use_oversampling = False
        if not self.training and hasattr(self, "use_oversampling"):
            use_oversampling = self.use_oversampling
        if use_oversampling:
            ir = self.get_impulse_responses_oversampling(sample_rate)
        else:
            normalized_time = torch.linspace(
                -1.0,
                1.0,
                kernel_size,
                device=self.device,
                requires_grad=False,
            )  # time
            ir = self._get_ir(normalized_time)
        return ir

    def get_impulse_responses_oversampling(self, sample_rate: int):
        """Calculating discrete-time impulse responses (corresponding to the weights of the convolutional layer) from MLP with oversampling for anti-aliasing.

        First, calculate the discrete-time impulse responses with the trained sample rate.
        Then, resample the calculated discrete-time impulse responses at the input sample rate.
        """
        normalized_time = torch.linspace(
            -1.0,
            1.0,
            # torch.linspace requires an integer step count; the buffer is
            # stored as a float tensor, so cast explicitly.
            int(self.init_kernel_size.item()),
            device=self.device,
            requires_grad=False,
        )  # time
        ir = self._get_ir(normalized_time)
        resampled_ir = torchaudio.functional.resample(
            ir,
            int(self.init_sample_rate.item()),
            int(sample_rate),
        )  # resampling
        return resampled_ir.float().to(self.device)