File size: 2,490 Bytes
d8b5a80 5d3bb47 df5578b d8b5a80 90ba39e d8b5a80 90ba39e df5578b d8b5a80 df5578b d8b5a80 90ba39e d8b5a80 857ba28 5d3bb47 90ba39e df5578b 857ba28 df5578b 90ba39e 857ba28 df5578b d8b5a80 857ba28 d8b5a80 df5578b d8b5a80 90ba39e df5578b d8b5a80 90ba39e d8b5a80 df5578b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import functools
import logging
import os
import shutil

import gradio as gr
from huggingface_hub import hf_hub_download
from audio_separator.separator import Separator
# Import the module where models are actually defined
from audio_separator.separator.architectures import bs_roformer_separator
# Send INFO-level records from every logger to the root handler so that
# download / load / separation progress shows up in the console.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Model source on the Hugging Face Hub -----------------------------------
# Repository holding the dereverberation checkpoint, and the two files fetched
# from it: the model weights and the matching YAML config.
REPO_ID: str = "anvuew/dereverb_room"
MODEL_FILENAME: str = "dereverb_room_anvuew_sdr_13.7432.ckpt"
CONFIG_FILENAME: str = "dereverb_room_anvuew.yaml"
# -----------------------------------------------------------------------------
def _prepare_model_files():
    """Download the checkpoint and its YAML config into ``./models``.

    The config is also copied next to the checkpoint under the checkpoint's
    stem (``<model_stem>.yaml``) when the downloaded name differs.
    NOTE(review): presumably the separator looks for the config there —
    confirm against the installed audio_separator version.

    Returns:
        A ``(models_dir, config_path)`` tuple: the absolute models directory
        and the path of the config placed beside the checkpoint.
    """
    local_models_dir = os.path.abspath("models")
    os.makedirs(local_models_dir, exist_ok=True)
    logger.info("Downloading model files...")
    model_path = hf_hub_download(
        repo_id=REPO_ID, filename=MODEL_FILENAME, local_dir=local_models_dir
    )
    config_path = hf_hub_download(
        repo_id=REPO_ID, filename=CONFIG_FILENAME, local_dir=local_models_dir
    )
    expected_config_path = os.path.splitext(model_path)[0] + ".yaml"
    if config_path != expected_config_path:
        shutil.copyfile(config_path, expected_config_path)
    return local_models_dir, expected_config_path


def _register_custom_model(config_path):
    """Add the downloaded checkpoint to the BS-Roformer model table.

    NOTE(review): relies on ``bs_roformer_separator`` exposing a
    ``BS_ROFORMER_MODELS`` dict that ``Separator.load_model`` consults;
    confirm for the installed audio_separator version.
    """
    bs_roformer_separator.BS_ROFORMER_MODELS[MODEL_FILENAME] = {
        "model_type": "bs_roformer",
        "config_filename": os.path.basename(config_path),
        "model_filename": MODEL_FILENAME,
        "model_friendly_name": "Custom Dereverb",
        "domain": "dereverb",
        "source": "local",
    }


@functools.lru_cache(maxsize=1)
def _get_separator():
    """Build a Separator with the dereverb model loaded, once per process.

    Cached so that repeated UI clicks reuse the loaded model instead of
    re-downloading, re-registering and reloading it on every request.
    If construction raises, nothing is cached and the next call retries.
    """
    models_dir, config_path = _prepare_model_files()
    logger.info("Registering custom model...")
    _register_custom_model(config_path)
    logger.info("Initializing separator...")
    separator = Separator(
        model_file_dir=models_dir,
        output_dir=".",
        output_format="FLAC",
        log_level=logging.INFO,
    )
    logger.info("Loading model: %s...", MODEL_FILENAME)
    separator.load_model(model_filename=MODEL_FILENAME)
    return separator


def inference(audio_path):
    """Remove room reverb from an audio file.

    Args:
        audio_path: Path of the uploaded audio (the Gradio input uses
            ``type="filepath"``), or a falsy value when nothing was uploaded.

    Returns:
        The first entry of the separator's output file list (written to the
        current directory as FLAC), or ``None`` when no input was given.

    Raises:
        gr.Error: If the separator produced no output files.
    """
    if not audio_path:
        return None
    separator = _get_separator()
    logger.info("Starting separation...")
    output_files = separator.separate(audio_path)
    if not output_files:
        raise gr.Error("Separation produced no output files.")
    # NOTE(review): stem order comes from the model config; confirm index 0
    # is the dereverbed (dry) stem rather than the extracted reverb.
    return output_files[0]
# --- Web UI -------------------------------------------------------------------
# One audio input, one read-only audio output, and a button that runs
# `inference` on the uploaded file.
with gr.Blocks(title="Dereverb Room Web UI") as demo:
    gr.Markdown("# Dereverb Room Inference")
    with gr.Row():
        input_audio = gr.Audio(label="Input", type="filepath")
        output_audio = gr.Audio(
            label="Output (Dereverbed)", type="filepath", interactive=False
        )
    run_button = gr.Button("Remove Reverb")
    run_button.click(fn=inference, inputs=input_audio, outputs=output_audio)

if __name__ == "__main__":
    demo.launch()