| import os |
| from abc import abstractmethod |
| from typing import Callable |
|
|
| import numpy as np |
| import torch |
| from librosa import hz_to_midi, midi_to_hz |
| from torchaudio import functional as taF |
|
|
| |
| |
|
|
|
|
def band_widths_from_specs(band_specs):
    """Return the width of every band in ``band_specs``.

    Args:
        band_specs: iterable of ``(start, end)`` pairs.

    Returns:
        A list holding ``end - start`` for each pair, in input order.
    """
    return [band_end - band_start for band_start, band_end in band_specs]
|
|
|
|
def check_nonzero_bandwidth(band_specs):
    """Validate that every band has a strictly positive width.

    Args:
        band_specs: iterable of ``(start, end)`` pairs.

    Raises:
        ValueError: if ``end - start`` is zero or negative for any band.
    """
    if any(band_end - band_start <= 0 for band_start, band_end in band_specs):
        raise ValueError("Bands cannot be zero-width")
|
|
|
|
def check_no_overlap(band_specs):
    """Raise ``ValueError`` if a band starts before an earlier band has ended.

    Bands are treated as half-open ``[start, end)`` index ranges, which is the
    convention implied by computing a band's width as ``end - start``: a band
    that starts exactly where the previous one ended does not overlap it.

    Each band is compared against the largest end index seen so far, so the
    check is meaningful when ``band_specs`` is ordered by start index.

    Args:
        band_specs: iterable of ``(start, end)`` pairs.

    Raises:
        ValueError: if a band starts below 0 or below the end of any
            preceding band.
    """
    # Starting at 0 keeps the pre-existing rejection of negative start indices.
    fend_prev = 0
    for fstart_curr, fend_curr in band_specs:
        if fstart_curr < fend_prev:
            raise ValueError("Bands cannot overlap")
        # Remember how far the bands seen so far extend; without this update
        # the loop would only ever compare against the initial sentinel.
        fend_prev = max(fend_prev, fend_curr)
|
|
|
|
def check_no_gap(band_specs):
    """Validate that consecutive bands leave no gap between them.

    The first band must start at index 0. After that, a band may start at
    most one index past the end of the band before it; anything further is
    reported as a gap.

    Args:
        band_specs: non-empty sequence of ``(start, end)`` pairs.

    Raises:
        AssertionError: if the first band does not start at 0.
        ValueError: if a band starts more than one index after the previous
            band's end.
    """
    first_start, _ = band_specs[0]
    assert first_start == 0

    previous_end = -1
    for band_start, band_end in band_specs:
        distance = band_start - previous_end
        if distance > 1:
            raise ValueError("Bands cannot leave gap")
        previous_end = band_end
