|
|
|
|
|
|
|
|
|
|
| """Pseudo QMF modules."""
|
|
|
| import numpy as np
|
| import torch
|
| import torch.nn.functional as F
|
|
|
| from scipy.signal import kaiser
|
|
|
|
|
def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
    """Design prototype filter for PQMF.

    This method is based on `A Kaiser window approach for the design of prototype
    filters of cosine modulated filterbanks`_.

    The ideal low-pass impulse response ``sin(pi * cutoff_ratio * n) / (pi * n)``,
    centred at tap ``taps / 2``, is multiplied by a Kaiser window of length
    ``taps + 1``.

    Args:
        taps (int): The number of filter taps (must be even).
        cutoff_ratio (float): Cut-off frequency ratio, i.e. the cut-off angular
            frequency divided by pi. Must satisfy ``0 < cutoff_ratio < 1``.
        beta (float): Beta coefficient for kaiser window.

    Returns:
        ndarray: Impulse response of prototype filter (taps + 1,).

    Raises:
        AssertionError: If ``taps`` is odd or ``cutoff_ratio`` is not in (0, 1).

    .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
        https://ieeexplore.ieee.org/abstract/document/681427

    """
    # ``scipy.signal.windows`` is the canonical home of the window functions; the
    # top-level ``scipy.signal.kaiser`` alias was deprecated and later removed.
    from scipy.signal.windows import kaiser

    assert taps % 2 == 0, "The number of taps must be an even number."
    assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."

    # Tap offsets from the filter centre: -taps/2, ..., 0, ..., taps/2.
    n = np.arange(taps + 1) - 0.5 * taps

    # sin(pi * c * n) / (pi * n) == c * sinc(c * n) with numpy's normalised sinc.
    # np.sinc(0) == 1, so the centre tap is exactly cutoff_ratio and no 0/0
    # division (or warning suppression) is needed.
    h_i = cutoff_ratio * np.sinc(cutoff_ratio * n)

    w = kaiser(taps + 1, beta)
    return h_i * w
|
|
|
|
|
class PQMF(torch.nn.Module):
    """Pseudo-QMF analysis/synthesis filterbank.

    Splits a single-channel signal into ``subbands`` decimated bands with a
    cosine-modulated filterbank and merges them back. The design follows
    `Near-perfect-reconstruction pseudo-QMF banks`_.

    .. _`Near-perfect-reconstruction pseudo-QMF banks`:
        https://ieeexplore.ieee.org/document/258122

    """

    def __init__(self, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0):
        """Initialize PQMF module.

        Args:
            subbands (int): The number of subbands.
            taps (int): The number of filter taps.
            cutoff_ratio (float): Cut-off frequency ratio.
            beta (float): Beta coefficient for kaiser window.

        """
        super().__init__()

        prototype = design_prototype_filter(taps, cutoff_ratio, beta)
        n_coeffs = len(prototype)

        # Tap positions relative to the modulation centre, shared by every band.
        # NOTE(review): the filter has taps + 1 coefficients, so its nominal centre
        # is taps / 2, whereas (taps - 1) / 2 is used here (half a sample off).
        # Left unchanged because changing it alters every filter coefficient.
        offsets = np.arange(taps + 1) - ((taps - 1) / 2)

        analysis_bank = np.zeros((subbands, n_coeffs))
        synthesis_bank = np.zeros((subbands, n_coeffs))
        for band in range(subbands):
            angle = (2 * band + 1) * (np.pi / (2 * subbands)) * offsets
            # Analysis and synthesis use opposite signs of the +/- pi/4 phase term.
            phase = (-1) ** band * np.pi / 4
            analysis_bank[band] = 2 * prototype * np.cos(angle + phase)
            synthesis_bank[band] = 2 * prototype * np.cos(angle - phase)

        # conv1d weight of shape (subbands, 1, taps + 1): 1 channel -> subbands.
        self.register_buffer(
            "analysis_filter",
            torch.from_numpy(analysis_bank).float().unsqueeze(1),
        )
        # conv1d weight of shape (1, subbands, taps + 1): subbands -> 1 channel.
        self.register_buffer(
            "synthesis_filter",
            torch.from_numpy(synthesis_bank).float().unsqueeze(0),
        )

        # Kernel of shape (subbands, subbands, subbands) whose only non-zero entries
        # are [k, k, 0] = 1; used with stride `subbands` to keep / re-insert one
        # sample per block in each band.
        updown = torch.zeros((subbands, subbands, subbands), dtype=torch.float32)
        updown[:, :, 0] = torch.eye(subbands, dtype=torch.float32)
        self.register_buffer("updown_filter", updown)
        self.subbands = subbands

        # Zero-pad taps // 2 samples on each side; for even taps the (taps + 1)-tap
        # convolution then keeps the time length unchanged.
        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

    def analysis(self, x):
        """Analysis with PQMF.

        Args:
            x (Tensor): Input tensor (B, 1, T).

        Returns:
            Tensor: Output tensor (B, subbands, T // subbands).

        """
        # Band-split at full rate, then keep every `subbands`-th sample per band.
        band_signals = F.conv1d(self.pad_fn(x), self.analysis_filter)
        return F.conv1d(band_signals, self.updown_filter, stride=self.subbands)

    def synthesis(self, x):
        """Synthesis with PQMF.

        Args:
            x (Tensor): Input tensor (B, subbands, T // subbands).

        Returns:
            Tensor: Output tensor (B, 1, T).

        """
        # Zero-stuff each band back to full rate, scaled by `subbands`
        # (presumably gain compensation for the inserted zeros).
        stuffed = F.conv_transpose1d(
            x, self.updown_filter * self.subbands, stride=self.subbands
        )
        # Filter every band and sum them into a single output channel.
        return F.conv1d(self.pad_fn(stuffed), self.synthesis_filter)
|
|
|