| | import math |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | from einops import rearrange |
| | from scipy.optimize import fmin |
| | from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord |
| |
|
class PQMF(nn.Module):
    """
    Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction.
    Uses a polyphase representation, which is computationally more efficient for real-time use.

    Parameters:
    - attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB.
    - num_bands (int): Number of desired frequency bands. It must be a power of 2.

    Registered buffers:
    - filter_bank: the cosine-modulated analysis filters (one row per band), zero-padded
      on both sides to a power-of-two length.
    - prototype: the lowpass prototype filter the bank was modulated from.
    """

    def __init__(self, attenuation, num_bands):
        super(PQMF, self).__init__()

        # The polyphase helpers split the power-of-two-padded filters into
        # num_bands phases, which only divides evenly for a power-of-two band count.
        is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands)))
        assert is_power_of_2, "'num_bands' must be a power of 2."

        prototype_filter = design_prototype_filter(attenuation, num_bands)
        filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands)
        padded_filter_bank = pad_to_nearest_power_of_two(filter_bank)

        # Buffers move with the module on .to()/.cuda() and are stored in state_dict.
        self.register_buffer("filter_bank", padded_filter_bank)
        self.register_buffer("prototype", prototype_filter)
        self.num_bands = num_bands

    def forward(self, signal):
        """
        Decompose the signal into multiple frequency bands.

        Parameters
        ----------
        signal : torch.Tensor or numpy.ndarray
            Mono (Length), multi-channel (Channels x Length; a 2-D input whose first
            axis is the longer one is transposed first) or batched
            (Batch x Channels x Length) signal.

        Returns
        -------
        torch.Tensor
            Sub-band signals of shape (Batch x Channels x Bands x Time), where Time is
            roughly Length / num_bands.
        """
        signal = prepare_signal_dimensions(signal)
        # torch.from_numpy keeps the array's dtype (NumPy defaults to float64) and
        # yields a CPU tensor, while conv1d requires the input to match the filter
        # bank's dtype and device. Align them (a no-op when they already match).
        signal = signal.to(device=self.filter_bank.device, dtype=self.filter_bank.dtype)
        # The polyphase split needs the length to be a multiple of num_bands.
        signal = pad_signal(signal, self.num_bands)
        signal = polyphase_analysis(signal, self.filter_bank)
        return apply_alias_cancellation(signal)

    def inverse(self, bands):
        """
        Reconstruct the original signal from the frequency bands.

        Parameters
        ----------
        bands : torch.Tensor
            Sub-band signals shaped like the output of ``forward``
            (Batch x Channels x Bands x Time).

        Returns
        -------
        torch.Tensor
            Reconstructed signal (Batch x Channels x Length). Any zero padding that
            ``forward`` appended is not removed here.
        """
        # Undo the sign flips applied at the end of forward; a +/-1 mask is its own inverse.
        bands = apply_alias_cancellation(bands)
        return polyphase_synthesis(bands, self.filter_bank)
| |
|
| |
|
def prepare_signal_dimensions(signal):
    """
    Coerce an input signal into a ``Batch x Channels x Length`` tensor.

    NumPy arrays are wrapped with ``torch.from_numpy``, so they share memory and
    keep their dtype. A 1-D input becomes ``1 x 1 x Length``. A 2-D input is
    treated as one multi-channel example: it is transposed when its first axis
    is longer than its second, so the longer axis ends up last, and then gains a
    leading batch axis. Tensors of any other rank are returned unchanged.

    Parameters
    ----------
    signal : torch.Tensor or numpy.ndarray
        The input signal.

    Returns
    -------
    torch.Tensor
        Preprocessed signal tensor.

    Raises
    ------
    ValueError
        If ``signal`` is neither a NumPy array nor a PyTorch tensor.
    """
    tensor = torch.from_numpy(signal) if isinstance(signal, np.ndarray) else signal
    if not isinstance(tensor, torch.Tensor):
        raise ValueError("Input should be either a numpy array or a PyTorch tensor.")

    rank = tensor.dim()
    if rank == 1:
        # Mono signal: add both a batch axis and a channel axis.
        return tensor[None, None, :]
    if rank == 2:
        length_first = tensor.shape[0] > tensor.shape[1]
        if length_first:
            tensor = tensor.T
        # Add the batch axis only.
        return tensor[None]
    return tensor
| | |
def pad_signal(signal, num_bands):
    """
    Right-pad the last dimension with zeros up to the next multiple of ``num_bands``.

    Parameters
    ----------
    signal : torch.Tensor
        The input signal tensor; its last dimension is the signal length.
    num_bands : int
        The number of bands by which the signal length should be divisible.

    Returns
    -------
    torch.Tensor
        The padded tensor, or the very same tensor object when its length is
        already a multiple of ``num_bands``.
    """
    # Python's modulo of a negative number gives the distance to the next multiple.
    shortfall = -signal.shape[-1] % num_bands
    if shortfall == 0:
        return signal
    return nn.functional.pad(signal, (0, shortfall))
