Spaces:
Running
Running
File size: 7,128 Bytes
ca97aa9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import { FeatureExtractor, validate_audio_inputs } from '../../base/feature_extraction_utils.js';
import { Tensor } from '../../utils/tensor.js';
import { mel_filter_bank, spectrogram, window_function } from '../../utils/audio.js';
/**
 * Feature extractor for CLAP audio models.
 *
 * Truncates or pads a raw waveform to a fixed number of samples, then computes a
 * log-mel spectrogram (power 2, dB scale).
 */
export class ClapFeatureExtractor extends FeatureExtractor {
    /**
     * @param {Object} config Feature-extractor configuration. The fields read by this class are
     *   `nb_frequency_bins`, `feature_size`, `frequency_min`, `frequency_max`, `sampling_rate`,
     *   `fft_window_size`, `hop_length`, `nb_max_samples`, `truncation` and `padding`.
     */
    constructor(config) {
        super(config);

        // HTK-scale filters without normalization (the torchaudio defaults). These are meant for the
        // "fusion" truncation mode, which is not implemented yet.
        this.mel_filters = mel_filter_bank(
            this.config.nb_frequency_bins, // num_frequency_bins
            this.config.feature_size, // num_mel_filters
            this.config.frequency_min, // min_frequency
            this.config.frequency_max, // max_frequency
            this.config.sampling_rate, // sampling_rate
            null, // norm
            "htk", // mel_scale
        );

        // Slaney-scale, Slaney-normalized filters (the librosa defaults), used for all non-fusion modes.
        this.mel_filters_slaney = mel_filter_bank(
            this.config.nb_frequency_bins, // num_frequency_bins
            this.config.feature_size, // num_mel_filters
            this.config.frequency_min, // min_frequency
            this.config.frequency_max, // max_frequency
            this.config.sampling_rate, // sampling_rate
            "slaney", // norm
            "slaney", // mel_scale
        );

        this.window = window_function(this.config.fft_window_size, 'hann');
    }

    /**
     * Extracts the mel spectrogram and prepares it for the model based on the `truncation` and `padding` arguments.
     *
     * Four different paths exist in the reference implementation:
     * - `truncation="fusion"` and the waveform is longer than the max length: the mel spectrogram is computed
     *   on the entire audio, and 3 random crops plus a downsampled version of the full spectrogram are stacked
     *   together for `feature_fusion`. **Not implemented: throws.**
     * - `truncation="rand_trunc"` and the waveform is shorter than the max length: the audio is padded based
     *   on `padding`.
     * - `truncation="fusion"` and the waveform is not longer than the max length: the audio is padded based
     *   on `padding` and repeated `4` times. **Not implemented: throws.**
     * - `truncation="rand_trunc"` and the waveform is longer than the max length: the mel spectrogram is
     *   computed on a random crop of exactly `max_length` samples.
     *
     * Padding modes (only applied when the waveform is shorter than `max_length`):
     * - `"repeat"`: the waveform is tiled until the buffer is full; the final copy may be cut short.
     * - `"repeatpad"`: only whole copies are tiled (`floor(max_length / length)` in total), then zeros.
     * - anything else: zero padding.
     * An empty waveform cannot be tiled and is always zero padded.
     *
     * @param {Float32Array|Float64Array} waveform The input waveform.
     * @param {number} max_length The target number of samples.
     * @param {string} truncation The truncation strategy to use.
     * @param {string} padding The padding strategy to use.
     * @returns {Promise<Tensor>} The spectrogram returned by `_extract_fbank_features`, with a new leading axis of size 1.
     * @throws {Error} If the requested truncation strategy is not implemented.
     * @private
     */
    async _get_input_mel(waveform, max_length, truncation, padding) {
        const diff = waveform.length - max_length;
        let samples = waveform;

        if (diff > 0) {
            if (truncation !== 'rand_trunc') {
                // TODO: implement the "fusion" strategy (random crops + downsampled full spectrogram).
                throw new Error(`Truncation strategy "${truncation}" not implemented`);
            }
            // Random crop of exactly `max_length` samples; the start index is uniform in [0, diff].
            const idx = Math.floor(Math.random() * (diff + 1));
            samples = waveform.subarray(idx, idx + max_length);
        } else {
            if (diff < 0) {
                const len = waveform.length;
                const padded = new Float64Array(max_length); // zero-initialized, so the tail is already zero padded
                padded.set(waveform);

                // With an empty waveform the tiling loops below would never advance (i += 0),
                // so it is left as pure zero padding.
                if (len > 0) {
                    if (padding === 'repeat') {
                        // Tile until the buffer is full; the last copy is truncated to fit.
                        for (let i = len; i < max_length; i += len) {
                            padded.set(waveform.subarray(0, Math.min(len, max_length - i)), i);
                        }
                    } else if (padding === 'repeatpad') {
                        // Tile whole copies only. `<=` keeps the final copy when max_length is an exact multiple of len.
                        for (let i = len; i + len <= max_length; i += len) {
                            padded.set(waveform, i);
                        }
                    }
                }
                samples = padded;
            }
            if (truncation === 'fusion') {
                throw new Error(`Truncation strategy "${truncation}" not implemented`);
            }
        }

        const input_mel = await this._extract_fbank_features(samples, this.mel_filters_slaney, this.config.nb_max_samples);
        return input_mel.unsqueeze_(0);
    }

