| | """Implementations of latent analog filters. |
| | |
| | Copyright (c) Tomohiko Nakamura |
| | All rights reserved. |
| | """ |
| |
|
import functools
import math
from typing import Sequence

import numpy as np
import torch
import torchaudio
from torch import nn
| |
|
| |
|
def erb_to_hz(x):
    """Convert ERB to Hz.

    Args:
        x (numpy.ndarray or float): Frequency in ERB scale

    Return:
        numpy.ndarray or float: Frequency in Hz

    """
    # Undo the logarithmic compression of the ERB scale, then rescale to Hz.
    expanded = np.exp(x / 9.265) - 1
    return expanded * 24.7 * 9.265
| |
|
| |
|
def hz_to_erb(x):
    """Convert Hz to ERB.

    Args:
        x (numpy.ndarray or float): Frequency in Hz

    Return:
        numpy.ndarray or float: Frequency in ERB scale

    """
    # Log-compress frequency relative to the scale constant 24.7 * 9.265.
    scaled = 1 + x / (24.7 * 9.265)
    return np.log(scaled) * 9.265
| |
|
| |
|
| | |
class ModulatedGaussianFilters(nn.Module):
    r"""Modulated Gaussian filters.

    The frequency response of this filter is given by

    [
        H(\omega) = e^{-(\omega-\omega_{c})^2/(2\sigma^2)} + e^{-(\omega+\omega_{c})^2/(2\sigma^2)}.
    ]

    If one_sided is True, this frequency response is changed as

    [
        H(\omega) = e^{-(\omega-\omega_{c})^2/(2\sigma^2)}.
    ]

    """

    def __init__(
        self,
        n_filters,
        init_type="erb",
        min_bw=1.0 * 2.0 * np.pi,
        initial_freq_range=None,
        one_sided=False,
        init_sigma=100.0 * 2.0 * np.pi,
        trainable=True,
    ) -> None:
        """Args:
        n_filters (int): Number of filters
        init_type (str): Initialization type of center frequencies.
            If "erb", set them from initial_freq_range[0] to initial_freq_range[1] with an equal interval in the ERB scale.
            If "linear", set them from initial_freq_range[0] to initial_freq_range[1] with an equal interval in the linear frequency scale.
        min_bw (float): Minimum bandwidth in radian
        initial_freq_range ([float,float]): Initial frequency ranges in Hz, as tuple of minimum (typically 50) and maximum values (typically, half of Nyquist frequency)
        one_sided (bool): If True, ignore the term in the negative frequency region. If False, the corresponding impulse response is modulated Gaussian window.
        init_sigma (float): Initial value for sigma
        trainable (bool): Whether filter parameters are trainable or not.

        Raises:
            ValueError: If init_type is neither "linear" nor "erb".

        """
        if initial_freq_range is None:
            initial_freq_range = [50.0, 32000 / 2]
        super().__init__()
        lf, hf = initial_freq_range
        if init_type == "linear":
            # Center frequencies equally spaced in Hz, converted to rad/s.
            mus = np.linspace(lf, hf, n_filters) * 2.0 * np.pi
        elif init_type == "erb":
            # Center frequencies equally spaced on the ERB scale, converted to rad/s.
            erb_mus = np.linspace(hz_to_erb(lf), hz_to_erb(hf), n_filters)
            mus = erb_to_hz(erb_mus) * 2.0 * np.pi
        else:
            msg = f"unknown init_type: {init_type!r} (expected 'linear' or 'erb')"
            raise ValueError(msg)
        # All filters share the same initial bandwidth.
        sigma2s = init_sigma**2 * np.ones((n_filters,), dtype="f")
        # Lower bound on ln(sigma^2) so the bandwidth never shrinks below min_bw.
        self.min_ln_sigma2s = np.log(min_bw**2)

        self.mus = nn.Parameter(torch.from_numpy(mus).float(), requires_grad=trainable)
        # Bandwidths are stored as log-variances for unconstrained optimization.
        self._ln_sigma2s = nn.Parameter(
            torch.from_numpy(np.log(sigma2s)).float().clamp(min=self.min_ln_sigma2s),
            requires_grad=trainable,
        )
        # Per-filter phase offsets, initialized uniformly at random in [0, pi).
        self.phase = nn.Parameter(
            torch.zeros((n_filters,), dtype=torch.float),
            requires_grad=trainable,
        )
        self.phase.data.uniform_(0.0, np.pi)
        self.one_sided = one_sided

    @property
    def sigma2s(self):
        # Clamp before exponentiation to enforce the minimum bandwidth.
        return self._ln_sigma2s.clamp(min=self.min_ln_sigma2s).exp()

    def get_frequency_responses(self, omega: torch.Tensor):
        """Sample frequency responses at omega.

        Args:
            omega (torch.Tensor): Angular frequencies (n_angs)

        Return:
            tuple[torch.Tensor]: Real and imaginary parts of frequency responses sampled at omega (each n_filters x n_angs).

        """
        # Gaussian lobe centered at +mu (used by both variants).
        resp_pos = torch.exp(
            -(omega[None, :] - self.mus[:, None]).pow(2.0)
            / (2.0 * self.sigma2s[:, None]),
        )
        cos_ph = self.phase.cos()[:, None]
        sin_ph = self.phase.sin()[:, None]
        if self.one_sided:
            return resp_pos * cos_ph, resp_pos * sin_ph
        # Mirror lobe centered at -mu carries the conjugate phase:
        # cos(-phi) = cos(phi), sin(-phi) = -sin(phi).
        resp_neg = torch.exp(
            -(omega[None, :] + self.mus[:, None]).pow(2.0)
            / (2.0 * self.sigma2s[:, None]),
        )
        resp_r = resp_pos * cos_ph + resp_neg * cos_ph
        resp_i = resp_pos * sin_ph - resp_neg * sin_ph
        return resp_r, resp_i

    def extra_repr(self):
        # The string is fully formatted here; the former trailing
        # s.format(**self.__dict__) call was a no-op and has been removed.
        return f"n_filters={int(self.mus.shape[0])}, one_sided={self.one_sided}"

    @property
    def device(self):
        # Device of the filter parameters (all parameters share it).
        return self.mus.device
