masszhou commited on
Commit
1218764
·
1 Parent(s): 1e9f689

Add application file

Browse files
Files changed (1) hide show
  1. app.py +104 -146
app.py CHANGED
@@ -71,21 +71,19 @@ def convert_to_stereo_and_wav(audio_path: Path) -> Path:
71
  return stereo_path
72
  else:
73
  return Path(audio_path)
74
-
75
 
76
  class MDXModel:
77
- def __init__(
78
- self,
79
- device,
80
- dim_f,
81
- dim_t,
82
- n_fft,
83
- hop=1024,
84
- stem_name=None,
85
- compensation=1.000,
86
- ):
87
- self.dim_f = dim_f
88
- self.dim_t = dim_t
89
  self.dim_c = 4
90
  self.n_fft = n_fft
91
  self.hop = hop
@@ -105,6 +103,9 @@ class MDXModel:
105
  ).to(device)
106
 
107
  def stft(self, x):
 
 
 
108
  x = x.reshape([-1, self.chunk_size])
109
  x = torch.stft(
110
  x,
@@ -122,6 +123,9 @@ class MDXModel:
122
  return x[:, :, : self.dim_f]
123
 
124
  def istft(self, x, freq_pad=None):
 
 
 
125
  freq_pad = (
126
  self.freq_pad.repeat([x.shape[0], 1, 1, 1])
127
  if freq_pad is None
@@ -143,17 +147,15 @@ class MDXModel:
143
  center=True,
144
  )
145
  return x.reshape([-1, 2, self.chunk_size])
146
-
147
 
148
  class MDX:
149
- DEFAULT_SR = 44100
150
  # Unit: seconds
151
  DEFAULT_CHUNK_SIZE = 0 * DEFAULT_SR
152
  DEFAULT_MARGIN_SIZE = 1 * DEFAULT_SR
153
 
154
- def __init__(
155
- self, model_path: str, params: MDXModel, processor=0
156
- ):
157
  # Set the device and the provider (CPU or CUDA)
158
  self.device = (
159
  torch.device(f"cuda:{processor}")
@@ -182,7 +184,7 @@ class MDX:
182
  self.prog = None
183
 
184
  @staticmethod
185
- def get_hash(model_path):
186
  try:
187
  with open(model_path, "rb") as f:
188
  f.seek(-10000 * 1024, 2)
@@ -193,20 +195,21 @@ class MDX:
193
  return model_hash
194
 
195
  @staticmethod
196
- def segment(
197
- wave,
198
- combine=True,
199
- chunk_size=DEFAULT_CHUNK_SIZE,
200
- margin_size=DEFAULT_MARGIN_SIZE,
201
- ):
202
  """
203
  Segment or join segmented wave array
 
204
  Args:
205
  wave: (np.array) Wave array to be segmented or joined
206
  combine: (bool) If True, combines segmented wave array.
207
  If False, segments wave array.
208
  chunk_size: (int) Size of each segment (in samples)
209
  margin_size: (int) Size of margin between segments (in samples)
 
210
  Returns:
211
  numpy array: Segmented or joined wave array
212
  """
@@ -251,11 +254,13 @@ class MDX:
251
 
252
  return processed_wave
253
 
254
- def pad_wave(self, wave):
255
  """
256
  Pad the wave array to match the required chunk size
 
257
  Args:
258
  wave: (np.array) Wave array to be padded
 
259
  Returns:
260
  tuple: (padded_wave, pad, trim)
261
  - padded_wave: Padded wave array
@@ -283,21 +288,21 @@ class MDX:
283
  waves = np.array(wave_p[:, i:i + self.model.chunk_size])
284
  mix_waves.append(waves)
285
 
286
- mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(
287
- self.device
288
- )
289
 
290
  return mix_waves, pad, trim
291
 
292
- def _process_wave(self, mix_waves, trim, pad, q: queue.Queue, _id: int):
293
  """
294
  Process each wave segment in a multi-threaded environment
 
295
  Args:
296
  mix_waves: (torch.Tensor) Wave segments to be processed
297
  trim: (int) Number of samples trimmed during padding
298
  pad: (int) Number of samples padded during padding
299
  q: (queue.Queue) Queue to hold the processed wave segments
300
  _id: (int) Identifier of the processed wave segment
 
301
  Returns:
302
  numpy array: Processed wave segment
303
  """
@@ -323,12 +328,14 @@ class MDX:
323
  q.put({_id: processed_signal})
324
  return processed_signal
325
 
326
- def process_wave(self, wave: np.array, mt_threads=1):
327
  """
328
  Process the wave array in a multi-threaded environment
 
329
  Args:
330
  wave: (np.array) Wave array to be processed
331
  mt_threads: (int) Number of threads to be used for processing
 
332
  Returns:
333
  numpy array: Processed wave array
334
  """
@@ -367,21 +374,17 @@ class MDX:
367
 
368
 
369
  @spaces.GPU()
370
- def run_mdx(
371
- model_params,
372
- output_dir,
373
- model_path,
374
- filename,
375
- exclude_main=False,
376
- exclude_inversion=False,
377
- suffix=None,
378
- invert_suffix=None,
379
- denoise=False,
380
- keep_orig=True,
381
- m_threads=2,
382
- device_base="cuda",
383
- ):
384
-
385
  if device_base == "cuda":
386
  device = torch.device("cuda:0")
387
  processor_num = 0
@@ -392,8 +395,9 @@ def run_mdx(
392
  device = torch.device("cpu")
393
  processor_num = -1
394
  m_threads = 1
 
395
 
396
- model_hash = MDX.get_hash(model_path)
397
  mp = model_params.get(model_hash)
398
  model = MDXModel(
399
  device,
@@ -405,51 +409,26 @@ def run_mdx(
405
  )
406
 
407
  mdx_sess = MDX(model_path, model, processor=processor_num)
408
- wave, sr = librosa.load(filename, mono=False, sr=44100)
409
  # normalizing input wave gives better output
410
  peak = max(np.max(wave), abs(np.min(wave)))
411
  wave /= peak
412
  if denoise:
413
- wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (
414
- mdx_sess.process_wave(wave, m_threads)
415
- )
416
  wave_processed *= 0.5
417
  else:
418
  wave_processed = mdx_sess.process_wave(wave, m_threads)
419
  # return to previous peak
420
  wave_processed *= peak
421
- stem_name = model.stem_name if suffix is None else suffix
422
 
423
- main_filepath = None
424
- if not exclude_main:
425
- main_filepath = os.path.join(
426
- output_dir,
427
- f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
428
- )
429
- sf.write(main_filepath, wave_processed.T, sr)
430
-
431
- invert_filepath = None
432
- if not exclude_inversion:
433
- diff_stem_name = (
434
- stem_naming.get(stem_name)
435
- if invert_suffix is None
436
- else invert_suffix
437
- )
438
- stem_name = (
439
- f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
440
- )
441
- invert_filepath = os.path.join(
442
- output_dir,
443
- f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
444
- )
445
- sf.write(
446
- invert_filepath,
447
- (-wave_processed.T * model.compensation) + wave.T,
448
- sr,
449
- )
450
 
451
- if not keep_orig:
452
- os.remove(filename)
 
453
 
454
  del mdx_sess, wave_processed, wave
455
  gc.collect()
@@ -457,31 +436,30 @@ def run_mdx(
457
  return main_filepath, invert_filepath
458
 
459
 
460
- def run_mdx_beta(
461
- model_params,
462
- output_dir,
463
- model_path,
464
- filename,
465
- exclude_main=False,
466
- exclude_inversion=False,
467
- suffix=None,
468
- invert_suffix=None,
469
- denoise=False,
470
- keep_orig=True,
471
- m_threads=2,
472
- device_base="",
473
- ):
474
-
475
- m_threads = 1
476
- duration = librosa.get_duration(filename=filename)
477
- if duration >= 60 and duration <= 120:
478
- m_threads = 8
479
- elif duration > 120:
480
- m_threads = 16
481
-
482
- model_hash = MDX.get_hash(model_path)
483
- device = torch.device("cpu")
484
- processor_num = -1
485
  mp = model_params.get(model_hash)
486
  model = MDXModel(
487
  device,
@@ -493,56 +471,26 @@ def run_mdx_beta(
493
  )
494
 
495
  mdx_sess = MDX(model_path, model, processor=processor_num)
496
- wave, sr = librosa.load(filename, mono=False, sr=44100)
497
  # normalizing input wave gives better output
498
  peak = max(np.max(wave), abs(np.min(wave)))
499
  wave /= peak
500
  if denoise:
501
- wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (
502
- mdx_sess.process_wave(wave, m_threads)
503
- )
504
  wave_processed *= 0.5
505
  else:
506
  wave_processed = mdx_sess.process_wave(wave, m_threads)
507
  # return to previous peak
508
  wave_processed *= peak
509
- stem_name = model.stem_name if suffix is None else suffix
510
 
511
- main_filepath = None
512
- if not exclude_main:
513
- main_filepath = os.path.join(
514
- output_dir,
515
- f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
516
- )
517
- sf.write(main_filepath, wave_processed.T, sr)
518
-
519
- invert_filepath = None
520
- if not exclude_inversion:
521
- diff_stem_name = (
522
- stem_naming.get(stem_name)
523
- if invert_suffix is None
524
- else invert_suffix
525
- )
526
- stem_name = (
527
- f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
528
- )
529
- invert_filepath = os.path.join(
530
- output_dir,
531
- f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
532
- )
533
- sf.write(
534
- invert_filepath,
535
- (-wave_processed.T * model.compensation) + wave.T,
536
- sr,
537
- )
538
 
539
- if not keep_orig:
540
- os.remove(filename)
541
 
542
- del mdx_sess, wave_processed, wave
543
- gc.collect()
544
- torch.cuda.empty_cache()
545
- return main_filepath, invert_filepath
546
 
547
 
548
  def extract_bgm(mdx_model_params: Dict,
@@ -592,10 +540,20 @@ def extract_vocal(mdx_model_params: Dict,
592
  device_base=device_base,
593
  )
594
  vocals_path = main_vocals_path
595
-
 
 
 
 
 
 
 
 
 
 
 
596
  return vocals_path
597
 
598
-
599
  def process_uvr_task(input_file_path: Path,
600
  output_dir: Path,
601
  models_path: Dict[str, Path],
 
71
  return stereo_path
72
  else:
73
  return Path(audio_path)
74
+
75
 
76
  class MDXModel:
77
+ def __init__(self,
78
+ device: torch.device,
79
+ dim_f: int,
80
+ dim_t: int,
81
+ n_fft: int,
82
+ hop: int = 1024,
83
+ stem_name: str = "Vocals",
84
+ compensation: float = 1.000,):
85
+ self.dim_f = dim_f # frequency bins
86
+ self.dim_t = dim_t
 
 
87
  self.dim_c = 4
88
  self.n_fft = n_fft
89
  self.hop = hop
 
103
  ).to(device)
104
 
105
  def stft(self, x):
106
+ """
107
+ computes the Fourier transform of short overlapping windows of the input
108
+ """
109
  x = x.reshape([-1, self.chunk_size])
110
  x = torch.stft(
111
  x,
 
123
  return x[:, :, : self.dim_f]
124
 
125
  def istft(self, x, freq_pad=None):
126
+ """
127
+ computes the inverse Fourier transform of short overlapping windows of the input
128
+ """
129
  freq_pad = (
130
  self.freq_pad.repeat([x.shape[0], 1, 1, 1])
131
  if freq_pad is None
 
147
  center=True,
148
  )
149
  return x.reshape([-1, 2, self.chunk_size])
150
+
151
 
152
  class MDX:
153
+ DEFAULT_SR = 44100 # unit: Hz
154
  # Unit: seconds
155
  DEFAULT_CHUNK_SIZE = 0 * DEFAULT_SR
156
  DEFAULT_MARGIN_SIZE = 1 * DEFAULT_SR
157
 
158
+ def __init__(self, model_path: Path, params: MDXModel, processor: int = 0):
 
 
159
  # Set the device and the provider (CPU or CUDA)
160
  self.device = (
161
  torch.device(f"cuda:{processor}")
 
184
  self.prog = None
185
 
186
  @staticmethod
187
+ def get_hash(model_path: Path) -> str:
188
  try:
189
  with open(model_path, "rb") as f:
190
  f.seek(-10000 * 1024, 2)
 
195
  return model_hash
196
 
197
  @staticmethod
198
+ def segment(wave: np.array,
199
+ combine: bool = True,
200
+ chunk_size: int = DEFAULT_CHUNK_SIZE,
201
+ margin_size: int = DEFAULT_MARGIN_SIZE,
202
+ ) -> np.array:
 
203
  """
204
  Segment or join segmented wave array
205
+
206
  Args:
207
  wave: (np.array) Wave array to be segmented or joined
208
  combine: (bool) If True, combines segmented wave array.
209
  If False, segments wave array.
210
  chunk_size: (int) Size of each segment (in samples)
211
  margin_size: (int) Size of margin between segments (in samples)
212
+
213
  Returns:
214
  numpy array: Segmented or joined wave array
215
  """
 
254
 
255
  return processed_wave
256
 
257
+ def pad_wave(self, wave: np.array) -> Tuple[np.array, int, int]:
258
  """
259
  Pad the wave array to match the required chunk size
260
+
261
  Args:
262
  wave: (np.array) Wave array to be padded
263
+
264
  Returns:
265
  tuple: (padded_wave, pad, trim)
266
  - padded_wave: Padded wave array
 
288
  waves = np.array(wave_p[:, i:i + self.model.chunk_size])
289
  mix_waves.append(waves)
290
 
291
+ mix_waves = torch.tensor(np.array(mix_waves), dtype=torch.float32).to(self.device)
 
 
292
 
293
  return mix_waves, pad, trim
294
 
295
+ def _process_wave(self, mix_waves, trim, pad, q: queue.Queue, _id: int) -> np.array:
296
  """
297
  Process each wave segment in a multi-threaded environment
298
+
299
  Args:
300
  mix_waves: (torch.Tensor) Wave segments to be processed
301
  trim: (int) Number of samples trimmed during padding
302
  pad: (int) Number of samples padded during padding
303
  q: (queue.Queue) Queue to hold the processed wave segments
304
  _id: (int) Identifier of the processed wave segment
305
+
306
  Returns:
307
  numpy array: Processed wave segment
308
  """
 
328
  q.put({_id: processed_signal})
329
  return processed_signal
330
 
331
+ def process_wave(self, wave: np.array, mt_threads=1) -> np.array:
332
  """
333
  Process the wave array in a multi-threaded environment
334
+
335
  Args:
336
  wave: (np.array) Wave array to be processed
337
  mt_threads: (int) Number of threads to be used for processing
338
+
339
  Returns:
340
  numpy array: Processed wave array
341
  """
 
374
 
375
 
376
  @spaces.GPU()
377
+ def run_mdx(model_params: Dict,
378
+ input_filename: Path,
379
+ output_dir: Path,
380
+ model_path: Path,
381
+ denoise: bool = False,
382
+ m_threads: int = 2,
383
+ device_base: str = "cuda",
384
+ ) -> Tuple[str, str]:
385
+ """
386
+ Separate vocals using MDX model
387
+ """
 
 
 
 
388
  if device_base == "cuda":
389
  device = torch.device("cuda:0")
390
  processor_num = 0
 
395
  device = torch.device("cpu")
396
  processor_num = -1
397
  m_threads = 1
398
+ print(f"device: {device}")
399
 
400
+ model_hash = MDX.get_hash(model_path) # type: str
401
  mp = model_params.get(model_hash)
402
  model = MDXModel(
403
  device,
 
409
  )
410
 
411
  mdx_sess = MDX(model_path, model, processor=processor_num)
412
+ wave, sr = librosa.load(input_filename, mono=False, sr=44100)
413
  # normalizing input wave gives better output
414
  peak = max(np.max(wave), abs(np.min(wave)))
415
  wave /= peak
416
  if denoise:
417
+ wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads)) # type: np.array
 
 
418
  wave_processed *= 0.5
