# S-KEY / skey / hcqt.py
# (Hugging Face upload metadata: user 2cylu2, "Upload 13 files", commit 77ac75d, verified)
import math
import nnAudio.features.vqt
import torch
import torchaudio
# relies on nnAudio v0.3.2
# pip install git+https://github.com/KinWaiCheuk/nnAudio.git#subdirectory=Installation
class VQT(nnAudio.features.vqt.VQT):
    r"""Harmonic VQT: a collection of VQTs at different bin shifts.

    Inspired by Bittner, McFee, Salamon, Li, Bello.
    "Deep Salience Representations for F0 Estimation in Polyphonic Music". ISMIR 2017

    A single wide VQT is computed once; one shifted slice per harmonic is then
    cut out of it, so all harmonics share one transform.

    Args:
        harmonics (Collection[float]): Harmonics to be included.
        fmin (float): Minimum frequency to be included.
        n_bins (int): Number of bins in each output slice.
        bins_per_octave (int, optional): Number of bins per octave. Defaults to 12.
        **kwargs: Forwarded to ``nnAudio.features.vqt.VQT``.
    """

    def __init__(self, *, harmonics, fmin, n_bins, bins_per_octave=12, **kwargs):
        # Plain (non-Module) attributes are safe to set before super().__init__().
        self.harmonics = harmonics
        self.n_bins_per_slice = n_bins
        self.fmin = fmin
        # Harmonic h is a pitch shift of log2(h) octaves, i.e. this many bins.
        self.bin_shifts = [round(bins_per_octave * math.log2(h)) for h in harmonics]
        # Widen the underlying VQT so every shifted slice fits inside it:
        # lower fmin for sub-harmonics (negative shifts), add bins for the span.
        min_shift = min([0] + self.bin_shifts)
        max_shift = max([0] + self.bin_shifts)
        fmin = fmin * (2 ** (min_shift / bins_per_octave))
        n_bins = n_bins + max_shift - min_shift
        super().__init__(fmin=fmin, n_bins=n_bins, bins_per_octave=bins_per_octave, **kwargs)
        # Hoisted out of forward(): avoids re-instantiating an nn.Module per call.
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB(top_db=80)

    def forward(self, x, output_format="Magnitude", normalization_type="librosa"):
        """Compute the harmonic VQT of ``x``.

        Args:
            x (torch.Tensor): Audio input, as accepted by ``nnAudio``'s VQT.
            output_format (str): Passed through to the base VQT.
            normalization_type (str): Passed through to the base VQT.

        Returns:
            torch.Tensor: Log-compressed HVQT of shape
            (batch, len(harmonics), n_bins_per_slice, time), scaled to [0, 1]
            by mapping the top 80 dB of dynamic range onto the unit interval.
        """
        vqt = super().forward(x, output_format, normalization_type)
        min_shift = min([0] + self.bin_shifts)  # loop-invariant, hoisted
        slices = []
        for shift in self.bin_shifts:
            bin_start = shift - min_shift
            bin_stop = bin_start + self.n_bins_per_slice
            slices.append(vqt[:, bin_start:bin_stop, ...])
        hvqt = torch.stack(slices, dim=1)
        # AmplitudeToDB(top_db=80) yields values in [max-80, max] dB;
        # dividing by 80 and adding 1 rescales that range into [0, 1].
        log_hcqt = ((1.0 / 80.0) * self.amplitude_to_db(hvqt)) + 1.0
        return log_hcqt
class CropCQT(torch.nn.Module):
    """Crop a fixed-height frequency band out of CQT spectrograms.

    Each spectrogram in the batch is cropped along its frequency axis
    (dim 1 of the per-sample tensor) starting at a per-sample offset.

    Args:
        height (int): Number of frequency bins to keep in each crop.
    """

    def __init__(self, height: int):
        super().__init__()
        self.height = height

    def forward(self, spectrograms: torch.Tensor, transpose: torch.Tensor) -> torch.Tensor:
        """Crop each spectrogram at its own starting bin.

        Args:
            spectrograms (torch.Tensor): Batch of spectrograms; each sample is
                indexed as ``sample[:, start:start+height, :]``.
            transpose (torch.Tensor): Per-sample starting bin indices.

        Returns:
            torch.Tensor: Stacked batch of cropped spectrograms.
        """
        cropped = []
        for sample, offset in zip(spectrograms, transpose):
            begin = int(offset)
            cropped.append(sample[:, begin : begin + self.height, :])
        return torch.stack(cropped)