Spaces:
Paused
Paused
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| from joblib import Parallel, delayed | |
| import numpy as np | |
| from pesq import pesq | |
| from typing import List | |
| from pesq import cypesq | |
def run_pesq(clean_audio: np.ndarray,
             noisy_audio: np.ndarray,
             sample_rate: int = 16000,
             mode: str = "wb",
             ) -> float:
    """Compute the PESQ score of one degraded signal against its reference.

    Args:
        clean_audio: reference (clean) signal.
        noisy_audio: degraded signal to be scored against ``clean_audio``.
        sample_rate: sampling rate in Hz, forwarded to ``pesq``.
        mode: PESQ mode string forwarded to ``pesq`` ("wb" or "nb").

    Returns:
        The PESQ score as a float, or the sentinel ``-1.0`` when ``pesq``
        reports no utterances or fails for any other reason (in the latter
        case the error is printed).

    Raises:
        AssertionError: if ``sample_rate`` is 8000 while ``mode`` is "wb".
    """
    # Reject this invalid combination up front so it raises instead of being
    # swallowed by the catch-all handler below and reported as -1.
    if sample_rate == 8000 and mode == "wb":
        raise AssertionError("mode should be `nb` when sample_rate is 8000")
    try:
        return float(pesq(sample_rate, clean_audio, noisy_audio, mode))
    except cypesq.NoUtterancesError:
        # NOTE(review): assumes `pesq.cypesq` exposes NoUtterancesError for the
        # installed pesq version -- confirm. The pair is unscorable; use sentinel.
        return -1.0
    except Exception as e:
        # Best-effort: a single failing pair yields the sentinel rather than
        # aborting callers such as run_batch_pesq.
        print(f"pesq failed. error type: {type(e)}, error text: {str(e)}")
        return -1.0
def run_batch_pesq(clean_audio_list: List[np.ndarray],
                   noisy_audio_list: List[np.ndarray],
                   sample_rate: int = 16000,
                   mode: str = "wb",
                   n_jobs: int = 4,
                   ) -> List[float]:
    """Score many (clean, noisy) pairs with ``run_pesq`` using joblib.

    Args:
        clean_audio_list: reference signals.
        noisy_audio_list: degraded signals, paired index-wise with
            ``clean_audio_list``.
        sample_rate: sampling rate in Hz, forwarded to ``run_pesq``.
        mode: PESQ mode, forwarded to ``run_pesq``.
        n_jobs: number of joblib workers.

    Returns:
        One score per pair, in input order. Failed pairs carry the ``-1``
        sentinel produced by ``run_pesq``.

    Raises:
        ValueError: if the two lists have different lengths.
    """
    # zip() would silently drop the unmatched tail and hide a caller bug.
    if len(clean_audio_list) != len(noisy_audio_list):
        raise ValueError(
            "clean_audio_list and noisy_audio_list must have the same length, "
            f"got {len(clean_audio_list)} and {len(noisy_audio_list)}"
        )
    tasks = [
        delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode)
        for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list)
    ]
    return Parallel(n_jobs=n_jobs)(tasks)
def run_pesq_score(clean_audio_list: List[np.ndarray],
                   noisy_audio_list: List[np.ndarray],
                   sample_rate: int = 16000,
                   mode: str = "wb",
                   n_jobs: int = 4,
                   *,
                   ignore_failed: bool = False,
                   ) -> float:
    """Return the mean PESQ score over index-paired (clean, noisy) signals.

    Args:
        clean_audio_list: reference signals.
        noisy_audio_list: degraded signals, paired index-wise with
            ``clean_audio_list``.
        sample_rate: sampling rate in Hz, forwarded to ``run_batch_pesq``.
        mode: PESQ mode, forwarded to ``run_batch_pesq``.
        n_jobs: number of joblib workers, forwarded to ``run_batch_pesq``.
        ignore_failed: if True, drop pairs that returned the ``-1`` failure
            sentinel before averaging. The default (False) keeps them, which
            matches the original behaviour but pulls the mean down whenever
            any pair fails.

    Returns:
        The mean score as a float, or NaN when there is nothing to average.
    """
    pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list,
                                     noisy_audio_list=noisy_audio_list,
                                     sample_rate=sample_rate,
                                     mode=mode,
                                     n_jobs=n_jobs,
                                     )
    if ignore_failed:
        pesq_score_list = [score for score in pesq_score_list if score != -1]
    if not pesq_score_list:
        # np.mean([]) is also NaN, but it emits a RuntimeWarning.
        return float("nan")
    return float(np.mean(pesq_score_list))
def main():
    """Smoke-test the PESQ helpers on two pairs of uniform random signals.

    Each signal has 160000 samples, which is 10 seconds at the helpers'
    default 16 kHz sample rate. Prints the per-pair scores, then their mean.
    """
    num_pairs, num_samples = 2, 160000
    # Draw the clean batch first and the noisy batch second, keeping the RNG order.
    clean_batch = np.random.uniform(low=0, high=1, size=(num_pairs, num_samples))
    noisy_batch = np.random.uniform(low=0, high=1, size=(num_pairs, num_samples))
    clean_signals = list(clean_batch)
    noisy_signals = list(noisy_batch)

    print(run_batch_pesq(clean_signals, noisy_signals))
    print(run_pesq_score(clean_signals, noisy_signals))


if __name__ == "__main__":
    main()