| |
|
def generate_modulated_filter_bank(prototype_filter, num_bands):
    """
    Generate a QMF bank of cosine modulated filters based on a given prototype filter.

    Band ``k`` is computed as::

        2 * p[n] * cos((2k + 1) * pi / (2 * num_bands) * (n - c) + (-1)**k * pi / 4)

    Here ``p`` is the prototype and ``c = (N - 1) / 2`` is its centre tap. For an
    odd-length prototype, ``n - c`` runs over the integers
    ``-(N // 2) .. N // 2``. For an even-length prototype it takes half-integer
    values, so even lengths are supported instead of failing to broadcast.

    Parameters
    ----------
    prototype_filter : torch.Tensor
        The prototype filter used as the basis for modulation; its last
        dimension is the filter length N.
    num_bands : int
        The number of desired subbands or filters.

    Returns
    -------
    torch.Tensor
        A bank of cosine modulated filters, shape (num_bands x N) for a 1-D prototype.
    """
    filter_length = prototype_filter.shape[-1]

    # Column vector of band indices k, so the terms below broadcast to (num_bands x N).
    subband_indices = torch.arange(num_bands).reshape(-1, 1)

    # Time offsets relative to the filter centre: n - (N - 1) / 2.
    time_indices = torch.arange(filter_length) - (filter_length - 1) / 2

    # Alternating +/- pi/4 phase per band.
    phase_offsets = (-1) ** subband_indices * np.pi / 4

    modulation = torch.cos(
        (2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets
    )

    return 2 * prototype_filter * modulation
| |
|
| |
|
def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None):
    """
    Design a lowpass FIR filter using the Kaiser window.

    Frequencies are angular (radians per sample), so the Nyquist frequency is ``pi``.

    Parameters
    ----------
    angular_cutoff : float
        The angular frequency cutoff of the filter, strictly between 0 and pi.
    attenuation : float
        The desired stopband attenuation in decibels (dB). It is forwarded to
        ``scipy.signal.kaiserord``, which rejects values below 8 dB.
    filter_length : int, optional
        Desired number of taps. If not provided, the Kaiser estimate is used,
        bumped up by one when it is even so the filter has an odd length.

    Returns
    -------
    ndarray
        The designed lowpass filter coefficients (not rescaled, ``scale=False``).
    """
    # kaiserord takes the transition width normalised to Nyquist (pi here);
    # the cutoff itself is used as that width.
    estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi)

    # Force an odd number of taps: even estimates are increased by one.
    estimated_length = 2 * (estimated_length // 2) + 1

    if filter_length is None:
        filter_length = estimated_length

    # fs = 2*pi puts Nyquist at pi. This is equivalent to the former nyq=np.pi
    # keyword, which SciPy has removed.
    return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, fs=2 * np.pi)
| |
|
| |
|
def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length):
    """
    Score a candidate prototype cutoff using the criterion from
    https://ieeexplore.ieee.org/document/681427.

    A Kaiser lowpass is designed for ``angular_cutoff`` and correlated with its
    own time reverse. The autocorrelation is then sampled every
    ``2 * num_bands`` taps to the right of its centre, excluding the centre
    itself. The largest magnitude among those samples is the score, and
    ``design_prototype_filter`` minimises it.

    Parameters
    ----------
    angular_cutoff : float
        Angular frequency cutoff of the filter.
    attenuation : float
        Desired stopband attenuation in dB.
    num_bands : int
        Number of bands for the multiband filter system.
    filter_length : int or None
        Desired length of the filter; ``None`` lets the Kaiser estimate decide.

    Returns
    -------
    float
        The computed objective (loss) value for the given filter specs.
    """
    taps = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length)
    autocorrelation = np.convolve(taps, taps[::-1], "full")

    stride = 2 * num_bands
    centre = autocorrelation.shape[-1] // 2
    # Starting one stride past the centre skips the centre sample.
    off_centre = autocorrelation[centre + stride::stride]
    return np.max(np.abs(off_centre))
| |
|
| |
|
def design_prototype_filter(attenuation, num_bands, filter_length=None):
    """
    Design the optimal prototype filter for a multiband system given the desired specs.

    The Kaiser lowpass cutoff is chosen with ``scipy.optimize.fmin`` (Nelder-Mead).
    The search starts at ``1 / num_bands`` rad/sample and minimises
    ``evaluate_filter_objective``.

    Parameters
    ----------
    attenuation : float
        The desired stopband attenuation in dB.
    num_bands : int
        Number of bands for the multiband filter system.
    filter_length : int, optional
        Desired length of the filter. If not provided, it's computed based on the given specs.

    Returns
    -------
    torch.Tensor
        The optimal prototype filter coefficients as a 1-D float32 tensor.
    """
    def objective(candidate):
        # fmin passes a length-1 array. Unwrap it to a plain float so the
        # downstream SciPy/NumPy calls do not rely on the deprecated implicit
        # conversion of a size-1 array to a scalar.
        cutoff = float(np.ravel(candidate)[0])
        return evaluate_filter_objective(cutoff, attenuation, num_bands, filter_length)

    optimal_angular_cutoff = fmin(objective, 1 / num_bands, disp=0)[0]

    prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length)
    return torch.tensor(prototype_filter, dtype=torch.float32)
