| |
|
|
| |
| |
|
|
| """ |
| |
| This module serves as a single entry point that integrates all the functions for extracting features from raw audio. |
| |
| The common audio features include: |
| 1. Acoustic features such as Mel Spectrogram, F0, Energy, etc. |
| 2. Content features such as phonetic posteriorgrams (PPG) and bottleneck features (BNF) from pretrained models |
| |
| Note: |
| All feature extraction is designed to utilize the GPU to the maximum extent, which eases on-the-fly extraction for large-scale datasets. |
| |
| """ |
|
|
| import torch |
| from torch.nn.utils.rnn import pad_sequence |
|
|
| from utils.mel import extract_mel_features |
| from utils.f0 import get_f0 as extract_f0_features |
| from processors.content_extractor import ( |
| WhisperExtractor, |
| ContentvecExtractor, |
| WenetExtractor, |
| ) |
|
|
|
|
class AudioFeaturesExtractor:
    """Single entry point for extracting features from batched raw audio.

    Acoustic features (mel spectrogram, F0, energy) are computed directly
    from the waveforms. Content features (Whisper, ContentVec, WeNet) are
    produced by extractor objects that are created and loaded lazily the
    first time the corresponding ``get_*_features`` method is called.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: Amphion config that would be used to specify the processing parameters
        """
        self.cfg = cfg

    def _get_content_extractor(self, attr_name, extractor_cls):
        """Return the extractor cached under ``attr_name``, building it on first use.

        The instance is stored on ``self`` only after ``load_model()`` has
        returned, so a failed load is retried on the next call instead of
        leaving a half-initialised extractor cached.

        Args:
            attr_name: Attribute name under which the extractor is cached
            extractor_cls: Extractor class, constructed as ``extractor_cls(self.cfg)``

        Returns:
            The loaded extractor instance
        """
        extractor = getattr(self, attr_name, None)
        if extractor is None:
            extractor = extractor_cls(self.cfg)
            extractor.load_model()
            setattr(self, attr_name, extractor)
        return extractor

    def get_mel_spectrogram(self, wavs):
        """Get Mel Spectrogram Features

        Args:
            wavs: Tensor whose shape is (B, T)

        Returns:
            Tensor whose shape is (B, n_mels, n_frames)
        """
        return extract_mel_features(y=wavs, cfg=self.cfg.preprocess)

    def get_f0(self, wavs, wav_lens=None, use_interpolate=False, return_uv=False):
        """Get F0 Features

        Each waveform is processed individually on the CPU (as a NumPy
        array) and the per-utterance results are moved back to the device of
        ``wavs`` and zero-padded to a common length.

        Args:
            wavs: Tensor whose shape is (B, T)
            wav_lens: Optional per-utterance lengths, indexable by batch
                position. When given, waveform ``i`` is truncated to its
                first ``wav_lens[i]`` samples before extraction.
            use_interpolate: Forwarded unchanged to the F0 extractor
            return_uv: If True, also return the UV tensor

        Returns:
            Tensor whose shape is (B, n_frames), padded with 0. When
            ``return_uv`` is True, a tuple ``(f0s, uvs)`` where ``uvs`` is a
            long tensor of the same layout (presumably the voiced/unvoiced
            flags -- confirm against ``utils.f0.get_f0``). An empty batch
            yields tensors of shape (0, 0).
        """
        device = wavs.device

        f0s = []
        uvs = []
        for i, w in enumerate(wavs):
            if wav_lens is not None:
                w = w[: wav_lens[i]]

            # UV is always requested from the extractor; whether it is handed
            # back to the caller is decided at the end of this method.
            f0, uv = extract_f0_features(
                # detach() first: Tensor.numpy() refuses tensors that require grad.
                w.detach().cpu().numpy(),
                self.cfg.preprocess,
                use_interpolate=use_interpolate,
                return_uv=True,
            )
            f0s.append(torch.as_tensor(f0, device=device))
            uvs.append(torch.as_tensor(uv, device=device, dtype=torch.long))

        if f0s:
            # (B, n_frames)
            f0s = pad_sequence(f0s, batch_first=True, padding_value=0)
            uvs = pad_sequence(uvs, batch_first=True, padding_value=0)
        else:
            # pad_sequence raises on an empty list, so build the empty batch
            # explicitly. NOTE(review): the F0 dtype here is torch's default
            # float dtype; the non-empty path inherits whatever dtype the
            # extractor returns.
            f0s = torch.zeros((0, 0), device=device)
            uvs = torch.zeros((0, 0), device=device, dtype=torch.long)

        if return_uv:
            return f0s, uvs

        return f0s

    def get_energy(self, wavs, mel_spec=None):
        """Get Energy Features

        The energy of a frame is the L2 norm over the mel axis (dim 1) of
        ``exp(mel_spec)``.

        Args:
            wavs: Tensor whose shape is (B, T). Only used when ``mel_spec``
                is None.
            mel_spec: Tensor whose shape is (B, n_mels, n_frames). Computed
                from ``wavs`` via ``get_mel_spectrogram`` when omitted.

        Returns:
            Tensor whose shape is (B, n_frames)
        """
        if mel_spec is None:
            mel_spec = self.get_mel_spectrogram(wavs)

        energies = (mel_spec.exp() ** 2).sum(dim=1).sqrt()
        return energies

    def get_whisper_features(self, wavs, target_frame_len):
        """Get Whisper Features

        Args:
            wavs: Tensor whose shape is (B, T)
            target_frame_len: int

        Returns:
            Tensor whose shape is (B, target_frame_len, D)
        """
        extractor = self._get_content_extractor("whisper_extractor", WhisperExtractor)

        whisper_feats = extractor.extract_content_features(wavs)
        whisper_feats = extractor.ReTrans(whisper_feats, target_frame_len)
        return whisper_feats

    def get_contentvec_features(self, wavs, target_frame_len):
        """Get ContentVec Features

        Args:
            wavs: Tensor whose shape is (B, T)
            target_frame_len: int

        Returns:
            Tensor whose shape is (B, target_frame_len, D)
        """
        extractor = self._get_content_extractor(
            "contentvec_extractor", ContentvecExtractor
        )

        contentvec_feats = extractor.extract_content_features(wavs)
        contentvec_feats = extractor.ReTrans(contentvec_feats, target_frame_len)
        return contentvec_feats

    def get_wenet_features(self, wavs, target_frame_len, wav_lens=None):
        """Get WeNet Features

        Args:
            wavs: Tensor whose shape is (B, T)
            target_frame_len: int
            wav_lens: Tensor whose shape is (B)

        Returns:
            Tensor whose shape is (B, target_frame_len, D)
        """
        extractor = self._get_content_extractor("wenet_extractor", WenetExtractor)

        wenet_feats = extractor.extract_content_features(wavs, lens=wav_lens)
        wenet_feats = extractor.ReTrans(wenet_feats, target_frame_len)
        return wenet_feats
|
|