| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Feature extractor class for CLAP.""" |
|
|
|
|
| import copy |
| from typing import Any, Dict, List, Optional, Union |
|
|
| import numpy as np |
| import torch |
|
|
| from ...audio_utils import mel_filter_bank, spectrogram, window_function |
| from ...feature_extraction_sequence_utils import SequenceFeatureExtractor |
| from ...feature_extraction_utils import BatchFeature |
| from ...utils import TensorType, logging |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
class ClapFeatureExtractor(SequenceFeatureExtractor):
    r"""
    Constructs a CLAP feature extractor.

    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
    most of the main methods. Users should refer to this superclass for more information regarding those methods.

    This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the *Short Time
    Fourier Transform* (STFT) which should match pytorch's `torch.stft` equivalent.

    Args:
        feature_size (`int`, *optional*, defaults to 64):
            The feature dimension of the extracted Mel spectrograms. This corresponds to the number of mel filters
            (`n_mels`).
        sampling_rate (`int`, *optional*, defaults to 48000):
            The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). This only serves
            to warn users if the audio fed to the feature extractor does not have the same sampling rate.
        hop_length (`int`, *optional*, defaults to 480):
            Length of the overlapping windows for the STFT used to obtain the Mel Spectrogram. The audio will be split
            in smaller `frames` with a step of `hop_length` between each frame.
        max_length_s (`int`, *optional*, defaults to 10):
            The maximum input length of the model in seconds. This is used to pad the audio.
        fft_window_size (`int`, *optional*, defaults to 1024):
            Size of the window (in samples) on which the Fourier transform is applied. This controls the frequency
            resolution of the spectrogram. 400 means that the Fourier transform is computed on windows of 400 samples.
        padding_value (`float`, *optional*, defaults to 0.0):
            Padding value used to pad the audio. Should correspond to silences.
        return_attention_mask (`bool`, *optional*, defaults to `False`):
            Whether or not the model should return the attention masks corresponding to the input.
        frequency_min (`float`, *optional*, defaults to 0):
            The lowest frequency of interest. The STFT will not be computed for values below this.
        frequency_max (`float`, *optional*, defaults to 14000):
            The highest frequency of interest. The STFT will not be computed for values above this.
        top_db (`float`, *optional*):
            The highest decibel value used to convert the mel spectrogram to the log scale. For more details see the
            `audio_utils.power_to_db` function.
        truncation (`str`, *optional*, defaults to `"fusion"`):
            Truncation pattern for long audio inputs. Two patterns are available:
            - `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and a
              downsampled version of the entire mel spectrogram.
              If `config.fusion` is set to True, shorter audios also need to return 4 mels, which will just be a copy
              of the original mel obtained from the padded audio.
            - `rand_trunc` will select a random crop of the mel spectrogram.
        padding (`str`, *optional*, defaults to `"repeatpad"`):
            Padding pattern for shorter audio inputs. Three patterns were originally implemented:
            - `repeatpad`: the audio is repeated, and then padded to fit the `max_length`.
            - `repeat`: the audio is repeated and then cut to fit the `max_length`
            - `pad`: the audio is padded.
    """

    model_input_names = ["input_features", "is_longer"]

    def __init__(
        self,
        feature_size=64,
        sampling_rate=48_000,
        hop_length=480,
        max_length_s=10,
        fft_window_size=1024,
        padding_value=0.0,
        return_attention_mask=False,
        frequency_min: float = 0,
        frequency_max: float = 14_000,
        top_db: Optional[int] = None,
        truncation: str = "fusion",
        padding: str = "repeatpad",
        **kwargs,
    ):
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            padding_value=padding_value,
            return_attention_mask=return_attention_mask,
            **kwargs,
        )
        self.top_db = top_db
        self.truncation = truncation
        self.padding = padding
        self.fft_window_size = fft_window_size
        # Number of positive-frequency bins of a real FFT of size `fft_window_size` (n_fft // 2 + 1).
        self.nb_frequency_bins = (fft_window_size >> 1) + 1
        self.hop_length = hop_length
        self.max_length_s = max_length_s
        # Maximum number of raw audio samples the model accepts.
        self.nb_max_samples = max_length_s * sampling_rate
        self.sampling_rate = sampling_rate
        self.frequency_min = frequency_min
        self.frequency_max = frequency_max
        # HTK-scale, unnormalized filters (torchaudio defaults); used when `truncation == "fusion"`.
        self.mel_filters = mel_filter_bank(
            num_frequency_bins=self.nb_frequency_bins,
            num_mel_filters=feature_size,
            min_frequency=frequency_min,
            max_frequency=frequency_max,
            sampling_rate=sampling_rate,
            norm=None,
            mel_scale="htk",
        )
        # Slaney-scale, slaney-normalized filters (librosa defaults); used for the other truncation modes.
        self.mel_filters_slaney = mel_filter_bank(
            num_frequency_bins=self.nb_frequency_bins,
            num_mel_filters=feature_size,
            min_frequency=frequency_min,
            max_frequency=frequency_max,
            sampling_rate=sampling_rate,
            norm="slaney",
            mel_scale="slaney",
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, except for
            the mel filter banks, which do not need to be saved or printed as they are too long.
        """
        output = copy.deepcopy(self.__dict__)
        output["feature_extractor_type"] = self.__class__.__name__
        # The filter banks are large arrays fully determined by the other attributes; drop them.
        if "mel_filters" in output:
            del output["mel_filters"]
        if "mel_filters_slaney" in output:
            del output["mel_filters_slaney"]
        return output

    def _np_extract_fbank_features(self, waveform: np.ndarray, mel_filters: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Compute the log-mel spectrogram of the provided `waveform` using the Hann window. In CLAP, two different filter
        banks are used depending on the truncation pattern:
            - `self.mel_filters`: they correspond to the default parameters of `torchaudio` which can be obtained from
              calling `torchaudio.transforms.MelSpectrogram().mel_scale.fb`. These filters are used when `truncation`
              is set to `"fusion"`.
            - `self.mel_filters_slaney`: they correspond to the default parameters of `librosa` which used
              `librosa.filters.mel` when computing the mel spectrogram. These filters were only used in the original
              implementation when the truncation mode is not `"fusion"`.
        """
        log_mel_spectrogram = spectrogram(
            waveform,
            window_function(self.fft_window_size, "hann"),
            frame_length=self.fft_window_size,
            hop_length=self.hop_length,
            power=2.0,
            mel_filters=mel_filters,
            log_mel="dB",
        )
        # `spectrogram` yields (num_mel_bins, num_frames); downstream code expects (num_frames, num_mel_bins).
        return log_mel_spectrogram.T

    def _random_mel_fusion(self, mel, total_frames, chunk_frames):
        """
        Build a 4-channel "fusion" mel input from a spectrogram longer than `chunk_frames`: three random crops of
        `chunk_frames` frames taken from the front, middle and back thirds of the spectrogram, stacked together with a
        bilinear downsampling of the full spectrogram.
        """
        ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
        if len(ranges[1]) == 0:
            # if the audio is too short to yield a middle third, fall back to the first chunk
            ranges[1] = [0]
        if len(ranges[2]) == 0:
            # if the audio is too short to yield a back third, fall back to the first chunk
            ranges[2] = [0]
        # randomly choose one start index inside each third
        idx_front = np.random.choice(ranges[0])
        idx_middle = np.random.choice(ranges[1])
        idx_back = np.random.choice(ranges[2])

        # crop `chunk_frames` frames starting at each chosen index
        mel_chunk_front = mel[idx_front : idx_front + chunk_frames, :]
        mel_chunk_middle = mel[idx_middle : idx_middle + chunk_frames, :]
        mel_chunk_back = mel[idx_back : idx_back + chunk_frames, :]

        # shrink the full spectrogram to (chunk_frames, 64) so it stacks with the crops.
        # NOTE(review): the 64 here is presumably `feature_size` (default 64) — confirm before changing `feature_size`.
        mel = torch.tensor(mel[None, None, :])
        mel_shrink = torch.nn.functional.interpolate(
            mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False
        )
        mel_shrink = mel_shrink[0][0].numpy()
        mel_fusion = np.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], axis=0)
        return mel_fusion

    def _get_input_mel(self, waveform: np.ndarray, max_length, truncation, padding):
        """
        Extracts the mel spectrogram and prepares it for the mode based on the `truncation` and `padding` arguments.
        Four different paths are possible:
            - `truncation="fusion"` and the length of the waveform is greater than the max length: the mel spectrogram
              will be computed on the entire audio. 3 random crops and a downsampled version of the full mel
              spectrogram are then stacked together. They will later be used for `feature_fusion`.
            - `truncation="rand_trunc"` and the length of the waveform is smaller than the max length: the audio is
              padded based on `padding`.
            - `truncation="fusion"` and the length of the waveform is smaller than the max length: the audio is padded
              based on `padding`, and is repeated `4` times.
            - `truncation="rand_trunc"` and the length of the waveform is greater than the max length: the mel
              spectrogram will be computed on a random crop of the waveform.

        Returns:
            `(input_mel, longer)` where `input_mel` is the stacked mel spectrogram(s) and `longer` is `True` when the
            waveform exceeded `max_length` and was actually truncated/fused.
        """
        if waveform.shape[0] > max_length:
            if truncation == "rand_trunc":
                longer = True
                # take a random crop of exactly `max_length` samples
                overflow = len(waveform) - max_length
                idx = np.random.randint(0, overflow + 1)
                waveform = waveform[idx : idx + max_length]
                input_mel = self._np_extract_fbank_features(waveform, self.mel_filters_slaney)[None, :]
            elif truncation == "fusion":
                mel = self._np_extract_fbank_features(waveform, self.mel_filters)
                # the +1 accounts for the extra frame produced by the spectrogram framing
                chunk_frames = max_length // self.hop_length + 1
                total_frames = mel.shape[0]
                if chunk_frames == total_frames:
                    # corner case: the audio is longer than `max_length` but still fits a single chunk of frames,
                    # so there is nothing to fuse — just repeat the whole mel 4 times.
                    input_mel = np.stack([mel, mel, mel, mel], axis=0)
                    longer = False
                else:
                    input_mel = self._random_mel_fusion(mel, total_frames, chunk_frames)
                    longer = True
            else:
                raise NotImplementedError(f"data_truncating {truncation} not implemented")

        else:
            longer = False
            # pad shorter audio up to `max_length` following the requested pattern
            if waveform.shape[0] < max_length:
                if padding == "repeat":
                    # repeat then cut: always reaches `max_length` without zero padding
                    n_repeat = int(max_length / len(waveform))
                    waveform = np.stack(np.tile(waveform, n_repeat + 1))[:max_length]
                if padding == "repeatpad":
                    # repeat as many whole copies as fit, then zero-pad the remainder below
                    n_repeat = int(max_length / len(waveform))
                    waveform = np.stack(np.tile(waveform, n_repeat))
                waveform = np.pad(waveform, (0, max_length - waveform.shape[0]), mode="constant", constant_values=0)

            if truncation == "fusion":
                # shorter audios still return 4 identical mels so the fusion model sees a constant shape
                input_mel = self._np_extract_fbank_features(waveform, self.mel_filters)
                input_mel = np.stack([input_mel, input_mel, input_mel, input_mel], axis=0)
            else:
                input_mel = self._np_extract_fbank_features(waveform, self.mel_filters_slaney)[None, :]

        return input_mel, longer

    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        truncation: str = None,
        padding: Optional[str] = None,
        max_length: Optional[int] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Main method to featurize and prepare for the model one or several sequence(s).

        Args:
            raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
                The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
                values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
                stereo, i.e. single float per timestep.
            truncation (`str`, *optional*):
                Truncation pattern for long audio inputs. Two patterns are available:
                - `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and
                  a downsampled version of the entire mel spectrogram.
                  If `config.fusion` is set to True, shorter audios also need to return 4 mels, which will just be a
                  copy of the original mel obtained from the padded audio.
                - `rand_trunc` will select a random crop of the mel spectrogram.
            padding (`str`, *optional*):
                Padding pattern for shorter audio inputs. Three patterns were originally implemented:
                - `repeatpad`: the audio is repeated, and then padded to fit the `max_length`.
                - `repeat`: the audio is repeated and then cut to fit the `max_length`
                - `pad`: the audio is padded.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
            sampling_rate (`int`, *optional*):
                The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
                `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition
                pipeline.
        """
        truncation = truncation if truncation is not None else self.truncation
        padding = padding if padding else self.padding

        if sampling_rate is not None:
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
                    f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
                    f" was sampled with {self.sampling_rate} and not {sampling_rate}."
                )
        else:
            logger.warning(
                "It is strongly recommended to pass the `sampling_rate` argument to this function. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
        if is_batched_numpy and len(raw_speech.shape) > 2:
            raise ValueError(f"Only mono-channel audio is supported for input to {self}")
        is_batched = is_batched_numpy or (
            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
        )

        if is_batched:
            raw_speech = [np.asarray(speech, dtype=np.float64) for speech in raw_speech]
        elif not is_batched and not isinstance(raw_speech, np.ndarray):
            raw_speech = np.asarray(raw_speech, dtype=np.float64)
        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
            # NOTE(review): this branch is a no-op (float64 -> float64); kept for parity with the sibling feature
            # extractors, which downcast here.
            raw_speech = raw_speech.astype(np.float64)

        # always return a batch: wrap a single unbatched waveform into a one-element list
        if not is_batched:
            raw_speech = [np.asarray(raw_speech)]

        # convert each waveform to its (possibly fused) mel representation
        padded_inputs = [
            self._get_input_mel(waveform, max_length if max_length else self.nb_max_samples, truncation, padding)
            for waveform in raw_speech
        ]

        input_mel = []
        is_longer = []
        for mel, longer in padded_inputs:
            input_mel.append(mel)
            is_longer.append(longer)

        if truncation == "fusion" and sum(is_longer) == 0:
            # if no audio exceeded max_length, randomly mark one sample as "longer" so feature fusion is exercised
            rand_idx = np.random.randint(0, len(input_mel))
            is_longer[rand_idx] = True

        if isinstance(input_mel[0], list):
            input_mel = [np.asarray(feature, dtype=np.float64) for feature in input_mel]

        # `is_longer` becomes a list of one-element lists to keep a batch dimension
        is_longer = [[longer] for longer in is_longer]

        input_features = {"input_features": input_mel, "is_longer": is_longer}
        input_features = BatchFeature(input_features)

        if return_tensors is not None:
            input_features = input_features.convert_to_tensors(return_tensors)

        return input_features
|
|