| |
|
def pad_to_nearest_power_of_two(x):
    """
    Zero-pad the last dimension of ``x`` on both sides up to the smallest power
    of two that is greater than or equal to its current size.

    When the total padding is odd, the extra zero goes on the right. A last
    dimension that is already a power of two receives no padding.

    Parameters:
    -----------
    x : torch.Tensor
        The input tensor to be padded.

    Returns:
    --------
    torch.Tensor
        The padded tensor.
    """
    length = x.shape[-1]
    padded_length = 2 ** math.ceil(math.log2(length))
    left, extra = divmod(padded_length - length, 2)
    return nn.functional.pad(x, (left, left + extra))
| |
|
def apply_alias_cancellation(x):
    """
    Negate selected elements of ``x`` as part of PQMF alias cancellation.

    On the last two axes, every odd-indexed row (1, 3, 5, ...) has its
    even-indexed columns (0, 2, 4, ...) negated; all other elements are
    unchanged. Because the sign pattern is its own inverse, the same function is
    applied after analysis and again before synthesis.

    Parameters:
    -----------
    x : torch.Tensor
        The input tensor, with at least two dimensions.

    Returns:
    --------
    torch.Tensor
        A new tensor with the selected elements' signs inverted.
    """
    signs = torch.ones_like(x)
    # fill_ on the strided view writes straight into `signs`.
    signs[..., 1::2, ::2].fill_(-1)
    return x * signs
| |
|
def ensure_odd_length(tensor):
    """
    Make the last dimension of ``tensor`` odd-sized by appending one zero if needed.

    Parameters:
    -----------
    tensor : torch.Tensor
        Input tensor whose last dimension might need padding.

    Returns:
    --------
    torch.Tensor
        The same tensor object when its last dimension is already odd; otherwise
        a copy right-padded with a single zero.
    """
    if tensor.shape[-1] % 2:
        return tensor
    return nn.functional.pad(tensor, (0, 1))
| |
|
def polyphase_analysis(signal, filter_bank):
    """
    Split a signal into sub-bands with a polyphase implementation of the filter bank.

    The signal is de-interleaved into as many phases as there are filters. Each
    filter is decomposed the same way, and a single strided-free ``conv1d`` then
    computes every band at the decimated rate.

    Parameters:
    -----------
    signal : torch.Tensor
        Input signal tensor with shape (Batch x Channels x Length); Length must be
        a multiple of the number of filters.
    filter_bank : torch.Tensor
        Filter bank tensor with shape (Bands x Length); its length must also be a
        multiple of Bands.

    Returns:
    --------
    torch.Tensor
        Signal split into sub-bands. (Batch x Channels x Bands x Length)
    """
    bands = filter_bank.shape[0]
    channels = signal.shape[1]

    # Fold batch and channels together. Sample t * bands + n of the input goes
    # to phase n, frame t.
    phases = rearrange(signal, "b c (t n) -> (b c) n t", n=bands)
    # Same phase decomposition for each filter: (filters, phases, taps per phase).
    kernels = rearrange(filter_bank, "c (t n) -> c n t", n=bands)

    half_kernel = kernels.shape[-1] // 2
    # Drop the final output frame. With an even per-phase kernel length, this
    # removes the one surplus frame produced by the symmetric padding.
    subbands = nn.functional.conv1d(phases, kernels, padding=half_kernel)[..., :-1]

    return rearrange(subbands, "(b c) n t -> b c n t", c=channels)
| |
|
def polyphase_synthesis(signal, filter_bank):
    """
    Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal.

    Parameters
    ----------
    signal : torch.Tensor
        Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Time).
    filter_bank : torch.Tensor
        Analysis filter bank (shape: Bands x Length). It is time-reversed here to
        serve as the synthesis bank; Length must be a multiple of Bands.

    Returns
    -------
    torch.Tensor
        Reconstructed signal (shape: Batch x Channels x Length'). With an even
        number of taps per phase, Length' equals Time * Bands.
    """
    num_bands = filter_bank.shape[0]
    num_channels = signal.shape[1]

    # Time-reverse the analysis filters and split each into num_bands phases.
    # The resulting conv weight is (phases, filters, taps per phase), so the
    # sub-bands act as input channels and the phases as output channels.
    filter_bank = filter_bank.flip(-1)
    filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands)

    # Fold batch and channels together for a single conv1d call.
    signal = rearrange(signal, "b c n t -> (b c) n t")

    padding_amount = filter_bank.shape[-1] // 2 + 1
    reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount))

    # Drop the final frame and scale by num_bands (presumably compensating the
    # decimation gain of the analysis stage).
    reconstructed_signal = reconstructed_signal[..., :-1] * num_bands

    # Reverse the phase order, then interleave phases back into one time axis.
    reconstructed_signal = reconstructed_signal.flip(1)
    reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands)
    # filter_bank.shape[1] is the number of filters, so this trims 2 * num_bands
    # leading samples. Presumably this aligns the output with the input by
    # removing the delay from the padding above.
    reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:]

    return reconstructed_signal