Sentimon commited on
Commit
90ba39e
·
verified ·
1 Parent(s): df5578b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -41
app.py CHANGED
@@ -4,42 +4,45 @@ import shutil
4
  import logging
5
  from huggingface_hub import hf_hub_download
6
  from audio_separator.separator import Separator
 
 
7
 
8
  # --- Configuration ---
9
  REPO_ID = "anvuew/dereverb_room"
10
- MODEL_FILENAME = "dereverb_room_anvuew_sdr_13.7432.ckpt"
11
  CONFIG_FILENAME = "dereverb_room_anvuew.yaml"
12
  # ---------------------
13
 
14
- # Setup basic logging
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
  def inference(audio_path):
19
- if not audio_path:
20
- return None
21
 
22
- # 1. Setup strictly defined local models directory
23
  local_models_dir = os.path.abspath("models")
24
  os.makedirs(local_models_dir, exist_ok=True)
25
 
26
- logger.info(f"Downloading model files to {local_models_dir}...")
27
  model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME, local_dir=local_models_dir)
28
  config_path = hf_hub_download(repo_id=REPO_ID, filename=CONFIG_FILENAME, local_dir=local_models_dir)
29
 
30
- # 2. Ensure config has the exact same base name as the model for auto-discovery
31
- # (Required by some underlying loaders even if we register it)
32
  expected_config_path = os.path.splitext(model_path)[0] + ".yaml"
33
  if config_path != expected_config_path:
34
- logger.info(f"Renaming config to {expected_config_path}...")
35
  shutil.copyfile(config_path, expected_config_path)
36
- # Update config filename to match what we just created
37
- actual_config_filename = os.path.basename(expected_config_path)
38
- else:
39
- actual_config_filename = CONFIG_FILENAME
 
 
 
 
 
 
 
 
40
 
41
  logger.info("Initializing separator...")
42
- # Initialize with our custom model directory
43
  separator = Separator(
44
  model_file_dir=local_models_dir,
45
  output_dir=".",
@@ -47,42 +50,19 @@ def inference(audio_path):
47
  log_level=logging.INFO
48
  )
49
 
50
- # --- CRITICAL FIX: MANUAL REGISTRATION ---
51
- # We manually inject the model metadata into the separator's internal registry.
52
- # This tricks the library into thinking it's an officially supported model.
53
- logger.info("Registering custom model at runtime...")
54
- separator.model_constants.BS_ROFORMER_MODELS[MODEL_FILENAME] = {
55
- "model_type": "bs_roformer",
56
- "config_filename": actual_config_filename,
57
- "model_filename": MODEL_FILENAME,
58
- "model_friendly_name": "Custom Dereverb Room",
59
- "domain": "dereverb",
60
- "source": "local" # strictly local, don't try to download
61
- }
62
- # -----------------------------------------
63
-
64
- logger.info(f"Loading validated model: {MODEL_FILENAME}...")
65
- # Now we can load it normally by filename, as it exists in the registry
66
  separator.load_model(model_filename=MODEL_FILENAME)
67
 
68
  logger.info("Starting separation...")
69
  output_files = separator.separate(audio_path)
70
-
71
- logger.info(f"Separation complete: {output_files}")
72
  return output_files[0]
73
 
74
- # --- Gradio UI ---
75
  with gr.Blocks(title="Dereverb Room Web UI") as demo:
76
  gr.Markdown("# Dereverb Room Inference")
77
- gr.Markdown(f"**Model:** `{REPO_ID}` (BS-RoFormer)")
78
-
79
  with gr.Row():
80
- input_audio = gr.Audio(label="Input Audio", type="filepath")
81
- # Using interactive=False prevents user from editing output
82
- output_audio = gr.Audio(label="Cleaned Audio (No Reverb)", type="filepath", interactive=False)
83
-
84
- process_btn = gr.Button("Remove Reverb", variant="primary")
85
- process_btn.click(fn=inference, inputs=input_audio, outputs=output_audio)
86
 
87
  if __name__ == "__main__":
88
  demo.launch()
 
4
  import logging
5
  from huggingface_hub import hf_hub_download
6
  from audio_separator.separator import Separator
7
+ # Import the module where models are actually defined
8
+ from audio_separator.separator.architectures import bs_roformer_separator
9
 
10
  # --- Configuration ---
11
  REPO_ID = "anvuew/dereverb_room"
12
+ MODEL_FILENAME = "dereverb_room_anvuew_sdr_13.7432.ckpt"
13
  CONFIG_FILENAME = "dereverb_room_anvuew.yaml"
14
  # ---------------------
15
 
 
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
19
  def inference(audio_path):
20
+ if not audio_path: return None
 
21
 
 
22
  local_models_dir = os.path.abspath("models")
23
  os.makedirs(local_models_dir, exist_ok=True)
24
 
25
+ logger.info(f"Downloading model files...")
26
  model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME, local_dir=local_models_dir)
27
  config_path = hf_hub_download(repo_id=REPO_ID, filename=CONFIG_FILENAME, local_dir=local_models_dir)
28
 
 
 
29
  expected_config_path = os.path.splitext(model_path)[0] + ".yaml"
30
  if config_path != expected_config_path:
 
31
  shutil.copyfile(config_path, expected_config_path)
32
+
33
+ logger.info("Registering custom model...")
34
+ # --- FIX: Register directly with the architecture module ---
35
+ bs_roformer_separator.BS_ROFORMER_MODELS[MODEL_FILENAME] = {
36
+ "model_type": "bs_roformer",
37
+ "config_filename": os.path.basename(expected_config_path),
38
+ "model_filename": MODEL_FILENAME,
39
+ "model_friendly_name": "Custom Dereverb",
40
+ "domain": "dereverb",
41
+ "source": "local"
42
+ }
43
+ # -----------------------------------------------------------
44
 
45
  logger.info("Initializing separator...")
 
46
  separator = Separator(
47
  model_file_dir=local_models_dir,
48
  output_dir=".",
 
50
  log_level=logging.INFO
51
  )
52
 
53
+ logger.info(f"Loading model: {MODEL_FILENAME}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  separator.load_model(model_filename=MODEL_FILENAME)
55
 
56
  logger.info("Starting separation...")
57
  output_files = separator.separate(audio_path)
 
 
58
  return output_files[0]
59
 
 
60
  with gr.Blocks(title="Dereverb Room Web UI") as demo:
61
  gr.Markdown("# Dereverb Room Inference")
 
 
62
  with gr.Row():
63
+ input_audio = gr.Audio(label="Input", type="filepath")
64
+ output_audio = gr.Audio(label="Output (Dereverbed)", type="filepath", interactive=False)
65
+ gr.Button("Remove Reverb").click(fn=inference, inputs=input_audio, outputs=output_audio)
 
 
 
66
 
67
  if __name__ == "__main__":
68
  demo.launch()