#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
import logging
import os
import subprocess
import uuid
from functools import lru_cache

import numpy as np
import sherpa_onnx
import soundfile as sf
from huggingface_hub import hf_hub_download
def convert_to_wav(in_filename: str) -> str:
    """Convert the input audio file to a wave file.

    ffmpeg resamples the input to 44100 Hz with 2 channels and writes the
    result next to the input as ``f"{in_filename}.wav"``, overwriting any
    existing file of that name.

    Args:
      in_filename:
        Path to an audio file that ffmpeg can decode.
    Returns:
      Path of the generated wave file.
    Raises:
      RuntimeError:
        If ffmpeg cannot be started or exits with a non-zero status.
    """
    out_filename = f"{in_filename}.wav"
    logging.info(f"Converting '{in_filename}' to '{out_filename}'")

    # Pass the arguments as a list (no shell) so file names containing
    # quotes, spaces or shell metacharacters are handled safely.
    cmd = [
        "ffmpeg",
        "-hide_banner",
        "-loglevel",
        "error",
        "-y",  # overwrite the output file without prompting
        "-i",
        in_filename,
        "-ar",
        "44100",
        "-ac",
        "2",
        out_filename,
    ]
    try:
        result = subprocess.run(
            cmd,
            stdin=subprocess.DEVNULL,  # never let ffmpeg wait on interactive input
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
            check=False,
        )
    except OSError as e:  # e.g. ffmpeg is not installed / not on PATH
        raise RuntimeError(f"Failed to run ffmpeg: {e}") from e

    if result.returncode != 0:
        # Fail here with ffmpeg's own message instead of letting a later read
        # of a missing/partial output file fail with an unrelated error.
        details = result.stderr.decode("utf-8", errors="replace").strip()
        raise RuntimeError(
            f"ffmpeg failed to convert '{in_filename}' "
            f"(exit code {result.returncode}): {details}"
        )
    return out_filename
def load_audio(filename):
    """Decode an audio file into a float32 sample array.

    The file is first converted with ``convert_to_wav``. The intermediate
    wave file is deleted once it has been read.

    Args:
      filename:
        Path to the input audio file.
    Returns:
      A tuple ``(samples, sample_rate)`` where ``samples`` is a float32
      array of shape ``(num_channels, num_samples)``.
    """
    wav_filename = convert_to_wav(filename)
    try:
        # always_2d=True keeps a 2-D (num_samples, num_channels) layout.
        samples, sample_rate = sf.read(
            wav_filename, dtype="float32", always_2d=True
        )
    finally:
        # The converted file is only an intermediate artifact; remove it so
        # repeated calls do not keep accumulating files on disk.
        try:
            os.remove(wav_filename)
        except OSError as e:
            logging.warning(f"Failed to remove temporary file '{wav_filename}': {e}")

    samples = np.transpose(samples)
    # now samples is of shape (num_channels, num_samples)
    assert (
        samples.shape[1] > samples.shape[0]
    ), f"You should use (num_channels, num_samples). {samples.shape}"
    assert (
        samples.dtype == np.float32
    ), f"Expect np.float32 as dtype. Given: {samples.dtype}"
    return samples, sample_rate
def get_file(
    repo_id: str,
    filename: str,
    subfolder: str = ".",
) -> str:
    """Fetch one file from a Hugging Face Hub repository.

    Args:
      repo_id:
        Repository to download from, e.g. ``"k2-fsa/sherpa-onnx-models"``.
      filename:
        Name of the file inside the repository.
      subfolder:
        Directory inside the repository that contains ``filename``.
    Returns:
      The local path returned by ``hf_hub_download``.
    """
    return hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
def load_model(name: str):
    """Create a source-separation engine for a model entry.

    ``name`` may carry a ``|``-separated label (e.g. ``"...|fastest"``); only
    the text before the first ``|`` is used to choose and call a loader.

    Raises:
      ValueError: if the model name matches neither the Spleeter nor the UVR
        loader.
    """
    base_name, _, _ = name.partition("|")
    if "spleeter" in base_name:
        return load_spleeter_model(base_name)
    if "UVR" in base_name:
        return load_uvr_model(base_name)
    raise ValueError(f"Unsupported model name {base_name}")
