kokole committed on
Commit
1fd58ee
·
1 Parent(s): 2566adf

Optimize GPU memory usage

Browse files
preprocess/tools/vocal_separation/model.py CHANGED
@@ -54,7 +54,7 @@ def build_model(args):
54
  return model, config
55
 
56
 
57
- def build_models(dict_args):
58
  args = parse_args_inference(dict_args)
59
 
60
  ########## load model ##########
@@ -65,13 +65,13 @@ def build_models(dict_args):
65
 
66
  sep_model, sep_config = build_model(args)
67
 
 
 
 
68
  args.config_path = args.der_config_path
69
  args.start_check_point = args.der_start_check_point
70
-
71
- dereverb_model, dereverb_config = build_model(args)
72
 
73
- sep_model = sep_model
74
- dereverb_model = dereverb_model
75
 
76
  return sep_model, sep_config, dereverb_model, dereverb_config, args
77
 
@@ -83,7 +83,10 @@ def main(args, sep_model=None, sep_config=None, dereverb_model=None, dereverb_co
83
 
84
  mix, _ = librosa.load(path, sr=sample_rate, mono=False)
85
  vocals = process(mix, sep_model, args, sep_config, device)
86
- dereverbed_vocals = process(vocals.mean(0), dereverb_model, args, dereverb_config, device)
 
 
 
87
  accompaniment = mix - dereverbed_vocals
88
 
89
  return mix, vocals, dereverbed_vocals, accompaniment, sample_rate
@@ -113,6 +116,7 @@ class VocalSeparator:
113
  der_model_path: str,
114
  der_config_path: str,
115
  *,
 
116
  model_type: str = "mel_band_roformer",
117
  disable_detailed_pbar: bool = True,
118
  device: str = "cuda",
@@ -127,6 +131,7 @@ class VocalSeparator:
127
  sep_start_check_point: Checkpoint path for separation model.
128
  der_config_path: Config path for dereverb model.
129
  der_start_check_point: Checkpoint path for dereverb model.
 
130
  disable_detailed_pbar: Disable detailed progress bars in underlying utils.
131
  verbose: Whether to print verbose logs.
132
  """
@@ -144,10 +149,11 @@ class VocalSeparator:
144
  if verbose:
145
  print("[vocal extraction] init: start")
146
 
147
- sep_model, sep_config, dereverb_model, dereverb_config, args = build_models(args_dict)
148
 
149
  sep_model = sep_model.to(device)
150
- dereverb_model = dereverb_model.to(device)
 
151
 
152
  self.sep_model = sep_model
153
  self.sep_config = sep_config
@@ -159,7 +165,7 @@ class VocalSeparator:
159
 
160
  if verbose:
161
  print(
162
- "[vocal extraction] init success: sep=loaded, dereverb=loaded, device=",
163
  device,
164
  )
165
 
 
54
  return model, config
55
 
56
 
57
+ def build_models(dict_args, use_der: bool = True):
58
  args = parse_args_inference(dict_args)
59
 
60
  ########## load model ##########
 
65
 
66
  sep_model, sep_config = build_model(args)
67
 
68
+ if not use_der:
69
+ return sep_model, sep_config, None, None, args
70
+
71
  args.config_path = args.der_config_path
72
  args.start_check_point = args.der_start_check_point
 
 
73
 
74
+ dereverb_model, dereverb_config = build_model(args)
 
75
 
76
  return sep_model, sep_config, dereverb_model, dereverb_config, args
77
 
 
83
 
84
  mix, _ = librosa.load(path, sr=sample_rate, mono=False)
85
  vocals = process(mix, sep_model, args, sep_config, device)
86
+ if dereverb_model is not None and dereverb_config is not None:
87
+ dereverbed_vocals = process(vocals.mean(0), dereverb_model, args, dereverb_config, device)
88
+ else:
89
+ dereverbed_vocals = vocals
90
  accompaniment = mix - dereverbed_vocals
91
 
92
  return mix, vocals, dereverbed_vocals, accompaniment, sample_rate
 
116
  der_model_path: str,
117
  der_config_path: str,
118
  *,
119
+ use_der: bool = False,
120
  model_type: str = "mel_band_roformer",
121
  disable_detailed_pbar: bool = True,
122
  device: str = "cuda",
 
131
  sep_start_check_point: Checkpoint path for separation model.
132
  der_config_path: Config path for dereverb model.
133
  der_start_check_point: Checkpoint path for dereverb model.
134
+ use_der: If False, do not load or run dereverb; vocals_dereverbed will equal vocals.
135
  disable_detailed_pbar: Disable detailed progress bars in underlying utils.
136
  verbose: Whether to print verbose logs.
137
  """
 
149
  if verbose:
150
  print("[vocal extraction] init: start")
151
 
152
+ sep_model, sep_config, dereverb_model, dereverb_config, args = build_models(args_dict, use_der=use_der)
153
 
154
  sep_model = sep_model.to(device)
155
+ if dereverb_model is not None:
156
+ dereverb_model = dereverb_model.to(device)
157
 
158
  self.sep_model = sep_model
159
  self.sep_config = sep_config
 
165
 
166
  if verbose:
167
  print(
168
+ "[vocal extraction] init success: sep=loaded, dereverb=" + ("loaded" if use_der else "skipped") + ", device=",
169
  device,
170
  )
171
 
webui_svc.py CHANGED
@@ -244,17 +244,11 @@ APP_STATE = AppState()
244
 
245
 
246
  @spaces.GPU
247
- def _start_svc(
248
- prompt_audio,
249
- target_audio,
250
- prompt_vocal_sep=False,
251
- target_vocal_sep=True,
252
- auto_shift=True,
253
- auto_mix_acc=True,
254
- pitch_shift=0,
255
- n_step=32,
256
- cfg=1.0,
257
- seed=42
258
  ):
259
  try:
260
  prompt_audio = _normalize_audio_input(prompt_audio)
@@ -288,11 +282,43 @@ def _start_svc(
288
  print(target_msg, file=sys.stderr, flush=True)
289
  return None
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  ok, msg, generated = APP_STATE.run_svc(
292
- prompt_wav_path=prompt_wav,
293
- target_wav_path=target_wav,
294
- prompt_f0_path=prompt_f0,
295
- target_f0_path=target_f0,
296
  session_base=session_base,
297
  auto_shift=bool(auto_shift),
298
  auto_mix_acc=bool(auto_mix_acc),
@@ -306,7 +332,7 @@ def _start_svc(
306
  return None
307
  return str(generated)
308
  except Exception:
309
- _print_exception("_start_svc")
310
  return None
311
  finally:
312
  gc.collect()
@@ -314,8 +340,26 @@ def _start_svc(
314
  torch.cuda.empty_cache()
315
 
316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
  def render_tab_content() -> None:
318
- """Render SVC tab content (for embedding in app.py). Same UI style as webui: two columns, no title."""
319
  with gr.Row(equal_height=False):
320
  # ── Left column: inputs & controls ──
321
  with gr.Column(scale=1):
@@ -351,6 +395,7 @@ def render_tab_content() -> None:
351
  # ── Right column: output ──
352
  with gr.Column(scale=1):
353
  output_audio = gr.Audio(label="Generated audio", type="filepath", interactive=False)
 
354
  gr.Examples(
355
  examples=EXAMPLE_LIST,
356
  inputs=[prompt_audio, target_audio],
@@ -361,19 +406,12 @@ def render_tab_content() -> None:
361
  )
362
 
363
  run_btn.click(
364
- fn=_start_svc,
365
- inputs=[
366
- prompt_audio,
367
- target_audio,
368
- prompt_vocal_sep,
369
- target_vocal_sep,
370
- auto_shift,
371
- auto_mix_acc,
372
- pitch_shift,
373
- n_step,
374
- cfg,
375
- seed_input,
376
- ],
377
  outputs=[output_audio],
378
  )
379
 
 
244
 
245
 
246
  @spaces.GPU
247
+ def _run_svc_preprocess(
248
+ prompt_audio,
249
+ target_audio,
250
+ prompt_vocal_sep=False,
251
+ target_vocal_sep=True,
 
 
 
 
 
 
252
  ):
253
  try:
254
  prompt_audio = _normalize_audio_input(prompt_audio)
 
282
  print(target_msg, file=sys.stderr, flush=True)
283
  return None
284
 
285
+ return (
286
+ str(session_base),
287
+ str(prompt_wav),
288
+ str(prompt_f0),
289
+ str(target_wav),
290
+ str(target_f0),
291
+ )
292
+ except Exception:
293
+ _print_exception("_run_svc_preprocess")
294
+ return None
295
+ finally:
296
+ gc.collect()
297
+ if torch.cuda.is_available():
298
+ torch.cuda.empty_cache()
299
+
300
+
301
+ @spaces.GPU
302
+ def _run_svc_convert(
303
+ preprocess_state,
304
+ auto_shift=True,
305
+ auto_mix_acc=True,
306
+ pitch_shift=0,
307
+ n_step=32,
308
+ cfg=1.0,
309
+ seed=42,
310
+ ):
311
+ try:
312
+ if preprocess_state is None or not isinstance(preprocess_state, (tuple, list)) or len(preprocess_state) != 5:
313
+ return None
314
+ session_base_str, prompt_wav, prompt_f0, target_wav, target_f0 = preprocess_state
315
+ session_base = Path(session_base_str)
316
+
317
  ok, msg, generated = APP_STATE.run_svc(
318
+ prompt_wav_path=Path(prompt_wav),
319
+ target_wav_path=Path(target_wav),
320
+ prompt_f0_path=Path(prompt_f0),
321
+ target_f0_path=Path(target_f0),
322
  session_base=session_base,
323
  auto_shift=bool(auto_shift),
324
  auto_mix_acc=bool(auto_mix_acc),
 
332
  return None
333
  return str(generated)
334
  except Exception:
335
+ _print_exception("_run_svc_convert")
336
  return None
337
  finally:
338
  gc.collect()
 
340
  torch.cuda.empty_cache()
341
 
342
 
343
@spaces.GPU
def _start_svc(
    prompt_audio,
    target_audio,
    prompt_vocal_sep=False,
    target_vocal_sep=True,
    auto_shift=True,
    auto_mix_acc=True,
    pitch_shift=0,
    n_step=32,
    cfg=1.0,
    seed=42,
):
    """Run the full SVC pipeline in one call: preprocess, then convert.

    Thin composition of the two-stage pipeline kept for callers that want a
    single entry point. Returns the generated-audio path string from
    ``_run_svc_convert``, or ``None`` if preprocessing failed.
    """
    # Stage 1: normalize/separate the inputs; yields the state tuple
    # (session_base, prompt_wav, prompt_f0, target_wav, target_f0) or None.
    preprocessed = _run_svc_preprocess(
        prompt_audio, target_audio, prompt_vocal_sep, target_vocal_sep
    )
    # Stage 2 only runs when stage 1 produced usable state.
    if preprocessed is None:
        return None
    return _run_svc_convert(
        preprocessed, auto_shift, auto_mix_acc, pitch_shift, n_step, cfg, seed
    )
360
+
361
+
362
  def render_tab_content() -> None:
 
363
  with gr.Row(equal_height=False):
364
  # ── Left column: inputs & controls ──
365
  with gr.Column(scale=1):
 
395
  # ── Right column: output ──
396
  with gr.Column(scale=1):
397
  output_audio = gr.Audio(label="Generated audio", type="filepath", interactive=False)
398
+ svc_state = gr.State(value=None)
399
  gr.Examples(
400
  examples=EXAMPLE_LIST,
401
  inputs=[prompt_audio, target_audio],
 
406
  )
407
 
408
  run_btn.click(
409
+ fn=_run_svc_preprocess,
410
+ inputs=[prompt_audio, target_audio, prompt_vocal_sep, target_vocal_sep],
411
+ outputs=[svc_state],
412
+ ).then(
413
+ fn=_run_svc_convert,
414
+ inputs=[svc_state, auto_shift, auto_mix_acc, pitch_shift, n_step, cfg, seed_input],
 
 
 
 
 
 
 
415
  outputs=[output_audio],
416
  )
417