| |
|
| |
|
class TDModulatedGaussianFilters(ModulatedGaussianFilters):
    """Modulated Gaussian filters sampled directly in the time domain."""

    def __init__(
        self,
        n_filters,
        train_sample_rate,
        init_type="erb",
        min_bw=1.0 * 2.0 * np.pi,
        initial_freq_range=None,
        one_sided=False,
        init_sigma=100.0 * 2.0 * np.pi,
        trainable=True,
    ) -> None:
        """Args:
        n_filters (int): Number of filters
        train_sample_rate (float): Trained sampling frequency
        init_type (str): Initialization type of center frequencies.
            If "erb", set them from initial_freq_range[0] to initial_freq_range[1] with an equal interval in the ERB scale.
            If "linear", set them from initial_freq_range[0] to initial_freq_range[1] with an equal interval in the linear frequency scale.
        min_bw (float): Minimum bandwidth in radian
        initial_freq_range ([float,float]): Initial frequency ranges in Hz, as tuple of minimum (typically 50) and maximum values (typically, half of Nyquist frequency)
        one_sided (bool): If True, ignore the term in the negative frequency region. If False, the corresponding impulse response is modulated Gaussian window.
        init_sigma (float): Initial value for sigma
        trainable (bool): Whether filter parameters are trainable or not.

        """
        if initial_freq_range is None:
            initial_freq_range = [50.0, 32000 / 2]
        super().__init__(
            n_filters=n_filters,
            init_type=init_type,
            min_bw=min_bw,
            initial_freq_range=initial_freq_range,
            one_sided=one_sided,
            init_sigma=init_sigma,
            trainable=trainable,
        )
        # Buffer so the trained rate moves with the module and is serialized.
        self.register_buffer(
            "train_sample_rate",
            torch.tensor(float(train_sample_rate)),
        )

    def get_impulse_responses(self, sample_rate: int, tap_size: int):
        """Sample impulse responses.

        Args:
            sample_rate (int): Target sampling frequency
            tap_size (int): Tap size

        Return:
            torch.Tensor: Sampled impulse responses (n_filters x tap_size)

        """
        center_freqs_in_hz = self.mus / (2.0 * np.pi)

        downsampling = self.train_sample_rate > sample_rate
        if downsampling:
            # Silence filters whose centers exceed the new Nyquist frequency.
            valid = center_freqs_in_hz <= sample_rate / 2

        # Time axis in seconds, shifted to be centered around zero.
        time_axis = (
            torch.arange(0.0, tap_size, 1).type_as(center_freqs_in_hz) / sample_rate
        )
        time_axis = (time_axis - time_axis.mean())[None, :]

        if self.one_sided:
            raise NotImplementedError
        # Gaussian envelope (time-domain counterpart of the Gaussian lobe).
        envelope = (
            2.0
            * (2.0 * np.pi * self.sigma2s[:, None]).sqrt()
            * (-self.sigma2s[:, None] * (time_axis**2) / 2.0).exp()
        )
        # Modulate the envelope by a cosine at each center frequency.
        ir = envelope * (self.mus[:, None] @ time_axis + self.phase[:, None]).cos()
        if downsampling:
            ir = ir * valid[:, None]
        # Reverse along time — presumably so the coefficients can serve as
        # Conv1d weights (which compute correlation); TODO confirm with caller.
        return torch.flip(ir, dims=(1,))
