Spaces:
Configuration error
Configuration error
| # pylint: disable=C0301 | |
| ''' | |
| This module contains the AudioProcessor class and related functions for processing audio data. | |
| It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction, | |
| and audio separation. The class is initialized with configuration parameters and can process | |
| audio files using the provided models. | |
| ''' | |
| import math | |
| import os | |
| import librosa | |
| import numpy as np | |
| import torch | |
| from audio_separator.separator import Separator | |
| from einops import rearrange | |
| from transformers import Wav2Vec2FeatureExtractor | |
| from hallo.models.wav2vec import Wav2VecModel | |
| from hallo.utils.util import resample_audio | |
| class AudioProcessor: | |
| """ | |
| AudioProcessor is a class that handles the processing of audio files. | |
| It takes care of preprocessing the audio files, extracting features | |
| using wav2vec models, and separating audio signals if needed. | |
| :param sample_rate: Sampling rate of the audio file | |
| :param fps: Frames per second for the extracted features | |
| :param wav2vec_model_path: Path to the wav2vec model | |
| :param only_last_features: Whether to only use the last features | |
| :param audio_separator_model_path: Path to the audio separator model | |
| :param audio_separator_model_name: Name of the audio separator model | |
| :param cache_dir: Directory to cache the intermediate results | |
| :param device: Device to run the processing on | |
| """ | |
| def __init__( | |
| self, | |
| sample_rate, | |
| fps, | |
| wav2vec_model_path, | |
| only_last_features, | |
| audio_separator_model_path:str=None, | |
| audio_separator_model_name:str=None, | |
| cache_dir:str='', | |
| device="cuda:0", | |
| ) -> None: | |
| self.sample_rate = sample_rate | |
| self.fps = fps | |
| self.device = device | |
| self.audio_encoder = Wav2VecModel.from_pretrained(wav2vec_model_path, local_files_only=True).to(device=device) | |
| self.audio_encoder.feature_extractor._freeze_parameters() | |
| self.only_last_features = only_last_features | |
| if audio_separator_model_name is not None: | |
| try: | |
| os.makedirs(cache_dir, exist_ok=True) | |
| except OSError as _: | |
| print("Fail to create the output cache dir.") | |
| self.audio_separator = Separator( | |
| output_dir=cache_dir, | |
| output_single_stem="vocals", | |
| model_file_dir=audio_separator_model_path, | |
| ) | |
| self.audio_separator.load_model(audio_separator_model_name) | |
| assert self.audio_separator.model_instance is not None, "Fail to load audio separate model." | |
| else: | |
| self.audio_separator=None | |
| print("Use audio directly without vocals seperator.") | |
| self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True) | |
| def preprocess(self, wav_file: str, clip_length: int): | |
| """ | |
| Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate. | |
| The separated vocal track is then converted into wav2vec2 for further processing or analysis. | |
| Args: | |
| wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format. | |
| Raises: | |
| RuntimeError: Raises an exception if the WAV file cannot be processed. This could be due to issues | |
| such as file not found, unsupported file format, or errors during the audio processing steps. | |
| Returns: | |
| torch.tensor: Returns an audio embedding as a torch.tensor | |
| """ | |
| if self.audio_separator is not None: | |
| # 1. separate vocals | |
| # TODO: process in memory | |
| outputs = self.audio_separator.separate(wav_file) | |
| if len(outputs) <= 0: | |
| raise RuntimeError("Audio separate failed.") | |
| vocal_audio_file = outputs[0] | |
| vocal_audio_name, _ = os.path.splitext(vocal_audio_file) | |
| vocal_audio_file = os.path.join(self.audio_separator.output_dir, vocal_audio_file) | |
| vocal_audio_file = resample_audio(vocal_audio_file, os.path.join(self.audio_separator.output_dir, f"{vocal_audio_name}-16k.wav"), self.sample_rate) | |
| else: | |
| vocal_audio_file=wav_file | |
| # 2. extract wav2vec features | |
| speech_array, sampling_rate = librosa.load(vocal_audio_file, sr=self.sample_rate) | |
| audio_feature = np.squeeze(self.wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values) | |
| seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps) | |
| audio_length = seq_len | |
| audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device) | |
| if seq_len % clip_length != 0: | |
| audio_feature = torch.nn.functional.pad(audio_feature, (0, (clip_length - seq_len % clip_length) * (self.sample_rate // self.fps)), 'constant', 0.0) | |
| seq_len += clip_length - seq_len % clip_length | |
| audio_feature = audio_feature.unsqueeze(0) | |
| with torch.no_grad(): | |
| embeddings = self.audio_encoder(audio_feature, seq_len=seq_len, output_hidden_states=True) | |
| assert len(embeddings) > 0, "Fail to extract audio embedding" | |
| if self.only_last_features: | |
| audio_emb = embeddings.last_hidden_state.squeeze() | |
| else: | |
| audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0) | |
| audio_emb = rearrange(audio_emb, "b s d -> s b d") | |
| audio_emb = audio_emb.cpu().detach() | |
| return audio_emb, audio_length | |
| def get_embedding(self, wav_file: str): | |
| """preprocess wav audio file convert to embeddings | |
| Args: | |
| wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format. | |
| Returns: | |
| torch.tensor: Returns an audio embedding as a torch.tensor | |
| """ | |
| speech_array, sampling_rate = librosa.load( | |
| wav_file, sr=self.sample_rate) | |
| assert sampling_rate == 16000, "The audio sample rate must be 16000" | |
| audio_feature = np.squeeze(self.wav2vec_feature_extractor( | |
| speech_array, sampling_rate=sampling_rate).input_values) | |
| seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps) | |
| audio_feature = torch.from_numpy( | |
| audio_feature).float().to(device=self.device) | |
| audio_feature = audio_feature.unsqueeze(0) | |
| with torch.no_grad(): | |
| embeddings = self.audio_encoder( | |
| audio_feature, seq_len=seq_len, output_hidden_states=True) | |
| assert len(embeddings) > 0, "Fail to extract audio embedding" | |
| if self.only_last_features: | |
| audio_emb = embeddings.last_hidden_state.squeeze() | |
| else: | |
| audio_emb = torch.stack( | |
| embeddings.hidden_states[1:], dim=1).squeeze(0) | |
| audio_emb = rearrange(audio_emb, "b s d -> s b d") | |
| audio_emb = audio_emb.cpu().detach() | |
| return audio_emb | |
| def close(self): | |
| """ | |
| TODO: to be implemented | |
| """ | |
| return self | |
| def __enter__(self): | |
| return self | |
| def __exit__(self, _exc_type, _exc_val, _exc_tb): | |
| self.close() | |