# sfi_hubert / continuous_filters.py
# (repository snapshot header — uploader: Wataru; commit message: "add non_integer stride"; commit 16293fd, verified)
"""Implementations of latent analog filters.
Copyright (c) Tomohiko Nakamura
All rights reserved.
"""
import functools
import math
from typing import Sequence

import numpy as np
import torch
import torchaudio
from torch import nn
def erb_to_hz(x):
    """Convert a frequency on the ERB scale to Hz.

    Args:
        x (numpy.ndarray or float): Frequency in ERB scale
    Return:
        numpy.ndarray or float: Frequency in Hz
    """
    # Inverse of hz_to_erb; 9.265 and 24.7 are the standard ERB-scale constants.
    growth = np.exp(x / 9.265) - 1
    return growth * 24.7 * 9.265
def hz_to_erb(x):
    """Convert a frequency in Hz to the ERB scale.

    Args:
        x (numpy.ndarray or float): Frequency in Hz
    Return:
        numpy.ndarray or float: Frequency in ERB scale
    """
    # Forward ERB mapping; 9.265 and 24.7 are the standard ERB-scale constants.
    ratio = 1 + x / (24.7 * 9.265)
    return np.log(ratio) * 9.265
#############################################
class ModulatedGaussianFilters(nn.Module):
    r"""Modulated Gaussian filters.

    The frequency response of this filter is given by
    [
    H(\omega) = e^{-(\omega-\omega_{c})^2/(2\sigma^2)} + e^{-(\omega+\omega_{c})^2/(2\sigma^2)}.
    ]
    If one_sided is True, this frequency response is changed as
    [
    H(\omega) = e^{-(\omega-\omega_{c})^2/(2\sigma^2)}.
    ]
    """

    def __init__(
        self,
        n_filters,
        init_type="erb",
        min_bw=1.0 * 2.0 * np.pi,
        initial_freq_range=None,
        one_sided=False,
        init_sigma=100.0 * 2.0 * np.pi,
        trainable=True,
    ) -> None:
        """Args:
        n_filters (int): Number of filters
        init_type (str): Initialization type of center frequencies.
            If "erb", set them from initial_freq_range[0] to initial_freq_range[1] with an equal interval in the ERB scale.
            If "linear", set them from initial_freq_range[0] to initial_freq_range[1] with an equal interval in the linear frequency scale.
        min_bw (float): Minimum bandwidth in radian
        initial_freq_range ([float,float]): Initial frequency ranges in Hz, as tuple of minimum (typically 50) and maximum values (typically, half of Nyquist frequency)
        one_sided (bool): If True, ignore the term in the negative frequency region. If False, the corresponding impulse response is a modulated Gaussian window.
        init_sigma (float): Initial value for sigma
        trainable (bool): Whether filter parameters are trainable or not.
        """
        if initial_freq_range is None:
            initial_freq_range = [50.0, 32000 / 2]
        super().__init__()
        lf, hf = initial_freq_range
        # Center frequencies (in rad/s) spaced per the requested scale.
        if init_type == "linear":
            mus = np.linspace(lf, hf, n_filters) * 2.0 * np.pi
        elif init_type == "erb":
            erb_mus = np.linspace(hz_to_erb(lf), hz_to_erb(hf), n_filters)
            mus = erb_to_hz(erb_mus) * 2.0 * np.pi
        else:
            raise ValueError(f'init_type must be "linear" or "erb", got {init_type!r}')
        # All filters share the same initial bandwidth.
        sigma2s = init_sigma**2 * np.ones((n_filters,), dtype="f")
        # Lower bound on the log-variance enforces the minimum bandwidth min_bw.
        self.min_ln_sigma2s = np.log(min_bw**2)
        self.mus = nn.Parameter(torch.from_numpy(mus).float(), requires_grad=trainable)
        self._ln_sigma2s = nn.Parameter(
            torch.from_numpy(np.log(sigma2s)).float().clamp(min=self.min_ln_sigma2s),
            requires_grad=trainable,
        )
        # Per-filter phase, initialized uniformly at random in [0, pi).
        self.phase = nn.Parameter(
            torch.zeros((n_filters,), dtype=torch.float),
            requires_grad=trainable,
        )
        self.phase.data.uniform_(0.0, np.pi)
        self.one_sided = one_sided

    @property
    def sigma2s(self):
        """Variances sigma^2, clamped so the bandwidth never falls below min_bw."""
        return self._ln_sigma2s.clamp(min=self.min_ln_sigma2s).exp()

    def get_frequency_responses(self, omega: torch.Tensor):
        """Sample frequency responses at omega.

        Args:
            omega (torch.Tensor): Angular frequencies (n_angfreqs)
        Return:
            tuple[torch.Tensor]: Real and imaginary parts of frequency responses
                sampled at omega (each n_filters x n_angfreqs).
        """
        # Gaussian bump centered at +mu (shared by both branches).
        resp_abs = torch.exp(
            -(omega[None, :] - self.mus[:, None]).pow(2.0)
            / (2.0 * self.sigma2s[:, None]),
        )  # n_filters x n_angfreqs
        if self.one_sided:
            resp_r = resp_abs * self.phase.cos()[:, None]
            resp_i = resp_abs * self.phase.sin()[:, None]
        else:
            # Mirror bump at -mu ensures the corresponding impulse responses are real.
            resp_abs2 = torch.exp(
                -(omega[None, :] + self.mus[:, None]).pow(2.0)
                / (2.0 * self.sigma2s[:, None]),
            )
            resp_r = (
                resp_abs * self.phase.cos()[:, None]
                + resp_abs2 * ((-self.phase).cos()[:, None])
            )
            resp_i = (
                resp_abs * self.phase.sin()[:, None]
                + resp_abs2 * ((-self.phase).sin()[:, None])
            )
        return resp_r, resp_i

    def extra_repr(self):
        # FIX: the previous version applied str.format(**self.__dict__) to an
        # already-formatted f-string — a no-op at best, and a crash if the
        # string ever contained literal braces.
        return f"n_filters={int(self.mus.shape[0])}, one_sided={self.one_sided}"

    @property
    def device(self):
        """Device of the filter parameters."""
        return self.mus.device
