cwitkowitz commited on
Commit
1dad963
·
1 Parent(s): 0910fa9

Abstracted model loading, added progress bar and increased threads.

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -19,11 +19,12 @@ STEM_CHOICES = {
19
  }
20
 
21
 
22
- def separate_stem(audio_file_path: str, model_name: str, stem_choice: str) -> AudioSignal:
23
- model = pretrained.get_model(model_name)
24
- model.to('cuda' if torch.cuda.is_available() else 'cpu')
25
  model.eval()
26
 
 
27
  waveform, sr = torchaudio.load(audio_file_path)
28
  is_mono = waveform.shape[0] == 1
29
  if is_mono:
@@ -31,11 +32,13 @@ def separate_stem(audio_file_path: str, model_name: str, stem_choice: str) -> Au
31
 
32
  with torch.no_grad():
33
  stems_batch = apply_model(
34
- model,
35
  waveform.unsqueeze(0),
36
  overlap=0.2,
37
  shifts=1,
38
- split=True
 
 
39
  )
40
 
41
  stems = stems_batch[0]
@@ -76,13 +79,13 @@ model_card = ModelCard(
76
  with gr.Blocks() as demo:
77
 
78
  dropdown_model = gr.Dropdown(
79
- label="Select Demucs Model",
80
  choices=DEMUX_MODELS,
81
  value="mdx_extra_q"
82
  )
83
 
84
  dropdown_stem = gr.Dropdown(
85
- label="Select Stem to Separate",
86
  choices=list(STEM_CHOICES.keys()),
87
  value="Vocals"
88
  )
 
19
  }
20
 
21
 
22
+ models = dict(zip(DEMUX_MODELS, [pretrained.get_model(m) for m in DEMUX_MODELS]))
23
+
24
+ for model in models.values():
25
  model.eval()
26
 
27
+ def separate_stem(audio_file_path: str, model_name: str, stem_choice: str) -> AudioSignal:
28
  waveform, sr = torchaudio.load(audio_file_path)
29
  is_mono = waveform.shape[0] == 1
30
  if is_mono:
 
32
 
33
  with torch.no_grad():
34
  stems_batch = apply_model(
35
+ models[model_name],
36
  waveform.unsqueeze(0),
37
  overlap=0.2,
38
  shifts=1,
39
+ split=True,
40
+ progress=True,
41
+ num_workers=4
42
  )
43
 
44
  stems = stems_batch[0]
 
79
  with gr.Blocks() as demo:
80
 
81
  dropdown_model = gr.Dropdown(
82
+ label="Demucs Model",
83
  choices=DEMUX_MODELS,
84
  value="mdx_extra_q"
85
  )
86
 
87
  dropdown_stem = gr.Dropdown(
88
+ label="Stem to Separate",
89
  choices=list(STEM_CHOICES.keys()),
90
  value="Vocals"
91
  )