    /**
     * Compute the log-mel spectrogram of the provided `waveform` using the Hann window.
     * In CLAP, two different filter banks are used depending on the truncation pattern:
     * - `this.mel_filters`: they correspond to the default parameters of `torchaudio`, which can be obtained from
     *   calling `torchaudio.transforms.MelSpectrogram().mel_scale.fb`. These filters are used when `truncation`
     *   is set to `"fusion"`.
     * - `this.mel_filters_slaney`: they correspond to the default parameters of `librosa`, which uses
     *   `librosa.filters.mel` when computing the mel spectrogram. These filters were only used in the original
     *   implementation when the truncation mode is not `"fusion"`.
     *
     * @param {Float32Array|Float64Array} waveform The audio waveform to process.
     * @param {number[][]} mel_filters The mel filters to use.
     * @param {number} [max_length=null] The maximum number of frames to return.
     * @returns {Promise<Tensor>} The log-mel (dB) spectrogram, transposed by `spectrogram` (`transpose: true`).
     */
    async _extract_fbank_features(waveform, mel_filters, max_length = null) {
        // NOTE: We don't pad/truncate since that is passed in as `max_num_frames`
        return spectrogram(
            waveform,
            this.window, // window
            this.config.fft_window_size, // frame_length
            this.config.hop_length, // hop_length
            {
                power: 2.0,
                mel_filters,
                log_mel: 'dB',

                // Custom
                max_num_frames: max_length,
                do_pad: false,
                transpose: true,
            },
        );
    }

    /**
     * Asynchronously extracts features from a given audio using the provided configuration.
     * @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array.
     * @param {Object} [options] Optional settings.
     * @param {number} [options.max_length=null] Target number of samples; falls back to `config.nb_max_samples` when null/undefined.
     * @returns {Promise<{ input_features: Tensor }>} A Promise resolving to an object containing the extracted input features as a Tensor.
     */
    async _call(audio, {
        max_length = null,
    } = {}) {
        validate_audio_inputs(audio, 'ClapFeatureExtractor');

        // Convert to a mel spectrogram, truncating and padding if needed.
        const padded_inputs = await this._get_input_mel(
            audio,
            max_length ?? this.config.nb_max_samples,
            this.config.truncation,
            this.config.padding,
        );

        return {
            input_features: padded_inputs.unsqueeze_(0),
        };
    }
}
|