class TDModulatedGaussianFilters(ModulatedGaussianFilters):
    """Time-domain variant of ModulatedGaussianFilters.

    Instead of sampling frequency responses, this class samples discrete-time
    impulse responses at an arbitrary target sampling frequency.
    """

    def __init__(
        self,
        n_filters,
        train_sample_rate,
        init_type="erb",
        min_bw=1.0 * 2.0 * np.pi,
        initial_freq_range=None,
        one_sided=False,
        init_sigma=100.0 * 2.0 * np.pi,
        trainable=True,
    ) -> None:
        """Args:
        n_filters (int): Number of filters
        train_sample_rate (float): Trained sampling frequency
        init_type (str): Initialization type of center frequencies.
            If "erb", set them from initial_freq_range[0] to initial_freq_range[1] with an equal interval in the ERB scale.
            If "linear", set them from initial_freq_range[0] to initial_freq_range[1] with an equal interval in the linear frequency scale.
        min_bw (float): Minimum bandwidth in radian
        initial_freq_range ([float,float]): Initial frequency ranges in Hz, as tuple of minimum (typically 50) and maximum values (typically, half of Nyquist frequency)
        one_sided (bool): If True, ignore the term in the negative frequency region. If False, the corresponding impulse response is a modulated Gaussian window.
        init_sigma (float): Initial value for sigma
        trainable (bool): Whether filter parameters are trainable or not.
        """
        if initial_freq_range is None:
            initial_freq_range = [50.0, 32000 / 2]
        super().__init__(
            n_filters=n_filters,
            init_type=init_type,
            min_bw=min_bw,
            initial_freq_range=initial_freq_range,
            one_sided=one_sided,
            init_sigma=init_sigma,
            trainable=trainable,
        )
        # Keep the sampling rate used during training as a non-trainable buffer.
        self.register_buffer(
            "train_sample_rate",
            torch.tensor(float(train_sample_rate)),
        )

    def get_impulse_responses(self, sample_rate: int, tap_size: int):
        """Sample impulse responses.

        Args:
            sample_rate (int): Target sampling frequency
            tap_size (int): Tap size
        Return:
            torch.Tensor: Sampled impulse responses (n_filters x tap_size)
        """
        center_freqs_in_hz = self.mus / (2.0 * np.pi)
        # When running below the training rate, zero out every filter whose
        # center frequency exceeds the new Nyquist frequency.
        below_training_rate = self.train_sample_rate > sample_rate
        if below_training_rate:
            nyquist_mask = center_freqs_in_hz <= sample_rate / 2
        # Zero-centered time grid in seconds (1 x tap_size).
        time_grid = (
            torch.arange(0.0, tap_size, 1).type_as(center_freqs_in_hz) / sample_rate
        )
        time_grid = (time_grid - time_grid.mean())[None, :]
        if self.one_sided:
            raise NotImplementedError
        # Gaussian envelope with the normalization constant 2*sqrt(2*pi*sigma^2).
        envelope = (
            2.0
            * (2.0 * np.pi * self.sigma2s[:, None]).sqrt()
            * (-self.sigma2s[:, None] * (time_grid**2) / 2.0).exp()
        )
        # Cosine carrier at each center frequency with the learned phase.
        carrier = (self.mus[:, None] @ time_grid + self.phase[:, None]).cos()
        filter_coeffs = envelope * carrier  # n_filters x tap_size
        if below_training_rate:
            filter_coeffs = filter_coeffs * nyquist_mask[:, None]
        # Reverse along the time axis (same as indexing with arange(tap_size-1, -1, -1)).
        return torch.flip(filter_coeffs, dims=[1])