def load_uvr_model(name: str):
    """Build an ``OfflineSourceSeparation`` engine from a UVR ONNX model.

    Args:
      name:
        File name of the model inside the ``source-separation-models``
        folder of the ``k2-fsa/sherpa-onnx-models`` repository.
    Returns:
      A ``sherpa_onnx.OfflineSourceSeparation`` running on CPU with 2 threads.
    Raises:
      ValueError: if ``config.validate()`` rejects the configuration.
    """
    model_path = get_file(
        repo_id="k2-fsa/sherpa-onnx-models",
        subfolder="source-separation-models",
        filename=name,
    )
    uvr_config = sherpa_onnx.OfflineSourceSeparationUvrModelConfig(
        model=model_path,
    )
    model_config = sherpa_onnx.OfflineSourceSeparationModelConfig(
        uvr=uvr_config,
        num_threads=2,
        debug=False,
        provider="cpu",
    )
    config = sherpa_onnx.OfflineSourceSeparationConfig(model=model_config)

    if not config.validate():
        raise ValueError("Please check your config.")
    return sherpa_onnx.OfflineSourceSeparation(config)
def load_spleeter_model(name: str):
    """Build an ``OfflineSourceSeparation`` engine from a Spleeter model.

    The ``vocals`` and ``accompaniment`` ONNX files are fetched from the
    ``csukuangfj/{name}`` repository. The file variant is chosen from ``name``:
    ``fp16`` is checked first, then ``int8``, otherwise the plain ``.onnx``
    files are used.

    Returns:
      A ``sherpa_onnx.OfflineSourceSeparation`` running on CPU with 2 threads.
    Raises:
      ValueError: if ``config.validate()`` rejects the configuration.
    """
    # First matching variant wins, in the same priority order as listed.
    variant = next((tag for tag in ("fp16", "int8") if tag in name), None)
    suffix = "onnx" if variant is None else f"{variant}.onnx"

    repo_id = f"csukuangfj/{name}"
    vocals_path = get_file(repo_id=repo_id, filename=f"vocals.{suffix}")
    accompaniment_path = get_file(
        repo_id=repo_id, filename=f"accompaniment.{suffix}"
    )

    spleeter_config = sherpa_onnx.OfflineSourceSeparationSpleeterModelConfig(
        vocals=vocals_path,
        accompaniment=accompaniment_path,
    )
    model_config = sherpa_onnx.OfflineSourceSeparationModelConfig(
        spleeter=spleeter_config,
        num_threads=2,
        debug=False,
        provider="cpu",
    )
    config = sherpa_onnx.OfflineSourceSeparationConfig(model=model_config)

    if not config.validate():
        raise ValueError("Please check your config.")
    return sherpa_onnx.OfflineSourceSeparation(config)
# Selectable models. Each entry has the form "<model name>|<speed label>".
# load_model() keeps only the text before the first "|":
#   - names containing "spleeter" are Hugging Face repos under "csukuangfj/"
#     (load_spleeter_model picks fp16/int8/plain files from the name);
#   - names containing "UVR" are ONNX file names inside
#     k2-fsa/sherpa-onnx-models/source-separation-models (load_uvr_model).
# The part after "|" (fastest/slow/slowest) is presumably a relative speed
# hint shown to users -- NOTE(review): confirm where it is displayed.
model_list = [
    "sherpa-onnx-spleeter-2stems|fastest",
    "sherpa-onnx-spleeter-2stems-fp16|fastest",
    "sherpa-onnx-spleeter-2stems-int8|fastest",
    "UVR_MDXNET_1_9703.onnx|slow",
    "UVR_MDXNET_2_9682.onnx|slow",
    "UVR_MDXNET_3_9662.onnx|slow",
    "UVR_MDXNET_9482.onnx|slow",
    "UVR_MDXNET_KARA.onnx|slow",
    "UVR_MDXNET_KARA_2.onnx|slowest",
    "UVR_MDXNET_Main.onnx|slowest",
    "UVR-MDX-NET-Inst_1.onnx|slowest",
    "UVR-MDX-NET-Inst_2.onnx|slowest",
    "UVR-MDX-NET-Inst_3.onnx|slowest",
    "UVR-MDX-NET-Inst_HQ_1.onnx|slowest",
    "UVR-MDX-NET-Inst_HQ_2.onnx|slowest",
    "UVR-MDX-NET-Inst_HQ_3.onnx|slowest",
    "UVR-MDX-NET-Inst_HQ_4.onnx|slowest",
    "UVR-MDX-NET-Inst_HQ_5.onnx|slowest",
    "UVR-MDX-NET-Inst_Main.onnx|slowest",
    "UVR-MDX-NET-Voc_FT.onnx|slowest",
    "UVR-MDX-NET_Crowd_HQ_1.onnx|slowest",
]