cwitkowitz commited on
Commit
c34090f
·
1 Parent(s): 1dad963

Removed choice and fixed to a single model.

Browse files
Files changed (1) hide show
  1. app.py +16 -13
app.py CHANGED
@@ -8,7 +8,7 @@ from typing import Dict
8
  from pyharp import *
9
 
10
 
11
- DEMUX_MODELS = ["mdx_extra_q", "mdx_extra", "htdemucs", "mdx_q"]
12
 
13
  STEM_CHOICES = {
14
  "Vocals": 3,
@@ -19,10 +19,13 @@ STEM_CHOICES = {
19
  }
20
 
21
 
22
- models = dict(zip(DEMUX_MODELS, [pretrained.get_model(m) for m in DEMUX_MODELS]))
 
 
 
 
 
23
 
24
- for model in models.values():
25
- model.eval()
26
 
27
  def separate_stem(audio_file_path: str, model_name: str, stem_choice: str) -> AudioSignal:
28
  waveform, sr = torchaudio.load(audio_file_path)
@@ -32,7 +35,7 @@ def separate_stem(audio_file_path: str, model_name: str, stem_choice: str) -> Au
32
 
33
  with torch.no_grad():
34
  stems_batch = apply_model(
35
- models[model_name],
36
  waveform.unsqueeze(0),
37
  overlap=0.2,
38
  shifts=1,
@@ -56,13 +59,13 @@ def separate_stem(audio_file_path: str, model_name: str, stem_choice: str) -> Au
56
 
57
  # Gradio Callback Function
58
 
59
- def process_fn_stem(audio_file_path: str, demucs_model: str, stem_choice: str):
60
  """
61
  PyHARP process function:
62
  - Separates the chosen stem using Demucs.
63
  - Saves the stem as a .wav file.
64
  """
65
- stem_signal = separate_stem(audio_file_path, model_name=demucs_model, stem_choice=stem_choice)
66
  stem_path = save_audio(stem_signal, f"{stem_choice.lower().replace(' ', '_')}.wav")
67
  return stem_path, LabelList(labels=[])
68
 
@@ -78,11 +81,11 @@ model_card = ModelCard(
78
  # Gradio UI
79
  with gr.Blocks() as demo:
80
 
81
- dropdown_model = gr.Dropdown(
82
- label="Demucs Model",
83
- choices=DEMUX_MODELS,
84
- value="mdx_extra_q"
85
- )
86
 
87
  dropdown_stem = gr.Dropdown(
88
  label="Stem to Separate",
@@ -92,7 +95,7 @@ with gr.Blocks() as demo:
92
 
93
  app = build_endpoint(
94
  model_card=model_card,
95
- components=[dropdown_model, dropdown_stem],
96
  process_fn=process_fn_stem
97
  )
98
 
 
8
  from pyharp import *
9
 
10
 
11
+ #DEMUX_MODELS = ["mdx_extra_q", "mdx_extra", "htdemucs", "mdx_q"]
12
 
13
  STEM_CHOICES = {
14
  "Vocals": 3,
 
19
  }
20
 
21
 
22
+ #models = dict(zip(DEMUX_MODELS, [pretrained.get_model(m) for m in DEMUX_MODELS]))
23
+
24
+ #for model in models.values():
25
+ #model.eval()
26
+
27
+ model = pretrained.get_model('mdx_extra_q')
28
 
 
 
29
 
30
  def separate_stem(audio_file_path: str, model_name: str, stem_choice: str) -> AudioSignal:
31
  waveform, sr = torchaudio.load(audio_file_path)
 
35
 
36
  with torch.no_grad():
37
  stems_batch = apply_model(
38
+ model,
39
  waveform.unsqueeze(0),
40
  overlap=0.2,
41
  shifts=1,
 
59
 
60
  # Gradio Callback Function
61
 
62
+ def process_fn_stem(audio_file_path: str, stem_choice: str):
63
  """
64
  PyHARP process function:
65
  - Separates the chosen stem using Demucs.
66
  - Saves the stem as a .wav file.
67
  """
68
+ stem_signal = separate_stem(audio_file_path, model_name='', stem_choice=stem_choice)
69
  stem_path = save_audio(stem_signal, f"{stem_choice.lower().replace(' ', '_')}.wav")
70
  return stem_path, LabelList(labels=[])
71
 
 
81
  # Gradio UI
82
  with gr.Blocks() as demo:
83
 
84
+ #dropdown_model = gr.Dropdown(
85
+ # label="Demucs Model",
86
+ # choices=DEMUX_MODELS,
87
+ # value="mdx_extra_q"
88
+ #)
89
 
90
  dropdown_stem = gr.Dropdown(
91
  label="Stem to Separate",
 
95
 
96
  app = build_endpoint(
97
  model_card=model_card,
98
+ components=[dropdown_stem],
99
  process_fn=process_fn_stem
100
  )
101