#############################################
class MultiPhaseGammaToneFilters(nn.Module):
    """Multiphase gamma tone filters.

    Remark:
        This class includes the creation of Hilbert transform pairs.

    [2] D. Ditter and T. Gerkmann, ``A multi-phase gammatone filterbank for speech separation via TasNet,'' in Proceedings of IEEE International Conference on Acoustics, Speech, and Signal Processing, 2020, pp. 36--40.
    """

    def __init__(
        self,
        n_filters,
        train_sample_rate,
        initial_freq_range=None,
        n_center_freqs=24,
        trainable=False,
    ) -> None:
        """Args:
        n_filters (int): Number of filters
        train_sample_rate (float): Trained sampling frequency
        initial_freq_range ([float,float]): Initial frequency ranges in Hz, as tuple of minimum (typically 50) and maximum values (typically, half of Nyquist frequency)
        n_center_freqs (int): Number of center frequencies
        trainable (bool): Whether filter parameters are trainable or not.
        """
        if initial_freq_range is None:
            initial_freq_range = [100.0, 16000 / 2]
        super().__init__()
        self.register_buffer(
            "train_sample_rate",
            torch.tensor(float(train_sample_rate)),
        )
        self.n_filters = n_filters
        # Filters are doubled into Hilbert pairs later, so only n_filters//2
        # phase slots exist; each center frequency needs at least one.
        assert n_filters // 2 >= n_center_freqs
        ## Ditter's initialization method: center frequencies equally spaced on the ERB scale.
        center_freqs = torch.from_numpy(
            erb_to_hz(
                np.linspace(
                    hz_to_erb(initial_freq_range[0]),
                    hz_to_erb(initial_freq_range[1]),
                    n_center_freqs,
                ),
            ).astype("f"),
        ).float()  # [Hz]
        if trainable:
            self.center_freqs_in_hz = nn.Parameter(center_freqs, requires_grad=trainable)
        else:
            self.register_buffer("center_freqs_in_hz", center_freqs)
        ### Distribute the n_filters//2 phase variations as evenly as possible.
        n_phase_variations_list = (
            np.ones(n_center_freqs) * np.floor(self.n_filters / 2 / n_center_freqs)
        ).astype("i")
        remaining_phases = int(self.n_filters // 2 - n_phase_variations_list.sum())
        if remaining_phases > 0:
            # Leftover slots go to the lowest center frequencies.
            n_phase_variations_list[:remaining_phases] += 1
        n_phase_variations_list = [int(_) for _ in n_phase_variations_list]
        self.register_buffer(
            "n_phase_variations",
            torch.tensor(n_phase_variations_list),
        )
        ### Initial phases: equally spaced in [0, pi] within each center frequency.
        phases = np.concatenate(
            [np.linspace(0.0, np.pi, N) for N in n_phase_variations_list],
            axis=0,
        )
        if trainable:
            self.phases = nn.Parameter(
                torch.from_numpy(phases).float(),
                requires_grad=trainable,
            )  # n_filters//2
        else:
            self.register_buffer("phases", torch.from_numpy(phases).float())

    def compute_gammatone_impulse_response(self, center_freqs_in_hz, phases, t):
        """Compute gammatone impulse responses.

        Args:
            center_freqs_in_hz (torch.Tensor): Center frequencies in Hz (n_center_freqs)
            phases (torch.Tensor): Phases (n_center_freqs)
            t (torch.Tensor): Time grid in seconds (1 x tap_size)
        Return:
            torch.Tensor: Sampled impulse response (n_center_freqs x tap_size)
        """
        center_freqs_in_hz = center_freqs_in_hz[:, None]
        n = 2  # gammatone filter order
        # Equivalent rectangular bandwidth.
        # FIX: use the stdlib math.factorial — the np.math alias was removed in NumPy 2.0.
        b = (24.7 + center_freqs_in_hz / 9.265) / (
            (np.pi * math.factorial(2 * n - 2) * np.power(2, float(-(2 * n - 2))))
            / np.square(math.factorial(n - 1))
        )
        a = 1.0  # amplitude
        return (
            a
            * (t ** (n - 1))
            * torch.exp(-2 * np.pi * b * t)
            * torch.cos(2 * np.pi * center_freqs_in_hz * t + phases[:, None])
        )  # n_center_freqs x tap_size

    def normalize_filters(self, filter_coeffs):
        """Normalize filter coefficients by per-filter RMS, relative to the max-RMS filter.

        Args:
            filter_coeffs (torch.Tensor): Filter coefficients (n_filters x tap_size)
        Return:
            torch.Tensor: Normalized filter coefficients (n_filters x tap_size)
        """
        rms_per_filter = (filter_coeffs**2).mean(dim=1).sqrt()
        C = 1.0 / (rms_per_filter / rms_per_filter.max())
        return filter_coeffs * C[:, None]

    def get_impulse_responses(self, sample_rate: int, tap_size: int):
        """Sample impulse responses.

        Args:
            sample_rate (int): Target sampling frequency
            tap_size (int): Tap size
        Return:
            torch.Tensor: Sampled impulse responses (n_filters x tap_size)
        """
        # Second half of the filters are Hilbert pairs (phase shifted by pi).
        phases = torch.cat((self.phases, self.phases + np.pi), dim=0)  # n_filters
        center_freqs_in_hz = self.center_freqs_in_hz.repeat_interleave(
            self.n_phase_variations,
            dim=0,
        )
        center_freqs_in_hz = center_freqs_in_hz.repeat(2)  # doubles for Hilbert pairs
        # Check whether the center frequencies are below the Nyquist rate.
        if self.train_sample_rate > sample_rate:
            mask = center_freqs_in_hz <= sample_rate / 2
        ###
        if tap_size % 2 == 0:
            # even: exclude the origin
            t = (
                torch.arange(1.0, tap_size + 1, 1).type_as(center_freqs_in_hz)
                / sample_rate
            )[None, :]
        else:
            # odd: include the origin
            t = (
                torch.arange(0.0, tap_size, 1).type_as(center_freqs_in_hz) / sample_rate
            )[None, :]
        filter_coeffs = self.compute_gammatone_impulse_response(
            center_freqs_in_hz,
            phases,
            t,
        ).type_as(center_freqs_in_hz)  # n_filters x tap_size
        filter_coeffs = self.normalize_filters(filter_coeffs).type_as(
            center_freqs_in_hz,
        )
        if self.train_sample_rate > sample_rate:
            filter_coeffs = filter_coeffs * mask[:, None]
        # Reverse along the time axis.
        return filter_coeffs[:, torch.arange(tap_size - 1, -1, -1)]
class RFFTimeDomainImplicitFilter(nn.Module):
    """Implicit time-domain filter: an MLP maps (random-Fourier-featured) normalized time to filter taps."""

    def __init__(
        self,
        n_filters: int,
        init_kernel_size: int,
        init_sample_rate: int,
        ch_list: Sequence[int] = (32, 32),  # FIX: tuple, not a mutable default list
        n_RFFs: int = 32,
        nonlinearity: str = "relu",
        train_RFF: bool = False,
        use_layer_norm: bool = False,
    ) -> None:
        """n_filters: Number of filters.
        init_kernel_size: Initial kernel size.
        init_sample_rate: Initial sample rate.
        ch_list: Channel list of MLP.
        n_RFFs: Number of RFFs. If n_RFFs <= 0, do not use random Fourier feature inputs (i.e., directly input normalized time).
        nonlinearity (str): Nonlinearity
        train_RFF (bool): If True, train RFFs.
        use_layer_norm (bool): If True, use layer norm.
        """
        super().__init__()
        self.n_filters = n_filters
        self.register_buffer("init_kernel_size", torch.tensor(init_kernel_size).float())
        self.register_buffer("init_sample_rate", torch.tensor(init_sample_rate).float())
        # nonlinearity
        if nonlinearity == "relu":
            NonlinearityClass = functools.partial(nn.ReLU, inplace=True)
        else:
            raise NotImplementedError
        # MLP built from 1x1 convolutions over the time axis.
        layers = []
        in_ch_list = [n_RFFs * 2 if n_RFFs > 0 else 1, *list(ch_list)]
        out_ch_list = [*list(ch_list), n_filters]
        for i, (in_ch, out_ch) in enumerate(zip(in_ch_list, out_ch_list)):
            layers.append(nn.Conv1d(in_ch, out_ch, 1))
            if i < len(in_ch_list) - 1:  # no norm/nonlinearity after the output layer
                if use_layer_norm:
                    layers.append(nn.GroupNorm(1, out_ch))
                layers.append(NonlinearityClass())
        self.implicit_filter = nn.Sequential(*layers)

        def init_weights(m) -> None:
            # Small-gain init keeps the initial filter responses near zero.
            if isinstance(m, nn.Conv1d):
                torch.nn.init.xavier_uniform_(m.weight, gain=1e-3)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

        self.implicit_filter.apply(init_weights)
        if n_RFFs > 0:
            # Random Fourier feature frequencies.
            self.RFF_param = nn.Parameter(
                torch.zeros((n_RFFs,), dtype=torch.float).normal_(
                    0.0,
                    2.0 * torch.pi * 10.0,
                ),
                requires_grad=train_RFF,
            )
        else:
            self.RFF_param = None

        def set_zero_bias(m) -> None:
            if isinstance(m, nn.Conv1d):
                if m.bias is None:
                    msg = "bias cannot be none"
                    raise ValueError(msg)
                m.bias.data.fill_(0.0)

        self.implicit_filter.apply(set_zero_bias)

    @staticmethod
    def normalize_filters(filter_coeffs):
        """Normalize filter coefficients by per-filter RMS, relative to the max-RMS filter."""
        rms_per_filter = (filter_coeffs**2).mean(dim=1).sqrt()
        C = 1.0 / (rms_per_filter / rms_per_filter.max())
        return filter_coeffs * C[:, None]

    @property
    def device(self):
        """Device of the MLP parameters."""
        return self.implicit_filter[0].weight.device

    def _get_ir(self, normalized_time):
        """Get impulse response.

        Args:
            normalized_time (torch.Tensor): Normalized time (time).
        Return:
            torch.Tensor: Discrete-time impulse responses (n_filters x time)
        """
        if self.RFF_param is not None:
            RFF = self.RFF_param[:, None] @ normalized_time[None, :]  # n_RFFs x time
            RFF = torch.cat((RFF.sin(), RFF.cos()), dim=0)  # n_RFFs*2 x time
            ir = self.implicit_filter(RFF[None, :, :])  # 1 x n_filters x time
        else:
            ir = self.implicit_filter(
                normalized_time[None, None, :],
            )  # 1 x n_filters x time
        return ir.view(*(ir.shape[1:]))

    def get_impulse_responses(self, sample_rate: int, kernel_size):
        """Calculate discrete-time impulse responses.

        Corresponding to the weights of the convolutional layer from MLP.
        """
        use_oversampling = False
        # NOTE(review): use_oversampling appears to be set externally on the
        # instance; at eval time it selects the anti-aliased path — confirm with callers.
        if not self.training and hasattr(self, "use_oversampling"):
            use_oversampling = self.use_oversampling
        if use_oversampling:
            ir = self.get_impulse_responses_oversampling(sample_rate)
        else:
            normalized_time = torch.linspace(
                -1.0,
                1.0,
                kernel_size,
                device=self.device,
                requires_grad=False,
            )  # time
            ir = self._get_ir(normalized_time)
        return ir

    def get_impulse_responses_oversampling(self, sample_rate: int):
        """Calculate discrete-time impulse responses from MLP with oversampling for anti-aliasing.

        First, calculate the discrete-time impulse responses with the trained sample
        rate.
        Then, resample the calculated discrete-time impulse responses at the input
        sample rate.
        """
        normalized_time = torch.linspace(
            -1.0,
            1.0,
            # FIX: torch.linspace requires an integer step count; .item() on a
            # float buffer returns a Python float and raises a TypeError.
            int(self.init_kernel_size.item()),
            device=self.device,
            requires_grad=False,
        )  # time
        ir = self._get_ir(normalized_time)
        resampled_ir = torchaudio.functional.resample(
            ir,
            int(self.init_sample_rate.item()),
            int(sample_rate),
        )  # resampling
        return resampled_ir.float().to(self.device)
class FrequencyDomainRFFImplicitFilter(nn.Module):
    """Neural analog filter (NAF) for frequency-domain sampling-frequency-independent convolutional layer in [1].

    [1] Kanami Imamura, Tomohiko Nakamura, Kohei Yatabe, and Hiroshi Saruwatari, ``Neural analog filter for sampling-frequency-independent convolutional layer," APSIPA Transactions on Signal and Information Processing, vol. 13, no. 1, e28, Nov. 2024.
    """

    def __init__(
        self,
        n_filters: int,
        max_freq: int,
        ch_list: Sequence[int] = (224, 224),  # FIX: tuple, not a mutable default list
        n_rffs: int = 128,
        nonlinearity: str = "relu",
        train_rff: bool = True,
        use_layer_norm: bool = True,
    ):
        """Initialize FrequencyDomainRFFImplicitFilter.

        Args:
            n_filters (int): Number of filters
            max_freq (float): Max. of frequency (i.e., Nyquist frequency of training data)
            ch_list (Sequence[int]): Channel list of MLP
            n_rffs (int): # of RFFs. If equal to or less than 0, RFFs are not used.
            nonlinearity (str): Nonlinearity
            train_rff (bool): If True, train RFFs.
            use_layer_norm (bool): If True, use layer norm.
        """
        super().__init__()
        self.use_RFFs = n_rffs > 0
        # nonlinearity (renamed local so the parameter is not shadowed)
        if nonlinearity == "relu":
            NonlinearityClass = functools.partial(nn.ReLU, inplace=True)
        elif nonlinearity == "none":
            NonlinearityClass = functools.partial(nn.Identity, inplace=True)
        else:
            raise NotImplementedError
        self.n_filters = n_filters
        # Nyquist frequency of the training data, in rad/s.
        self.register_buffer("max_ang_freq", torch.tensor(max_freq * 2.0 * np.pi))
        # MLP of 1x1 convolutions; output has real and imaginary halves (n_filters*2).
        layers = []
        in_ch_list = [n_rffs * 2 if self.use_RFFs else 1, *list(ch_list)]
        out_ch_list = [*list(ch_list), n_filters * 2]
        for i, (in_ch, out_ch) in enumerate(zip(in_ch_list, out_ch_list)):
            layers.append(nn.Conv1d(in_ch, out_ch, 1))
            if i < len(in_ch_list) - 1:  # no norm/nonlinearity after the output layer
                if use_layer_norm:
                    layers.append(nn.GroupNorm(1, out_ch))
                layers.append(NonlinearityClass())
        self.implicit_filter = nn.Sequential(*layers)
        if self.use_RFFs:
            # Random Fourier feature frequencies.
            self.RFF_param = nn.Parameter(
                torch.zeros((n_rffs,), dtype=torch.float).normal_(
                    0.0, 2.0 * np.pi * 10.0
                ),
                requires_grad=train_rff,
            )

        # NOTE(review): placed at method level (applied regardless of use_RFFs),
        # mirroring RFFTimeDomainImplicitFilter — confirm against the original layout.
        def set_zero_bias(m):
            if isinstance(m, nn.Conv1d):
                m.bias.data.fill_(0.0)

        self.implicit_filter.apply(set_zero_bias)
        self.use_ideal_low_pass_filter = True

    @property
    def device(self):
        """Device."""
        return self.implicit_filter[0].weight.device

    def get_frequency_responses(self, omega: torch.Tensor):
        """Calculating frequency responses from MLP.

        Args:
            omega (torch.Tensor): (Unnormalized) angular frequencies (n_angfreqs)
        Return:
            Tuple[torch.Tensor,torch.Tensor]: Real and imaginary parts of frequency characteristics (pair of n_filters x n_angfreqs as tuple)
        """
        # Normalize so omega == 1.0 corresponds to the training Nyquist frequency.
        omega = omega / self.max_ang_freq  # n_angfreqs
        if self.use_RFFs:
            x = self.RFF_param[:, None] @ omega[None, :]  # n_RFFs x n_angfreqs
            x = torch.cat((x.cos(), x.sin()), dim=0)  # n_RFFs*2 x n_angfreqs
        else:
            x = omega[None, :]  # 1 x n_angfreqs
        freq_resps = self.implicit_filter(
            x[None, :, :]
        )  # 1 x n_RFFs*2 (or 1 (ang. freq.)) x n_angfreqs -> 1 x n_filters*2 x n_angfreqs
        # Apply ideal low pass filter (eval only): zero responses above the training Nyquist.
        if not self.training and omega.max() > 1.0 and self.use_ideal_low_pass_filter:
            freq_resps *= (omega <= 1.0).float()[None, None, :]
        # First half of channels is the real part, second half the imaginary part.
        return freq_resps[0, : self.n_filters, :], freq_resps[0, self.n_filters :, :]