|
|
|
|
class BandsplitSpecification:
    """Base class that maps frequencies to FFT-bin indices and builds bands.

    A band is a ``(start_index, end_index)`` pair of bin indices. Subclasses
    supply the actual layout through ``get_band_specs``.

    Attributes set by ``__init__``:
        fs, nfft: the constructor arguments.
        nyquist: ``fs / 2``.
        max_index: ``nfft // 2 + 1``, the upper bound used for the last band.
        split500 ... split20k: bin indices of 500 Hz, 1/2/4/8/16/20 kHz.
        above20k: single band from ``split20k`` to ``max_index``.
        above16k: band from ``split16k`` to ``split20k`` followed by
            ``above20k``.
    """

    def __init__(self, nfft: int, fs: int) -> None:
        self.fs = fs
        self.nfft = nfft
        self.nyquist = fs / 2
        self.max_index = nfft // 2 + 1

        # Bin indices of the fixed split frequencies shared by subclasses.
        self.split500 = self.hertz_to_index(500)
        self.split1k = self.hertz_to_index(1000)
        self.split2k = self.hertz_to_index(2000)
        self.split4k = self.hertz_to_index(4000)
        self.split8k = self.hertz_to_index(8000)
        self.split16k = self.hertz_to_index(16000)
        self.split20k = self.hertz_to_index(20000)

        self.above20k = [(self.split20k, self.max_index)]
        self.above16k = [(self.split16k, self.split20k)] + self.above20k

    def index_to_hertz(self, index: int):
        """Return the frequency in Hz of bin ``index`` (``index * fs / nfft``)."""
        return index * self.fs / self.nfft

    def hertz_to_index(self, hz: float, round: bool = True):
        """Convert a frequency in Hz to a bin index (``hz * nfft / fs``).

        Args:
            hz: frequency in Hz.
            round: when true, round to the nearest integer with ``np.round``
                and return an ``int``; otherwise return the fractional index.
        """
        index = hz * self.nfft / self.fs

        if round:
            index = int(np.round(index))

        return index

    def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
        """Tile ``[start_index, end_index)`` with bands of a fixed Hz width.

        The width is converted to a whole number of bins once; consecutive
        bands share their boundary index and the final band is clipped to
        ``end_index``.

        Args:
            start_index: first bin index of the first band.
            end_index: bin index at which the last band ends.
            bandwidth_hz: nominal width of each band in Hz.

        Returns:
            List of ``(lower, upper)`` index pairs; empty when
            ``start_index >= end_index``.

        Raises:
            ValueError: if ``bandwidth_hz`` converts to a bin step that does
                not advance the band edge (for example it rounds to 0 bins at
                this ``nfft``/``fs``), which would otherwise loop forever.
        """
        band_specs = []
        lower = start_index
        if not lower < end_index:
            return band_specs

        # The step is loop-invariant, so convert it a single time.
        step = self.hertz_to_index(bandwidth_hz)

        while lower < end_index:
            upper = int(np.floor(lower + step))
            upper = min(upper, end_index)

            if upper <= lower:
                raise ValueError(
                    f"bandwidth_hz={bandwidth_hz} maps to {step} bin(s) with "
                    f"nfft={self.nfft}, fs={self.fs}; bands cannot advance"
                )

            band_specs.append((lower, upper))
            lower = upper

        return band_specs

    @abstractmethod
    def get_band_specs(self):
        """Return the list of ``(start_index, end_index)`` bands.

        Subclasses override this; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError
|
|
|
|
class VocalBandsplitSpecification(BandsplitSpecification):
    """Versioned band layouts, selected by the ``version`` string.

    Every ``versionN`` layout starts at bin 0, uses fixed-Hz-width bands
    inside each frequency range, and ends with the inherited ``above16k`` or
    ``above20k`` bands (``version1`` instead tiles all the way up to
    ``max_index`` with a single width).
    """

    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

        self.version = version

    def get_band_specs(self):
        """Return the band list for ``self.version``.

        Looks up the attribute named ``version<version>``. ``version1`` is a
        property (already a list) while the other layouts are methods, so the
        attribute is only called when it is callable.

        Raises:
            AttributeError: if no ``version<version>`` attribute exists.
        """
        layout = getattr(self, f"version{self.version}")
        return layout() if callable(layout) else layout

    def _piecewise(self, segments):
        """Concatenate fixed-width bands for each ``(start, end, hz)`` segment."""
        bands = []
        for start_index, end_index, bandwidth_hz in segments:
            bands += self.get_band_specs_with_bandwidth(
                start_index=start_index,
                end_index=end_index,
                bandwidth_hz=bandwidth_hz,
            )
        return bands

    @property
    def version1(self):
        """1 kHz bands from bin 0 up to ``max_index``."""
        return self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.max_index, bandwidth_hz=1000
        )

    def version2(self):
        """1 kHz bands below 16 kHz, 2 kHz bands up to 20 kHz, then the rest."""
        return (
            self._piecewise(
                [
                    (0, self.split16k, 1000),
                    (self.split16k, self.split20k, 2000),
                ]
            )
            + self.above20k
        )

    def version3(self):
        """1 kHz bands below 8 kHz, 2 kHz bands up to 16 kHz, then the rest."""
        return (
            self._piecewise(
                [
                    (0, self.split8k, 1000),
                    (self.split8k, self.split16k, 2000),
                ]
            )
            + self.above16k
        )

    def version4(self):
        """100 Hz / 1 kHz / 2 kHz bands split at 1, 8 and 16 kHz."""
        return (
            self._piecewise(
                [
                    (0, self.split1k, 100),
                    (self.split1k, self.split8k, 1000),
                    (self.split8k, self.split16k, 2000),
                ]
            )
            + self.above16k
        )

    def version5(self):
        """100 Hz / 1 kHz / 2 kHz bands split at 1, 16 and 20 kHz."""
        return (
            self._piecewise(
                [
                    (0, self.split1k, 100),
                    (self.split1k, self.split16k, 1000),
                    (self.split16k, self.split20k, 2000),
                ]
            )
            + self.above20k
        )

    def version6(self):
        """100 Hz / 500 Hz / 1 kHz / 2 kHz bands split at 1, 4, 8 and 16 kHz."""
        return (
            self._piecewise(
                [
                    (0, self.split1k, 100),
                    (self.split1k, self.split4k, 500),
                    (self.split4k, self.split8k, 1000),
                    (self.split8k, self.split16k, 2000),
                ]
            )
            + self.above16k
        )

    def version7(self):
        """100/250/500 Hz and 1/2 kHz bands split at 1, 4, 8, 16 and 20 kHz."""
        return (
            self._piecewise(
                [
                    (0, self.split1k, 100),
                    (self.split1k, self.split4k, 250),
                    (self.split4k, self.split8k, 500),
                    (self.split8k, self.split16k, 1000),
                    (self.split16k, self.split20k, 2000),
                ]
            )
            + self.above20k
        )
|
|
|
|
class OtherBandsplitSpecification(VocalBandsplitSpecification):
    """Band layout that reuses ``VocalBandsplitSpecification`` version ``"7"``."""

    def __init__(self, nfft: int, fs: int) -> None:
        # The layout version is pinned; callers only choose nfft and fs.
        pinned_version = "7"
        super().__init__(nfft=nfft, fs=fs, version=pinned_version)
|
|
|
|
class BassBandsplitSpecification(BandsplitSpecification):
    """Band layout with 50 Hz bands below 500 Hz and wider bands above.

    The ``version`` constructor argument is accepted but not used by this
    class.
    """

    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        """Return the bands: fixed-width segments up to 16 kHz, then one band
        from ``split16k`` to ``max_index``."""
        # (start bin, end bin, band width in Hz) for each frequency range.
        segments = [
            (0, self.split500, 50),
            (self.split500, self.split1k, 100),
            (self.split1k, self.split4k, 500),
            (self.split4k, self.split8k, 1000),
            (self.split8k, self.split16k, 2000),
        ]

        bands = []
        for seg_start, seg_end, width_hz in segments:
            bands.extend(
                self.get_band_specs_with_bandwidth(
                    start_index=seg_start, end_index=seg_end, bandwidth_hz=width_hz
                )
            )

        bands.append((self.split16k, self.max_index))
        return bands
|
|
|
|
class DrumBandsplitSpecification(BandsplitSpecification):
    """Band layout with 50 Hz bands below 1 kHz and wider bands above."""

    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        """Return the bands: fixed-width segments up to 16 kHz, then one band
        from ``split16k`` to ``max_index``."""
        # (start bin, end bin, band width in Hz) for each frequency range.
        segments = [
            (0, self.split1k, 50),
            (self.split1k, self.split2k, 100),
            (self.split2k, self.split4k, 250),
            (self.split4k, self.split8k, 500),
            (self.split8k, self.split16k, 1000),
        ]

        bands = []
        for seg_start, seg_end, width_hz in segments:
            bands.extend(
                self.get_band_specs_with_bandwidth(
                    start_index=seg_start, end_index=seg_end, bandwidth_hz=width_hz
                )
            )

        bands.append((self.split16k, self.max_index))
        return bands
|
|
|
|
class PerceptualBandsplitSpecification(BandsplitSpecification):
    """Band layout derived from a filterbank instead of fixed Hz widths.

    ``fbank_fn`` is called as ``fbank_fn(n_bands, fs, f_min, f_max,
    max_index)`` and its result is indexed as ``[band, bin]``. For every band
    with at least one non-zero entry, the band spans from its first non-zero
    bin to one past its last non-zero bin, and its frequency weights are the
    matching slice of the filterbank after dividing each bin by the sum of
    that bin over all bands. Bands with no non-zero entry are skipped.
    """

    def __init__(
        self,
        nfft: int,
        fs: int,
        fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
        n_bands: int,
        f_min: float = 0.0,
        f_max: float = None,
    ) -> None:
        super().__init__(nfft=nfft, fs=fs)
        self.n_bands = n_bands
        if f_max is None:
            # Default upper edge: half the sampling rate.
            f_max = fs / 2

        self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)

        # Per-bin total across bands, used to normalise each bin's weights.
        per_bin_total = torch.sum(self.filterbank, dim=0, keepdim=True)
        normalized_fb = self.filterbank / per_bin_total

        collected_weights = []
        collected_specs = []
        for band in range(self.n_bands):
            hits = torch.nonzero(self.filterbank[band, :]).flatten().tolist()
            if not hits:
                continue
            first_bin = hits[0]
            past_last_bin = hits[-1] + 1
            collected_specs.append((first_bin, past_last_bin))
            collected_weights.append(normalized_fb[band, first_bin:past_last_bin])

        self.freq_weights = collected_weights
        self.band_specs = collected_specs

    def get_band_specs(self):
        """Return the ``(start_index, end_index)`` pairs built in ``__init__``."""
        return self.band_specs

    def get_freq_weights(self):
        """Return the per-band weight tensors built in ``__init__``."""
        return self.freq_weights

    def save_to_file(self, dir_path: str) -> None:
        """Pickle the band specs, weights and filterbank into ``dir_path``.

        The directory is created if needed and the data is written to
        ``mel_bandsplit_spec.pkl`` inside it.
        """
        import pickle

        os.makedirs(dir_path, exist_ok=True)

        payload = {
            "band_specs": self.band_specs,
            "freq_weights": self.freq_weights,
            "filterbank": self.filterbank,
        }
        target = os.path.join(dir_path, "mel_bandsplit_spec.pkl")
        with open(target, "wb") as f:
            pickle.dump(payload, f)
|
|
|
|
def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    """Build a mel filterbank with torchaudio and transpose it.

    Args:
        n_bands: number of mel filters (passed as ``n_mels``).
        fs: sampling rate (passed as ``sample_rate``).
        f_min: lowest frequency of the filterbank.
        f_max: highest frequency of the filterbank.
        n_freqs: number of frequency bins.

    Returns:
        The transposed result of ``melscale_fbanks`` with element ``[0, 0]``
        set to ``1.0``.
    """
    mel_fb = taF.melscale_fbanks(
        n_freqs=n_freqs,
        f_min=f_min,
        f_max=f_max,
        n_mels=n_bands,
        sample_rate=fs,
    )
    fb = mel_fb.T

    # Force a non-zero weight at [0, 0]; presumably so that the first band
    # covers bin 0 -- TODO confirm intent.
    fb[0, 0] = 1.0

    return fb
|
|
|
|
class MelBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band layout whose filterbank comes from ``mel_filterbank``."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        # Everything is forwarded unchanged; only the filterbank factory is fixed.
        super().__init__(
            nfft=nfft,
            fs=fs,
            fbank_fn=mel_filterbank,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )
|
|
|
|
def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
    """Build a 0/1 filterbank with band centres evenly spaced in MIDI pitch.

    Band centres are ``n_bands`` points spaced linearly on the MIDI scale
    between the lower and upper edge frequencies. Each band covers its centre
    divided and multiplied by ``2 ** (octaves / n_bands)``, where ``octaves``
    is the number of octaves between the edges. The first band is extended
    down to bin 0 and the last band up to the final bin.

    Args:
        n_bands: number of bands (rows of the result).
        fs: sampling rate.
        f_min: accepted but not used; the lower edge is always ``fs / nfft``.
        f_max: upper edge in Hz; a falsy value means ``fs / 2``.
        n_freqs: number of frequency bins (columns of the result).
        scale: accepted but not used.

    Returns:
        ``torch.Tensor`` of shape ``(n_bands, n_freqs)`` containing 0.0/1.0.
    """
    nfft = 2 * (n_freqs - 1)
    bin_hz = fs / nfft

    f_max = f_max or fs / 2
    # The lower edge is fixed to one bin spacing regardless of the argument.
    f_min = fs / nfft

    octave_span = np.log2(f_max / f_min)
    octaves_per_band = octave_span / n_bands
    bandwidth_mult = np.power(2.0, octaves_per_band)

    midi_lo = max(0, hz_to_midi(f_min))
    midi_hi = hz_to_midi(f_max)
    centers_hz = midi_to_hz(np.linspace(midi_lo, midi_hi, n_bands))

    # Lower/upper band edges converted to bin indices, rounded outwards.
    low_bins = np.floor(centers_hz / bandwidth_mult / bin_hz).astype(int)
    high_bins = np.ceil(centers_hz * bandwidth_mult / bin_hz).astype(int)

    fb = np.zeros((n_bands, n_freqs))
    for band, (lo_bin, hi_bin) in enumerate(zip(low_bins, high_bins)):
        fb[band, lo_bin : hi_bin + 1] = 1.0

    # Make sure the outermost bins belong to the first and last bands.
    fb[0, : low_bins[0]] = 1.0
    fb[-1, high_bins[-1] + 1 :] = 1.0

    return torch.as_tensor(fb)
|
|
|
|
class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band layout whose filterbank comes from ``musical_filterbank``."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        # Everything is forwarded unchanged; only the filterbank factory is fixed.
        super().__init__(
            nfft=nfft,
            fs=fs,
            fbank_fn=musical_filterbank,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
| |
|
|
| |
| |
|
|
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
| |
|
|
|
|
| |
| |
|
|
| |
|
|
| |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
if __name__ == "__main__":
    # Script entry point: dump the default vocal band layout for
    # nfft=2048, fs=44100 to a CSV file in the working directory.
    import pandas as pd

    band_defs = []

    for spec_cls in [VocalBandsplitSpecification]:
        # Strip the common class-name suffix to get a short label.
        band_name = spec_cls.__name__.replace("BandsplitSpecification", "")
        band_list = spec_cls(nfft=2048, fs=44100).get_band_specs()

        # One row per band; f_min/f_max hold the band's start/end entries
        # exactly as returned by get_band_specs().
        band_defs.extend(
            {"band": band_name, "band_index": idx, "f_min": lo, "f_max": hi}
            for idx, (lo, hi) in enumerate(band_list)
        )

    df = pd.DataFrame(band_defs)
    df.to_csv("vox7bands.csv", index=False)
|
|