RoomTone / app.py
Sentimon's picture
Update app.py
90ba39e verified
import gradio as gr
import os
import shutil
import logging
from huggingface_hub import hf_hub_download
from audio_separator.separator import Separator
# Import the module where models are actually defined
from audio_separator.separator.architectures import bs_roformer_separator
# Configure root logging once at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Model location on the Hugging Face Hub ---
# Repository that hosts both the checkpoint and its YAML config.
REPO_ID: str = "anvuew/dereverb_room"
# Checkpoint file name inside the repository.
MODEL_FILENAME: str = "dereverb_room_anvuew_sdr_13.7432.ckpt"
# YAML config file name inside the repository.
CONFIG_FILENAME: str = "dereverb_room_anvuew.yaml"
# ----------------------------------------------
def inference(audio_path):
    """Run the dereverb model on an audio file and return the result's path.

    Downloads the checkpoint and its YAML config from the Hugging Face Hub
    into a local ``models`` directory, registers the checkpoint with the
    BS-Roformer architecture module, then separates the input with
    ``audio_separator``.

    Args:
        audio_path: Path of the input audio file. A falsy value (``None`` or
            an empty string) makes the function return immediately.

    Returns:
        The first entry of the list returned by ``Separator.separate``, or
        ``None`` when no input path was supplied.

    Raises:
        RuntimeError: If the separator reports no output files.
    """
    if not audio_path:
        return None

    local_models_dir = os.path.abspath("models")
    os.makedirs(local_models_dir, exist_ok=True)

    logger.info("Downloading model files...")
    model_path = hf_hub_download(
        repo_id=REPO_ID, filename=MODEL_FILENAME, local_dir=local_models_dir
    )
    config_path = hf_hub_download(
        repo_id=REPO_ID, filename=CONFIG_FILENAME, local_dir=local_models_dir
    )

    # The registry entry below refers to a config named after the checkpoint
    # (same stem, ".yaml" extension), so place a copy under that name when the
    # downloaded config is called something else.
    expected_config_path = os.path.splitext(model_path)[0] + ".yaml"
    if config_path != expected_config_path:
        shutil.copyfile(config_path, expected_config_path)

    logger.info("Registering custom model...")
    # --- FIX: Register directly with the architecture module ---
    # NOTE(review): assumes the installed audio_separator exposes a mutable
    # BS_ROFORMER_MODELS mapping on this module - confirm against its version.
    bs_roformer_separator.BS_ROFORMER_MODELS[MODEL_FILENAME] = {
        "model_type": "bs_roformer",
        "config_filename": os.path.basename(expected_config_path),
        "model_filename": MODEL_FILENAME,
        "model_friendly_name": "Custom Dereverb",
        "domain": "dereverb",
        "source": "local",
    }
    # -----------------------------------------------------------

    logger.info("Initializing separator...")
    separator = Separator(
        model_file_dir=local_models_dir,
        output_dir=".",
        output_format="FLAC",
        log_level=logging.INFO,
    )

    logger.info("Loading model: %s...", MODEL_FILENAME)
    separator.load_model(model_filename=MODEL_FILENAME)

    logger.info("Starting separation...")
    output_files = separator.separate(audio_path)

    # Fail with a descriptive error instead of an IndexError when the
    # separator produced nothing.
    if not output_files:
        raise RuntimeError("Separation produced no output files")
    return output_files[0]
# Build the web UI: a heading, a row holding the input and output audio
# players, and a button that runs `inference` on the uploaded file.
with gr.Blocks(title="Dereverb Room Web UI") as demo:
    gr.Markdown("# Dereverb Room Inference")

    with gr.Row():
        input_audio = gr.Audio(label="Input", type="filepath")
        output_audio = gr.Audio(
            label="Output (Dereverbed)",
            type="filepath",
            interactive=False,
        )

    # Wire the button click to the inference function.
    run_button = gr.Button("Remove Reverb")
    run_button.click(fn=inference, inputs=input_audio, outputs=output_audio)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()