419
  else:
420
  wave_processed = mdx_sess.process_wave(wave, m_threads)
421
  # return to previous peak
422
  wave_processed *= peak
423
+ stem_name = model.stem_name
424
 
425
+ # output main track
426
+ main_filepath = output_dir / input_filename.with_name(f"{input_filename.stem}_{stem_name}.wav")
427
+ sf.write(main_filepath, wave_processed.T, sr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
 
429
+ # output reverse track
430
+ invert_filepath = output_dir / input_filename.with_name(f"{input_filename.stem}_{stem_name}_reverse.wav")
431
+ sf.write(invert_filepath, (-wave_processed.T * model.compensation) + wave.T, sr)
432
 
433
  del mdx_sess, wave_processed, wave
434
  gc.collect()
 
436
  return main_filepath, invert_filepath
437
 
438
 
439
+ @spaces.GPU()
440
+ def run_mdx_return_np(model_params: Dict,
441
+ input_filename: Path,
442
+ model_path: Path,
443
+ denoise: bool = False,
444
+ m_threads: int = 2,
445
+ device_base: str = "cuda",
446
+ ) -> Tuple[np.ndarray, np.ndarray]:
447
+ """
448
+ Separate vocals using MDX model
449
+ """
450
+ if device_base == "cuda":
451
+ device = torch.device("cuda:0")
452
+ processor_num = 0
453
+ device_properties = torch.cuda.get_device_properties(device)
454
+ vram_gb = device_properties.total_memory / 1024**3
455
+ m_threads = 1 if vram_gb < 8 else (8 if vram_gb > 32 else 2)
456
+ else:
457
+ device = torch.device("cpu")
458
+ processor_num = -1
459
+ m_threads = 1
460
+ print(f"device: {device}")
461
+
462
+ model_hash = MDX.get_hash(model_path) # type: str
 
463
  mp = model_params.get(model_hash)
464
  model = MDXModel(
465
  device,
 
471
  )
472
 
473
  mdx_sess = MDX(model_path, model, processor=processor_num)
474
+ wave, sr = librosa.load(input_filename, mono=False, sr=44100)
475
  # normalizing input wave gives better output
476
  peak = max(np.max(wave), abs(np.min(wave)))
477
  wave /= peak
478
  if denoise:
479
+ wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads)) # type: np.array
 
 
480
  wave_processed *= 0.5
