Sentimon commited on
Commit
857ba28
·
verified ·
1 Parent(s): 5d3bb47

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -25
app.py CHANGED
@@ -14,38 +14,46 @@ def inference(audio_path):
14
  if not audio_path:
15
  return None
16
 
17
- print("Downloading model files...")
18
- # Download to a specific local directory to ensure they are together
19
- os.makedirs("models", exist_ok=True)
20
- model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME, local_dir="models")
21
- config_path = hf_hub_download(repo_id=REPO_ID, filename=CONFIG_FILENAME, local_dir="models")
22
 
23
- # Rename config to match model filename if they differ significantly,
24
- # though standard loading might just work if they are in the same dir.
25
- # Let's try standard loading first by just pointing to the checkpoint.
26
 
27
- print("Initializing separator...")
 
 
 
 
 
 
 
 
 
 
28
  separator = Separator(
 
29
  output_dir=".",
30
  output_format="FLAC"
31
  )
32
 
33
- print(f"Loading model from {model_path}...")
34
- # Simply pass the model path. The library usually looks for a .yaml
35
- # with a matching name or the specific config it needs in the same dir.
36
- # If it strictly needs the config passed, we might need to use a different method
37
- # but standard usage often just takes the model file if packaged correctly.
38
- # Let's try to pass JUST the model path first as per common usage.
39
- #
40
- # NOTE: If this still fails because it can't find the config, we will need
41
- # to rename 'dereverb_room_anvuew.yaml' to 'dereverb_room_anvuew_sdr_13.7432.yaml'
42
-
43
- expected_config_path = os.path.splitext(model_path)[0] + ".yaml"
44
- if not os.path.exists(expected_config_path):
45
- print(f"Renaming config to {expected_config_path} for auto-discovery...")
46
- shutil.copy(config_path, expected_config_path)
47
-
48
- separator.load_model(model_filename=model_path)
49
 
50
  print("Starting inference...")
51
  output_files = separator.separate(audio_path)
 
14
  if not audio_path:
15
  return None
16
 
17
+ # 1. Setup local models directory
18
+ local_models_dir = os.path.abspath("models")
19
+ os.makedirs(local_models_dir, exist_ok=True)
 
 
20
 
21
+ print(f"Downloading model files to {local_models_dir}...")
22
+ hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME, local_dir=local_models_dir)
23
+ hf_hub_download(repo_id=REPO_ID, filename=CONFIG_FILENAME, local_dir=local_models_dir)
24
 
25
+ # 2. Ensure config has the exact same base name as the model
26
+ model_full_path = os.path.join(local_models_dir, MODEL_FILENAME)
27
+ original_config_path = os.path.join(local_models_dir, CONFIG_FILENAME)
28
+ new_config_path = os.path.splitext(model_full_path)[0] + ".yaml"
29
+
30
+ if original_config_path != new_config_path:
31
+ print(f"Renaming config to {new_config_path} for auto-discovery...")
32
+ shutil.copyfile(original_config_path, new_config_path)
33
+
34
+ print("Initializing separator with custom model directory...")
35
+ # IMPORTANT: Set model_file_dir to our local directory
36
  separator = Separator(
37
+ model_file_dir=local_models_dir,
38
  output_dir=".",
39
  output_format="FLAC"
40
  )
41
 
42
+ print(f"Loading model: {MODEL_FILENAME}...")
43
+ # Now we just pass the FILENAME, not the full path.
44
+ # Because we set model_file_dir, it should find it there.
45
+ # It might still complain about "not found in supported", if so we have one more trick.
46
+ try:
47
+ separator.load_model(model_filename=MODEL_FILENAME)
48
+ except ValueError as e:
49
+ if "not found in supported model files" in str(e):
50
+ print("Standard load failed, attempting raw load (might fail if unsupported by this version)...")
51
+ # Fallback: Some versions might allow direct path loading if specifically formatted,
52
+ # or we might have to stick to standard models.
53
+ # Let's try passing the full absolute path as a last resort if above fails,
54
+ # but usually setting model_file_dir and passing just filename works for custom models
55
+ # IF they don't strictly enforce the catalog.
56
+ raise e
 
57
 
58
  print("Starting inference...")
59
  output_files = separator.separate(audio_path)