import logging
import os
import shutil

import gradio as gr
from audio_separator.separator import Separator
from audio_separator.separator.architectures import bs_roformer_separator
from huggingface_hub import hf_hub_download

# Hugging Face Hub repository the checkpoint and its config are downloaded from.
REPO_ID = "anvuew/dereverb_room"
# Checkpoint file fetched from REPO_ID; also used as the key under which the
# model is registered and loaded.
MODEL_FILENAME = "dereverb_room_anvuew_sdr_13.7432.ckpt"
# YAML config file fetched from REPO_ID alongside the checkpoint.
CONFIG_FILENAME = "dereverb_room_anvuew.yaml"

# Configure root logging at INFO so the progress messages logged by
# `inference` (download, registration, load, separation) are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def inference(audio_path):
    """Run the dereverb model on one audio file and return an output path.

    Steps, in order: download the checkpoint and config from ``REPO_ID`` into
    a local ``models`` directory, make sure a ``.yaml`` config sits next to
    the checkpoint under the checkpoint's base name, register the checkpoint
    in ``bs_roformer_separator.BS_ROFORMER_MODELS``, build a ``Separator``,
    load the model and separate ``audio_path``.

    Args:
        audio_path: Filesystem path of the input audio. Any falsy value
            (``None``, ``""``) short-circuits the function.

    Returns:
        ``None`` when ``audio_path`` is falsy; otherwise the first entry of
        the list returned by ``Separator.separate``.

    Raises:
        gr.Error: If separation returns no output files.
    """
    if not audio_path:
        return None

    # Absolute path so the download target and the separator's model
    # directory are guaranteed to be the same location.
    local_models_dir = os.path.abspath("models")
    os.makedirs(local_models_dir, exist_ok=True)

    logger.info("Downloading model files...")
    model_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=MODEL_FILENAME,
        local_dir=local_models_dir,
    )
    config_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=CONFIG_FILENAME,
        local_dir=local_models_dir,
    )

    # The config is registered below under "<checkpoint base name>.yaml", so
    # place a copy under that name when the downloaded file is named otherwise.
    expected_config_path = os.path.splitext(model_path)[0] + ".yaml"
    if config_path != expected_config_path:
        shutil.copyfile(config_path, expected_config_path)

    logger.info("Registering custom model...")
    # NOTE(review): relies on `BS_ROFORMER_MODELS` being a mutable registry
    # consulted by `Separator.load_model` - confirm against the installed
    # audio-separator version.
    bs_roformer_separator.BS_ROFORMER_MODELS[MODEL_FILENAME] = {
        "model_type": "bs_roformer",
        "config_filename": os.path.basename(expected_config_path),
        "model_filename": MODEL_FILENAME,
        "model_friendly_name": "Custom Dereverb",
        "domain": "dereverb",
        "source": "local",
    }

    logger.info("Initializing separator...")
    separator = Separator(
        model_file_dir=local_models_dir,
        output_dir=".",
        output_format="FLAC",
        log_level=logging.INFO,
    )

    logger.info("Loading model: %s...", MODEL_FILENAME)
    separator.load_model(model_filename=MODEL_FILENAME)

    logger.info("Starting separation...")
    output_files = separator.separate(audio_path)
    if not output_files:
        # Surface a readable message in the UI instead of a bare IndexError.
        raise gr.Error("Separation finished but produced no output files.")

    # NOTE(review): the first entry is returned as-is; which stem the library
    # lists first (dereverbed vs. reverb residue) should be confirmed.
    return output_files[0]


# Web UI: a heading, an input/output audio pair side by side, and a button
# that runs `inference` on the input file and shows the result in the output.
with gr.Blocks(title="Dereverb Room Web UI") as demo:
    gr.Markdown("# Dereverb Room Inference")
    with gr.Row():
        input_audio = gr.Audio(label="Input", type="filepath")
        output_audio = gr.Audio(
            label="Output (Dereverbed)",
            type="filepath",
            interactive=False,
        )
    # NOTE(review): the original nesting was lost; the button is placed below
    # the row rather than inside it - confirm the intended layout.
    gr.Button("Remove Reverb").click(
        fn=inference,
        inputs=input_audio,
        outputs=output_audio,
    )


# Launch the Gradio app only when this file is executed as a script, so that
# importing the module does not start a server.
if __name__ == "__main__":
    demo.launch()