481
  else:
482
  wave_processed = mdx_sess.process_wave(wave, m_threads)
483
  # return to previous peak
484
  wave_processed *= peak
485
+ stem_name = model.stem_name
486
 
487
+ # output main track
488
+ main_track = wave_processed.T
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
489
 
490
+ # output reverse track
491
+ invert_track = (-wave_processed.T * model.compensation) + wave.T
492
 
493
+ return main_track, invert_track
 
 
 
494
 
495
 
496
  def extract_bgm(mdx_model_params: Dict,
 
540
  device_base=device_base,
541
  )
542
  vocals_path = main_vocals_path
543
+ # If "dereverb_flag" is enabled, use Reverb_HQ_By_FoxJoy.onnx for dereverberation
544
+ # deactived since Model license unknown
545
+ # if dereverb_flag:
546
+ # time.sleep(2)
547
+ # _, vocals_dereverb_path = run_mdx(mdx_model_params,
548
+ # output_dir,
549
+ # mdxnet_models_dir/"Reverb_HQ_By_FoxJoy.onnx",
550
+ # vocals_path,
551
+ # denoise=True,
552
+ # device_base=device_base,
553
+ # )
554
+ # vocals_path = vocals_dereverb_path
555
  return vocals_path
556
 
 
557
  def process_uvr_task(input_file_path: Path,
558
  output_dir: Path,
559
  models_path: Dict[str, Path],