File size: 2,490 Bytes
d8b5a80
 
5d3bb47
df5578b
d8b5a80
 
90ba39e
 
d8b5a80
 
 
90ba39e
df5578b
d8b5a80
 
df5578b
 
 
d8b5a80
90ba39e
d8b5a80
857ba28
 
5d3bb47
90ba39e
df5578b
 
857ba28
df5578b
 
 
90ba39e
 
 
 
 
 
 
 
 
 
 
 
857ba28
df5578b
d8b5a80
857ba28
d8b5a80
df5578b
 
d8b5a80
 
90ba39e
df5578b
 
 
d8b5a80
 
 
 
 
 
90ba39e
 
 
d8b5a80
df5578b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
import os
import shutil
import logging
from huggingface_hub import hf_hub_download
from audio_separator.separator import Separator
# Import the module where models are actually defined
from audio_separator.separator.architectures import bs_roformer_separator

# --- Configuration ---
# Hugging Face Hub repository the model files are downloaded from.
REPO_ID = "anvuew/dereverb_room"
# Checkpoint file fetched from REPO_ID; also used as the key under which the
# model is registered and as the name passed to Separator.load_model().
MODEL_FILENAME = "dereverb_room_anvuew_sdr_13.7432.ckpt"
# YAML config file fetched from REPO_ID alongside the checkpoint.
CONFIG_FILENAME = "dereverb_room_anvuew.yaml"
# ---------------------

# Configure root logging at INFO level; this module logs through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def _fetch_model_files(models_dir):
    """Download the checkpoint and its config into *models_dir*.

    The config is also copied next to the checkpoint under the checkpoint's
    base name with a ``.yaml`` extension.
    NOTE(review): presumably that is the name audio-separator looks for when
    loading the checkpoint -- confirm against the library.

    Args:
        models_dir: Absolute path of the directory to download into.

    Returns:
        The base name of the ``.yaml`` file that sits next to the checkpoint.
    """
    logger.info("Downloading model files...")
    model_path = hf_hub_download(
        repo_id=REPO_ID, filename=MODEL_FILENAME, local_dir=models_dir
    )
    config_path = hf_hub_download(
        repo_id=REPO_ID, filename=CONFIG_FILENAME, local_dir=models_dir
    )

    expected_config_path = os.path.splitext(model_path)[0] + ".yaml"
    if config_path != expected_config_path:
        shutil.copyfile(config_path, expected_config_path)
    return os.path.basename(expected_config_path)


def _register_custom_model(config_basename):
    """Add the custom checkpoint to the BS-Roformer model registry.

    Args:
        config_basename: File name of the YAML config stored in the models
            directory for this checkpoint.
    """
    logger.info("Registering custom model...")
    # Register directly with the architecture module.
    bs_roformer_separator.BS_ROFORMER_MODELS[MODEL_FILENAME] = {
        "model_type": "bs_roformer",
        "config_filename": config_basename,
        "model_filename": MODEL_FILENAME,
        "model_friendly_name": "Custom Dereverb",
        "domain": "dereverb",
        "source": "local",
    }


def inference(audio_path):
    """Run the dereverb model on one audio file.

    Downloads the model files into ``./models``, registers the checkpoint,
    loads it into a fresh ``Separator`` and separates *audio_path*. Output
    files are written to the current working directory in FLAC format.

    Args:
        audio_path: Path of the input audio file. A falsy value (``None`` or
            an empty string) means there is nothing to process.

    Returns:
        The first entry of the list returned by ``Separator.separate``, or
        ``None`` when *audio_path* is falsy.

    Raises:
        gr.Error: If separation returns no output files.
    """
    if not audio_path:
        return None

    local_models_dir = os.path.abspath("models")
    os.makedirs(local_models_dir, exist_ok=True)

    config_basename = _fetch_model_files(local_models_dir)
    _register_custom_model(config_basename)

    logger.info("Initializing separator...")
    separator = Separator(
        model_file_dir=local_models_dir,
        output_dir=".",
        output_format="FLAC",
        log_level=logging.INFO,
    )

    logger.info("Loading model: %s...", MODEL_FILENAME)
    separator.load_model(model_filename=MODEL_FILENAME)

    logger.info("Starting separation...")
    output_files = separator.separate(audio_path)
    if not output_files:
        # Surface a readable message in the UI instead of a bare IndexError.
        raise gr.Error("Separation produced no output files.")
    return output_files[0]

# Web UI: a heading, one row holding the input and output audio players, and
# a button that feeds the input file path to `inference` and shows its result.
with gr.Blocks(title="Dereverb Room Web UI") as demo:
    gr.Markdown("# Dereverb Room Inference")

    with gr.Row():
        input_audio = gr.Audio(type="filepath", label="Input")
        output_audio = gr.Audio(
            type="filepath",
            label="Output (Dereverbed)",
            interactive=False,
        )

    run_button = gr.Button("Remove Reverb")
    run_button.click(fn=inference, inputs=input_audio, outputs=output_audio)

# Start the server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()