Sentimon commited on
Commit
df5578b
·
verified ·
1 Parent(s): 857ba28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -37
app.py CHANGED
@@ -1,73 +1,88 @@
1
  import gradio as gr
2
  import os
3
  import shutil
 
4
  from huggingface_hub import hf_hub_download
5
  from audio_separator.separator import Separator
6
 
7
  # --- Configuration ---
8
  REPO_ID = "anvuew/dereverb_room"
9
  MODEL_FILENAME = "dereverb_room_anvuew_sdr_13.7432.ckpt"
10
- CONFIG_FILENAME = "dereverb_room_anvuew.yaml"
11
  # ---------------------
12
 
 
 
 
 
13
  def inference(audio_path):
14
  if not audio_path:
15
  return None
16
 
17
- # 1. Setup local models directory
18
  local_models_dir = os.path.abspath("models")
19
  os.makedirs(local_models_dir, exist_ok=True)
20
 
21
- print(f"Downloading model files to {local_models_dir}...")
22
- hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME, local_dir=local_models_dir)
23
- hf_hub_download(repo_id=REPO_ID, filename=CONFIG_FILENAME, local_dir=local_models_dir)
24
-
25
- # 2. Ensure config has the exact same base name as the model
26
- model_full_path = os.path.join(local_models_dir, MODEL_FILENAME)
27
- original_config_path = os.path.join(local_models_dir, CONFIG_FILENAME)
28
- new_config_path = os.path.splitext(model_full_path)[0] + ".yaml"
29
 
30
- if original_config_path != new_config_path:
31
- print(f"Renaming config to {new_config_path} for auto-discovery...")
32
- shutil.copyfile(original_config_path, new_config_path)
 
 
 
 
 
 
 
33
 
34
- print("Initializing separator with custom model directory...")
35
- # IMPORTANT: Set model_file_dir to our local directory
36
  separator = Separator(
37
  model_file_dir=local_models_dir,
38
  output_dir=".",
39
- output_format="FLAC"
 
40
  )
41
 
42
- print(f"Loading model: {MODEL_FILENAME}...")
43
- # Now we just pass the FILENAME, not the full path.
44
- # Because we set model_file_dir, it should find it there.
45
- # It might still complain about "not found in supported", if so we have one more trick.
46
- try:
47
- separator.load_model(model_filename=MODEL_FILENAME)
48
- except ValueError as e:
49
- if "not found in supported model files" in str(e):
50
- print("Standard load failed, attempting raw load (might fail if unsupported by this version)...")
51
- # Fallback: Some versions might allow direct path loading if specifically formatted,
52
- # or we might have to stick to standard models.
53
- # Let's try passing the full absolute path as a last resort if above fails,
54
- # but usually setting model_file_dir and passing just filename works for custom models
55
- # IF they don't strictly enforce the catalog.
56
- raise e
57
 
58
- print("Starting inference...")
 
 
 
 
59
  output_files = separator.separate(audio_path)
60
 
61
- print(f"Separation complete. Outputs: {output_files}")
62
  return output_files[0]
63
 
 
64
  with gr.Blocks(title="Dereverb Room Web UI") as demo:
65
  gr.Markdown("# Dereverb Room Inference")
 
 
66
  with gr.Row():
67
- input_audio = gr.Audio(label="Upload Audio", type="filepath")
68
- output_audio = gr.Audio(label="Dereverbed Audio", type="filepath", interactive=False)
 
69
 
70
- process_btn = gr.Button("Remove Reverb")
71
  process_btn.click(fn=inference, inputs=input_audio, outputs=output_audio)
72
 
73
- demo.launch()
 
 
1
  import gradio as gr
2
  import os
3
  import shutil
4
+ import logging
5
  from huggingface_hub import hf_hub_download
6
  from audio_separator.separator import Separator
7
 
8
  # --- Configuration ---
9
  REPO_ID = "anvuew/dereverb_room"
10
  MODEL_FILENAME = "dereverb_room_anvuew_sdr_13.7432.ckpt"
11
+ CONFIG_FILENAME = "dereverb_room_anvuew.yaml"
12
  # ---------------------
13
 
14
+ # Setup basic logging
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
+
18
  def inference(audio_path):
19
  if not audio_path:
20
  return None
21
 
22
+ # 1. Setup strictly defined local models directory
23
  local_models_dir = os.path.abspath("models")
24
  os.makedirs(local_models_dir, exist_ok=True)
25
 
26
+ logger.info(f"Downloading model files to {local_models_dir}...")
27
+ model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME, local_dir=local_models_dir)
28
+ config_path = hf_hub_download(repo_id=REPO_ID, filename=CONFIG_FILENAME, local_dir=local_models_dir)
 
 
 
 
 
29
 
30
+ # 2. Ensure config has the exact same base name as the model for auto-discovery
31
+ # (Required by some underlying loaders even if we register it)
32
+ expected_config_path = os.path.splitext(model_path)[0] + ".yaml"
33
+ if config_path != expected_config_path:
34
+ logger.info(f"Renaming config to {expected_config_path}...")
35
+ shutil.copyfile(config_path, expected_config_path)
36
+ # Update config filename to match what we just created
37
+ actual_config_filename = os.path.basename(expected_config_path)
38
+ else:
39
+ actual_config_filename = CONFIG_FILENAME
40
 
41
+ logger.info("Initializing separator...")
42
+ # Initialize with our custom model directory
43
  separator = Separator(
44
  model_file_dir=local_models_dir,
45
  output_dir=".",
46
+ output_format="FLAC",
47
+ log_level=logging.INFO
48
  )
49
 
50
+ # --- CRITICAL FIX: MANUAL REGISTRATION ---
51
+ # We manually inject the model metadata into the separator's internal registry.
52
+ # This tricks the library into thinking it's an officially supported model.
53
+ logger.info("Registering custom model at runtime...")
54
+ separator.model_constants.BS_ROFORMER_MODELS[MODEL_FILENAME] = {
55
+ "model_type": "bs_roformer",
56
+ "config_filename": actual_config_filename,
57
+ "model_filename": MODEL_FILENAME,
58
+ "model_friendly_name": "Custom Dereverb Room",
59
+ "domain": "dereverb",
60
+ "source": "local" # strictly local, don't try to download
61
+ }
62
+ # -----------------------------------------
 
 
63
 
64
+ logger.info(f"Loading validated model: {MODEL_FILENAME}...")
65
+ # Now we can load it normally by filename, as it exists in the registry
66
+ separator.load_model(model_filename=MODEL_FILENAME)
67
+
68
+ logger.info("Starting separation...")
69
  output_files = separator.separate(audio_path)
70
 
71
+ logger.info(f"Separation complete: {output_files}")
72
  return output_files[0]
73
 
74
+ # --- Gradio UI ---
75
  with gr.Blocks(title="Dereverb Room Web UI") as demo:
76
  gr.Markdown("# Dereverb Room Inference")
77
+ gr.Markdown(f"**Model:** `{REPO_ID}` (BS-RoFormer)")
78
+
79
  with gr.Row():
80
+ input_audio = gr.Audio(label="Input Audio", type="filepath")
81
+ # Using interactive=False prevents user from editing output
82
+ output_audio = gr.Audio(label="Cleaned Audio (No Reverb)", type="filepath", interactive=False)
83
 
84
+ process_btn = gr.Button("Remove Reverb", variant="primary")
85
  process_btn.click(fn=inference, inputs=input_audio, outputs=output_audio)
86
 
87
+ if __name__ == "__main__":
88
+ demo.launch()