| |
|
| |
|
| | |
class MultiPhaseGammaToneFilters(nn.Module):
    """Multiphase gamma tone filters.

    Remark:
        This class includes the creation of Hilbert transform pairs.

    [2] D. Ditter and T. Gerkmann, ``A multi-phase gammatone filterbank for speech separation via TasNet,'' in Proceedings of IEEE International Conference on Acoustics, Speech, and Signal Processing, 2020, pp. 36--40.
    """

    def __init__(
        self,
        n_filters,
        train_sample_rate,
        initial_freq_range=None,
        n_center_freqs=24,
        trainable=False,
    ) -> None:
        """Args:
        n_filters (int): Number of filters
        train_sample_rate (float): Trained sampling frequency
        initial_freq_range ([float,float]): Initial frequency ranges in Hz, as tuple of minimum (typically 50) and maximum values (typically, half of Nyquist frequency)
        n_center_freqs (int): Number of center frequencies
        trainable (bool): Whether filter parameters are trainable or not.

        """
        if initial_freq_range is None:
            initial_freq_range = [100.0, 16000 / 2]
        super().__init__()
        self.register_buffer(
            "train_sample_rate",
            torch.tensor(float(train_sample_rate)),
        )
        self.n_filters = n_filters
        # Each center frequency must get at least one phase variation per half
        # of the bank (the second half is the pi-shifted copy of the first).
        assert n_filters // 2 >= n_center_freqs

        # Center frequencies equally spaced on the ERB scale; built once and
        # registered either as a trainable parameter or as a fixed buffer.
        center_freqs = torch.from_numpy(
            erb_to_hz(
                np.linspace(
                    hz_to_erb(initial_freq_range[0]),
                    hz_to_erb(initial_freq_range[1]),
                    n_center_freqs,
                ),
            ).astype("f"),
        ).float()
        if trainable:
            self.center_freqs_in_hz = nn.Parameter(center_freqs, requires_grad=True)
        else:
            self.register_buffer("center_freqs_in_hz", center_freqs)

        # Distribute n_filters // 2 phase variations as evenly as possible over
        # the center frequencies; any remainder goes to the lowest frequencies.
        n_phase_variations_list = (
            np.ones(n_center_freqs) * np.floor(self.n_filters / 2 / n_center_freqs)
        ).astype("i")
        remaining_phases = int(self.n_filters // 2 - n_phase_variations_list.sum())
        if remaining_phases > 0:
            n_phase_variations_list[:remaining_phases] += 1
        n_phase_variations_list = [int(_) for _ in n_phase_variations_list]
        self.register_buffer(
            "n_phase_variations",
            torch.tensor(n_phase_variations_list),
        )

        # For each center frequency, spread its phases uniformly over [0, pi].
        phases = np.concatenate(
            [np.linspace(0.0, np.pi, N) for N in n_phase_variations_list],
            axis=0,
        )
        if trainable:
            self.phases = nn.Parameter(
                torch.from_numpy(phases).float(),
                requires_grad=True,
            )
        else:
            self.register_buffer("phases", torch.from_numpy(phases).float())

    def compute_gammatone_impulse_response(self, center_freqs_in_hz, phases, t):
        """Compute gammatone impulse responses.

        Args:
            center_freqs_in_hz (torch.Tensor): Center frequencies in Hz (n_filters)
            phases (torch.Tensor): Phases in radian (n_filters)
            t (torch.Tensor): Sampled time instants in seconds (1 x tap_size)

        Return:
            torch.Tensor: Sampled impulse responses (n_filters x tap_size)

        """
        center_freqs_in_hz = center_freqs_in_hz[:, None]
        n = 2  # Gammatone filter order.
        # Bandwidth derived from the equivalent rectangular bandwidth (ERB) of
        # each center frequency. NOTE: uses the stdlib math.factorial because
        # the np.math alias was removed in NumPy 2.0.
        b = (24.7 + center_freqs_in_hz / 9.265) / (
            (np.pi * math.factorial(2 * n - 2) * np.power(2, float(-(2 * n - 2))))
            / np.square(math.factorial(n - 1))
        )
        a = 1.0  # Amplitude factor.
        # Gammatone: polynomial rise, exponential decay, cosine carrier.
        return (
            a
            * (t ** (n - 1))
            * torch.exp(-2 * np.pi * b * t)
            * torch.cos(2 * np.pi * center_freqs_in_hz * t + phases[:, None])
        )

    def normalize_filters(self, filter_coeffs):
        """Normalize filter coefficients.

        Scales each filter so that all filters share the RMS of the
        largest-RMS filter.

        Args:
            filter_coeffs (torch.Tensor): Filter coefficients (n_filters x tap_size)

        Return:
            torch.Tensor: Normalized filter coefficients (n_filters x tap_size)

        """
        rms_per_filter = (filter_coeffs**2).mean(dim=1).sqrt()
        C = 1.0 / (rms_per_filter / rms_per_filter.max())
        return filter_coeffs * C[:, None]

    def get_impulse_responses(self, sample_rate: int, tap_size: int):
        """Sample impulse responses.

        Args:
            sample_rate (int): Target sampling frequency
            tap_size (int): Tap size

        Return:
            torch.Tensor: Sampled impulse responses (n_filters x tap_size)

        """
        # Second half of the bank repeats the first half shifted by pi
        # (sign-flipped responses), doubling the phase variations.
        phases = torch.cat((self.phases, self.phases + np.pi), dim=0)
        center_freqs_in_hz = self.center_freqs_in_hz.repeat_interleave(
            self.n_phase_variations,
            dim=0,
        )
        center_freqs_in_hz = center_freqs_in_hz.repeat(2)

        if self.train_sample_rate > sample_rate:
            # Silence filters whose centers exceed the new Nyquist frequency.
            mask = center_freqs_in_hz <= sample_rate / 2

        # NOTE(review): even tap sizes start the grid at 1/sample_rate while odd
        # ones start at 0 — presumably an alignment choice; confirm with caller.
        if tap_size % 2 == 0:
            t = (
                torch.arange(1.0, tap_size + 1, 1).type_as(center_freqs_in_hz)
                / sample_rate
            )[None, :]
        else:
            t = (
                torch.arange(0.0, tap_size, 1).type_as(center_freqs_in_hz) / sample_rate
            )[None, :]
        filter_coeffs = self.compute_gammatone_impulse_response(
            center_freqs_in_hz,
            phases,
            t,
        ).type_as(center_freqs_in_hz)
        filter_coeffs = self.normalize_filters(filter_coeffs).type_as(
            center_freqs_in_hz,
        )
        if self.train_sample_rate > sample_rate:
            filter_coeffs = filter_coeffs * mask[:, None]
        # Reverse along time — presumably so the coefficients can serve as
        # Conv1d weights (which compute correlation); TODO confirm with caller.
        return torch.flip(filter_coeffs, dims=(1,))
| |
|
| |
|
class RFFTimeDomainImplicitFilter(nn.Module):
    """Implicit neural filter mapping normalized time to impulse responses.

    An MLP (optionally fed with random Fourier features, RFFs) is evaluated on
    a normalized time grid in [-1, 1] to produce the discrete-time impulse
    responses of n_filters filters.
    """

    def __init__(
        self,
        n_filters: int,
        init_kernel_size: int,
        init_sample_rate: int,
        ch_list: Sequence[int] = (32, 32),
        n_RFFs: int = 32,
        nonlinearity: str = "relu",
        train_RFF: bool = False,
        use_layer_norm: bool = False,
    ) -> None:
        """n_filters: Number of filters.
        init_kernel_size: Initial kernel size.
        init_sample_rate: Initial sample rate.
        ch_list: Channel list of MLP.
        n_RFFs: Number of RFFs. If n_RFFs <= 0, do not use random Fourier feature inputs (i.e., directly input normalized time).
        nonlinearity (str): Nonlinearity
        train_RFF (bool): If True, train RFFs.
        use_layer_norm (bool): If True, use layer norm.

        Raises:
            NotImplementedError: If nonlinearity is not "relu".
        """
        super().__init__()
        self.n_filters = n_filters
        # Buffers so the trained geometry is serialized and moved with the module.
        self.register_buffer("init_kernel_size", torch.tensor(init_kernel_size).float())
        self.register_buffer("init_sample_rate", torch.tensor(init_sample_rate).float())

        if nonlinearity == "relu":
            NonlinearityClass = functools.partial(nn.ReLU, inplace=True)
        else:
            raise NotImplementedError

        # Pointwise (kernel size 1) convolutions act as a per-time-step MLP.
        layers = []
        in_ch_list = [n_RFFs * 2 if n_RFFs > 0 else 1, *list(ch_list)]
        out_ch_list = [*list(ch_list), n_filters]
        for i, (in_ch, out_ch) in enumerate(zip(in_ch_list, out_ch_list)):
            layers.append(nn.Conv1d(in_ch, out_ch, 1))
            if i < len(in_ch_list) - 1:  # No norm/activation after the last layer.
                if use_layer_norm:
                    # GroupNorm with one group == layer norm over channels.
                    layers.append(nn.GroupNorm(1, out_ch))
                layers.append(NonlinearityClass())
        self.implicit_filter = nn.Sequential(*layers)

        def init_weights(m) -> None:
            # Small-gain Xavier init keeps the initial responses near zero.
            if isinstance(m, nn.Conv1d):
                torch.nn.init.xavier_uniform_(m.weight, gain=1e-3)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

        self.implicit_filter.apply(init_weights)

        if n_RFFs > 0:
            # Random projection frequencies for the Fourier features.
            self.RFF_param = nn.Parameter(
                torch.zeros((n_RFFs,), dtype=torch.float).normal_(
                    0.0,
                    2.0 * torch.pi * 10.0,
                ),
                requires_grad=train_RFF,
            )
        else:
            self.RFF_param = None

        def set_zero_bias(m) -> None:
            # Redundant with init_weights above; kept as an explicit guarantee
            # that every convolution starts with zero bias.
            if isinstance(m, nn.Conv1d):
                if m.bias is None:
                    msg = "bias cannot be none"
                    raise ValueError(msg)
                m.bias.data.fill_(0.0)

        self.implicit_filter.apply(set_zero_bias)

    @staticmethod
    def normalize_filters(filter_coeffs):
        """Scale each filter so all filters share the RMS of the largest one.

        Args:
            filter_coeffs (torch.Tensor): Filter coefficients (n_filters x tap_size)

        Return:
            torch.Tensor: Normalized filter coefficients (n_filters x tap_size)

        """
        rms_per_filter = (filter_coeffs**2).mean(dim=1).sqrt()
        C = 1.0 / (rms_per_filter / rms_per_filter.max())
        return filter_coeffs * C[:, None]

    @property
    def device(self):
        """Device of the module parameters."""
        return self.implicit_filter[0].weight.device

    def _get_ir(self, normalized_time):
        """Get impulse response.

        Args:
            normalized_time (torch.Tensor): Normalized time in [-1, 1] (time).

        Return:
            torch.Tensor: Discrete-time impulse responses (n_filters x time)

        """
        if self.RFF_param is not None:
            # Random Fourier features: sin/cos of random projections of time.
            RFF = self.RFF_param[:, None] @ normalized_time[None, :]
            RFF = torch.cat((RFF.sin(), RFF.cos()), dim=0)
            ir = self.implicit_filter(RFF[None, :, :])
        else:
            ir = self.implicit_filter(normalized_time[None, None, :])
        return ir.squeeze(0)  # Drop the singleton batch dimension.

    def get_impulse_responses(self, sample_rate: int, kernel_size):
        """Calculate discrete-time impulse responses.

        Corresponding to the weights of the convolutional layer from MLP.

        Args:
            sample_rate (int): Target sampling frequency
            kernel_size (int): Number of taps to sample

        Return:
            torch.Tensor: Impulse responses (n_filters x kernel_size)

        """
        # Oversampling is an opt-in, eval-only anti-aliasing mode; the
        # attribute may be set externally after construction.
        use_oversampling = False
        if not self.training and hasattr(self, "use_oversampling"):
            use_oversampling = self.use_oversampling

        if use_oversampling:
            return self.get_impulse_responses_oversampling(sample_rate)
        normalized_time = torch.linspace(
            -1.0,
            1.0,
            kernel_size,
            device=self.device,
            requires_grad=False,
        )
        return self._get_ir(normalized_time)

    def get_impulse_responses_oversampling(self, sample_rate: int):
        """Calculate discrete-time impulse responses from MLP with oversampling for anti-aliasing.

        First, calculate the discrete-time impulse responses with the trained
        sample rate.

        Then, resample the calculated discrete-time impulse responses at the
        input sample rate.

        Args:
            sample_rate (int): Target sampling frequency

        Return:
            torch.Tensor: Resampled impulse responses

        """
        normalized_time = torch.linspace(
            -1.0,
            1.0,
            # torch.linspace requires an integer step count; the buffer holds a
            # float, so cast explicitly.
            int(self.init_kernel_size.item()),
            device=self.device,
            requires_grad=False,
        )
        ir = self._get_ir(normalized_time)
        resampled_ir = torchaudio.functional.resample(
            ir,
            int(self.init_sample_rate.item()),
            int(sample_rate),
        )
        return resampled_ir.float().to(self.device)
| |
|
| |
|
class FrequencyDomainRFFImplicitFilter(nn.Module):
    """Neural analog filter (NAF) for frequency-domain sampling-frequency-independent convolutional layer in [1].

    [1] Kanami Imamura, Tomohiko Nakamura, Kohei Yatabe, and Hiroshi Saruwatari, ``Neural analog filter for sampling-frequency-independent convolutional layer," APSIPA Transactions on Signal and Information Processing, vol. 13, no. 1, e28, Nov. 2024.
    """

    def __init__(
        self,
        n_filters: int,
        max_freq: float,
        ch_list: Sequence[int] = (224, 224),
        n_rffs: int = 128,
        nonlinearity: str = "relu",
        train_rff: bool = True,
        use_layer_norm: bool = True,
    ):
        """Initialize FrequencyDomainRFFImplicitFilter.

        Args:
            n_filters (int): Number of filters
            max_freq (float): Max. of frequency (i.e., Nyquist frequency of training data)
            ch_list (Sequence[int]): Channel list of MLP
            n_rffs (int): # of RFFs. If equal to or less than 0, RFFs are not used.
            nonlinearity (str): Nonlinearity
            train_rff (bool): If True, train RFFs.
            use_layer_norm (bool): If True, use layer norm.

        Raises:
            NotImplementedError: If nonlinearity is neither "relu" nor "none".
        """
        super().__init__()
        self.use_RFFs = n_rffs > 0

        if nonlinearity == "relu":
            nonlinearity = functools.partial(nn.ReLU, inplace=True)
        elif nonlinearity == "none":
            nonlinearity = nn.Identity
        else:
            raise NotImplementedError

        self.n_filters = n_filters
        # Angular Nyquist frequency used to normalize input frequencies.
        self.register_buffer("max_ang_freq", torch.tensor(max_freq * 2.0 * np.pi))

        # Pointwise convolutions act as an MLP over frequency samples; the last
        # layer emits real and imaginary parts stacked along channels.
        layers = []
        in_ch_list = [n_rffs * 2 if self.use_RFFs else 1, *list(ch_list)]
        out_ch_list = [*list(ch_list), n_filters * 2]
        for i, (in_ch, out_ch) in enumerate(zip(in_ch_list, out_ch_list)):
            layers.append(nn.Conv1d(in_ch, out_ch, 1))
            if i < len(in_ch_list) - 1:  # No norm/activation after the last layer.
                if use_layer_norm:
                    # GroupNorm with one group == layer norm over channels.
                    layers.append(nn.GroupNorm(1, out_ch))
                layers.append(nonlinearity())
        self.implicit_filter = nn.Sequential(*layers)

        if self.use_RFFs:
            # Random projection frequencies for the Fourier features.
            self.RFF_param = nn.Parameter(
                torch.zeros((n_rffs,), dtype=torch.float).normal_(
                    0.0, 2.0 * np.pi * 10.0
                ),
                requires_grad=train_rff,
            )

        def set_zero_bias(m):
            if isinstance(m, nn.Conv1d):
                m.bias.data.fill_(0.0)

        self.implicit_filter.apply(set_zero_bias)
        # In eval mode, zero responses above the trained Nyquist frequency.
        self.use_ideal_low_pass_filter = True

    @property
    def device(self):
        """Device."""
        return self.implicit_filter[0].weight.device

    def get_frequency_responses(self, omega: torch.Tensor):
        """Calculating frequency responses from MLP.

        Args:
            omega (torch.Tensor): (Unnormalized) angular frequencies (n_angfreqs)

        Return:
            Tuple[torch.Tensor,torch.Tensor]: Real and imaginary parts of frequency characteristics (pair of n_filters x n_angfreqs as tuple)
        """
        # Normalize so the trained Nyquist frequency maps to 1.
        omega = omega / self.max_ang_freq
        if self.use_RFFs:
            x = self.RFF_param[:, None] @ omega[None, :]
            x = torch.cat((x.cos(), x.sin()), dim=0)
        else:
            x = omega[None, :]
        freq_resps = self.implicit_filter(x[None, :, :])

        # At inference beyond the trained band, optionally apply an ideal
        # low-pass filter to suppress extrapolated responses.
        if not self.training and omega.max() > 1.0 and self.use_ideal_low_pass_filter:
            freq_resps *= (omega <= 1.0).float()[None, None, :]

        # First half of channels: real part; second half: imaginary part.
        return freq_resps[0, : self.n_filters, :], freq_resps[0, self.n_filters :, :]
| |
|