"""Gradio web UI that removes room reverb using the anvuew BS-RoFormer checkpoint.

The checkpoint and its YAML config are fetched from the Hugging Face Hub,
registered with audio-separator, and loaded into a single Separator that is
built on first use and reused for every subsequent request.
"""

import logging
import os
import shutil
import threading

import gradio as gr
from huggingface_hub import hf_hub_download
from audio_separator.separator import Separator
# Import the module where models are actually defined
from audio_separator.separator.architectures import bs_roformer_separator

# --- Configuration ---
REPO_ID = "anvuew/dereverb_room"
MODEL_FILENAME = "dereverb_room_anvuew_sdr_13.7432.ckpt"
CONFIG_FILENAME = "dereverb_room_anvuew.yaml"
# ---------------------

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Downloading and loading the model is expensive, so the Separator is created
# once and shared. Because one instance now serves every request, the lock
# serialises both the lazy initialisation and each separation call.
_separator = None
_separator_lock = threading.Lock()


def _download_model_files(models_dir):
    """Download the checkpoint and config into *models_dir*.

    Returns the path of a copy of the config named after the checkpoint's stem
    (``<checkpoint stem>.yaml``), presumably so the loader can locate it next
    to the checkpoint.
    """
    logger.info("Downloading model files...")
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME, local_dir=models_dir)
    config_path = hf_hub_download(repo_id=REPO_ID, filename=CONFIG_FILENAME, local_dir=models_dir)

    expected_config_path = os.path.splitext(model_path)[0] + ".yaml"
    if config_path != expected_config_path:
        shutil.copyfile(config_path, expected_config_path)
    return expected_config_path


def _register_custom_model(config_path):
    """Add the downloaded checkpoint to the BS-RoFormer model table."""
    logger.info("Registering custom model...")
    # NOTE(review): this writes into a module-level dict of the installed
    # audio-separator package; confirm BS_ROFORMER_MODELS exists and expects
    # this schema in the pinned version.
    bs_roformer_separator.BS_ROFORMER_MODELS[MODEL_FILENAME] = {
        "model_type": "bs_roformer",
        "config_filename": os.path.basename(config_path),
        "model_filename": MODEL_FILENAME,
        "model_friendly_name": "Custom Dereverb",
        "domain": "dereverb",
        "source": "local",
    }


def _get_separator():
    """Return the shared Separator, building and loading it on first call.

    Must be called with ``_separator_lock`` held. The global is assigned only
    after ``load_model`` succeeds, so a failed initialisation is retried on the
    next request.
    """
    global _separator
    if _separator is None:
        local_models_dir = os.path.abspath("models")
        os.makedirs(local_models_dir, exist_ok=True)

        config_path = _download_model_files(local_models_dir)
        _register_custom_model(config_path)

        logger.info("Initializing separator...")
        separator = Separator(
            model_file_dir=local_models_dir,
            output_dir=".",
            output_format="FLAC",
            log_level=logging.INFO,
        )
        logger.info(f"Loading model: {MODEL_FILENAME}...")
        separator.load_model(model_filename=MODEL_FILENAME)
        _separator = separator
    return _separator


def inference(audio_path):
    """Run dereverberation on *audio_path* and return the first output file.

    Args:
        audio_path: Filesystem path of the uploaded audio, or a falsy value
            when nothing was uploaded.

    Returns:
        The first entry of the separator's output list, or ``None`` when no
        audio was given.

    Raises:
        gr.Error: if the separator produced no output files.
    """
    if not audio_path:
        return None

    with _separator_lock:
        separator = _get_separator()
        logger.info("Starting separation...")
        output_files = separator.separate(audio_path)

    if not output_files:
        raise gr.Error("Separation produced no output files.")
    # NOTE(review): assumes the first output is the dereverbed (dry) stem —
    # confirm the stem ordering for this model.
    return output_files[0]


with gr.Blocks(title="Dereverb Room Web UI") as demo:
    gr.Markdown("# Dereverb Room Inference")
    with gr.Row():
        input_audio = gr.Audio(label="Input", type="filepath")
        output_audio = gr.Audio(label="Output (Dereverbed)", type="filepath", interactive=False)
    gr.Button("Remove Reverb").click(fn=inference, inputs=input_audio, outputs=output_audio)

if __name__ == "__main__":
    demo.launch()