noblebarkrr commited on
Commit
8a06b33
·
1 Parent(s): e4bb613

Code actualized

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. mvsepless/additional_app.py +15 -4
  2. mvsepless/app.py +11 -13
  3. mvsepless/i18n.py +4 -4
  4. mvsepless/infer_utils.py +18 -4
  5. mvsepless/install.py +6 -15
  6. mvsepless/models/bandit/core/__init__.py +669 -669
  7. mvsepless/models/bandit/core/data/__init__.py +2 -2
  8. mvsepless/models/bandit/core/data/_types.py +17 -17
  9. mvsepless/models/bandit/core/data/augmentation.py +102 -102
  10. mvsepless/models/bandit/core/data/augmented.py +34 -34
  11. mvsepless/models/bandit/core/data/base.py +60 -60
  12. mvsepless/models/bandit/core/data/dnr/datamodule.py +64 -64
  13. mvsepless/models/bandit/core/data/dnr/dataset.py +360 -360
  14. mvsepless/models/bandit/core/data/dnr/preprocess.py +51 -51
  15. mvsepless/models/bandit/core/data/musdb/datamodule.py +75 -75
  16. mvsepless/models/bandit/core/data/musdb/dataset.py +241 -241
  17. mvsepless/models/bandit/core/data/musdb/preprocess.py +223 -223
  18. mvsepless/models/bandit/core/data/musdb/validation.yaml +14 -14
  19. mvsepless/models/bandit/core/loss/__init__.py +8 -8
  20. mvsepless/models/bandit/core/loss/_complex.py +27 -27
  21. mvsepless/models/bandit/core/loss/_multistem.py +43 -43
  22. mvsepless/models/bandit/core/loss/_timefreq.py +94 -94
  23. mvsepless/models/bandit/core/loss/snr.py +131 -131
  24. mvsepless/models/bandit/core/metrics/__init__.py +7 -7
  25. mvsepless/models/bandit/core/metrics/_squim.py +350 -350
  26. mvsepless/models/bandit/core/metrics/snr.py +124 -124
  27. mvsepless/models/bandit/core/model/__init__.py +3 -3
  28. mvsepless/models/bandit/core/model/_spectral.py +54 -54
  29. mvsepless/models/bandit/core/model/bsrnn/__init__.py +23 -23
  30. mvsepless/models/bandit/core/model/bsrnn/bandsplit.py +119 -119
  31. mvsepless/models/bandit/core/model/bsrnn/core.py +619 -619
  32. mvsepless/models/bandit/core/model/bsrnn/maskestim.py +327 -327
  33. mvsepless/models/bandit/core/model/bsrnn/tfmodel.py +287 -287
  34. mvsepless/models/bandit/core/model/bsrnn/utils.py +518 -518
  35. mvsepless/models/bandit/core/model/bsrnn/wrapper.py +828 -828
  36. mvsepless/models/bandit/core/utils/audio.py +324 -324
  37. mvsepless/models/bandit/model_from_config.py +26 -26
  38. mvsepless/models/bandit_v2/bandit.py +360 -360
  39. mvsepless/models/bandit_v2/bandsplit.py +127 -127
  40. mvsepless/models/bandit_v2/film.py +23 -23
  41. mvsepless/models/bandit_v2/maskestim.py +269 -269
  42. mvsepless/models/bandit_v2/tfmodel.py +141 -141
  43. mvsepless/models/bandit_v2/utils.py +384 -384
  44. mvsepless/models/bs_roformer/__init__.py +14 -14
  45. mvsepless/models/bs_roformer/attend.py +127 -127
  46. mvsepless/models/bs_roformer/attend_sage.py +146 -146
  47. mvsepless/models/bs_roformer/attend_sw.py +88 -88
  48. mvsepless/models/bs_roformer/bs_roformer.py +696 -696
  49. mvsepless/models/bs_roformer/bs_roformer_fno.py +704 -704
  50. mvsepless/models/bs_roformer/bs_roformer_hyperace.py +1122 -1122
mvsepless/additional_app.py CHANGED
@@ -279,6 +279,7 @@ class CustomSeparator(Separator, GradioHelper):
279
  use_spec_invert: bool = False,
280
  econom_mode: Optional[bool] = None,
281
  chunk_duration: float = 300,
 
282
  progress: gr.Progress = gr.Progress(track_tqdm=True),
283
  ) -> List:
284
  """
@@ -295,7 +296,7 @@ class CustomSeparator(Separator, GradioHelper):
295
  vr_aggr: Агрессивность для VR
296
  vr_post_process: Постобработка для VR
297
  vr_high_end_process: Обработка высоких частот для VR
298
- mdx_denoise: Шумоподавление для MDX
299
  use_spec_invert: Использовать инверсию спектрограммы
300
  econom_mode: Эконом-режим
301
  chunk_duration: Длительность чанка
@@ -309,7 +310,8 @@ class CustomSeparator(Separator, GradioHelper):
309
 
310
  add_settings: Dict[str, Any] = {
311
  "add_single_sep_text_progress": None,
312
- "single_mode": False
 
313
  }
314
 
315
  if econom_mode is not None:
@@ -412,6 +414,11 @@ class CustomSeparator(Separator, GradioHelper):
412
  value=False,
413
  interactive=True
414
  )
 
 
 
 
 
415
 
416
  gr.Markdown(f"<h4>{_i18n('economy_settings')}</h4>", container=True)
417
  econom_mode = gr.Checkbox(
@@ -497,7 +504,8 @@ class CustomSeparator(Separator, GradioHelper):
497
  stems,
498
  use_spec_for_extract_instrumental,
499
  econom_mode,
500
- chunk_dur_slider
 
501
  ],
502
  outputs=[sep_state, status],
503
  show_progress="full",
@@ -514,6 +522,7 @@ class CustomSeparator(Separator, GradioHelper):
514
  u_spec: bool,
515
  ec_mode: bool,
516
  ch_dur: float,
 
517
  progress: gr.Progress = gr.Progress(track_tqdm=True),
518
  ) -> Tuple[gr.update, gr.update]:
519
  results = self._separate_batch(
@@ -527,6 +536,7 @@ class CustomSeparator(Separator, GradioHelper):
527
  u_spec,
528
  ec_mode,
529
  ch_dur * 60,
 
530
  progress=progress,
531
  )
532
  return gr.update(value=str(results)), gr.update(visible=False)
@@ -801,10 +811,11 @@ class CustomSeparator(Separator, GradioHelper):
801
  def refresh_all_models() -> Tuple[gr.update, gr.update, gr.update]:
802
  models = self.get_mn()
803
  first_model = models[0] if models else None
 
804
  return (
805
  gr.update(choices=models, value=first_model),
806
  gr.update(choices=models, value=first_model),
807
- gr.update(choices=models, value=[first_model]),
808
  )
809
 
810
  class AutoEnsembless(Separator, GradioHelper):
 
279
  use_spec_invert: bool = False,
280
  econom_mode: Optional[bool] = None,
281
  chunk_duration: float = 300,
282
+ denoise: bool = False,
283
  progress: gr.Progress = gr.Progress(track_tqdm=True),
284
  ) -> List:
285
  """
 
296
  vr_aggr: Агрессивность для VR
297
  vr_post_process: Постобработка для VR
298
  vr_high_end_process: Обработка высоких частот для VR
299
+ denoise: Шумоподавление для MDX
300
  use_spec_invert: Использовать инверсию спектрограммы
301
  econom_mode: Эконом-режим
302
  chunk_duration: Длительность чанка
 
310
 
311
  add_settings: Dict[str, Any] = {
312
  "add_single_sep_text_progress": None,
313
+ "single_mode": False,
314
+ "denoise": denoise
315
  }
316
 
317
  if econom_mode is not None:
 
414
  value=False,
415
  interactive=True
416
  )
417
+ denoise = gr.Checkbox(
418
+ label=_i18n("denoise"),
419
+ value=False,
420
+ interactive=True,
421
+ )
422
 
423
  gr.Markdown(f"<h4>{_i18n('economy_settings')}</h4>", container=True)
424
  econom_mode = gr.Checkbox(
 
504
  stems,
505
  use_spec_for_extract_instrumental,
506
  econom_mode,
507
+ chunk_dur_slider,
508
+ denoise
509
  ],
510
  outputs=[sep_state, status],
511
  show_progress="full",
 
522
  u_spec: bool,
523
  ec_mode: bool,
524
  ch_dur: float,
525
+ den: bool,
526
  progress: gr.Progress = gr.Progress(track_tqdm=True),
527
  ) -> Tuple[gr.update, gr.update]:
528
  results = self._separate_batch(
 
536
  u_spec,
537
  ec_mode,
538
  ch_dur * 60,
539
+ den,
540
  progress=progress,
541
  )
542
  return gr.update(value=str(results)), gr.update(visible=False)
 
811
  def refresh_all_models() -> Tuple[gr.update, gr.update, gr.update]:
812
  models = self.get_mn()
813
  first_model = models[0] if models else None
814
+ first_model2 = [models[0]] if models else []
815
  return (
816
  gr.update(choices=models, value=first_model),
817
  gr.update(choices=models, value=first_model),
818
+ gr.update(choices=models, value=first_model2),
819
  )
820
 
821
  class AutoEnsembless(Separator, GradioHelper):
mvsepless/app.py CHANGED
@@ -278,7 +278,7 @@ class SeparatorGradio(GradioHelper, DownloadModelManager):
278
  vr_aggr: int = 5,
279
  vr_post_process: bool = False,
280
  vr_high_end_process: bool = False,
281
- mdx_denoise: bool = False,
282
  use_spec_invert: bool = False,
283
  econom_mode: Optional[bool] = None,
284
  chunk_duration: float = 300,
@@ -298,7 +298,7 @@ class SeparatorGradio(GradioHelper, DownloadModelManager):
298
  vr_aggr: Агрессивность для VR
299
  vr_post_process: Постобработка для VR
300
  vr_high_end_process: Обработка высоких частот для VR
301
- mdx_denoise: Шумоподавление для MDX
302
  use_spec_invert: Использовать инверсию спектрограммы
303
  econom_mode: Эконом-режим
304
  chunk_duration: Длительность чанка
@@ -311,7 +311,7 @@ class SeparatorGradio(GradioHelper, DownloadModelManager):
311
  self.chunk_duration = chunk_duration
312
 
313
  add_settings: Dict[str, Any] = {
314
- "mdx_denoise": mdx_denoise,
315
  "vr_aggr": vr_aggr,
316
  "vr_post_process": vr_post_process,
317
  "vr_high_end_process": vr_high_end_process,
@@ -487,19 +487,17 @@ class SeparatorGradio(GradioHelper, DownloadModelManager):
487
  interactive=True
488
  )
489
 
490
- gr.Markdown(f"<h4>{_i18n('mdx_settings')}</h4>", container=True)
491
- mdx_denoise = gr.Checkbox(
492
- label=_i18n("mdx_denoise"),
493
- value=False,
494
- interactive=True,
495
- )
496
-
497
  gr.Markdown(f"<h4>{_i18n('invert_settings')}</h4>", container=True)
498
  use_spec_for_extract_instrumental = gr.Checkbox(
499
  label=_i18n("use_spectrogram_invert"),
500
  value=False,
501
  interactive=True
502
  )
 
 
 
 
 
503
 
504
  gr.Markdown(f"<h4>{_i18n('economy_settings')}</h4>", container=True)
505
  econom_mode = gr.Checkbox(
@@ -583,7 +581,7 @@ class SeparatorGradio(GradioHelper, DownloadModelManager):
583
  output_bitrate,
584
  template,
585
  stems,
586
- mdx_denoise,
587
  vr_aggr,
588
  vr_enable_post_process,
589
  vr_enable_high_end_process,
@@ -603,7 +601,7 @@ class SeparatorGradio(GradioHelper, DownloadModelManager):
603
  output_bitrate: int,
604
  template: str,
605
  stems: List[str],
606
- mdx_denoise: bool,
607
  vr_aggr: int,
608
  vr_pp: bool,
609
  vr_hip: bool,
@@ -623,7 +621,7 @@ class SeparatorGradio(GradioHelper, DownloadModelManager):
623
  vr_aggr,
624
  vr_pp,
625
  vr_hip,
626
- mdx_denoise,
627
  u_spec,
628
  ec_mode,
629
  ch_dur * 60,
 
278
  vr_aggr: int = 5,
279
  vr_post_process: bool = False,
280
  vr_high_end_process: bool = False,
281
+ denoise: bool = False,
282
  use_spec_invert: bool = False,
283
  econom_mode: Optional[bool] = None,
284
  chunk_duration: float = 300,
 
298
  vr_aggr: Агрессивность для VR
299
  vr_post_process: Постобработка для VR
300
  vr_high_end_process: Обработка высоких частот для VR
301
+ denoise: Шумоподавление для MDX
302
  use_spec_invert: Использовать инверсию спектрограммы
303
  econom_mode: Эконом-режим
304
  chunk_duration: Длительность чанка
 
311
  self.chunk_duration = chunk_duration
312
 
313
  add_settings: Dict[str, Any] = {
314
+ "denoise": denoise,
315
  "vr_aggr": vr_aggr,
316
  "vr_post_process": vr_post_process,
317
  "vr_high_end_process": vr_high_end_process,
 
487
  interactive=True
488
  )
489
 
 
 
 
 
 
 
 
490
  gr.Markdown(f"<h4>{_i18n('invert_settings')}</h4>", container=True)
491
  use_spec_for_extract_instrumental = gr.Checkbox(
492
  label=_i18n("use_spectrogram_invert"),
493
  value=False,
494
  interactive=True
495
  )
496
+ denoise = gr.Checkbox(
497
+ label=_i18n("denoise"),
498
+ value=False,
499
+ interactive=True,
500
+ )
501
 
502
  gr.Markdown(f"<h4>{_i18n('economy_settings')}</h4>", container=True)
503
  econom_mode = gr.Checkbox(
 
581
  output_bitrate,
582
  template,
583
  stems,
584
+ denoise,
585
  vr_aggr,
586
  vr_enable_post_process,
587
  vr_enable_high_end_process,
 
601
  output_bitrate: int,
602
  template: str,
603
  stems: List[str],
604
+ denoise: bool,
605
  vr_aggr: int,
606
  vr_pp: bool,
607
  vr_hip: bool,
 
621
  vr_aggr,
622
  vr_pp,
623
  vr_hip,
624
+ denoise,
625
  u_spec,
626
  ec_mode,
627
  ch_dur * 60,
mvsepless/i18n.py CHANGED
@@ -128,7 +128,7 @@ TRANSLATIONS: Dict[Language, Dict[str, str]] = {
128
  "vr_post_process": "Дополнительная обработка для улучшения качества разделения",
129
  "vr_high_end_process": "Восстановление недостающих высоких частот",
130
  "mdx_settings": "MDX-NET",
131
- "mdx_denoise": "Шумоподавление",
132
  "invert_settings": "Инвертирование результата",
133
  "use_spectrogram_invert": "При извлечении инструментала/второго стема/остатка использовать спектрограмму",
134
  "economy_settings": "Экономия",
@@ -441,7 +441,7 @@ TRANSLATIONS: Dict[Language, Dict[str, str]] = {
441
  "ext_inst_help": "Извлечь инструментал вычитанием",
442
  "use_spec_invert_help": "Инверсия спектрограммы для вторичного стема",
443
  "install_only_help": "Только установка модели",
444
- "mdx_denoise_help": "Шумоподавление для MDX-NET моделей",
445
  "vr_aggression_help": "Агрессивность для VR моделей (по умолчанию: 5)",
446
  "vr_high_end_help": "Восстановление недостающих высоких частот на VR моделях",
447
  "vr_post_process_help": "Дополнительная обработка для улучшения качества разделения VR модели",
@@ -856,7 +856,7 @@ TRANSLATIONS: Dict[Language, Dict[str, str]] = {
856
  "vr_post_process": "Additional post-processing to improve separation quality",
857
  "vr_high_end_process": "Restore missing high frequencies",
858
  "mdx_settings": "MDX-NET",
859
- "mdx_denoise": "Denoise",
860
  "invert_settings": "Result Inversion",
861
  "use_spectrogram_invert": "Use spectrogram for extracting instrumental/second stem/remainder",
862
  "economy_settings": "Economy",
@@ -1169,7 +1169,7 @@ TRANSLATIONS: Dict[Language, Dict[str, str]] = {
1169
  "ext_inst_help": "Extract instrumental by subtraction",
1170
  "use_spec_invert_help": "Use spectrogram inversion for secondary stem",
1171
  "install_only_help": "Download model only",
1172
- "mdx_denoise_help": "Denoise for MDX-NET models",
1173
  "vr_aggression_help": "Aggression for VR models (default: 5)",
1174
  "vr_high_end_help": "Restore missing high frequencies on VR models",
1175
  "vr_post_process_help": "Additional post-processing for VR models",
 
128
  "vr_post_process": "Дополнительная обработка для улучшения качества разделения",
129
  "vr_high_end_process": "Восстановление недостающих высоких частот",
130
  "mdx_settings": "MDX-NET",
131
+ "denoise": "Шумоподавление",
132
  "invert_settings": "Инвертирование результата",
133
  "use_spectrogram_invert": "При извлечении инструментала/второго стема/остатка использовать спектрограмму",
134
  "economy_settings": "Экономия",
 
441
  "ext_inst_help": "Извлечь инструментал вычитанием",
442
  "use_spec_invert_help": "Инверсия спектрограммы для вторичного стема",
443
  "install_only_help": "Только установка модели",
444
+ "denoise_help": "Удаление шума из вывода",
445
  "vr_aggression_help": "Агрессивность для VR моделей (по умолчанию: 5)",
446
  "vr_high_end_help": "Восстановление недостающих высоких частот на VR моделях",
447
  "vr_post_process_help": "Дополнительная обработка для улучшения качества разделения VR модели",
 
856
  "vr_post_process": "Additional post-processing to improve separation quality",
857
  "vr_high_end_process": "Restore missing high frequencies",
858
  "mdx_settings": "MDX-NET",
859
+ "denoise": "Denoise",
860
  "invert_settings": "Result Inversion",
861
  "use_spectrogram_invert": "Use spectrogram for extracting instrumental/second stem/remainder",
862
  "economy_settings": "Economy",
 
1169
  "ext_inst_help": "Extract instrumental by subtraction",
1170
  "use_spec_invert_help": "Use spectrogram inversion for secondary stem",
1171
  "install_only_help": "Download model only",
1172
+ "denoise_help": "Remove noise from output",
1173
  "vr_aggression_help": "Aggression for VR models (default: 5)",
1174
  "vr_high_end_help": "Restore missing high frequencies on VR models",
1175
  "vr_post_process_help": "Additional post-processing for VR models",
mvsepless/infer_utils.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import sys
2
  sys.stdout.reconfigure(encoding='utf-8')
3
  sys.stderr.reconfigure(encoding='utf-8')
@@ -70,7 +71,9 @@ def get_model_from_config(model_type: str, config_path: str) -> Tuple[Any, Any]:
70
  from models.vr_arch import VRNet
71
  model = VRNet(**dict(config.model))
72
  elif model_type == "htdemucs":
73
- from models.demucs4ht import get_model
 
 
74
  model = get_model(config)
75
  elif model_type == "mel_band_roformer":
76
  if hasattr(config, "windowed"):
@@ -419,6 +422,7 @@ def demix_demucs(
419
  mix = torch.tensor(mix, dtype=torch.float32)
420
  chunk_size = config.training.samplerate * config.training.segment
421
  num_instruments = len(config.training.instruments)
 
422
  num_overlap = config.inference.num_overlap
423
  step = chunk_size // num_overlap
424
  fade_size = chunk_size // 10
@@ -451,8 +455,12 @@ def demix_demucs(
451
 
452
  if len(batch_data) >= batch_size or i >= mix.shape[1]:
453
  arr = torch.stack(batch_data, dim=0)
454
- x = model(arr)
455
-
 
 
 
 
456
  window = windowing_array.clone()
457
  if i - step == 0:
458
  window[:fade_size] = 1
@@ -511,6 +519,7 @@ def demix_generic(
511
  chunk_size = config.audio.chunk_size
512
  instruments = prefer_target_instrument(config)
513
  num_instruments = len(instruments)
 
514
  num_overlap = config.inference.num_overlap
515
 
516
  fade_size = chunk_size // 10
@@ -550,7 +559,12 @@ def demix_generic(
550
 
551
  if len(batch_data) >= batch_size or i >= mix.shape[1]:
552
  arr = torch.stack(batch_data, dim=0)
553
- x = model(arr)
 
 
 
 
 
554
 
555
  window = windowing_array.clone()
556
  if i - step == 0:
 
1
+ import os
2
  import sys
3
  sys.stdout.reconfigure(encoding='utf-8')
4
  sys.stderr.reconfigure(encoding='utf-8')
 
71
  from models.vr_arch import VRNet
72
  model = VRNet(**dict(config.model))
73
  elif model_type == "htdemucs":
74
+ models_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models')
75
+ sys.path.append(models_path)
76
+ from demucs import get_model
77
  model = get_model(config)
78
  elif model_type == "mel_band_roformer":
79
  if hasattr(config, "windowed"):
 
422
  mix = torch.tensor(mix, dtype=torch.float32)
423
  chunk_size = config.training.samplerate * config.training.segment
424
  num_instruments = len(config.training.instruments)
425
+ denoise = config.inference.denoise
426
  num_overlap = config.inference.num_overlap
427
  step = chunk_size // num_overlap
428
  fade_size = chunk_size // 10
 
455
 
456
  if len(batch_data) >= batch_size or i >= mix.shape[1]:
457
  arr = torch.stack(batch_data, dim=0)
458
+ if denoise:
459
+ x1 = model(arr)
460
+ x2 = model(-arr)
461
+ x = (x1 + -x2) * 0.5
462
+ else:
463
+ x = model(arr)
464
  window = windowing_array.clone()
465
  if i - step == 0:
466
  window[:fade_size] = 1
 
519
  chunk_size = config.audio.chunk_size
520
  instruments = prefer_target_instrument(config)
521
  num_instruments = len(instruments)
522
+ denoise = config.inference.denoise
523
  num_overlap = config.inference.num_overlap
524
 
525
  fade_size = chunk_size // 10
 
559
 
560
  if len(batch_data) >= batch_size or i >= mix.shape[1]:
561
  arr = torch.stack(batch_data, dim=0)
562
+ if denoise:
563
+ x1 = model(arr)
564
+ x2 = model(-arr)
565
+ x = (x1 + -x2) * 0.5
566
+ else:
567
+ x = model(arr)
568
 
569
  window = windowing_array.clone()
570
  if i - step == 0:
mvsepless/install.py CHANGED
@@ -221,33 +221,28 @@ universal_requirements: List[str] = [
221
  "matplotlib",
222
  "tqdm",
223
  "einops",
224
- "protobuf",
225
  "soundfile",
226
  "pydub",
227
  "webrtcvad",
228
  "audiomentations",
229
  "pedalboard",
230
  "ml_collections",
231
- "timm",
232
  "wandb",
233
- "accelerate",
234
  "bitsandbytes",
235
  "tokenizers",
236
  "huggingface-hub",
237
  "transformers",
238
- "torchseg",
239
- "demucs==4.0.0",
 
240
  "asteroid>=0.6.0",
241
  "pyloudnorm",
242
- "prodigyopt",
243
- "torch_log_wmse",
244
  "rotary_embedding_torch",
245
  "gradio<6.0",
246
  "omegaconf",
247
  "beartype",
248
  "spafe",
249
  "torch_audiomentations",
250
- "auraloss",
251
  "onnx>=1.17",
252
  "onnx2torch>=0.3.0",
253
  "onnxruntime-gpu>=1.17" if cuda_available else "onnxruntime>=1.17",
@@ -279,32 +274,28 @@ old_requirements: List[str] = [
279
  "matplotlib==3.10.8",
280
  "tqdm==4.67.1",
281
  "einops==0.8.1",
282
- "protobuf==6.33.4",
283
  "soundfile==0.13.1",
284
  "pydub==0.25.1",
285
  "webrtcvad==2.0.10",
286
  "audiomentations==0.43.1",
287
  "pedalboard==0.8.2",
288
  "ml_collections==1.1.0",
289
- "timm==1.0.24",
290
  "wandb==0.24.0",
291
- "accelerate==1.2.1",
292
  "bitsandbytes==0.45.0",
293
  "tokenizers==0.15.2",
294
  "huggingface-hub==0.34.2",
295
  "transformers==4.39.3",
296
- "torchseg==0.0.1a4",
297
- "demucs==4.0.0",
 
298
  "asteroid==0.6.0",
299
  "pyloudnorm",
300
- "prodigyopt==1.1.2",
301
  "rotary_embedding_torch==0.3.6",
302
  "gradio<6.0.0",
303
  "omegaconf==2.3.0",
304
  "beartype==0.22.9",
305
  "spafe==0.3.3",
306
  "torch_audiomentations==0.12.0",
307
- "auraloss==0.4.0",
308
  "onnx>=1.17",
309
  "onnx2torch>=0.3.0",
310
  "onnxruntime-gpu>=1.17" if cuda_available else "onnxruntime>=1.17",
 
221
  "matplotlib",
222
  "tqdm",
223
  "einops",
 
224
  "soundfile",
225
  "pydub",
226
  "webrtcvad",
227
  "audiomentations",
228
  "pedalboard",
229
  "ml_collections",
 
230
  "wandb",
 
231
  "bitsandbytes",
232
  "tokenizers",
233
  "huggingface-hub",
234
  "transformers",
235
+ "diffq>=0.2.1",
236
+ "julius>=0.2.3",
237
+ "openunmix",
238
  "asteroid>=0.6.0",
239
  "pyloudnorm",
 
 
240
  "rotary_embedding_torch",
241
  "gradio<6.0",
242
  "omegaconf",
243
  "beartype",
244
  "spafe",
245
  "torch_audiomentations",
 
246
  "onnx>=1.17",
247
  "onnx2torch>=0.3.0",
248
  "onnxruntime-gpu>=1.17" if cuda_available else "onnxruntime>=1.17",
 
274
  "matplotlib==3.10.8",
275
  "tqdm==4.67.1",
276
  "einops==0.8.1",
 
277
  "soundfile==0.13.1",
278
  "pydub==0.25.1",
279
  "webrtcvad==2.0.10",
280
  "audiomentations==0.43.1",
281
  "pedalboard==0.8.2",
282
  "ml_collections==1.1.0",
 
283
  "wandb==0.24.0",
 
284
  "bitsandbytes==0.45.0",
285
  "tokenizers==0.15.2",
286
  "huggingface-hub==0.34.2",
287
  "transformers==4.39.3",
288
+ "diffq>=0.2.1",
289
+ "julius>=0.2.3",
290
+ "openunmix",
291
  "asteroid==0.6.0",
292
  "pyloudnorm",
 
293
  "rotary_embedding_torch==0.3.6",
294
  "gradio<6.0.0",
295
  "omegaconf==2.3.0",
296
  "beartype==0.22.9",
297
  "spafe==0.3.3",
298
  "torch_audiomentations==0.12.0",
 
299
  "onnx>=1.17",
300
  "onnx2torch>=0.3.0",
301
  "onnxruntime-gpu>=1.17" if cuda_available else "onnxruntime>=1.17",
mvsepless/models/bandit/core/__init__.py CHANGED
@@ -1,669 +1,669 @@
1
- import os.path
2
- from collections import defaultdict
3
- from itertools import chain, combinations
4
- from typing import Any, Dict, Iterator, Mapping, Optional, Tuple, Type, TypedDict
5
-
6
- import pytorch_lightning as pl
7
- import torch
8
- import torchaudio as ta
9
- import torchmetrics as tm
10
- from asteroid import losses as asteroid_losses
11
-
12
- from pytorch_lightning.utilities.types import STEP_OUTPUT
13
- from torch import nn, optim
14
- from torch.optim import lr_scheduler
15
- from torch.optim.lr_scheduler import LRScheduler
16
-
17
- from . import loss, metrics as metrics_, model
18
- from .data._types import BatchedDataDict
19
- from .data.augmentation import BaseAugmentor, StemAugmentor
20
- from .utils import audio as audio_
21
- from .utils.audio import BaseFader
22
-
23
-
24
- ConfigDict = TypedDict("ConfigDict", {"name": str, "kwargs": Dict[str, Any]})
25
-
26
-
27
- class SchedulerConfigDict(ConfigDict):
28
- monitor: str
29
-
30
-
31
- OptimizerSchedulerConfigDict = TypedDict(
32
- "OptimizerSchedulerConfigDict",
33
- {"optimizer": ConfigDict, "scheduler": SchedulerConfigDict},
34
- total=False,
35
- )
36
-
37
-
38
- class LRSchedulerReturnDict(TypedDict, total=False):
39
- scheduler: LRScheduler
40
- monitor: str
41
-
42
-
43
- class ConfigureOptimizerReturnDict(TypedDict, total=False):
44
- optimizer: torch.optim.Optimizer
45
- lr_scheduler: LRSchedulerReturnDict
46
-
47
-
48
- OutputType = Dict[str, Any]
49
- MetricsType = Dict[str, torch.Tensor]
50
-
51
-
52
- def get_optimizer_class(name: str) -> Type[optim.Optimizer]:
53
-
54
- if name == "DeepSpeedCPUAdam":
55
- return DeepSpeedCPUAdam
56
-
57
- for module in [optim, gooptim]:
58
- if name in module.__dict__:
59
- return module.__dict__[name]
60
-
61
- raise NameError
62
-
63
-
64
- def parse_optimizer_config(
65
- config: OptimizerSchedulerConfigDict, parameters: Iterator[nn.Parameter]
66
- ) -> ConfigureOptimizerReturnDict:
67
- optim_class = get_optimizer_class(config["optimizer"]["name"])
68
- optimizer = optim_class(parameters, **config["optimizer"]["kwargs"])
69
-
70
- optim_dict: ConfigureOptimizerReturnDict = {
71
- "optimizer": optimizer,
72
- }
73
-
74
- if "scheduler" in config:
75
-
76
- lr_scheduler_class_ = config["scheduler"]["name"]
77
- lr_scheduler_class = lr_scheduler.__dict__[lr_scheduler_class_]
78
- lr_scheduler_dict: LRSchedulerReturnDict = {
79
- "scheduler": lr_scheduler_class(optimizer, **config["scheduler"]["kwargs"])
80
- }
81
-
82
- if lr_scheduler_class_ == "ReduceLROnPlateau":
83
- lr_scheduler_dict["monitor"] = config["scheduler"]["monitor"]
84
-
85
- optim_dict["lr_scheduler"] = lr_scheduler_dict
86
-
87
- return optim_dict
88
-
89
-
90
- def parse_model_config(config: ConfigDict) -> Any:
91
- name = config["name"]
92
-
93
- for module in [model]:
94
- if name in module.__dict__:
95
- return module.__dict__[name](**config["kwargs"])
96
-
97
- raise NameError
98
-
99
-
100
- _LEGACY_LOSS_NAMES = ["HybridL1Loss"]
101
-
102
-
103
- def _parse_legacy_loss_config(config: ConfigDict) -> nn.Module:
104
- name = config["name"]
105
-
106
- if name == "HybridL1Loss":
107
- return loss.TimeFreqL1Loss(**config["kwargs"])
108
-
109
- raise NameError
110
-
111
-
112
- def parse_loss_config(config: ConfigDict) -> nn.Module:
113
- name = config["name"]
114
-
115
- if name in _LEGACY_LOSS_NAMES:
116
- return _parse_legacy_loss_config(config)
117
-
118
- for module in [loss, nn.modules.loss, asteroid_losses]:
119
- if name in module.__dict__:
120
- return module.__dict__[name](**config["kwargs"])
121
-
122
- raise NameError
123
-
124
-
125
- def get_metric(config: ConfigDict) -> tm.Metric:
126
- name = config["name"]
127
-
128
- for module in [tm, metrics_]:
129
- if name in module.__dict__:
130
- return module.__dict__[name](**config["kwargs"])
131
- raise NameError
132
-
133
-
134
- def parse_metric_config(config: Dict[str, ConfigDict]) -> tm.MetricCollection:
135
- metrics = {}
136
-
137
- for metric in config:
138
- metrics[metric] = get_metric(config[metric])
139
-
140
- return tm.MetricCollection(metrics)
141
-
142
-
143
- def parse_fader_config(config: ConfigDict) -> BaseFader:
144
- name = config["name"]
145
-
146
- for module in [audio_]:
147
- if name in module.__dict__:
148
- return module.__dict__[name](**config["kwargs"])
149
-
150
- raise NameError
151
-
152
-
153
- class LightningSystem(pl.LightningModule):
154
- _VOX_STEMS = ["speech", "vocals"]
155
- _BG_STEMS = ["background", "effects", "mne"]
156
-
157
- def __init__(
158
- self, config: Dict, loss_adjustment: float = 1.0, attach_fader: bool = False
159
- ) -> None:
160
- super().__init__()
161
- self.optimizer_config = config["optimizer"]
162
- self.model = parse_model_config(config["model"])
163
- self.loss = parse_loss_config(config["loss"])
164
- self.metrics = nn.ModuleDict(
165
- {
166
- stem: parse_metric_config(config["metrics"]["dev"])
167
- for stem in self.model.stems
168
- }
169
- )
170
-
171
- self.metrics.disallow_fsdp = True
172
-
173
- self.test_metrics = nn.ModuleDict(
174
- {
175
- stem: parse_metric_config(config["metrics"]["test"])
176
- for stem in self.model.stems
177
- }
178
- )
179
-
180
- self.test_metrics.disallow_fsdp = True
181
-
182
- self.fs = config["model"]["kwargs"]["fs"]
183
-
184
- self.fader_config = config["inference"]["fader"]
185
- if attach_fader:
186
- self.fader = parse_fader_config(config["inference"]["fader"])
187
- else:
188
- self.fader = None
189
-
190
- self.augmentation: Optional[BaseAugmentor]
191
- if config.get("augmentation", None) is not None:
192
- self.augmentation = StemAugmentor(**config["augmentation"])
193
- else:
194
- self.augmentation = None
195
-
196
- self.predict_output_path: Optional[str] = None
197
- self.loss_adjustment = loss_adjustment
198
-
199
- self.val_prefix = None
200
- self.test_prefix = None
201
-
202
- def configure_optimizers(self) -> Any:
203
- return parse_optimizer_config(
204
- self.optimizer_config, self.trainer.model.parameters()
205
- )
206
-
207
- def compute_loss(
208
- self, batch: BatchedDataDict, output: OutputType
209
- ) -> Dict[str, torch.Tensor]:
210
- return {"loss": self.loss(output, batch)}
211
-
212
- def update_metrics(
213
- self, batch: BatchedDataDict, output: OutputType, mode: str
214
- ) -> None:
215
-
216
- if mode == "test":
217
- metrics = self.test_metrics
218
- else:
219
- metrics = self.metrics
220
-
221
- for stem, metric in metrics.items():
222
-
223
- if stem == "mne:+":
224
- stem = "mne"
225
-
226
- if mode == "train":
227
- metric.update(
228
- output["audio"][stem],
229
- batch["audio"][stem],
230
- )
231
- else:
232
- if stem not in batch["audio"]:
233
- matched = False
234
- if stem in self._VOX_STEMS:
235
- for bstem in self._VOX_STEMS:
236
- if bstem in batch["audio"]:
237
- batch["audio"][stem] = batch["audio"][bstem]
238
- matched = True
239
- break
240
- elif stem in self._BG_STEMS:
241
- for bstem in self._BG_STEMS:
242
- if bstem in batch["audio"]:
243
- batch["audio"][stem] = batch["audio"][bstem]
244
- matched = True
245
- break
246
- else:
247
- matched = True
248
-
249
- if matched:
250
- if stem == "mne" and "mne" not in output["audio"]:
251
- output["audio"]["mne"] = (
252
- output["audio"]["music"] + output["audio"]["effects"]
253
- )
254
-
255
- metric.update(
256
- output["audio"][stem],
257
- batch["audio"][stem],
258
- )
259
-
260
- def compute_metrics(self, mode: str = "dev") -> Dict[str, torch.Tensor]:
261
-
262
- if mode == "test":
263
- metrics = self.test_metrics
264
- else:
265
- metrics = self.metrics
266
-
267
- metric_dict = {}
268
-
269
- for stem, metric in metrics.items():
270
- md = metric.compute()
271
- metric_dict.update({f"{stem}/{k}": v for k, v in md.items()})
272
-
273
- self.log_dict(metric_dict, prog_bar=True, logger=False)
274
-
275
- return metric_dict
276
-
277
- def reset_metrics(self, test_mode: bool = False) -> None:
278
-
279
- if test_mode:
280
- metrics = self.test_metrics
281
- else:
282
- metrics = self.metrics
283
-
284
- for _, metric in metrics.items():
285
- metric.reset()
286
-
287
- def forward(self, batch: BatchedDataDict) -> Any:
288
- batch, output = self.model(batch)
289
-
290
- return batch, output
291
-
292
- def common_step(self, batch: BatchedDataDict, mode: str) -> Any:
293
- batch, output = self.forward(batch)
294
- loss_dict = self.compute_loss(batch, output)
295
-
296
- with torch.no_grad():
297
- self.update_metrics(batch, output, mode=mode)
298
-
299
- if mode == "train":
300
- self.log("loss", loss_dict["loss"], prog_bar=True)
301
-
302
- return output, loss_dict
303
-
304
- def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]:
305
-
306
- if self.augmentation is not None:
307
- with torch.no_grad():
308
- batch = self.augmentation(batch)
309
-
310
- _, loss_dict = self.common_step(batch, mode="train")
311
-
312
- with torch.inference_mode():
313
- self.log_dict_with_prefix(
314
- loss_dict, "train", batch_size=batch["audio"]["mixture"].shape[0]
315
- )
316
-
317
- loss_dict["loss"] *= self.loss_adjustment
318
-
319
- return loss_dict
320
-
321
- def on_train_batch_end(
322
- self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int
323
- ) -> None:
324
-
325
- metric_dict = self.compute_metrics()
326
- self.log_dict_with_prefix(metric_dict, "train")
327
- self.reset_metrics()
328
-
329
- def validation_step(
330
- self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
331
- ) -> Dict[str, Any]:
332
-
333
- with torch.inference_mode():
334
- curr_val_prefix = f"val{dataloader_idx}" if dataloader_idx > 0 else "val"
335
-
336
- if curr_val_prefix != self.val_prefix:
337
- if self.val_prefix is not None:
338
- self._on_validation_epoch_end()
339
- self.val_prefix = curr_val_prefix
340
- _, loss_dict = self.common_step(batch, mode="val")
341
-
342
- self.log_dict_with_prefix(
343
- loss_dict,
344
- self.val_prefix,
345
- batch_size=batch["audio"]["mixture"].shape[0],
346
- prog_bar=True,
347
- add_dataloader_idx=False,
348
- )
349
-
350
- return loss_dict
351
-
352
- def on_validation_epoch_end(self) -> None:
353
- self._on_validation_epoch_end()
354
-
355
- def _on_validation_epoch_end(self) -> None:
356
- metric_dict = self.compute_metrics()
357
- self.log_dict_with_prefix(
358
- metric_dict, self.val_prefix, prog_bar=True, add_dataloader_idx=False
359
- )
360
- self.reset_metrics()
361
-
362
- def old_predtest_step(
363
- self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
364
- ) -> Tuple[BatchedDataDict, OutputType]:
365
-
366
- audio_batch = batch["audio"]["mixture"]
367
- track_batch = batch.get("track", ["" for _ in range(len(audio_batch))])
368
-
369
- output_list_of_dicts = [
370
- self.fader(audio[None, ...], lambda a: self.test_forward(a, track))
371
- for audio, track in zip(audio_batch, track_batch)
372
- ]
373
-
374
- output_dict_of_lists = defaultdict(list)
375
-
376
- for output_dict in output_list_of_dicts:
377
- for stem, audio in output_dict.items():
378
- output_dict_of_lists[stem].append(audio)
379
-
380
- output = {
381
- "audio": {
382
- stem: torch.concat(output_list, dim=0)
383
- for stem, output_list in output_dict_of_lists.items()
384
- }
385
- }
386
-
387
- return batch, output
388
-
389
- def predtest_step(
390
- self, batch: BatchedDataDict, batch_idx: int = -1, dataloader_idx: int = 0
391
- ) -> Tuple[BatchedDataDict, OutputType]:
392
-
393
- if getattr(self.model, "bypass_fader", False):
394
- batch, output = self.model(batch)
395
- else:
396
- audio_batch = batch["audio"]["mixture"]
397
- output = self.fader(
398
- audio_batch, lambda a: self.test_forward(a, "", batch=batch)
399
- )
400
-
401
- return batch, output
402
-
403
- def test_forward(
404
- self, audio: torch.Tensor, track: str = "", batch: BatchedDataDict = None
405
- ) -> torch.Tensor:
406
-
407
- if self.fader is None:
408
- self.attach_fader()
409
-
410
- cond = batch.get("condition", None)
411
-
412
- if cond is not None and cond.shape[0] == 1:
413
- cond = cond.repeat(audio.shape[0], 1)
414
-
415
- _, output = self.forward(
416
- {
417
- "audio": {"mixture": audio},
418
- "track": track,
419
- "condition": cond,
420
- }
421
- )
422
-
423
- return output["audio"]
424
-
425
- def on_test_epoch_start(self) -> None:
426
- self.attach_fader(force_reattach=True)
427
-
428
- def test_step(
429
- self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
430
- ) -> Any:
431
- curr_test_prefix = f"test{dataloader_idx}"
432
-
433
- if curr_test_prefix != self.test_prefix:
434
- if self.test_prefix is not None:
435
- self._on_test_epoch_end()
436
- self.test_prefix = curr_test_prefix
437
-
438
- with torch.inference_mode():
439
- _, output = self.predtest_step(batch, batch_idx, dataloader_idx)
440
- self.update_metrics(batch, output, mode="test")
441
-
442
- return output
443
-
444
- def on_test_epoch_end(self) -> None:
445
- self._on_test_epoch_end()
446
-
447
- def _on_test_epoch_end(self) -> None:
448
- metric_dict = self.compute_metrics(mode="test")
449
- self.log_dict_with_prefix(
450
- metric_dict, self.test_prefix, prog_bar=True, add_dataloader_idx=False
451
- )
452
- self.reset_metrics()
453
-
454
- def predict_step(
455
- self,
456
- batch: BatchedDataDict,
457
- batch_idx: int = 0,
458
- dataloader_idx: int = 0,
459
- include_track_name: Optional[bool] = None,
460
- get_no_vox_combinations: bool = True,
461
- get_residual: bool = False,
462
- treat_batch_as_channels: bool = False,
463
- fs: Optional[int] = None,
464
- ) -> Any:
465
- assert self.predict_output_path is not None
466
-
467
- batch_size = batch["audio"]["mixture"].shape[0]
468
-
469
- if include_track_name is None:
470
- include_track_name = batch_size > 1
471
-
472
- with torch.inference_mode():
473
- batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
474
- print("Pred test finished...")
475
- torch.cuda.empty_cache()
476
- metric_dict = {}
477
-
478
- if get_residual:
479
- mixture = batch["audio"]["mixture"]
480
- extracted = sum([output["audio"][stem] for stem in output["audio"]])
481
- residual = mixture - extracted
482
- print(extracted.shape, mixture.shape, residual.shape)
483
-
484
- output["audio"]["residual"] = residual
485
-
486
- if get_no_vox_combinations:
487
- no_vox_stems = [
488
- stem for stem in output["audio"] if stem not in self._VOX_STEMS
489
- ]
490
- no_vox_combinations = chain.from_iterable(
491
- combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1)
492
- )
493
-
494
- for combination in no_vox_combinations:
495
- combination_ = list(combination)
496
- output["audio"]["+".join(combination_)] = sum(
497
- [output["audio"][stem] for stem in combination_]
498
- )
499
-
500
- if treat_batch_as_channels:
501
- for stem in output["audio"]:
502
- output["audio"][stem] = output["audio"][stem].reshape(
503
- 1, -1, output["audio"][stem].shape[-1]
504
- )
505
- batch_size = 1
506
-
507
- for b in range(batch_size):
508
- print("!!", b)
509
- for stem in output["audio"]:
510
- print(f"Saving audio for {stem} to {self.predict_output_path}")
511
- track_name = batch["track"][b].split("/")[-1]
512
-
513
- if batch.get("audio", {}).get(stem, None) is not None:
514
- self.test_metrics[stem].reset()
515
- metrics = self.test_metrics[stem](
516
- batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...]
517
- )
518
- snr = metrics["snr"]
519
- sisnr = metrics["sisnr"]
520
- sdr = metrics["sdr"]
521
- metric_dict[stem] = metrics
522
- print(
523
- track_name,
524
- f"snr={snr:2.2f} dB",
525
- f"sisnr={sisnr:2.2f}",
526
- f"sdr={sdr:2.2f} dB",
527
- )
528
- filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
529
- else:
530
- filename = f"{stem}.wav"
531
-
532
- if include_track_name:
533
- output_dir = os.path.join(self.predict_output_path, track_name)
534
- else:
535
- output_dir = self.predict_output_path
536
-
537
- os.makedirs(output_dir, exist_ok=True)
538
-
539
- if fs is None:
540
- fs = self.fs
541
-
542
- ta.save(
543
- os.path.join(output_dir, filename),
544
- output["audio"][stem][b, ...].cpu(),
545
- fs,
546
- )
547
-
548
- return metric_dict
549
-
550
- def get_stems(
551
- self,
552
- batch: BatchedDataDict,
553
- batch_idx: int = 0,
554
- dataloader_idx: int = 0,
555
- include_track_name: Optional[bool] = None,
556
- get_no_vox_combinations: bool = True,
557
- get_residual: bool = False,
558
- treat_batch_as_channels: bool = False,
559
- fs: Optional[int] = None,
560
- ) -> Any:
561
- assert self.predict_output_path is not None
562
-
563
- batch_size = batch["audio"]["mixture"].shape[0]
564
-
565
- if include_track_name is None:
566
- include_track_name = batch_size > 1
567
-
568
- with torch.inference_mode():
569
- batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
570
- torch.cuda.empty_cache()
571
- metric_dict = {}
572
-
573
- if get_residual:
574
- mixture = batch["audio"]["mixture"]
575
- extracted = sum([output["audio"][stem] for stem in output["audio"]])
576
- residual = mixture - extracted
577
-
578
- output["audio"]["residual"] = residual
579
-
580
- if get_no_vox_combinations:
581
- no_vox_stems = [
582
- stem for stem in output["audio"] if stem not in self._VOX_STEMS
583
- ]
584
- no_vox_combinations = chain.from_iterable(
585
- combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1)
586
- )
587
-
588
- for combination in no_vox_combinations:
589
- combination_ = list(combination)
590
- output["audio"]["+".join(combination_)] = sum(
591
- [output["audio"][stem] for stem in combination_]
592
- )
593
-
594
- if treat_batch_as_channels:
595
- for stem in output["audio"]:
596
- output["audio"][stem] = output["audio"][stem].reshape(
597
- 1, -1, output["audio"][stem].shape[-1]
598
- )
599
- batch_size = 1
600
-
601
- result = {}
602
- for b in range(batch_size):
603
- for stem in output["audio"]:
604
- track_name = batch["track"][b].split("/")[-1]
605
-
606
- if batch.get("audio", {}).get(stem, None) is not None:
607
- self.test_metrics[stem].reset()
608
- metrics = self.test_metrics[stem](
609
- batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...]
610
- )
611
- snr = metrics["snr"]
612
- sisnr = metrics["sisnr"]
613
- sdr = metrics["sdr"]
614
- metric_dict[stem] = metrics
615
- print(
616
- track_name,
617
- f"snr={snr:2.2f} dB",
618
- f"sisnr={sisnr:2.2f}",
619
- f"sdr={sdr:2.2f} dB",
620
- )
621
- filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
622
- else:
623
- filename = f"{stem}.wav"
624
-
625
- if include_track_name:
626
- output_dir = os.path.join(self.predict_output_path, track_name)
627
- else:
628
- output_dir = self.predict_output_path
629
-
630
- os.makedirs(output_dir, exist_ok=True)
631
-
632
- if fs is None:
633
- fs = self.fs
634
-
635
- result[stem] = output["audio"][stem][b, ...].cpu().numpy()
636
-
637
- return result
638
-
639
- def load_state_dict(
640
- self, state_dict: Mapping[str, Any], strict: bool = False
641
- ) -> Any:
642
-
643
- return super().load_state_dict(state_dict, strict=False)
644
-
645
- def set_predict_output_path(self, path: str) -> None:
646
- self.predict_output_path = path
647
- os.makedirs(self.predict_output_path, exist_ok=True)
648
-
649
- self.attach_fader()
650
-
651
- def attach_fader(self, force_reattach=False) -> None:
652
- if self.fader is None or force_reattach:
653
- self.fader = parse_fader_config(self.fader_config)
654
- self.fader.to(self.device)
655
-
656
- def log_dict_with_prefix(
657
- self,
658
- dict_: Dict[str, torch.Tensor],
659
- prefix: str,
660
- batch_size: Optional[int] = None,
661
- **kwargs: Any,
662
- ) -> None:
663
- self.log_dict(
664
- {f"{prefix}/{k}": v for k, v in dict_.items()},
665
- batch_size=batch_size,
666
- logger=True,
667
- sync_dist=True,
668
- **kwargs,
669
- )
 
1
+ import os.path
2
+ from collections import defaultdict
3
+ from itertools import chain, combinations
4
+ from typing import Any, Dict, Iterator, Mapping, Optional, Tuple, Type, TypedDict
5
+
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ import torchaudio as ta
9
+ import torchmetrics as tm
10
+ from asteroid import losses as asteroid_losses
11
+
12
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
13
+ from torch import nn, optim
14
+ from torch.optim import lr_scheduler
15
+ from torch.optim.lr_scheduler import LRScheduler
16
+
17
+ from . import loss, metrics as metrics_, model
18
+ from .data._types import BatchedDataDict
19
+ from .data.augmentation import BaseAugmentor, StemAugmentor
20
+ from .utils import audio as audio_
21
+ from .utils.audio import BaseFader
22
+
23
+
24
+ ConfigDict = TypedDict("ConfigDict", {"name": str, "kwargs": Dict[str, Any]})
25
+
26
+
27
+ class SchedulerConfigDict(ConfigDict):
28
+ monitor: str
29
+
30
+
31
+ OptimizerSchedulerConfigDict = TypedDict(
32
+ "OptimizerSchedulerConfigDict",
33
+ {"optimizer": ConfigDict, "scheduler": SchedulerConfigDict},
34
+ total=False,
35
+ )
36
+
37
+
38
+ class LRSchedulerReturnDict(TypedDict, total=False):
39
+ scheduler: LRScheduler
40
+ monitor: str
41
+
42
+
43
+ class ConfigureOptimizerReturnDict(TypedDict, total=False):
44
+ optimizer: torch.optim.Optimizer
45
+ lr_scheduler: LRSchedulerReturnDict
46
+
47
+
48
+ OutputType = Dict[str, Any]
49
+ MetricsType = Dict[str, torch.Tensor]
50
+
51
+
52
+ def get_optimizer_class(name: str) -> Type[optim.Optimizer]:
53
+
54
+ if name == "DeepSpeedCPUAdam":
55
+ return DeepSpeedCPUAdam
56
+
57
+ for module in [optim, gooptim]:
58
+ if name in module.__dict__:
59
+ return module.__dict__[name]
60
+
61
+ raise NameError
62
+
63
+
64
+ def parse_optimizer_config(
65
+ config: OptimizerSchedulerConfigDict, parameters: Iterator[nn.Parameter]
66
+ ) -> ConfigureOptimizerReturnDict:
67
+ optim_class = get_optimizer_class(config["optimizer"]["name"])
68
+ optimizer = optim_class(parameters, **config["optimizer"]["kwargs"])
69
+
70
+ optim_dict: ConfigureOptimizerReturnDict = {
71
+ "optimizer": optimizer,
72
+ }
73
+
74
+ if "scheduler" in config:
75
+
76
+ lr_scheduler_class_ = config["scheduler"]["name"]
77
+ lr_scheduler_class = lr_scheduler.__dict__[lr_scheduler_class_]
78
+ lr_scheduler_dict: LRSchedulerReturnDict = {
79
+ "scheduler": lr_scheduler_class(optimizer, **config["scheduler"]["kwargs"])
80
+ }
81
+
82
+ if lr_scheduler_class_ == "ReduceLROnPlateau":
83
+ lr_scheduler_dict["monitor"] = config["scheduler"]["monitor"]
84
+
85
+ optim_dict["lr_scheduler"] = lr_scheduler_dict
86
+
87
+ return optim_dict
88
+
89
+
90
+ def parse_model_config(config: ConfigDict) -> Any:
91
+ name = config["name"]
92
+
93
+ for module in [model]:
94
+ if name in module.__dict__:
95
+ return module.__dict__[name](**config["kwargs"])
96
+
97
+ raise NameError
98
+
99
+
100
+ _LEGACY_LOSS_NAMES = ["HybridL1Loss"]
101
+
102
+
103
+ def _parse_legacy_loss_config(config: ConfigDict) -> nn.Module:
104
+ name = config["name"]
105
+
106
+ if name == "HybridL1Loss":
107
+ return loss.TimeFreqL1Loss(**config["kwargs"])
108
+
109
+ raise NameError
110
+
111
+
112
+ def parse_loss_config(config: ConfigDict) -> nn.Module:
113
+ name = config["name"]
114
+
115
+ if name in _LEGACY_LOSS_NAMES:
116
+ return _parse_legacy_loss_config(config)
117
+
118
+ for module in [loss, nn.modules.loss, asteroid_losses]:
119
+ if name in module.__dict__:
120
+ return module.__dict__[name](**config["kwargs"])
121
+
122
+ raise NameError
123
+
124
+
125
+ def get_metric(config: ConfigDict) -> tm.Metric:
126
+ name = config["name"]
127
+
128
+ for module in [tm, metrics_]:
129
+ if name in module.__dict__:
130
+ return module.__dict__[name](**config["kwargs"])
131
+ raise NameError
132
+
133
+
134
+ def parse_metric_config(config: Dict[str, ConfigDict]) -> tm.MetricCollection:
135
+ metrics = {}
136
+
137
+ for metric in config:
138
+ metrics[metric] = get_metric(config[metric])
139
+
140
+ return tm.MetricCollection(metrics)
141
+
142
+
143
+ def parse_fader_config(config: ConfigDict) -> BaseFader:
144
+ name = config["name"]
145
+
146
+ for module in [audio_]:
147
+ if name in module.__dict__:
148
+ return module.__dict__[name](**config["kwargs"])
149
+
150
+ raise NameError
151
+
152
+
153
+ class LightningSystem(pl.LightningModule):
154
+ _VOX_STEMS = ["speech", "vocals"]
155
+ _BG_STEMS = ["background", "effects", "mne"]
156
+
157
+ def __init__(
158
+ self, config: Dict, loss_adjustment: float = 1.0, attach_fader: bool = False
159
+ ) -> None:
160
+ super().__init__()
161
+ self.optimizer_config = config["optimizer"]
162
+ self.model = parse_model_config(config["model"])
163
+ self.loss = parse_loss_config(config["loss"])
164
+ self.metrics = nn.ModuleDict(
165
+ {
166
+ stem: parse_metric_config(config["metrics"]["dev"])
167
+ for stem in self.model.stems
168
+ }
169
+ )
170
+
171
+ self.metrics.disallow_fsdp = True
172
+
173
+ self.test_metrics = nn.ModuleDict(
174
+ {
175
+ stem: parse_metric_config(config["metrics"]["test"])
176
+ for stem in self.model.stems
177
+ }
178
+ )
179
+
180
+ self.test_metrics.disallow_fsdp = True
181
+
182
+ self.fs = config["model"]["kwargs"]["fs"]
183
+
184
+ self.fader_config = config["inference"]["fader"]
185
+ if attach_fader:
186
+ self.fader = parse_fader_config(config["inference"]["fader"])
187
+ else:
188
+ self.fader = None
189
+
190
+ self.augmentation: Optional[BaseAugmentor]
191
+ if config.get("augmentation", None) is not None:
192
+ self.augmentation = StemAugmentor(**config["augmentation"])
193
+ else:
194
+ self.augmentation = None
195
+
196
+ self.predict_output_path: Optional[str] = None
197
+ self.loss_adjustment = loss_adjustment
198
+
199
+ self.val_prefix = None
200
+ self.test_prefix = None
201
+
202
+ def configure_optimizers(self) -> Any:
203
+ return parse_optimizer_config(
204
+ self.optimizer_config, self.trainer.model.parameters()
205
+ )
206
+
207
+ def compute_loss(
208
+ self, batch: BatchedDataDict, output: OutputType
209
+ ) -> Dict[str, torch.Tensor]:
210
+ return {"loss": self.loss(output, batch)}
211
+
212
+ def update_metrics(
213
+ self, batch: BatchedDataDict, output: OutputType, mode: str
214
+ ) -> None:
215
+
216
+ if mode == "test":
217
+ metrics = self.test_metrics
218
+ else:
219
+ metrics = self.metrics
220
+
221
+ for stem, metric in metrics.items():
222
+
223
+ if stem == "mne:+":
224
+ stem = "mne"
225
+
226
+ if mode == "train":
227
+ metric.update(
228
+ output["audio"][stem],
229
+ batch["audio"][stem],
230
+ )
231
+ else:
232
+ if stem not in batch["audio"]:
233
+ matched = False
234
+ if stem in self._VOX_STEMS:
235
+ for bstem in self._VOX_STEMS:
236
+ if bstem in batch["audio"]:
237
+ batch["audio"][stem] = batch["audio"][bstem]
238
+ matched = True
239
+ break
240
+ elif stem in self._BG_STEMS:
241
+ for bstem in self._BG_STEMS:
242
+ if bstem in batch["audio"]:
243
+ batch["audio"][stem] = batch["audio"][bstem]
244
+ matched = True
245
+ break
246
+ else:
247
+ matched = True
248
+
249
+ if matched:
250
+ if stem == "mne" and "mne" not in output["audio"]:
251
+ output["audio"]["mne"] = (
252
+ output["audio"]["music"] + output["audio"]["effects"]
253
+ )
254
+
255
+ metric.update(
256
+ output["audio"][stem],
257
+ batch["audio"][stem],
258
+ )
259
+
260
+ def compute_metrics(self, mode: str = "dev") -> Dict[str, torch.Tensor]:
261
+
262
+ if mode == "test":
263
+ metrics = self.test_metrics
264
+ else:
265
+ metrics = self.metrics
266
+
267
+ metric_dict = {}
268
+
269
+ for stem, metric in metrics.items():
270
+ md = metric.compute()
271
+ metric_dict.update({f"{stem}/{k}": v for k, v in md.items()})
272
+
273
+ self.log_dict(metric_dict, prog_bar=True, logger=False)
274
+
275
+ return metric_dict
276
+
277
+ def reset_metrics(self, test_mode: bool = False) -> None:
278
+
279
+ if test_mode:
280
+ metrics = self.test_metrics
281
+ else:
282
+ metrics = self.metrics
283
+
284
+ for _, metric in metrics.items():
285
+ metric.reset()
286
+
287
+ def forward(self, batch: BatchedDataDict) -> Any:
288
+ batch, output = self.model(batch)
289
+
290
+ return batch, output
291
+
292
+ def common_step(self, batch: BatchedDataDict, mode: str) -> Any:
293
+ batch, output = self.forward(batch)
294
+ loss_dict = self.compute_loss(batch, output)
295
+
296
+ with torch.no_grad():
297
+ self.update_metrics(batch, output, mode=mode)
298
+
299
+ if mode == "train":
300
+ self.log("loss", loss_dict["loss"], prog_bar=True)
301
+
302
+ return output, loss_dict
303
+
304
+ def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]:
305
+
306
+ if self.augmentation is not None:
307
+ with torch.no_grad():
308
+ batch = self.augmentation(batch)
309
+
310
+ _, loss_dict = self.common_step(batch, mode="train")
311
+
312
+ with torch.inference_mode():
313
+ self.log_dict_with_prefix(
314
+ loss_dict, "train", batch_size=batch["audio"]["mixture"].shape[0]
315
+ )
316
+
317
+ loss_dict["loss"] *= self.loss_adjustment
318
+
319
+ return loss_dict
320
+
321
+ def on_train_batch_end(
322
+ self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int
323
+ ) -> None:
324
+
325
+ metric_dict = self.compute_metrics()
326
+ self.log_dict_with_prefix(metric_dict, "train")
327
+ self.reset_metrics()
328
+
329
+ def validation_step(
330
+ self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
331
+ ) -> Dict[str, Any]:
332
+
333
+ with torch.inference_mode():
334
+ curr_val_prefix = f"val{dataloader_idx}" if dataloader_idx > 0 else "val"
335
+
336
+ if curr_val_prefix != self.val_prefix:
337
+ if self.val_prefix is not None:
338
+ self._on_validation_epoch_end()
339
+ self.val_prefix = curr_val_prefix
340
+ _, loss_dict = self.common_step(batch, mode="val")
341
+
342
+ self.log_dict_with_prefix(
343
+ loss_dict,
344
+ self.val_prefix,
345
+ batch_size=batch["audio"]["mixture"].shape[0],
346
+ prog_bar=True,
347
+ add_dataloader_idx=False,
348
+ )
349
+
350
+ return loss_dict
351
+
352
+ def on_validation_epoch_end(self) -> None:
353
+ self._on_validation_epoch_end()
354
+
355
+ def _on_validation_epoch_end(self) -> None:
356
+ metric_dict = self.compute_metrics()
357
+ self.log_dict_with_prefix(
358
+ metric_dict, self.val_prefix, prog_bar=True, add_dataloader_idx=False
359
+ )
360
+ self.reset_metrics()
361
+
362
+ def old_predtest_step(
363
+ self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
364
+ ) -> Tuple[BatchedDataDict, OutputType]:
365
+
366
+ audio_batch = batch["audio"]["mixture"]
367
+ track_batch = batch.get("track", ["" for _ in range(len(audio_batch))])
368
+
369
+ output_list_of_dicts = [
370
+ self.fader(audio[None, ...], lambda a: self.test_forward(a, track))
371
+ for audio, track in zip(audio_batch, track_batch)
372
+ ]
373
+
374
+ output_dict_of_lists = defaultdict(list)
375
+
376
+ for output_dict in output_list_of_dicts:
377
+ for stem, audio in output_dict.items():
378
+ output_dict_of_lists[stem].append(audio)
379
+
380
+ output = {
381
+ "audio": {
382
+ stem: torch.concat(output_list, dim=0)
383
+ for stem, output_list in output_dict_of_lists.items()
384
+ }
385
+ }
386
+
387
+ return batch, output
388
+
389
+ def predtest_step(
390
+ self, batch: BatchedDataDict, batch_idx: int = -1, dataloader_idx: int = 0
391
+ ) -> Tuple[BatchedDataDict, OutputType]:
392
+
393
+ if getattr(self.model, "bypass_fader", False):
394
+ batch, output = self.model(batch)
395
+ else:
396
+ audio_batch = batch["audio"]["mixture"]
397
+ output = self.fader(
398
+ audio_batch, lambda a: self.test_forward(a, "", batch=batch)
399
+ )
400
+
401
+ return batch, output
402
+
403
+ def test_forward(
404
+ self, audio: torch.Tensor, track: str = "", batch: BatchedDataDict = None
405
+ ) -> torch.Tensor:
406
+
407
+ if self.fader is None:
408
+ self.attach_fader()
409
+
410
+ cond = batch.get("condition", None)
411
+
412
+ if cond is not None and cond.shape[0] == 1:
413
+ cond = cond.repeat(audio.shape[0], 1)
414
+
415
+ _, output = self.forward(
416
+ {
417
+ "audio": {"mixture": audio},
418
+ "track": track,
419
+ "condition": cond,
420
+ }
421
+ )
422
+
423
+ return output["audio"]
424
+
425
+ def on_test_epoch_start(self) -> None:
426
+ self.attach_fader(force_reattach=True)
427
+
428
+ def test_step(
429
+ self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
430
+ ) -> Any:
431
+ curr_test_prefix = f"test{dataloader_idx}"
432
+
433
+ if curr_test_prefix != self.test_prefix:
434
+ if self.test_prefix is not None:
435
+ self._on_test_epoch_end()
436
+ self.test_prefix = curr_test_prefix
437
+
438
+ with torch.inference_mode():
439
+ _, output = self.predtest_step(batch, batch_idx, dataloader_idx)
440
+ self.update_metrics(batch, output, mode="test")
441
+
442
+ return output
443
+
444
+ def on_test_epoch_end(self) -> None:
445
+ self._on_test_epoch_end()
446
+
447
+ def _on_test_epoch_end(self) -> None:
448
+ metric_dict = self.compute_metrics(mode="test")
449
+ self.log_dict_with_prefix(
450
+ metric_dict, self.test_prefix, prog_bar=True, add_dataloader_idx=False
451
+ )
452
+ self.reset_metrics()
453
+
454
+ def predict_step(
455
+ self,
456
+ batch: BatchedDataDict,
457
+ batch_idx: int = 0,
458
+ dataloader_idx: int = 0,
459
+ include_track_name: Optional[bool] = None,
460
+ get_no_vox_combinations: bool = True,
461
+ get_residual: bool = False,
462
+ treat_batch_as_channels: bool = False,
463
+ fs: Optional[int] = None,
464
+ ) -> Any:
465
+ assert self.predict_output_path is not None
466
+
467
+ batch_size = batch["audio"]["mixture"].shape[0]
468
+
469
+ if include_track_name is None:
470
+ include_track_name = batch_size > 1
471
+
472
+ with torch.inference_mode():
473
+ batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
474
+ print("Pred test finished...")
475
+ torch.cuda.empty_cache()
476
+ metric_dict = {}
477
+
478
+ if get_residual:
479
+ mixture = batch["audio"]["mixture"]
480
+ extracted = sum([output["audio"][stem] for stem in output["audio"]])
481
+ residual = mixture - extracted
482
+ print(extracted.shape, mixture.shape, residual.shape)
483
+
484
+ output["audio"]["residual"] = residual
485
+
486
+ if get_no_vox_combinations:
487
+ no_vox_stems = [
488
+ stem for stem in output["audio"] if stem not in self._VOX_STEMS
489
+ ]
490
+ no_vox_combinations = chain.from_iterable(
491
+ combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1)
492
+ )
493
+
494
+ for combination in no_vox_combinations:
495
+ combination_ = list(combination)
496
+ output["audio"]["+".join(combination_)] = sum(
497
+ [output["audio"][stem] for stem in combination_]
498
+ )
499
+
500
+ if treat_batch_as_channels:
501
+ for stem in output["audio"]:
502
+ output["audio"][stem] = output["audio"][stem].reshape(
503
+ 1, -1, output["audio"][stem].shape[-1]
504
+ )
505
+ batch_size = 1
506
+
507
+ for b in range(batch_size):
508
+ print("!!", b)
509
+ for stem in output["audio"]:
510
+ print(f"Saving audio for {stem} to {self.predict_output_path}")
511
+ track_name = batch["track"][b].split("/")[-1]
512
+
513
+ if batch.get("audio", {}).get(stem, None) is not None:
514
+ self.test_metrics[stem].reset()
515
+ metrics = self.test_metrics[stem](
516
+ batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...]
517
+ )
518
+ snr = metrics["snr"]
519
+ sisnr = metrics["sisnr"]
520
+ sdr = metrics["sdr"]
521
+ metric_dict[stem] = metrics
522
+ print(
523
+ track_name,
524
+ f"snr={snr:2.2f} dB",
525
+ f"sisnr={sisnr:2.2f}",
526
+ f"sdr={sdr:2.2f} dB",
527
+ )
528
+ filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
529
+ else:
530
+ filename = f"{stem}.wav"
531
+
532
+ if include_track_name:
533
+ output_dir = os.path.join(self.predict_output_path, track_name)
534
+ else:
535
+ output_dir = self.predict_output_path
536
+
537
+ os.makedirs(output_dir, exist_ok=True)
538
+
539
+ if fs is None:
540
+ fs = self.fs
541
+
542
+ ta.save(
543
+ os.path.join(output_dir, filename),
544
+ output["audio"][stem][b, ...].cpu(),
545
+ fs,
546
+ )
547
+
548
+ return metric_dict
549
+
550
+ def get_stems(
551
+ self,
552
+ batch: BatchedDataDict,
553
+ batch_idx: int = 0,
554
+ dataloader_idx: int = 0,
555
+ include_track_name: Optional[bool] = None,
556
+ get_no_vox_combinations: bool = True,
557
+ get_residual: bool = False,
558
+ treat_batch_as_channels: bool = False,
559
+ fs: Optional[int] = None,
560
+ ) -> Any:
561
+ assert self.predict_output_path is not None
562
+
563
+ batch_size = batch["audio"]["mixture"].shape[0]
564
+
565
+ if include_track_name is None:
566
+ include_track_name = batch_size > 1
567
+
568
+ with torch.inference_mode():
569
+ batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
570
+ torch.cuda.empty_cache()
571
+ metric_dict = {}
572
+
573
+ if get_residual:
574
+ mixture = batch["audio"]["mixture"]
575
+ extracted = sum([output["audio"][stem] for stem in output["audio"]])
576
+ residual = mixture - extracted
577
+
578
+ output["audio"]["residual"] = residual
579
+
580
+ if get_no_vox_combinations:
581
+ no_vox_stems = [
582
+ stem for stem in output["audio"] if stem not in self._VOX_STEMS
583
+ ]
584
+ no_vox_combinations = chain.from_iterable(
585
+ combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1)
586
+ )
587
+
588
+ for combination in no_vox_combinations:
589
+ combination_ = list(combination)
590
+ output["audio"]["+".join(combination_)] = sum(
591
+ [output["audio"][stem] for stem in combination_]
592
+ )
593
+
594
+ if treat_batch_as_channels:
595
+ for stem in output["audio"]:
596
+ output["audio"][stem] = output["audio"][stem].reshape(
597
+ 1, -1, output["audio"][stem].shape[-1]
598
+ )
599
+ batch_size = 1
600
+
601
+ result = {}
602
+ for b in range(batch_size):
603
+ for stem in output["audio"]:
604
+ track_name = batch["track"][b].split("/")[-1]
605
+
606
+ if batch.get("audio", {}).get(stem, None) is not None:
607
+ self.test_metrics[stem].reset()
608
+ metrics = self.test_metrics[stem](
609
+ batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...]
610
+ )
611
+ snr = metrics["snr"]
612
+ sisnr = metrics["sisnr"]
613
+ sdr = metrics["sdr"]
614
+ metric_dict[stem] = metrics
615
+ print(
616
+ track_name,
617
+ f"snr={snr:2.2f} dB",
618
+ f"sisnr={sisnr:2.2f}",
619
+ f"sdr={sdr:2.2f} dB",
620
+ )
621
+ filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
622
+ else:
623
+ filename = f"{stem}.wav"
624
+
625
+ if include_track_name:
626
+ output_dir = os.path.join(self.predict_output_path, track_name)
627
+ else:
628
+ output_dir = self.predict_output_path
629
+
630
+ os.makedirs(output_dir, exist_ok=True)
631
+
632
+ if fs is None:
633
+ fs = self.fs
634
+
635
+ result[stem] = output["audio"][stem][b, ...].cpu().numpy()
636
+
637
+ return result
638
+
639
+ def load_state_dict(
640
+ self, state_dict: Mapping[str, Any], strict: bool = False
641
+ ) -> Any:
642
+
643
+ return super().load_state_dict(state_dict, strict=False)
644
+
645
+ def set_predict_output_path(self, path: str) -> None:
646
+ self.predict_output_path = path
647
+ os.makedirs(self.predict_output_path, exist_ok=True)
648
+
649
+ self.attach_fader()
650
+
651
+ def attach_fader(self, force_reattach=False) -> None:
652
+ if self.fader is None or force_reattach:
653
+ self.fader = parse_fader_config(self.fader_config)
654
+ self.fader.to(self.device)
655
+
656
+ def log_dict_with_prefix(
657
+ self,
658
+ dict_: Dict[str, torch.Tensor],
659
+ prefix: str,
660
+ batch_size: Optional[int] = None,
661
+ **kwargs: Any,
662
+ ) -> None:
663
+ self.log_dict(
664
+ {f"{prefix}/{k}": v for k, v in dict_.items()},
665
+ batch_size=batch_size,
666
+ logger=True,
667
+ sync_dist=True,
668
+ **kwargs,
669
+ )
mvsepless/models/bandit/core/data/__init__.py CHANGED
@@ -1,2 +1,2 @@
1
- from .dnr.datamodule import DivideAndRemasterDataModule
2
- from .musdb.datamodule import MUSDB18DataModule
 
1
+ from .dnr.datamodule import DivideAndRemasterDataModule
2
+ from .musdb.datamodule import MUSDB18DataModule
mvsepless/models/bandit/core/data/_types.py CHANGED
@@ -1,17 +1,17 @@
1
- from typing import Dict, Sequence, TypedDict
2
-
3
- import torch
4
-
5
- AudioDict = Dict[str, torch.Tensor]
6
-
7
- DataDict = TypedDict("DataDict", {"audio": AudioDict, "track": str})
8
-
9
- BatchedDataDict = TypedDict(
10
- "BatchedDataDict", {"audio": AudioDict, "track": Sequence[str]}
11
- )
12
-
13
-
14
- class DataDictWithLanguage(TypedDict):
15
- audio: AudioDict
16
- track: str
17
- language: str
 
1
+ from typing import Dict, Sequence, TypedDict
2
+
3
+ import torch
4
+
5
+ AudioDict = Dict[str, torch.Tensor]
6
+
7
+ DataDict = TypedDict("DataDict", {"audio": AudioDict, "track": str})
8
+
9
+ BatchedDataDict = TypedDict(
10
+ "BatchedDataDict", {"audio": AudioDict, "track": Sequence[str]}
11
+ )
12
+
13
+
14
+ class DataDictWithLanguage(TypedDict):
15
+ audio: AudioDict
16
+ track: str
17
+ language: str
mvsepless/models/bandit/core/data/augmentation.py CHANGED
@@ -1,102 +1,102 @@
1
- from abc import ABC
2
- from typing import Any, Dict, Union
3
-
4
- import torch
5
- import torch_audiomentations as tam
6
- from torch import nn
7
-
8
- from ._types import BatchedDataDict, DataDict
9
-
10
-
11
- class BaseAugmentor(nn.Module, ABC):
12
- def forward(
13
- self, item: Union[DataDict, BatchedDataDict]
14
- ) -> Union[DataDict, BatchedDataDict]:
15
- raise NotImplementedError
16
-
17
-
18
- class StemAugmentor(BaseAugmentor):
19
- def __init__(
20
- self,
21
- audiomentations: Dict[str, Dict[str, Any]],
22
- fix_clipping: bool = True,
23
- scaler_margin: float = 0.5,
24
- apply_both_default_and_common: bool = False,
25
- ) -> None:
26
- super().__init__()
27
-
28
- augmentations = {}
29
-
30
- self.has_default = "[default]" in audiomentations
31
- self.has_common = "[common]" in audiomentations
32
- self.apply_both_default_and_common = apply_both_default_and_common
33
-
34
- for stem in audiomentations:
35
- if audiomentations[stem]["name"] == "Compose":
36
- augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
37
- [
38
- getattr(tam, aug["name"])(**aug["kwargs"])
39
- for aug in audiomentations[stem]["kwargs"]["transforms"]
40
- ],
41
- **audiomentations[stem]["kwargs"]["kwargs"],
42
- )
43
- else:
44
- augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
45
- **audiomentations[stem]["kwargs"]
46
- )
47
-
48
- self.augmentations = nn.ModuleDict(augmentations)
49
- self.fix_clipping = fix_clipping
50
- self.scaler_margin = scaler_margin
51
-
52
- def check_and_fix_clipping(
53
- self, item: Union[DataDict, BatchedDataDict]
54
- ) -> Union[DataDict, BatchedDataDict]:
55
- max_abs = []
56
-
57
- for stem in item["audio"]:
58
- max_abs.append(item["audio"][stem].abs().max().item())
59
-
60
- if max(max_abs) > 1.0:
61
- scaler = 1.0 / (
62
- max(max_abs)
63
- + torch.rand((1,), device=item["audio"]["mixture"].device)
64
- * self.scaler_margin
65
- )
66
-
67
- for stem in item["audio"]:
68
- item["audio"][stem] *= scaler
69
-
70
- return item
71
-
72
- def forward(
73
- self, item: Union[DataDict, BatchedDataDict]
74
- ) -> Union[DataDict, BatchedDataDict]:
75
-
76
- for stem in item["audio"]:
77
- if stem == "mixture":
78
- continue
79
-
80
- if self.has_common:
81
- item["audio"][stem] = self.augmentations["[common]"](
82
- item["audio"][stem]
83
- ).samples
84
-
85
- if stem in self.augmentations:
86
- item["audio"][stem] = self.augmentations[stem](
87
- item["audio"][stem]
88
- ).samples
89
- elif self.has_default:
90
- if not self.has_common or self.apply_both_default_and_common:
91
- item["audio"][stem] = self.augmentations["[default]"](
92
- item["audio"][stem]
93
- ).samples
94
-
95
- item["audio"]["mixture"] = sum(
96
- [item["audio"][stem] for stem in item["audio"] if stem != "mixture"]
97
- ) # type: ignore[call-overload, assignment]
98
-
99
- if self.fix_clipping:
100
- item = self.check_and_fix_clipping(item)
101
-
102
- return item
 
1
+ from abc import ABC
2
+ from typing import Any, Dict, Union
3
+
4
+ import torch
5
+ import torch_audiomentations as tam
6
+ from torch import nn
7
+
8
+ from ._types import BatchedDataDict, DataDict
9
+
10
+
11
+ class BaseAugmentor(nn.Module, ABC):
12
+ def forward(
13
+ self, item: Union[DataDict, BatchedDataDict]
14
+ ) -> Union[DataDict, BatchedDataDict]:
15
+ raise NotImplementedError
16
+
17
+
18
+ class StemAugmentor(BaseAugmentor):
19
+ def __init__(
20
+ self,
21
+ audiomentations: Dict[str, Dict[str, Any]],
22
+ fix_clipping: bool = True,
23
+ scaler_margin: float = 0.5,
24
+ apply_both_default_and_common: bool = False,
25
+ ) -> None:
26
+ super().__init__()
27
+
28
+ augmentations = {}
29
+
30
+ self.has_default = "[default]" in audiomentations
31
+ self.has_common = "[common]" in audiomentations
32
+ self.apply_both_default_and_common = apply_both_default_and_common
33
+
34
+ for stem in audiomentations:
35
+ if audiomentations[stem]["name"] == "Compose":
36
+ augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
37
+ [
38
+ getattr(tam, aug["name"])(**aug["kwargs"])
39
+ for aug in audiomentations[stem]["kwargs"]["transforms"]
40
+ ],
41
+ **audiomentations[stem]["kwargs"]["kwargs"],
42
+ )
43
+ else:
44
+ augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
45
+ **audiomentations[stem]["kwargs"]
46
+ )
47
+
48
+ self.augmentations = nn.ModuleDict(augmentations)
49
+ self.fix_clipping = fix_clipping
50
+ self.scaler_margin = scaler_margin
51
+
52
+ def check_and_fix_clipping(
53
+ self, item: Union[DataDict, BatchedDataDict]
54
+ ) -> Union[DataDict, BatchedDataDict]:
55
+ max_abs = []
56
+
57
+ for stem in item["audio"]:
58
+ max_abs.append(item["audio"][stem].abs().max().item())
59
+
60
+ if max(max_abs) > 1.0:
61
+ scaler = 1.0 / (
62
+ max(max_abs)
63
+ + torch.rand((1,), device=item["audio"]["mixture"].device)
64
+ * self.scaler_margin
65
+ )
66
+
67
+ for stem in item["audio"]:
68
+ item["audio"][stem] *= scaler
69
+
70
+ return item
71
+
72
+ def forward(
73
+ self, item: Union[DataDict, BatchedDataDict]
74
+ ) -> Union[DataDict, BatchedDataDict]:
75
+
76
+ for stem in item["audio"]:
77
+ if stem == "mixture":
78
+ continue
79
+
80
+ if self.has_common:
81
+ item["audio"][stem] = self.augmentations["[common]"](
82
+ item["audio"][stem]
83
+ ).samples
84
+
85
+ if stem in self.augmentations:
86
+ item["audio"][stem] = self.augmentations[stem](
87
+ item["audio"][stem]
88
+ ).samples
89
+ elif self.has_default:
90
+ if not self.has_common or self.apply_both_default_and_common:
91
+ item["audio"][stem] = self.augmentations["[default]"](
92
+ item["audio"][stem]
93
+ ).samples
94
+
95
+ item["audio"]["mixture"] = sum(
96
+ [item["audio"][stem] for stem in item["audio"] if stem != "mixture"]
97
+ ) # type: ignore[call-overload, assignment]
98
+
99
+ if self.fix_clipping:
100
+ item = self.check_and_fix_clipping(item)
101
+
102
+ return item
mvsepless/models/bandit/core/data/augmented.py CHANGED
@@ -1,34 +1,34 @@
1
- import warnings
2
- from typing import Dict, Optional, Union
3
-
4
- import torch
5
- from torch import nn
6
- from torch.utils import data
7
-
8
-
9
- class AugmentedDataset(data.Dataset):
10
- def __init__(
11
- self,
12
- dataset: data.Dataset,
13
- augmentation: nn.Module = nn.Identity(),
14
- target_length: Optional[int] = None,
15
- ) -> None:
16
- warnings.warn(
17
- "This class is no longer used. Attach augmentation to "
18
- "the LightningSystem instead.",
19
- DeprecationWarning,
20
- )
21
-
22
- self.dataset = dataset
23
- self.augmentation = augmentation
24
-
25
- self.ds_length: int = len(dataset) # type: ignore[arg-type]
26
- self.length = target_length if target_length is not None else self.ds_length
27
-
28
- def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str, torch.Tensor]]]:
29
- item = self.dataset[index % self.ds_length]
30
- item = self.augmentation(item)
31
- return item
32
-
33
- def __len__(self) -> int:
34
- return self.length
 
1
+ import warnings
2
+ from typing import Dict, Optional, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.utils import data
7
+
8
+
9
+ class AugmentedDataset(data.Dataset):
10
+ def __init__(
11
+ self,
12
+ dataset: data.Dataset,
13
+ augmentation: nn.Module = nn.Identity(),
14
+ target_length: Optional[int] = None,
15
+ ) -> None:
16
+ warnings.warn(
17
+ "This class is no longer used. Attach augmentation to "
18
+ "the LightningSystem instead.",
19
+ DeprecationWarning,
20
+ )
21
+
22
+ self.dataset = dataset
23
+ self.augmentation = augmentation
24
+
25
+ self.ds_length: int = len(dataset) # type: ignore[arg-type]
26
+ self.length = target_length if target_length is not None else self.ds_length
27
+
28
+ def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str, torch.Tensor]]]:
29
+ item = self.dataset[index % self.ds_length]
30
+ item = self.augmentation(item)
31
+ return item
32
+
33
+ def __len__(self) -> int:
34
+ return self.length
mvsepless/models/bandit/core/data/base.py CHANGED
@@ -1,60 +1,60 @@
1
- import os
2
- from abc import ABC, abstractmethod
3
- from typing import Any, Dict, List, Optional
4
-
5
- import numpy as np
6
- import pedalboard as pb
7
- import torch
8
- import torchaudio as ta
9
- from torch.utils import data
10
-
11
- from ._types import AudioDict, DataDict
12
-
13
-
14
- class BaseSourceSeparationDataset(data.Dataset, ABC):
15
- def __init__(
16
- self,
17
- split: str,
18
- stems: List[str],
19
- files: List[str],
20
- data_path: str,
21
- fs: int,
22
- npy_memmap: bool,
23
- recompute_mixture: bool,
24
- ):
25
- self.split = split
26
- self.stems = stems
27
- self.stems_no_mixture = [s for s in stems if s != "mixture"]
28
- self.files = files
29
- self.data_path = data_path
30
- self.fs = fs
31
- self.npy_memmap = npy_memmap
32
- self.recompute_mixture = recompute_mixture
33
-
34
- @abstractmethod
35
- def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
36
- raise NotImplementedError
37
-
38
- def _get_audio(self, stems, identifier: Dict[str, Any]):
39
- audio = {}
40
- for stem in stems:
41
- audio[stem] = self.get_stem(stem=stem, identifier=identifier)
42
-
43
- return audio
44
-
45
- def get_audio(self, identifier: Dict[str, Any]) -> AudioDict:
46
-
47
- if self.recompute_mixture:
48
- audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
49
- audio["mixture"] = self.compute_mixture(audio)
50
- return audio
51
- else:
52
- return self._get_audio(self.stems, identifier=identifier)
53
-
54
- @abstractmethod
55
- def get_identifier(self, index: int) -> Dict[str, Any]:
56
- pass
57
-
58
- def compute_mixture(self, audio: AudioDict) -> torch.Tensor:
59
-
60
- return sum(audio[stem] for stem in audio if stem != "mixture")
 
1
+ import os
2
+ from abc import ABC, abstractmethod
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ import numpy as np
6
+ import pedalboard as pb
7
+ import torch
8
+ import torchaudio as ta
9
+ from torch.utils import data
10
+
11
+ from ._types import AudioDict, DataDict
12
+
13
+
14
+ class BaseSourceSeparationDataset(data.Dataset, ABC):
15
+ def __init__(
16
+ self,
17
+ split: str,
18
+ stems: List[str],
19
+ files: List[str],
20
+ data_path: str,
21
+ fs: int,
22
+ npy_memmap: bool,
23
+ recompute_mixture: bool,
24
+ ):
25
+ self.split = split
26
+ self.stems = stems
27
+ self.stems_no_mixture = [s for s in stems if s != "mixture"]
28
+ self.files = files
29
+ self.data_path = data_path
30
+ self.fs = fs
31
+ self.npy_memmap = npy_memmap
32
+ self.recompute_mixture = recompute_mixture
33
+
34
+ @abstractmethod
35
+ def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
36
+ raise NotImplementedError
37
+
38
+ def _get_audio(self, stems, identifier: Dict[str, Any]):
39
+ audio = {}
40
+ for stem in stems:
41
+ audio[stem] = self.get_stem(stem=stem, identifier=identifier)
42
+
43
+ return audio
44
+
45
+ def get_audio(self, identifier: Dict[str, Any]) -> AudioDict:
46
+
47
+ if self.recompute_mixture:
48
+ audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
49
+ audio["mixture"] = self.compute_mixture(audio)
50
+ return audio
51
+ else:
52
+ return self._get_audio(self.stems, identifier=identifier)
53
+
54
+ @abstractmethod
55
+ def get_identifier(self, index: int) -> Dict[str, Any]:
56
+ pass
57
+
58
+ def compute_mixture(self, audio: AudioDict) -> torch.Tensor:
59
+
60
+ return sum(audio[stem] for stem in audio if stem != "mixture")
mvsepless/models/bandit/core/data/dnr/datamodule.py CHANGED
@@ -1,64 +1,64 @@
1
- import os
2
- from typing import Mapping, Optional
3
-
4
- import pytorch_lightning as pl
5
-
6
- from .dataset import (
7
- DivideAndRemasterDataset,
8
- DivideAndRemasterDeterministicChunkDataset,
9
- DivideAndRemasterRandomChunkDataset,
10
- DivideAndRemasterRandomChunkDatasetWithSpeechReverb,
11
- )
12
-
13
-
14
- def DivideAndRemasterDataModule(
15
- data_root: str = "$DATA_ROOT/DnR/v2",
16
- batch_size: int = 2,
17
- num_workers: int = 8,
18
- train_kwargs: Optional[Mapping] = None,
19
- val_kwargs: Optional[Mapping] = None,
20
- test_kwargs: Optional[Mapping] = None,
21
- datamodule_kwargs: Optional[Mapping] = None,
22
- use_speech_reverb: bool = False,
23
- ) -> pl.LightningDataModule:
24
- if train_kwargs is None:
25
- train_kwargs = {}
26
-
27
- if val_kwargs is None:
28
- val_kwargs = {}
29
-
30
- if test_kwargs is None:
31
- test_kwargs = {}
32
-
33
- if datamodule_kwargs is None:
34
- datamodule_kwargs = {}
35
-
36
- if num_workers is None:
37
- num_workers = os.cpu_count()
38
-
39
- if num_workers is None:
40
- num_workers = 32
41
-
42
- num_workers = min(num_workers, 64)
43
-
44
- if use_speech_reverb:
45
- train_cls = DivideAndRemasterRandomChunkDatasetWithSpeechReverb
46
- else:
47
- train_cls = DivideAndRemasterRandomChunkDataset
48
-
49
- train_dataset = train_cls(data_root, "train", **train_kwargs)
50
-
51
- datamodule = pl.LightningDataModule.from_datasets(
52
- train_dataset=train_dataset,
53
- val_dataset=DivideAndRemasterDeterministicChunkDataset(
54
- data_root, "val", **val_kwargs
55
- ),
56
- test_dataset=DivideAndRemasterDataset(data_root, "test", **test_kwargs),
57
- batch_size=batch_size,
58
- num_workers=num_workers,
59
- **datamodule_kwargs,
60
- )
61
-
62
- datamodule.predict_dataloader = datamodule.test_dataloader # type: ignore[method-assign]
63
-
64
- return datamodule
 
1
+ import os
2
+ from typing import Mapping, Optional
3
+
4
+ import pytorch_lightning as pl
5
+
6
+ from .dataset import (
7
+ DivideAndRemasterDataset,
8
+ DivideAndRemasterDeterministicChunkDataset,
9
+ DivideAndRemasterRandomChunkDataset,
10
+ DivideAndRemasterRandomChunkDatasetWithSpeechReverb,
11
+ )
12
+
13
+
14
+ def DivideAndRemasterDataModule(
15
+ data_root: str = "$DATA_ROOT/DnR/v2",
16
+ batch_size: int = 2,
17
+ num_workers: int = 8,
18
+ train_kwargs: Optional[Mapping] = None,
19
+ val_kwargs: Optional[Mapping] = None,
20
+ test_kwargs: Optional[Mapping] = None,
21
+ datamodule_kwargs: Optional[Mapping] = None,
22
+ use_speech_reverb: bool = False,
23
+ ) -> pl.LightningDataModule:
24
+ if train_kwargs is None:
25
+ train_kwargs = {}
26
+
27
+ if val_kwargs is None:
28
+ val_kwargs = {}
29
+
30
+ if test_kwargs is None:
31
+ test_kwargs = {}
32
+
33
+ if datamodule_kwargs is None:
34
+ datamodule_kwargs = {}
35
+
36
+ if num_workers is None:
37
+ num_workers = os.cpu_count()
38
+
39
+ if num_workers is None:
40
+ num_workers = 32
41
+
42
+ num_workers = min(num_workers, 64)
43
+
44
+ if use_speech_reverb:
45
+ train_cls = DivideAndRemasterRandomChunkDatasetWithSpeechReverb
46
+ else:
47
+ train_cls = DivideAndRemasterRandomChunkDataset
48
+
49
+ train_dataset = train_cls(data_root, "train", **train_kwargs)
50
+
51
+ datamodule = pl.LightningDataModule.from_datasets(
52
+ train_dataset=train_dataset,
53
+ val_dataset=DivideAndRemasterDeterministicChunkDataset(
54
+ data_root, "val", **val_kwargs
55
+ ),
56
+ test_dataset=DivideAndRemasterDataset(data_root, "test", **test_kwargs),
57
+ batch_size=batch_size,
58
+ num_workers=num_workers,
59
+ **datamodule_kwargs,
60
+ )
61
+
62
+ datamodule.predict_dataloader = datamodule.test_dataloader # type: ignore[method-assign]
63
+
64
+ return datamodule
mvsepless/models/bandit/core/data/dnr/dataset.py CHANGED
@@ -1,360 +1,360 @@
1
- import os
2
- from abc import ABC
3
- from typing import Any, Dict, List, Optional
4
-
5
- import numpy as np
6
- import pedalboard as pb
7
- import torch
8
- import torchaudio as ta
9
- from torch.utils import data
10
-
11
- from .._types import AudioDict, DataDict
12
- from ..base import BaseSourceSeparationDataset
13
-
14
-
15
- class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC):
16
- ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"]
17
- STEM_NAME_MAP = {
18
- "mixture": "mix",
19
- "speech": "speech",
20
- "music": "music",
21
- "effects": "sfx",
22
- }
23
- SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"}
24
-
25
- FULL_TRACK_LENGTH_SECOND = 60
26
- FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100
27
-
28
- def __init__(
29
- self,
30
- split: str,
31
- stems: List[str],
32
- files: List[str],
33
- data_path: str,
34
- fs: int = 44100,
35
- npy_memmap: bool = True,
36
- recompute_mixture: bool = False,
37
- ) -> None:
38
- super().__init__(
39
- split=split,
40
- stems=stems,
41
- files=files,
42
- data_path=data_path,
43
- fs=fs,
44
- npy_memmap=npy_memmap,
45
- recompute_mixture=recompute_mixture,
46
- )
47
-
48
- def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
49
-
50
- if stem == "mne":
51
- return self.get_stem(stem="music", identifier=identifier) + self.get_stem(
52
- stem="effects", identifier=identifier
53
- )
54
-
55
- track = identifier["track"]
56
- path = os.path.join(self.data_path, track)
57
-
58
- if self.npy_memmap:
59
- audio = np.load(
60
- os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"), mmap_mode="r"
61
- )
62
- else:
63
- audio, _ = ta.load(os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav"))
64
-
65
- return audio
66
-
67
- def get_identifier(self, index):
68
- return dict(track=self.files[index])
69
-
70
- def __getitem__(self, index: int) -> DataDict:
71
- identifier = self.get_identifier(index)
72
- audio = self.get_audio(identifier)
73
-
74
- return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
75
-
76
-
77
- class DivideAndRemasterDataset(DivideAndRemasterBaseDataset):
78
- def __init__(
79
- self,
80
- data_root: str,
81
- split: str,
82
- stems: Optional[List[str]] = None,
83
- fs: int = 44100,
84
- npy_memmap: bool = True,
85
- ) -> None:
86
-
87
- if stems is None:
88
- stems = self.ALLOWED_STEMS
89
- self.stems = stems
90
-
91
- data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
92
-
93
- files = sorted(os.listdir(data_path))
94
- files = [
95
- f
96
- for f in files
97
- if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
98
- ]
99
- if split == "train":
100
- assert len(files) == 3406, len(files)
101
- elif split == "val":
102
- assert len(files) == 487, len(files)
103
- elif split == "test":
104
- assert len(files) == 973, len(files)
105
-
106
- self.n_tracks = len(files)
107
-
108
- super().__init__(
109
- data_path=data_path,
110
- split=split,
111
- stems=stems,
112
- files=files,
113
- fs=fs,
114
- npy_memmap=npy_memmap,
115
- )
116
-
117
- def __len__(self) -> int:
118
- return self.n_tracks
119
-
120
-
121
- class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset):
122
- def __init__(
123
- self,
124
- data_root: str,
125
- split: str,
126
- target_length: int,
127
- chunk_size_second: float,
128
- stems: Optional[List[str]] = None,
129
- fs: int = 44100,
130
- npy_memmap: bool = True,
131
- ) -> None:
132
-
133
- if stems is None:
134
- stems = self.ALLOWED_STEMS
135
- self.stems = stems
136
-
137
- data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
138
-
139
- files = sorted(os.listdir(data_path))
140
- files = [
141
- f
142
- for f in files
143
- if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
144
- ]
145
-
146
- if split == "train":
147
- assert len(files) == 3406, len(files)
148
- elif split == "val":
149
- assert len(files) == 487, len(files)
150
- elif split == "test":
151
- assert len(files) == 973, len(files)
152
-
153
- self.n_tracks = len(files)
154
-
155
- self.target_length = target_length
156
- self.chunk_size = int(chunk_size_second * fs)
157
-
158
- super().__init__(
159
- data_path=data_path,
160
- split=split,
161
- stems=stems,
162
- files=files,
163
- fs=fs,
164
- npy_memmap=npy_memmap,
165
- )
166
-
167
- def __len__(self) -> int:
168
- return self.target_length
169
-
170
- def get_identifier(self, index):
171
- return super().get_identifier(index % self.n_tracks)
172
-
173
- def get_stem(
174
- self,
175
- *,
176
- stem: str,
177
- identifier: Dict[str, Any],
178
- chunk_here: bool = False,
179
- ) -> torch.Tensor:
180
-
181
- stem = super().get_stem(stem=stem, identifier=identifier)
182
-
183
- if chunk_here:
184
- start = np.random.randint(
185
- 0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
186
- )
187
- end = start + self.chunk_size
188
-
189
- stem = stem[:, start:end]
190
-
191
- return stem
192
-
193
- def __getitem__(self, index: int) -> DataDict:
194
- identifier = self.get_identifier(index)
195
- audio = self.get_audio(identifier)
196
-
197
- start = np.random.randint(0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size)
198
- end = start + self.chunk_size
199
-
200
- audio = {k: v[:, start:end] for k, v in audio.items()}
201
-
202
- return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
203
-
204
-
205
- class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset):
206
- def __init__(
207
- self,
208
- data_root: str,
209
- split: str,
210
- chunk_size_second: float,
211
- hop_size_second: float,
212
- stems: Optional[List[str]] = None,
213
- fs: int = 44100,
214
- npy_memmap: bool = True,
215
- ) -> None:
216
-
217
- if stems is None:
218
- stems = self.ALLOWED_STEMS
219
- self.stems = stems
220
-
221
- data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
222
-
223
- files = sorted(os.listdir(data_path))
224
- files = [
225
- f
226
- for f in files
227
- if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
228
- ]
229
- if split == "train":
230
- assert len(files) == 3406, len(files)
231
- elif split == "val":
232
- assert len(files) == 487, len(files)
233
- elif split == "test":
234
- assert len(files) == 973, len(files)
235
-
236
- self.n_tracks = len(files)
237
-
238
- self.chunk_size = int(chunk_size_second * fs)
239
- self.hop_size = int(hop_size_second * fs)
240
- self.n_chunks_per_track = int(
241
- (self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second
242
- )
243
-
244
- self.length = self.n_tracks * self.n_chunks_per_track
245
-
246
- super().__init__(
247
- data_path=data_path,
248
- split=split,
249
- stems=stems,
250
- files=files,
251
- fs=fs,
252
- npy_memmap=npy_memmap,
253
- )
254
-
255
- def get_identifier(self, index):
256
- return super().get_identifier(index % self.n_tracks)
257
-
258
- def __len__(self) -> int:
259
- return self.length
260
-
261
- def __getitem__(self, item: int) -> DataDict:
262
-
263
- index = item % self.n_tracks
264
- chunk = item // self.n_tracks
265
-
266
- data_ = super().__getitem__(index)
267
-
268
- audio = data_["audio"]
269
-
270
- start = chunk * self.hop_size
271
- end = start + self.chunk_size
272
-
273
- for stem in self.stems:
274
- data_["audio"][stem] = audio[stem][:, start:end]
275
-
276
- return data_
277
-
278
-
279
- class DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
280
- DivideAndRemasterRandomChunkDataset
281
- ):
282
- def __init__(
283
- self,
284
- data_root: str,
285
- split: str,
286
- target_length: int,
287
- chunk_size_second: float,
288
- stems: Optional[List[str]] = None,
289
- fs: int = 44100,
290
- npy_memmap: bool = True,
291
- ) -> None:
292
-
293
- if stems is None:
294
- stems = self.ALLOWED_STEMS
295
-
296
- stems_no_mixture = [s for s in stems if s != "mixture"]
297
-
298
- super().__init__(
299
- data_root=data_root,
300
- split=split,
301
- target_length=target_length,
302
- chunk_size_second=chunk_size_second,
303
- stems=stems_no_mixture,
304
- fs=fs,
305
- npy_memmap=npy_memmap,
306
- )
307
-
308
- self.stems = stems
309
- self.stems_no_mixture = stems_no_mixture
310
-
311
- def __getitem__(self, index: int) -> DataDict:
312
-
313
- data_ = super().__getitem__(index)
314
-
315
- dry = data_["audio"]["speech"][:]
316
- n_samples = dry.shape[-1]
317
-
318
- wet_level = np.random.rand()
319
-
320
- speech = pb.Reverb(
321
- room_size=np.random.rand(),
322
- damping=np.random.rand(),
323
- wet_level=wet_level,
324
- dry_level=(1 - wet_level),
325
- width=np.random.rand(),
326
- ).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples]
327
-
328
- data_["audio"]["speech"] = speech
329
-
330
- data_["audio"]["mixture"] = sum(
331
- [data_["audio"][s] for s in self.stems_no_mixture]
332
- )
333
-
334
- return data_
335
-
336
- def __len__(self) -> int:
337
- return super().__len__()
338
-
339
-
340
- if __name__ == "__main__":
341
-
342
- from pprint import pprint
343
- from tqdm.auto import tqdm
344
-
345
- for split_ in ["train", "val", "test"]:
346
- ds = DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
347
- data_root="$DATA_ROOT/DnR/v2np",
348
- split=split_,
349
- target_length=100,
350
- chunk_size_second=6.0,
351
- )
352
-
353
- print(split_, len(ds))
354
-
355
- for track_ in tqdm(ds): # type: ignore
356
- pprint(track_)
357
- track_["audio"] = {k: v.shape for k, v in track_["audio"].items()}
358
- pprint(track_)
359
-
360
- break
 
1
+ import os
2
+ from abc import ABC
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ import numpy as np
6
+ import pedalboard as pb
7
+ import torch
8
+ import torchaudio as ta
9
+ from torch.utils import data
10
+
11
+ from .._types import AudioDict, DataDict
12
+ from ..base import BaseSourceSeparationDataset
13
+
14
+
15
+ class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC):
16
+ ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"]
17
+ STEM_NAME_MAP = {
18
+ "mixture": "mix",
19
+ "speech": "speech",
20
+ "music": "music",
21
+ "effects": "sfx",
22
+ }
23
+ SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"}
24
+
25
+ FULL_TRACK_LENGTH_SECOND = 60
26
+ FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100
27
+
28
+ def __init__(
29
+ self,
30
+ split: str,
31
+ stems: List[str],
32
+ files: List[str],
33
+ data_path: str,
34
+ fs: int = 44100,
35
+ npy_memmap: bool = True,
36
+ recompute_mixture: bool = False,
37
+ ) -> None:
38
+ super().__init__(
39
+ split=split,
40
+ stems=stems,
41
+ files=files,
42
+ data_path=data_path,
43
+ fs=fs,
44
+ npy_memmap=npy_memmap,
45
+ recompute_mixture=recompute_mixture,
46
+ )
47
+
48
+ def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
49
+
50
+ if stem == "mne":
51
+ return self.get_stem(stem="music", identifier=identifier) + self.get_stem(
52
+ stem="effects", identifier=identifier
53
+ )
54
+
55
+ track = identifier["track"]
56
+ path = os.path.join(self.data_path, track)
57
+
58
+ if self.npy_memmap:
59
+ audio = np.load(
60
+ os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"), mmap_mode="r"
61
+ )
62
+ else:
63
+ audio, _ = ta.load(os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav"))
64
+
65
+ return audio
66
+
67
+ def get_identifier(self, index):
68
+ return dict(track=self.files[index])
69
+
70
+ def __getitem__(self, index: int) -> DataDict:
71
+ identifier = self.get_identifier(index)
72
+ audio = self.get_audio(identifier)
73
+
74
+ return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
75
+
76
+
77
+ class DivideAndRemasterDataset(DivideAndRemasterBaseDataset):
78
+ def __init__(
79
+ self,
80
+ data_root: str,
81
+ split: str,
82
+ stems: Optional[List[str]] = None,
83
+ fs: int = 44100,
84
+ npy_memmap: bool = True,
85
+ ) -> None:
86
+
87
+ if stems is None:
88
+ stems = self.ALLOWED_STEMS
89
+ self.stems = stems
90
+
91
+ data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
92
+
93
+ files = sorted(os.listdir(data_path))
94
+ files = [
95
+ f
96
+ for f in files
97
+ if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
98
+ ]
99
+ if split == "train":
100
+ assert len(files) == 3406, len(files)
101
+ elif split == "val":
102
+ assert len(files) == 487, len(files)
103
+ elif split == "test":
104
+ assert len(files) == 973, len(files)
105
+
106
+ self.n_tracks = len(files)
107
+
108
+ super().__init__(
109
+ data_path=data_path,
110
+ split=split,
111
+ stems=stems,
112
+ files=files,
113
+ fs=fs,
114
+ npy_memmap=npy_memmap,
115
+ )
116
+
117
+ def __len__(self) -> int:
118
+ return self.n_tracks
119
+
120
+
121
+ class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset):
122
+ def __init__(
123
+ self,
124
+ data_root: str,
125
+ split: str,
126
+ target_length: int,
127
+ chunk_size_second: float,
128
+ stems: Optional[List[str]] = None,
129
+ fs: int = 44100,
130
+ npy_memmap: bool = True,
131
+ ) -> None:
132
+
133
+ if stems is None:
134
+ stems = self.ALLOWED_STEMS
135
+ self.stems = stems
136
+
137
+ data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
138
+
139
+ files = sorted(os.listdir(data_path))
140
+ files = [
141
+ f
142
+ for f in files
143
+ if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
144
+ ]
145
+
146
+ if split == "train":
147
+ assert len(files) == 3406, len(files)
148
+ elif split == "val":
149
+ assert len(files) == 487, len(files)
150
+ elif split == "test":
151
+ assert len(files) == 973, len(files)
152
+
153
+ self.n_tracks = len(files)
154
+
155
+ self.target_length = target_length
156
+ self.chunk_size = int(chunk_size_second * fs)
157
+
158
+ super().__init__(
159
+ data_path=data_path,
160
+ split=split,
161
+ stems=stems,
162
+ files=files,
163
+ fs=fs,
164
+ npy_memmap=npy_memmap,
165
+ )
166
+
167
+ def __len__(self) -> int:
168
+ return self.target_length
169
+
170
+ def get_identifier(self, index):
171
+ return super().get_identifier(index % self.n_tracks)
172
+
173
+ def get_stem(
174
+ self,
175
+ *,
176
+ stem: str,
177
+ identifier: Dict[str, Any],
178
+ chunk_here: bool = False,
179
+ ) -> torch.Tensor:
180
+
181
+ stem = super().get_stem(stem=stem, identifier=identifier)
182
+
183
+ if chunk_here:
184
+ start = np.random.randint(
185
+ 0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
186
+ )
187
+ end = start + self.chunk_size
188
+
189
+ stem = stem[:, start:end]
190
+
191
+ return stem
192
+
193
+ def __getitem__(self, index: int) -> DataDict:
194
+ identifier = self.get_identifier(index)
195
+ audio = self.get_audio(identifier)
196
+
197
+ start = np.random.randint(0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size)
198
+ end = start + self.chunk_size
199
+
200
+ audio = {k: v[:, start:end] for k, v in audio.items()}
201
+
202
+ return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
203
+
204
+
205
+ class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset):
206
+ def __init__(
207
+ self,
208
+ data_root: str,
209
+ split: str,
210
+ chunk_size_second: float,
211
+ hop_size_second: float,
212
+ stems: Optional[List[str]] = None,
213
+ fs: int = 44100,
214
+ npy_memmap: bool = True,
215
+ ) -> None:
216
+
217
+ if stems is None:
218
+ stems = self.ALLOWED_STEMS
219
+ self.stems = stems
220
+
221
+ data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
222
+
223
+ files = sorted(os.listdir(data_path))
224
+ files = [
225
+ f
226
+ for f in files
227
+ if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
228
+ ]
229
+ if split == "train":
230
+ assert len(files) == 3406, len(files)
231
+ elif split == "val":
232
+ assert len(files) == 487, len(files)
233
+ elif split == "test":
234
+ assert len(files) == 973, len(files)
235
+
236
+ self.n_tracks = len(files)
237
+
238
+ self.chunk_size = int(chunk_size_second * fs)
239
+ self.hop_size = int(hop_size_second * fs)
240
+ self.n_chunks_per_track = int(
241
+ (self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second
242
+ )
243
+
244
+ self.length = self.n_tracks * self.n_chunks_per_track
245
+
246
+ super().__init__(
247
+ data_path=data_path,
248
+ split=split,
249
+ stems=stems,
250
+ files=files,
251
+ fs=fs,
252
+ npy_memmap=npy_memmap,
253
+ )
254
+
255
+ def get_identifier(self, index):
256
+ return super().get_identifier(index % self.n_tracks)
257
+
258
+ def __len__(self) -> int:
259
+ return self.length
260
+
261
+ def __getitem__(self, item: int) -> DataDict:
262
+
263
+ index = item % self.n_tracks
264
+ chunk = item // self.n_tracks
265
+
266
+ data_ = super().__getitem__(index)
267
+
268
+ audio = data_["audio"]
269
+
270
+ start = chunk * self.hop_size
271
+ end = start + self.chunk_size
272
+
273
+ for stem in self.stems:
274
+ data_["audio"][stem] = audio[stem][:, start:end]
275
+
276
+ return data_
277
+
278
+
279
+ class DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
280
+ DivideAndRemasterRandomChunkDataset
281
+ ):
282
+ def __init__(
283
+ self,
284
+ data_root: str,
285
+ split: str,
286
+ target_length: int,
287
+ chunk_size_second: float,
288
+ stems: Optional[List[str]] = None,
289
+ fs: int = 44100,
290
+ npy_memmap: bool = True,
291
+ ) -> None:
292
+
293
+ if stems is None:
294
+ stems = self.ALLOWED_STEMS
295
+
296
+ stems_no_mixture = [s for s in stems if s != "mixture"]
297
+
298
+ super().__init__(
299
+ data_root=data_root,
300
+ split=split,
301
+ target_length=target_length,
302
+ chunk_size_second=chunk_size_second,
303
+ stems=stems_no_mixture,
304
+ fs=fs,
305
+ npy_memmap=npy_memmap,
306
+ )
307
+
308
+ self.stems = stems
309
+ self.stems_no_mixture = stems_no_mixture
310
+
311
+ def __getitem__(self, index: int) -> DataDict:
312
+
313
+ data_ = super().__getitem__(index)
314
+
315
+ dry = data_["audio"]["speech"][:]
316
+ n_samples = dry.shape[-1]
317
+
318
+ wet_level = np.random.rand()
319
+
320
+ speech = pb.Reverb(
321
+ room_size=np.random.rand(),
322
+ damping=np.random.rand(),
323
+ wet_level=wet_level,
324
+ dry_level=(1 - wet_level),
325
+ width=np.random.rand(),
326
+ ).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples]
327
+
328
+ data_["audio"]["speech"] = speech
329
+
330
+ data_["audio"]["mixture"] = sum(
331
+ [data_["audio"][s] for s in self.stems_no_mixture]
332
+ )
333
+
334
+ return data_
335
+
336
+ def __len__(self) -> int:
337
+ return super().__len__()
338
+
339
+
340
+ if __name__ == "__main__":
341
+
342
+ from pprint import pprint
343
+ from tqdm.auto import tqdm
344
+
345
+ for split_ in ["train", "val", "test"]:
346
+ ds = DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
347
+ data_root="$DATA_ROOT/DnR/v2np",
348
+ split=split_,
349
+ target_length=100,
350
+ chunk_size_second=6.0,
351
+ )
352
+
353
+ print(split_, len(ds))
354
+
355
+ for track_ in tqdm(ds): # type: ignore
356
+ pprint(track_)
357
+ track_["audio"] = {k: v.shape for k, v in track_["audio"].items()}
358
+ pprint(track_)
359
+
360
+ break
mvsepless/models/bandit/core/data/dnr/preprocess.py CHANGED
@@ -1,51 +1,51 @@
1
- import glob
2
- import os
3
- from typing import Tuple
4
-
5
- import numpy as np
6
- import torchaudio as ta
7
- from tqdm.contrib.concurrent import process_map
8
-
9
-
10
- def process_one(inputs: Tuple[str, str, int]) -> None:
11
- infile, outfile, target_fs = inputs
12
-
13
- dir = os.path.dirname(outfile)
14
- os.makedirs(dir, exist_ok=True)
15
-
16
- data, fs = ta.load(infile)
17
-
18
- if fs != target_fs:
19
- data = ta.functional.resample(
20
- data, fs, target_fs, resampling_method="sinc_interp_kaiser"
21
- )
22
- fs = target_fs
23
-
24
- data = data.numpy()
25
- data = data.astype(np.float32)
26
-
27
- if os.path.exists(outfile):
28
- data_ = np.load(outfile)
29
- if np.allclose(data, data_):
30
- return
31
-
32
- np.save(outfile, data)
33
-
34
-
35
- def preprocess(data_path: str, output_path: str, fs: int) -> None:
36
- files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
37
- print(files)
38
- outfiles = [
39
- f.replace(data_path, output_path).replace(".wav", ".npy") for f in files
40
- ]
41
-
42
- os.makedirs(output_path, exist_ok=True)
43
- inputs = list(zip(files, outfiles, [fs] * len(files)))
44
-
45
- process_map(process_one, inputs, chunksize=32)
46
-
47
-
48
- if __name__ == "__main__":
49
- import fire
50
-
51
- fire.Fire()
 
1
+ import glob
2
+ import os
3
+ from typing import Tuple
4
+
5
+ import numpy as np
6
+ import torchaudio as ta
7
+ from tqdm.contrib.concurrent import process_map
8
+
9
+
10
def process_one(inputs: Tuple[str, str, int]) -> None:
    """Resample one audio file to ``target_fs`` and cache it as a float32
    ``.npy`` array, skipping the write when an identical cache exists.

    Args:
        inputs: ``(infile, outfile, target_fs)`` tuple (packed for
            ``process_map``).
    """
    infile, outfile, target_fs = inputs

    # Renamed from ``dir`` to avoid shadowing the builtin.
    out_dir = os.path.dirname(outfile)
    os.makedirs(out_dir, exist_ok=True)

    data, fs = ta.load(infile)

    if fs != target_fs:
        data = ta.functional.resample(
            data, fs, target_fs, resampling_method="sinc_interp_kaiser"
        )
        fs = target_fs

    data = data.numpy()
    data = data.astype(np.float32)

    if os.path.exists(outfile):
        cached = np.load(outfile)
        # BUGFIX: np.allclose raises on non-broadcastable shapes, e.g. when
        # a stale cache was produced with a different target_fs — compare
        # shapes first so such caches are simply overwritten.
        if cached.shape == data.shape and np.allclose(data, cached):
            return

    np.save(outfile, data)
33
+
34
+
35
def preprocess(data_path: str, output_path: str, fs: int) -> None:
    """Convert every ``.wav`` under ``data_path`` into a resampled ``.npy``
    cache, mirroring the directory structure under ``output_path``.

    Args:
        data_path: Root directory searched recursively for ``.wav`` files.
        output_path: Root of the mirrored ``.npy`` output tree.
        fs: Target sampling rate passed to :func:`process_one`.
    """
    files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
    print(files)
    # BUGFIX: the previous str.replace-based mapping substituted
    # ``data_path`` / ".wav" anywhere in the path string; relpath +
    # splitext only rewrites the intended components.
    outfiles = [
        os.path.join(
            output_path,
            os.path.splitext(os.path.relpath(f, data_path))[0] + ".npy",
        )
        for f in files
    ]

    os.makedirs(output_path, exist_ok=True)
    inputs = list(zip(files, outfiles, [fs] * len(files)))

    process_map(process_one, inputs, chunksize=32)
46
+
47
+
48
+ if __name__ == "__main__":
49
+ import fire
50
+
51
+ fire.Fire()
mvsepless/models/bandit/core/data/musdb/datamodule.py CHANGED
@@ -1,75 +1,75 @@
1
- import os.path
2
- from typing import Mapping, Optional
3
-
4
- import pytorch_lightning as pl
5
-
6
- from .dataset import (
7
- MUSDB18BaseDataset,
8
- MUSDB18FullTrackDataset,
9
- MUSDB18SadDataset,
10
- MUSDB18SadOnTheFlyAugmentedDataset,
11
- )
12
-
13
-
14
- def MUSDB18DataModule(
15
- data_root: str = "$DATA_ROOT/MUSDB18/HQ",
16
- target_stem: str = "vocals",
17
- batch_size: int = 2,
18
- num_workers: int = 8,
19
- train_kwargs: Optional[Mapping] = None,
20
- val_kwargs: Optional[Mapping] = None,
21
- test_kwargs: Optional[Mapping] = None,
22
- datamodule_kwargs: Optional[Mapping] = None,
23
- use_on_the_fly: bool = True,
24
- npy_memmap: bool = True,
25
- ) -> pl.LightningDataModule:
26
- if train_kwargs is None:
27
- train_kwargs = {}
28
-
29
- if val_kwargs is None:
30
- val_kwargs = {}
31
-
32
- if test_kwargs is None:
33
- test_kwargs = {}
34
-
35
- if datamodule_kwargs is None:
36
- datamodule_kwargs = {}
37
-
38
- train_dataset: MUSDB18BaseDataset
39
-
40
- if use_on_the_fly:
41
- train_dataset = MUSDB18SadOnTheFlyAugmentedDataset(
42
- data_root=os.path.join(data_root, "saded-np"),
43
- split="train",
44
- target_stem=target_stem,
45
- **train_kwargs,
46
- )
47
- else:
48
- train_dataset = MUSDB18SadDataset(
49
- data_root=os.path.join(data_root, "saded-np"),
50
- split="train",
51
- target_stem=target_stem,
52
- **train_kwargs,
53
- )
54
-
55
- datamodule = pl.LightningDataModule.from_datasets(
56
- train_dataset=train_dataset,
57
- val_dataset=MUSDB18SadDataset(
58
- data_root=os.path.join(data_root, "saded-np"),
59
- split="val",
60
- target_stem=target_stem,
61
- **val_kwargs,
62
- ),
63
- test_dataset=MUSDB18FullTrackDataset(
64
- data_root=os.path.join(data_root, "canonical"), split="test", **test_kwargs
65
- ),
66
- batch_size=batch_size,
67
- num_workers=num_workers,
68
- **datamodule_kwargs,
69
- )
70
-
71
- datamodule.predict_dataloader = ( # type: ignore[method-assign]
72
- datamodule.test_dataloader
73
- )
74
-
75
- return datamodule
 
1
+ import os.path
2
+ from typing import Mapping, Optional
3
+
4
+ import pytorch_lightning as pl
5
+
6
+ from .dataset import (
7
+ MUSDB18BaseDataset,
8
+ MUSDB18FullTrackDataset,
9
+ MUSDB18SadDataset,
10
+ MUSDB18SadOnTheFlyAugmentedDataset,
11
+ )
12
+
13
+
14
def MUSDB18DataModule(
    data_root: str = "$DATA_ROOT/MUSDB18/HQ",
    target_stem: str = "vocals",
    batch_size: int = 2,
    num_workers: int = 8,
    train_kwargs: Optional[Mapping] = None,
    val_kwargs: Optional[Mapping] = None,
    test_kwargs: Optional[Mapping] = None,
    datamodule_kwargs: Optional[Mapping] = None,
    use_on_the_fly: bool = True,
    npy_memmap: bool = True,
) -> pl.LightningDataModule:
    """Build a LightningDataModule over MUSDB18-HQ.

    Trains and validates on salient ("SAD") segments of ``target_stem``
    (optionally with on-the-fly augmentation); tests and predicts on full
    canonical tracks.

    Args:
        data_root: Root containing ``saded-np`` and ``canonical`` subdirs.
        target_stem: Stem whose salient segments drive train/val.
        batch_size: Batch size for all dataloaders.
        num_workers: Dataloader worker count.
        train_kwargs: Extra kwargs for the training dataset constructor.
        val_kwargs: Extra kwargs for the validation dataset constructor.
        test_kwargs: Extra kwargs for the test dataset constructor.
        datamodule_kwargs: Extra kwargs for
            ``LightningDataModule.from_datasets``.
        use_on_the_fly: Use the on-the-fly augmented training dataset.
        npy_memmap: Memory-map cached ``.npy`` stems where the dataset
            supports it.
    """
    # Copy caller kwargs so the defaults injected below never mutate them.
    train_kwargs = dict(train_kwargs) if train_kwargs is not None else {}
    val_kwargs = dict(val_kwargs) if val_kwargs is not None else {}
    test_kwargs = dict(test_kwargs) if test_kwargs is not None else {}
    datamodule_kwargs = dict(datamodule_kwargs) if datamodule_kwargs is not None else {}

    train_dataset: MUSDB18BaseDataset

    if use_on_the_fly:
        # NOTE: the on-the-fly dataset does not accept npy_memmap.
        train_dataset = MUSDB18SadOnTheFlyAugmentedDataset(
            data_root=os.path.join(data_root, "saded-np"),
            split="train",
            target_stem=target_stem,
            **train_kwargs,
        )
    else:
        # BUGFIX: npy_memmap was previously accepted but never forwarded
        # to any dataset; forward it here (explicit kwargs still win).
        train_kwargs.setdefault("npy_memmap", npy_memmap)
        train_dataset = MUSDB18SadDataset(
            data_root=os.path.join(data_root, "saded-np"),
            split="train",
            target_stem=target_stem,
            **train_kwargs,
        )

    val_kwargs.setdefault("npy_memmap", npy_memmap)

    datamodule = pl.LightningDataModule.from_datasets(
        train_dataset=train_dataset,
        val_dataset=MUSDB18SadDataset(
            data_root=os.path.join(data_root, "saded-np"),
            split="val",
            target_stem=target_stem,
            **val_kwargs,
        ),
        test_dataset=MUSDB18FullTrackDataset(
            data_root=os.path.join(data_root, "canonical"), split="test", **test_kwargs
        ),
        batch_size=batch_size,
        num_workers=num_workers,
        **datamodule_kwargs,
    )

    # Prediction reuses the full-track test loader.
    datamodule.predict_dataloader = (  # type: ignore[method-assign]
        datamodule.test_dataloader
    )

    return datamodule
+ return datamodule
mvsepless/models/bandit/core/data/musdb/dataset.py CHANGED
@@ -1,241 +1,241 @@
1
- import os
2
- from abc import ABC
3
- from typing import List, Optional, Tuple
4
-
5
- import numpy as np
6
- import torch
7
- import torchaudio as ta
8
- from torch.utils import data
9
-
10
- from .._types import AudioDict, DataDict
11
- from ..base import BaseSourceSeparationDataset
12
-
13
-
14
- class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC):
15
-
16
- ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"]
17
-
18
- def __init__(
19
- self,
20
- split: str,
21
- stems: List[str],
22
- files: List[str],
23
- data_path: str,
24
- fs: int = 44100,
25
- npy_memmap=False,
26
- ) -> None:
27
- super().__init__(
28
- split=split,
29
- stems=stems,
30
- files=files,
31
- data_path=data_path,
32
- fs=fs,
33
- npy_memmap=npy_memmap,
34
- recompute_mixture=False,
35
- )
36
-
37
- def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
38
- track = identifier["track"]
39
- path = os.path.join(self.data_path, track)
40
-
41
- if self.npy_memmap:
42
- audio = np.load(os.path.join(path, f"{stem}.wav.npy"), mmap_mode="r")
43
- else:
44
- audio, _ = ta.load(os.path.join(path, f"{stem}.wav"))
45
-
46
- return audio
47
-
48
- def get_identifier(self, index):
49
- return dict(track=self.files[index])
50
-
51
- def __getitem__(self, index: int) -> DataDict:
52
- identifier = self.get_identifier(index)
53
- audio = self.get_audio(identifier)
54
-
55
- return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
56
-
57
-
58
- class MUSDB18FullTrackDataset(MUSDB18BaseDataset):
59
-
60
- N_TRAIN_TRACKS = 100
61
- N_TEST_TRACKS = 50
62
- VALIDATION_FILES = [
63
- "Actions - One Minute Smile",
64
- "Clara Berry And Wooldog - Waltz For My Victims",
65
- "Johnny Lokke - Promises & Lies",
66
- "Patrick Talbot - A Reason To Leave",
67
- "Triviul - Angelsaint",
68
- "Alexander Ross - Goodbye Bolero",
69
- "Fergessen - Nos Palpitants",
70
- "Leaf - Summerghost",
71
- "Skelpolu - Human Mistakes",
72
- "Young Griffo - Pennies",
73
- "ANiMAL - Rockshow",
74
- "James May - On The Line",
75
- "Meaxic - Take A Step",
76
- "Traffic Experiment - Sirens",
77
- ]
78
-
79
- def __init__(
80
- self, data_root: str, split: str, stems: Optional[List[str]] = None
81
- ) -> None:
82
-
83
- if stems is None:
84
- stems = self.ALLOWED_STEMS
85
- self.stems = stems
86
-
87
- if split == "test":
88
- subset = "test"
89
- elif split in ["train", "val"]:
90
- subset = "train"
91
- else:
92
- raise NameError
93
-
94
- data_path = os.path.join(data_root, subset)
95
-
96
- files = sorted(os.listdir(data_path))
97
- files = [f for f in files if not f.startswith(".")]
98
- if subset == "train":
99
- assert len(files) == 100, len(files)
100
- if split == "train":
101
- files = [f for f in files if f not in self.VALIDATION_FILES]
102
- assert len(files) == 100 - len(self.VALIDATION_FILES)
103
- else:
104
- files = [f for f in files if f in self.VALIDATION_FILES]
105
- assert len(files) == len(self.VALIDATION_FILES)
106
- else:
107
- split = "test"
108
- assert len(files) == 50
109
-
110
- self.n_tracks = len(files)
111
-
112
- super().__init__(data_path=data_path, split=split, stems=stems, files=files)
113
-
114
- def __len__(self) -> int:
115
- return self.n_tracks
116
-
117
-
118
- class MUSDB18SadDataset(MUSDB18BaseDataset):
119
- def __init__(
120
- self,
121
- data_root: str,
122
- split: str,
123
- target_stem: str,
124
- stems: Optional[List[str]] = None,
125
- target_length: Optional[int] = None,
126
- npy_memmap=False,
127
- ) -> None:
128
-
129
- if stems is None:
130
- stems = self.ALLOWED_STEMS
131
-
132
- data_path = os.path.join(data_root, target_stem, split)
133
-
134
- files = sorted(os.listdir(data_path))
135
- files = [f for f in files if not f.startswith(".")]
136
-
137
- super().__init__(
138
- data_path=data_path,
139
- split=split,
140
- stems=stems,
141
- files=files,
142
- npy_memmap=npy_memmap,
143
- )
144
- self.n_segments = len(files)
145
- self.target_stem = target_stem
146
- self.target_length = (
147
- target_length if target_length is not None else self.n_segments
148
- )
149
-
150
- def __len__(self) -> int:
151
- return self.target_length
152
-
153
- def __getitem__(self, index: int) -> DataDict:
154
-
155
- index = index % self.n_segments
156
-
157
- return super().__getitem__(index)
158
-
159
- def get_identifier(self, index):
160
- return super().get_identifier(index % self.n_segments)
161
-
162
-
163
- class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset):
164
- def __init__(
165
- self,
166
- data_root: str,
167
- split: str,
168
- target_stem: str,
169
- stems: Optional[List[str]] = None,
170
- target_length: int = 20000,
171
- apply_probability: Optional[float] = None,
172
- chunk_size_second: float = 3.0,
173
- random_scale_range_db: Tuple[float, float] = (-10, 10),
174
- drop_probability: float = 0.1,
175
- rescale: bool = True,
176
- ) -> None:
177
- super().__init__(data_root, split, target_stem, stems)
178
-
179
- if apply_probability is None:
180
- apply_probability = (target_length - self.n_segments) / target_length
181
-
182
- self.apply_probability = apply_probability
183
- self.drop_probability = drop_probability
184
- self.chunk_size_second = chunk_size_second
185
- self.random_scale_range_db = random_scale_range_db
186
- self.rescale = rescale
187
-
188
- self.chunk_size_sample = int(self.chunk_size_second * self.fs)
189
- self.target_length = target_length
190
-
191
- def __len__(self) -> int:
192
- return self.target_length
193
-
194
- def __getitem__(self, index: int) -> DataDict:
195
-
196
- index = index % self.n_segments
197
-
198
- audio = {}
199
- identifier = self.get_identifier(index)
200
-
201
- for stem in self.stems_no_mixture:
202
- if stem == self.target_stem:
203
- identifier_ = identifier
204
- else:
205
- if np.random.rand() < self.apply_probability:
206
- index_ = np.random.randint(self.n_segments)
207
- identifier_ = self.get_identifier(index_)
208
- else:
209
- identifier_ = identifier
210
-
211
- audio[stem] = self.get_stem(stem=stem, identifier=identifier_)
212
-
213
- if self.chunk_size_sample < audio[stem].shape[-1]:
214
- chunk_start = np.random.randint(
215
- audio[stem].shape[-1] - self.chunk_size_sample
216
- )
217
- else:
218
- chunk_start = 0
219
-
220
- if np.random.rand() < self.drop_probability:
221
- linear_scale = 0.0
222
- else:
223
- db_scale = np.random.uniform(*self.random_scale_range_db)
224
- linear_scale = np.power(10, db_scale / 20)
225
- audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample] = (
226
- linear_scale
227
- * audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample]
228
- )
229
-
230
- audio["mixture"] = self.compute_mixture(audio)
231
-
232
- if self.rescale:
233
- max_abs_val = max(
234
- [torch.max(torch.abs(audio[stem])) for stem in self.stems]
235
- ) # type: ignore[type-var]
236
- if max_abs_val > 1:
237
- audio = {k: v / max_abs_val for k, v in audio.items()}
238
-
239
- track = identifier["track"]
240
-
241
- return {"audio": audio, "track": f"{self.split}/{track}"}
 
1
+ import os
2
+ from abc import ABC
3
+ from typing import List, Optional, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchaudio as ta
8
+ from torch.utils import data
9
+
10
+ from .._types import AudioDict, DataDict
11
+ from ..base import BaseSourceSeparationDataset
12
+
13
+
14
class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC):
    """Shared behaviour for MUSDB18-derived datasets.

    Stems are read per track either from memory-mapped ``.npy`` caches or
    directly from ``.wav`` files.
    """

    ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"]

    def __init__(
        self,
        split: str,
        stems: List[str],
        files: List[str],
        data_path: str,
        fs: int = 44100,
        npy_memmap=False,
    ) -> None:
        super().__init__(
            split=split,
            stems=stems,
            files=files,
            data_path=data_path,
            fs=fs,
            npy_memmap=npy_memmap,
            recompute_mixture=False,
        )

    def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
        """Load one stem's audio for the track named in ``identifier``."""
        track_dir = os.path.join(self.data_path, identifier["track"])

        if self.npy_memmap:
            # Lazy, read-only view into the cached array.
            return np.load(
                os.path.join(track_dir, f"{stem}.wav.npy"), mmap_mode="r"
            )

        waveform, _ = ta.load(os.path.join(track_dir, f"{stem}.wav"))
        return waveform

    def get_identifier(self, index):
        return {"track": self.files[index]}

    def __getitem__(self, index: int) -> DataDict:
        identifier = self.get_identifier(index)
        return {
            "audio": self.get_audio(identifier),
            "track": f"{self.split}/{identifier['track']}",
        }
56
+
57
+
58
class MUSDB18FullTrackDataset(MUSDB18BaseDataset):
    """Full-length MUSDB18-HQ tracks with the canonical 14-track
    validation split carved out of the 100-track train subset.
    """

    N_TRAIN_TRACKS = 100
    N_TEST_TRACKS = 50
    VALIDATION_FILES = [
        "Actions - One Minute Smile",
        "Clara Berry And Wooldog - Waltz For My Victims",
        "Johnny Lokke - Promises & Lies",
        "Patrick Talbot - A Reason To Leave",
        "Triviul - Angelsaint",
        "Alexander Ross - Goodbye Bolero",
        "Fergessen - Nos Palpitants",
        "Leaf - Summerghost",
        "Skelpolu - Human Mistakes",
        "Young Griffo - Pennies",
        "ANiMAL - Rockshow",
        "James May - On The Line",
        "Meaxic - Take A Step",
        "Traffic Experiment - Sirens",
    ]

    def __init__(
        self, data_root: str, split: str, stems: Optional[List[str]] = None
    ) -> None:
        """
        Args:
            data_root: Directory containing ``train`` and ``test`` subsets.
            split: One of ``"train"``, ``"val"`` or ``"test"``.
            stems: Stems to load; defaults to ``ALLOWED_STEMS``.

        Raises:
            NameError: If ``split`` is not a recognised split name.
        """
        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        # "train" and "val" both live in the on-disk "train" subset.
        if split == "test":
            subset = "test"
        elif split in ["train", "val"]:
            subset = "train"
        else:
            # BUGFIX: previously a bare ``raise NameError`` with no message.
            # The (unusual) exception type is kept for caller compatibility.
            raise NameError(
                f"Invalid split {split!r}; expected 'train', 'val' or 'test'"
            )

        data_path = os.path.join(data_root, subset)

        files = sorted(os.listdir(data_path))
        files = [f for f in files if not f.startswith(".")]
        if subset == "train":
            # Use the declared constants instead of magic 100/50 literals.
            assert len(files) == self.N_TRAIN_TRACKS, len(files)
            if split == "train":
                files = [f for f in files if f not in self.VALIDATION_FILES]
                assert len(files) == self.N_TRAIN_TRACKS - len(self.VALIDATION_FILES)
            else:
                files = [f for f in files if f in self.VALIDATION_FILES]
                assert len(files) == len(self.VALIDATION_FILES)
        else:
            assert len(files) == self.N_TEST_TRACKS

        self.n_tracks = len(files)

        super().__init__(data_path=data_path, split=split, stems=stems, files=files)

    def __len__(self) -> int:
        return self.n_tracks
116
+
117
+
118
class MUSDB18SadDataset(MUSDB18BaseDataset):
    """Salient (source-activity-detected) segments for one target stem.

    Indices wrap modulo the number of available segments, so the dataset
    can advertise an arbitrary ``target_length``.
    """

    def __init__(
        self,
        data_root: str,
        split: str,
        target_stem: str,
        stems: Optional[List[str]] = None,
        target_length: Optional[int] = None,
        npy_memmap=False,
    ) -> None:
        if stems is None:
            stems = self.ALLOWED_STEMS

        segment_dir = os.path.join(data_root, target_stem, split)

        segment_names = [
            name
            for name in sorted(os.listdir(segment_dir))
            if not name.startswith(".")
        ]

        super().__init__(
            data_path=segment_dir,
            split=split,
            stems=stems,
            files=segment_names,
            npy_memmap=npy_memmap,
        )
        self.n_segments = len(segment_names)
        self.target_stem = target_stem
        if target_length is None:
            self.target_length = self.n_segments
        else:
            self.target_length = target_length

    def __len__(self) -> int:
        return self.target_length

    def __getitem__(self, index: int) -> DataDict:
        # Wrap so virtual indices beyond n_segments reuse real segments.
        return super().__getitem__(index % self.n_segments)

    def get_identifier(self, index):
        return super().get_identifier(index % self.n_segments)
161
+
162
+
163
class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset):
    """Salient-segment dataset with on-the-fly augmentation.

    Non-target stems may be swapped in from other random segments, a random
    chunk of each stem is re-scaled (or dropped), and the mixture is then
    recomputed from the augmented stems.
    """

    def __init__(
        self,
        data_root: str,
        split: str,
        target_stem: str,
        stems: Optional[List[str]] = None,
        target_length: int = 20000,
        apply_probability: Optional[float] = None,
        chunk_size_second: float = 3.0,
        random_scale_range_db: Tuple[float, float] = (-10, 10),
        drop_probability: float = 0.1,
        rescale: bool = True,
    ) -> None:
        super().__init__(data_root, split, target_stem, stems)

        if apply_probability is None:
            # Default: augment just often enough to fill the gap between the
            # real segment count and the virtual epoch length.
            apply_probability = (target_length - self.n_segments) / target_length

        self.apply_probability = apply_probability
        self.drop_probability = drop_probability
        self.chunk_size_second = chunk_size_second
        self.random_scale_range_db = random_scale_range_db
        self.rescale = rescale

        self.chunk_size_sample = int(self.chunk_size_second * self.fs)
        # Overrides the target_length set by the parent constructor.
        self.target_length = target_length

    def __len__(self) -> int:
        return self.target_length

    def __getitem__(self, index: int) -> DataDict:
        # Virtual indices wrap modulo the real number of segments.
        index = index % self.n_segments

        audio = {}
        identifier = self.get_identifier(index)

        # NOTE: the sequence of np.random draws below is part of the
        # behaviour — do not reorder.
        for stem in self.stems_no_mixture:
            if stem == self.target_stem:
                identifier_ = identifier
            else:
                # With probability apply_probability, take this stem from a
                # different, randomly chosen segment.
                if np.random.rand() < self.apply_probability:
                    index_ = np.random.randint(self.n_segments)
                    identifier_ = self.get_identifier(index_)
                else:
                    identifier_ = identifier

            audio[stem] = self.get_stem(stem=stem, identifier=identifier_)

            # Choose a random chunk of the stem to re-scale.
            if self.chunk_size_sample < audio[stem].shape[-1]:
                chunk_start = np.random.randint(
                    audio[stem].shape[-1] - self.chunk_size_sample
                )
            else:
                chunk_start = 0

            # Either silence the chunk entirely, or apply a random gain
            # drawn in dB from random_scale_range_db.
            if np.random.rand() < self.drop_probability:
                linear_scale = 0.0
            else:
                db_scale = np.random.uniform(*self.random_scale_range_db)
                linear_scale = np.power(10, db_scale / 20)
            audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample] = (
                linear_scale
                * audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample]
            )

        # Mixture must be recomputed since the stems were modified.
        audio["mixture"] = self.compute_mixture(audio)

        if self.rescale:
            # Peak-normalise only when the augmented mix would clip.
            max_abs_val = max(
                [torch.max(torch.abs(audio[stem])) for stem in self.stems]
            )  # type: ignore[type-var]
            if max_abs_val > 1:
                audio = {k: v / max_abs_val for k, v in audio.items()}

        track = identifier["track"]

        return {"audio": audio, "track": f"{self.split}/{track}"}
mvsepless/models/bandit/core/data/musdb/preprocess.py CHANGED
@@ -1,223 +1,223 @@
1
- import glob
2
- import os
3
-
4
- import numpy as np
5
- import torch
6
- import torchaudio as ta
7
- from torch import nn
8
- from torch.nn import functional as F
9
- from tqdm.contrib.concurrent import process_map
10
-
11
- from .._types import DataDict
12
- from .dataset import MUSDB18FullTrackDataset
13
- import pyloudnorm as pyln
14
-
15
-
16
- class SourceActivityDetector(nn.Module):
17
- def __init__(
18
- self,
19
- analysis_stem: str,
20
- output_path: str,
21
- fs: int = 44100,
22
- segment_length_second: float = 6.0,
23
- hop_length_second: float = 3.0,
24
- n_chunks: int = 10,
25
- chunk_epsilon: float = 1e-5,
26
- energy_threshold_quantile: float = 0.15,
27
- segment_epsilon: float = 1e-3,
28
- salient_proportion_threshold: float = 0.5,
29
- target_lufs: float = -24,
30
- ) -> None:
31
- super().__init__()
32
-
33
- self.fs = fs
34
- self.segment_length = int(segment_length_second * self.fs)
35
- self.hop_length = int(hop_length_second * self.fs)
36
- self.n_chunks = n_chunks
37
- assert self.segment_length % self.n_chunks == 0
38
- self.chunk_size = self.segment_length // self.n_chunks
39
- self.chunk_epsilon = chunk_epsilon
40
- self.energy_threshold_quantile = energy_threshold_quantile
41
- self.segment_epsilon = segment_epsilon
42
- self.salient_proportion_threshold = salient_proportion_threshold
43
- self.analysis_stem = analysis_stem
44
-
45
- self.meter = pyln.Meter(self.fs)
46
- self.target_lufs = target_lufs
47
-
48
- self.output_path = output_path
49
-
50
- def forward(self, data: DataDict) -> None:
51
-
52
- stem_ = self.analysis_stem if (self.analysis_stem != "none") else "mixture"
53
-
54
- x = data["audio"][stem_]
55
-
56
- xnp = x.numpy()
57
- loudness = self.meter.integrated_loudness(xnp.T)
58
-
59
- for stem in data["audio"]:
60
- s = data["audio"][stem]
61
- s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T
62
- s = torch.as_tensor(s)
63
- data["audio"][stem] = s
64
-
65
- if x.ndim == 3:
66
- assert x.shape[0] == 1
67
- x = x[0]
68
-
69
- n_chan, n_samples = x.shape
70
-
71
- n_segments = (
72
- int(np.ceil((n_samples - self.segment_length) / self.hop_length)) + 1
73
- )
74
-
75
- segments = torch.zeros((n_segments, n_chan, self.segment_length))
76
- for i in range(n_segments):
77
- start = i * self.hop_length
78
- end = start + self.segment_length
79
- end = min(end, n_samples)
80
-
81
- xseg = x[:, start:end]
82
-
83
- if end - start < self.segment_length:
84
- xseg = F.pad(
85
- xseg, pad=(0, self.segment_length - (end - start)), value=torch.nan
86
- )
87
-
88
- segments[i, :, :] = xseg
89
-
90
- chunks = segments.reshape((n_segments, n_chan, self.n_chunks, self.chunk_size))
91
-
92
- if self.analysis_stem != "none":
93
- chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3))
94
- chunk_energies = torch.nan_to_num(chunk_energies, nan=0)
95
- chunk_energies[chunk_energies == 0] = self.chunk_epsilon
96
-
97
- energy_threshold = torch.nanquantile(
98
- chunk_energies, q=self.energy_threshold_quantile
99
- )
100
-
101
- if energy_threshold < self.segment_epsilon:
102
- energy_threshold = self.segment_epsilon # type: ignore[assignment]
103
-
104
- chunks_above_threshold = chunk_energies > energy_threshold
105
- n_chunks_above_threshold = torch.mean(
106
- chunks_above_threshold.to(torch.float), dim=-1
107
- )
108
-
109
- segment_above_threshold = (
110
- n_chunks_above_threshold > self.salient_proportion_threshold
111
- )
112
-
113
- if torch.sum(segment_above_threshold) == 0:
114
- return
115
-
116
- else:
117
- segment_above_threshold = torch.ones((n_segments,))
118
-
119
- for i in range(n_segments):
120
- if not segment_above_threshold[i]:
121
- continue
122
-
123
- outpath = os.path.join(
124
- self.output_path,
125
- self.analysis_stem,
126
- f"{data['track']} - {self.analysis_stem}{i:03d}",
127
- )
128
- os.makedirs(outpath, exist_ok=True)
129
-
130
- for stem in data["audio"]:
131
- if stem == self.analysis_stem:
132
- segment = torch.nan_to_num(segments[i, :, :], nan=0)
133
- else:
134
- start = i * self.hop_length
135
- end = start + self.segment_length
136
- end = min(n_samples, end)
137
-
138
- segment = data["audio"][stem][:, start:end]
139
-
140
- if end - start < self.segment_length:
141
- segment = F.pad(
142
- segment, (0, self.segment_length - (end - start))
143
- )
144
-
145
- assert segment.shape[-1] == self.segment_length, segment.shape
146
-
147
- np.save(os.path.join(outpath, f"{stem}.wav"), segment)
148
-
149
-
150
- def preprocess(
151
- analysis_stem: str,
152
- output_path: str = "/data/MUSDB18/HQ/saded-np",
153
- fs: int = 44100,
154
- segment_length_second: float = 6.0,
155
- hop_length_second: float = 3.0,
156
- n_chunks: int = 10,
157
- chunk_epsilon: float = 1e-5,
158
- energy_threshold_quantile: float = 0.15,
159
- segment_epsilon: float = 1e-3,
160
- salient_proportion_threshold: float = 0.5,
161
- ) -> None:
162
-
163
- sad = SourceActivityDetector(
164
- analysis_stem=analysis_stem,
165
- output_path=output_path,
166
- fs=fs,
167
- segment_length_second=segment_length_second,
168
- hop_length_second=hop_length_second,
169
- n_chunks=n_chunks,
170
- chunk_epsilon=chunk_epsilon,
171
- energy_threshold_quantile=energy_threshold_quantile,
172
- segment_epsilon=segment_epsilon,
173
- salient_proportion_threshold=salient_proportion_threshold,
174
- )
175
-
176
- for split in ["train", "val", "test"]:
177
- ds = MUSDB18FullTrackDataset(
178
- data_root="/data/MUSDB18/HQ/canonical",
179
- split=split,
180
- )
181
-
182
- tracks = []
183
- for i, track in enumerate(tqdm(ds, total=len(ds))):
184
- if i % 32 == 0 and tracks:
185
- process_map(sad, tracks, max_workers=8)
186
- tracks = []
187
- tracks.append(track)
188
- process_map(sad, tracks, max_workers=8)
189
-
190
-
191
- def loudness_norm_one(inputs):
192
- infile, outfile, target_lufs = inputs
193
-
194
- audio, fs = ta.load(infile)
195
- audio = audio.mean(dim=0, keepdim=True).numpy().T
196
-
197
- meter = pyln.Meter(fs)
198
- loudness = meter.integrated_loudness(audio)
199
- audio = pyln.normalize.loudness(audio, loudness, target_lufs)
200
-
201
- os.makedirs(os.path.dirname(outfile), exist_ok=True)
202
- np.save(outfile, audio.T)
203
-
204
-
205
- def loudness_norm(
206
- data_path: str,
207
- target_lufs=-17.0,
208
- ):
209
- files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
210
-
211
- outfiles = [f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files]
212
-
213
- files = [(f, o, target_lufs) for f, o in zip(files, outfiles)]
214
-
215
- process_map(loudness_norm_one, files, chunksize=2)
216
-
217
-
218
- if __name__ == "__main__":
219
-
220
- from tqdm.auto import tqdm
221
- import fire
222
-
223
- fire.Fire()
 
1
+ import glob
2
+ import os
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torchaudio as ta
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from tqdm.contrib.concurrent import process_map
10
+
11
+ from .._types import DataDict
12
+ from .dataset import MUSDB18FullTrackDataset
13
+ import pyloudnorm as pyln
14
+
15
+
16
class SourceActivityDetector(nn.Module):
    """Slice a track into fixed-length segments and keep those where the
    analysis stem is salient, saving every stem for each kept segment.

    Saliency: each segment is split into ``n_chunks``; a chunk is "active"
    when its mean energy exceeds a quantile-derived threshold, and a
    segment is kept when a sufficient fraction of its chunks is active.
    """

    def __init__(
        self,
        analysis_stem: str,
        output_path: str,
        fs: int = 44100,
        segment_length_second: float = 6.0,
        hop_length_second: float = 3.0,
        n_chunks: int = 10,
        chunk_epsilon: float = 1e-5,
        energy_threshold_quantile: float = 0.15,
        segment_epsilon: float = 1e-3,
        salient_proportion_threshold: float = 0.5,
        target_lufs: float = -24,
    ) -> None:
        super().__init__()

        self.fs = fs
        self.segment_length = int(segment_length_second * self.fs)
        self.hop_length = int(hop_length_second * self.fs)
        self.n_chunks = n_chunks
        # Chunks must tile the segment exactly.
        assert self.segment_length % self.n_chunks == 0
        self.chunk_size = self.segment_length // self.n_chunks
        self.chunk_epsilon = chunk_epsilon
        self.energy_threshold_quantile = energy_threshold_quantile
        self.segment_epsilon = segment_epsilon
        self.salient_proportion_threshold = salient_proportion_threshold
        self.analysis_stem = analysis_stem

        # Loudness meter used to normalise every stem to target_lufs.
        self.meter = pyln.Meter(self.fs)
        self.target_lufs = target_lufs

        self.output_path = output_path

    def forward(self, data: DataDict) -> None:
        """Detect salient segments in ``data`` and write them to disk.

        Mutates ``data["audio"]`` in place (loudness normalisation). Each
        kept segment gets its own directory; files are written with
        ``np.save`` under a ``.wav`` name (NumPy appends ``.npy``).
        """
        # "none" disables saliency filtering; analyse the mixture instead.
        stem_ = self.analysis_stem if (self.analysis_stem != "none") else "mixture"

        x = data["audio"][stem_]

        # Loudness is measured once on the analysis stem and the same gain
        # applied to every stem, preserving relative levels.
        xnp = x.numpy()
        loudness = self.meter.integrated_loudness(xnp.T)

        for stem in data["audio"]:
            s = data["audio"][stem]
            s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T
            s = torch.as_tensor(s)
            data["audio"][stem] = s

        if x.ndim == 3:
            # Drop a singleton batch dimension if present.
            assert x.shape[0] == 1
            x = x[0]

        n_chan, n_samples = x.shape

        n_segments = (
            int(np.ceil((n_samples - self.segment_length) / self.hop_length)) + 1
        )

        # Slice the analysis signal into overlapping segments; the tail is
        # padded with NaN so padding can be excluded from energy stats.
        segments = torch.zeros((n_segments, n_chan, self.segment_length))
        for i in range(n_segments):
            start = i * self.hop_length
            end = start + self.segment_length
            end = min(end, n_samples)

            xseg = x[:, start:end]

            if end - start < self.segment_length:
                xseg = F.pad(
                    xseg, pad=(0, self.segment_length - (end - start)), value=torch.nan
                )

            segments[i, :, :] = xseg

        chunks = segments.reshape((n_segments, n_chan, self.n_chunks, self.chunk_size))

        if self.analysis_stem != "none":
            # Mean energy per (segment, chunk); NaN padding becomes 0 and
            # silent chunks are floored at chunk_epsilon.
            chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3))
            chunk_energies = torch.nan_to_num(chunk_energies, nan=0)
            chunk_energies[chunk_energies == 0] = self.chunk_epsilon

            energy_threshold = torch.nanquantile(
                chunk_energies, q=self.energy_threshold_quantile
            )

            # Never let the threshold fall below the segment floor.
            if energy_threshold < self.segment_epsilon:
                energy_threshold = self.segment_epsilon  # type: ignore[assignment]

            chunks_above_threshold = chunk_energies > energy_threshold
            n_chunks_above_threshold = torch.mean(
                chunks_above_threshold.to(torch.float), dim=-1
            )

            # Keep segments where enough chunks are active.
            segment_above_threshold = (
                n_chunks_above_threshold > self.salient_proportion_threshold
            )

            if torch.sum(segment_above_threshold) == 0:
                return

        else:
            segment_above_threshold = torch.ones((n_segments,))

        for i in range(n_segments):
            if not segment_above_threshold[i]:
                continue

            outpath = os.path.join(
                self.output_path,
                self.analysis_stem,
                f"{data['track']} - {self.analysis_stem}{i:03d}",
            )
            os.makedirs(outpath, exist_ok=True)

            for stem in data["audio"]:
                if stem == self.analysis_stem:
                    # Already segmented above; just replace NaN padding.
                    segment = torch.nan_to_num(segments[i, :, :], nan=0)
                else:
                    start = i * self.hop_length
                    end = start + self.segment_length
                    end = min(n_samples, end)

                    segment = data["audio"][stem][:, start:end]

                    if end - start < self.segment_length:
                        segment = F.pad(
                            segment, (0, self.segment_length - (end - start))
                        )

                assert segment.shape[-1] == self.segment_length, segment.shape

                np.save(os.path.join(outpath, f"{stem}.wav"), segment)
148
+
149
+
150
def preprocess(
    analysis_stem: str,
    output_path: str = "/data/MUSDB18/HQ/saded-np",
    fs: int = 44100,
    segment_length_second: float = 6.0,
    hop_length_second: float = 3.0,
    n_chunks: int = 10,
    chunk_epsilon: float = 1e-5,
    energy_threshold_quantile: float = 0.15,
    segment_epsilon: float = 1e-3,
    salient_proportion_threshold: float = 0.5,
) -> None:
    """Run source-activity detection over MUSDB18-HQ and export salient
    segments for ``analysis_stem`` across the train/val/test splits.

    All keyword arguments are forwarded to :class:`SourceActivityDetector`.
    """
    # BUGFIX: tqdm was only imported inside the __main__ guard, so calling
    # preprocess() as a library function raised NameError.
    from tqdm.auto import tqdm

    sad = SourceActivityDetector(
        analysis_stem=analysis_stem,
        output_path=output_path,
        fs=fs,
        segment_length_second=segment_length_second,
        hop_length_second=hop_length_second,
        n_chunks=n_chunks,
        chunk_epsilon=chunk_epsilon,
        energy_threshold_quantile=energy_threshold_quantile,
        segment_epsilon=segment_epsilon,
        salient_proportion_threshold=salient_proportion_threshold,
    )

    for split in ["train", "val", "test"]:
        ds = MUSDB18FullTrackDataset(
            data_root="/data/MUSDB18/HQ/canonical",
            split=split,
        )

        # Process tracks in batches of 32 to bound memory held across the
        # multiprocessing map.
        tracks = []
        for i, track in enumerate(tqdm(ds, total=len(ds))):
            if i % 32 == 0 and tracks:
                process_map(sad, tracks, max_workers=8)
                tracks = []
            tracks.append(track)
        process_map(sad, tracks, max_workers=8)
189
+
190
+
191
def loudness_norm_one(inputs):
    """Downmix one file to mono, normalise to the target LUFS, and save the
    result as a ``.npy`` array (packed-tuple signature for process_map).
    """
    infile, outfile, target_lufs = inputs

    waveform, fs = ta.load(infile)
    # Mono downmix, transposed to (samples, channels) for pyloudnorm.
    mono = waveform.mean(dim=0, keepdim=True).numpy().T

    meter = pyln.Meter(fs)
    measured = meter.integrated_loudness(mono)
    normalised = pyln.normalize.loudness(mono, measured, target_lufs)

    os.makedirs(os.path.dirname(outfile), exist_ok=True)
    np.save(outfile, normalised.T)
203
+
204
+
205
def loudness_norm(
    data_path: str,
    target_lufs=-17.0,
):
    """Loudness-normalise every ``.wav`` under ``data_path`` into the
    parallel ``saded-np`` tree as ``.npy`` files.
    """
    wav_files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)

    jobs = [
        (
            wav,
            wav.replace(".wav", ".npy").replace("saded", "saded-np"),
            target_lufs,
        )
        for wav in wav_files
    ]

    process_map(loudness_norm_one, jobs, chunksize=2)
216
+
217
+
218
+ if __name__ == "__main__":
219
+
220
+ from tqdm.auto import tqdm
221
+ import fire
222
+
223
+ fire.Fire()
mvsepless/models/bandit/core/data/musdb/validation.yaml CHANGED
@@ -1,15 +1,15 @@
1
- validation:
2
- - 'Actions - One Minute Smile'
3
- - 'Clara Berry And Wooldog - Waltz For My Victims'
4
- - 'Johnny Lokke - Promises & Lies'
5
- - 'Patrick Talbot - A Reason To Leave'
6
- - 'Triviul - Angelsaint'
7
- - 'Alexander Ross - Goodbye Bolero'
8
- - 'Fergessen - Nos Palpitants'
9
- - 'Leaf - Summerghost'
10
- - 'Skelpolu - Human Mistakes'
11
- - 'Young Griffo - Pennies'
12
- - 'ANiMAL - Rockshow'
13
- - 'James May - On The Line'
14
- - 'Meaxic - Take A Step'
15
  - 'Traffic Experiment - Sirens'
 
1
+ validation:
2
+ - 'Actions - One Minute Smile'
3
+ - 'Clara Berry And Wooldog - Waltz For My Victims'
4
+ - 'Johnny Lokke - Promises & Lies'
5
+ - 'Patrick Talbot - A Reason To Leave'
6
+ - 'Triviul - Angelsaint'
7
+ - 'Alexander Ross - Goodbye Bolero'
8
+ - 'Fergessen - Nos Palpitants'
9
+ - 'Leaf - Summerghost'
10
+ - 'Skelpolu - Human Mistakes'
11
+ - 'Young Griffo - Pennies'
12
+ - 'ANiMAL - Rockshow'
13
+ - 'James May - On The Line'
14
+ - 'Meaxic - Take A Step'
15
  - 'Traffic Experiment - Sirens'
mvsepless/models/bandit/core/loss/__init__.py CHANGED
@@ -1,8 +1,8 @@
1
- from ._multistem import MultiStemWrapperFromConfig
2
- from ._timefreq import (
3
- ReImL1Loss,
4
- ReImL2Loss,
5
- TimeFreqL1Loss,
6
- TimeFreqL2Loss,
7
- TimeFreqSignalNoisePNormRatioLoss,
8
- )
 
1
+ from ._multistem import MultiStemWrapperFromConfig
2
+ from ._timefreq import (
3
+ ReImL1Loss,
4
+ ReImL2Loss,
5
+ TimeFreqL1Loss,
6
+ TimeFreqL2Loss,
7
+ TimeFreqSignalNoisePNormRatioLoss,
8
+ )
mvsepless/models/bandit/core/loss/_complex.py CHANGED
@@ -1,27 +1,27 @@
1
- from typing import Any
2
-
3
- import torch
4
- from torch import nn
5
- from torch.nn.modules import loss as _loss
6
- from torch.nn.modules.loss import _Loss
7
-
8
-
9
- class ReImLossWrapper(_Loss):
10
- def __init__(self, module: _Loss) -> None:
11
- super().__init__()
12
- self.module = module
13
-
14
- def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
15
- return self.module(torch.view_as_real(preds), torch.view_as_real(target))
16
-
17
-
18
- class ReImL1Loss(ReImLossWrapper):
19
- def __init__(self, **kwargs: Any) -> None:
20
- l1_loss = _loss.L1Loss(**kwargs)
21
- super().__init__(module=(l1_loss))
22
-
23
-
24
- class ReImL2Loss(ReImLossWrapper):
25
- def __init__(self, **kwargs: Any) -> None:
26
- l2_loss = _loss.MSELoss(**kwargs)
27
- super().__init__(module=(l2_loss))
 
1
+ from typing import Any
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn.modules import loss as _loss
6
+ from torch.nn.modules.loss import _Loss
7
+
8
+
9
+ class ReImLossWrapper(_Loss):
10
+ def __init__(self, module: _Loss) -> None:
11
+ super().__init__()
12
+ self.module = module
13
+
14
+ def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
15
+ return self.module(torch.view_as_real(preds), torch.view_as_real(target))
16
+
17
+
18
+ class ReImL1Loss(ReImLossWrapper):
19
+ def __init__(self, **kwargs: Any) -> None:
20
+ l1_loss = _loss.L1Loss(**kwargs)
21
+ super().__init__(module=(l1_loss))
22
+
23
+
24
+ class ReImL2Loss(ReImLossWrapper):
25
+ def __init__(self, **kwargs: Any) -> None:
26
+ l2_loss = _loss.MSELoss(**kwargs)
27
+ super().__init__(module=(l2_loss))
mvsepless/models/bandit/core/loss/_multistem.py CHANGED
@@ -1,43 +1,43 @@
1
- from typing import Any, Dict
2
-
3
- import torch
4
- from asteroid import losses as asteroid_losses
5
- from torch import nn
6
- from torch.nn.modules.loss import _Loss
7
-
8
- from . import snr
9
-
10
-
11
- def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss:
12
-
13
- for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]:
14
- if name in module.__dict__:
15
- return module.__dict__[name](**kwargs)
16
-
17
- raise NameError
18
-
19
-
20
- class MultiStemWrapper(_Loss):
21
- def __init__(self, module: _Loss, modality: str = "audio") -> None:
22
- super().__init__()
23
- self.loss = module
24
- self.modality = modality
25
-
26
- def forward(
27
- self,
28
- preds: Dict[str, Dict[str, torch.Tensor]],
29
- target: Dict[str, Dict[str, torch.Tensor]],
30
- ) -> torch.Tensor:
31
- loss = {
32
- stem: self.loss(preds[self.modality][stem], target[self.modality][stem])
33
- for stem in preds[self.modality]
34
- if stem in target[self.modality]
35
- }
36
-
37
- return sum(list(loss.values()))
38
-
39
-
40
- class MultiStemWrapperFromConfig(MultiStemWrapper):
41
- def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None:
42
- loss = parse_loss(name, kwargs)
43
- super().__init__(module=loss, modality=modality)
 
1
+ from typing import Any, Dict
2
+
3
+ import torch
4
+ from asteroid import losses as asteroid_losses
5
+ from torch import nn
6
+ from torch.nn.modules.loss import _Loss
7
+
8
+ from . import snr
9
+
10
+
11
+ def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss:
12
+
13
+ for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]:
14
+ if name in module.__dict__:
15
+ return module.__dict__[name](**kwargs)
16
+
17
+ raise NameError
18
+
19
+
20
+ class MultiStemWrapper(_Loss):
21
+ def __init__(self, module: _Loss, modality: str = "audio") -> None:
22
+ super().__init__()
23
+ self.loss = module
24
+ self.modality = modality
25
+
26
+ def forward(
27
+ self,
28
+ preds: Dict[str, Dict[str, torch.Tensor]],
29
+ target: Dict[str, Dict[str, torch.Tensor]],
30
+ ) -> torch.Tensor:
31
+ loss = {
32
+ stem: self.loss(preds[self.modality][stem], target[self.modality][stem])
33
+ for stem in preds[self.modality]
34
+ if stem in target[self.modality]
35
+ }
36
+
37
+ return sum(list(loss.values()))
38
+
39
+
40
+ class MultiStemWrapperFromConfig(MultiStemWrapper):
41
+ def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None:
42
+ loss = parse_loss(name, kwargs)
43
+ super().__init__(module=loss, modality=modality)
mvsepless/models/bandit/core/loss/_timefreq.py CHANGED
@@ -1,94 +1,94 @@
1
- from typing import Any, Dict, Optional
2
-
3
- import torch
4
- from torch import nn
5
- from torch.nn.modules.loss import _Loss
6
-
7
- from ._multistem import MultiStemWrapper
8
- from ._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper
9
- from .snr import SignalNoisePNormRatio
10
-
11
-
12
- class TimeFreqWrapper(_Loss):
13
- def __init__(
14
- self,
15
- time_module: _Loss,
16
- freq_module: Optional[_Loss] = None,
17
- time_weight: float = 1.0,
18
- freq_weight: float = 1.0,
19
- multistem: bool = True,
20
- ) -> None:
21
- super().__init__()
22
-
23
- if freq_module is None:
24
- freq_module = time_module
25
-
26
- if multistem:
27
- time_module = MultiStemWrapper(time_module, modality="audio")
28
- freq_module = MultiStemWrapper(freq_module, modality="spectrogram")
29
-
30
- self.time_module = time_module
31
- self.freq_module = freq_module
32
-
33
- self.time_weight = time_weight
34
- self.freq_weight = freq_weight
35
-
36
- def forward(self, preds: Any, target: Any) -> torch.Tensor:
37
-
38
- return self.time_weight * self.time_module(
39
- preds, target
40
- ) + self.freq_weight * self.freq_module(preds, target)
41
-
42
-
43
- class TimeFreqL1Loss(TimeFreqWrapper):
44
- def __init__(
45
- self,
46
- time_weight: float = 1.0,
47
- freq_weight: float = 1.0,
48
- tkwargs: Optional[Dict[str, Any]] = None,
49
- fkwargs: Optional[Dict[str, Any]] = None,
50
- multistem: bool = True,
51
- ) -> None:
52
- if tkwargs is None:
53
- tkwargs = {}
54
- if fkwargs is None:
55
- fkwargs = {}
56
- time_module = nn.L1Loss(**tkwargs)
57
- freq_module = ReImL1Loss(**fkwargs)
58
- super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
59
-
60
-
61
- class TimeFreqL2Loss(TimeFreqWrapper):
62
- def __init__(
63
- self,
64
- time_weight: float = 1.0,
65
- freq_weight: float = 1.0,
66
- tkwargs: Optional[Dict[str, Any]] = None,
67
- fkwargs: Optional[Dict[str, Any]] = None,
68
- multistem: bool = True,
69
- ) -> None:
70
- if tkwargs is None:
71
- tkwargs = {}
72
- if fkwargs is None:
73
- fkwargs = {}
74
- time_module = nn.MSELoss(**tkwargs)
75
- freq_module = ReImL2Loss(**fkwargs)
76
- super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
77
-
78
-
79
- class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper):
80
- def __init__(
81
- self,
82
- time_weight: float = 1.0,
83
- freq_weight: float = 1.0,
84
- tkwargs: Optional[Dict[str, Any]] = None,
85
- fkwargs: Optional[Dict[str, Any]] = None,
86
- multistem: bool = True,
87
- ) -> None:
88
- if tkwargs is None:
89
- tkwargs = {}
90
- if fkwargs is None:
91
- fkwargs = {}
92
- time_module = SignalNoisePNormRatio(**tkwargs)
93
- freq_module = SignalNoisePNormRatio(**fkwargs)
94
- super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
 
1
+ from typing import Any, Dict, Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn.modules.loss import _Loss
6
+
7
+ from ._multistem import MultiStemWrapper
8
+ from ._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper
9
+ from .snr import SignalNoisePNormRatio
10
+
11
+
12
+ class TimeFreqWrapper(_Loss):
13
+ def __init__(
14
+ self,
15
+ time_module: _Loss,
16
+ freq_module: Optional[_Loss] = None,
17
+ time_weight: float = 1.0,
18
+ freq_weight: float = 1.0,
19
+ multistem: bool = True,
20
+ ) -> None:
21
+ super().__init__()
22
+
23
+ if freq_module is None:
24
+ freq_module = time_module
25
+
26
+ if multistem:
27
+ time_module = MultiStemWrapper(time_module, modality="audio")
28
+ freq_module = MultiStemWrapper(freq_module, modality="spectrogram")
29
+
30
+ self.time_module = time_module
31
+ self.freq_module = freq_module
32
+
33
+ self.time_weight = time_weight
34
+ self.freq_weight = freq_weight
35
+
36
+ def forward(self, preds: Any, target: Any) -> torch.Tensor:
37
+
38
+ return self.time_weight * self.time_module(
39
+ preds, target
40
+ ) + self.freq_weight * self.freq_module(preds, target)
41
+
42
+
43
+ class TimeFreqL1Loss(TimeFreqWrapper):
44
+ def __init__(
45
+ self,
46
+ time_weight: float = 1.0,
47
+ freq_weight: float = 1.0,
48
+ tkwargs: Optional[Dict[str, Any]] = None,
49
+ fkwargs: Optional[Dict[str, Any]] = None,
50
+ multistem: bool = True,
51
+ ) -> None:
52
+ if tkwargs is None:
53
+ tkwargs = {}
54
+ if fkwargs is None:
55
+ fkwargs = {}
56
+ time_module = nn.L1Loss(**tkwargs)
57
+ freq_module = ReImL1Loss(**fkwargs)
58
+ super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
59
+
60
+
61
+ class TimeFreqL2Loss(TimeFreqWrapper):
62
+ def __init__(
63
+ self,
64
+ time_weight: float = 1.0,
65
+ freq_weight: float = 1.0,
66
+ tkwargs: Optional[Dict[str, Any]] = None,
67
+ fkwargs: Optional[Dict[str, Any]] = None,
68
+ multistem: bool = True,
69
+ ) -> None:
70
+ if tkwargs is None:
71
+ tkwargs = {}
72
+ if fkwargs is None:
73
+ fkwargs = {}
74
+ time_module = nn.MSELoss(**tkwargs)
75
+ freq_module = ReImL2Loss(**fkwargs)
76
+ super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
77
+
78
+
79
+ class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper):
80
+ def __init__(
81
+ self,
82
+ time_weight: float = 1.0,
83
+ freq_weight: float = 1.0,
84
+ tkwargs: Optional[Dict[str, Any]] = None,
85
+ fkwargs: Optional[Dict[str, Any]] = None,
86
+ multistem: bool = True,
87
+ ) -> None:
88
+ if tkwargs is None:
89
+ tkwargs = {}
90
+ if fkwargs is None:
91
+ fkwargs = {}
92
+ time_module = SignalNoisePNormRatio(**tkwargs)
93
+ freq_module = SignalNoisePNormRatio(**fkwargs)
94
+ super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
mvsepless/models/bandit/core/loss/snr.py CHANGED
@@ -1,131 +1,131 @@
1
- import torch
2
- from torch.nn.modules.loss import _Loss
3
- from torch.nn import functional as F
4
-
5
-
6
- class SignalNoisePNormRatio(_Loss):
7
- def __init__(
8
- self,
9
- p: float = 1.0,
10
- scale_invariant: bool = False,
11
- zero_mean: bool = False,
12
- take_log: bool = True,
13
- reduction: str = "mean",
14
- EPS: float = 1e-3,
15
- ) -> None:
16
- assert reduction != "sum", NotImplementedError
17
- super().__init__(reduction=reduction)
18
- assert not zero_mean
19
-
20
- self.p = p
21
-
22
- self.EPS = EPS
23
- self.take_log = take_log
24
-
25
- self.scale_invariant = scale_invariant
26
-
27
- def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
28
-
29
- target_ = target
30
- if self.scale_invariant:
31
- ndim = target.ndim
32
- dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True)
33
- s_target_energy = torch.sum(
34
- target * torch.conj(target), dim=-1, keepdim=True
35
- )
36
-
37
- if ndim > 2:
38
- dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True)
39
- s_target_energy = torch.sum(
40
- s_target_energy, dim=list(range(1, ndim)), keepdim=True
41
- )
42
-
43
- target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8)
44
- target = target_ * target_scaler
45
-
46
- if torch.is_complex(est_target):
47
- est_target = torch.view_as_real(est_target)
48
- target = torch.view_as_real(target)
49
-
50
- batch_size = est_target.shape[0]
51
- est_target = est_target.reshape(batch_size, -1)
52
- target = target.reshape(batch_size, -1)
53
-
54
- if self.p == 1:
55
- e_error = torch.abs(est_target - target).mean(dim=-1)
56
- e_target = torch.abs(target).mean(dim=-1)
57
- elif self.p == 2:
58
- e_error = torch.square(est_target - target).mean(dim=-1)
59
- e_target = torch.square(target).mean(dim=-1)
60
- else:
61
- raise NotImplementedError
62
-
63
- if self.take_log:
64
- loss = 10 * (
65
- torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS)
66
- )
67
- else:
68
- loss = (e_error + self.EPS) / (e_target + self.EPS)
69
-
70
- if self.reduction == "mean":
71
- loss = loss.mean()
72
- elif self.reduction == "sum":
73
- loss = loss.sum()
74
-
75
- return loss
76
-
77
-
78
- class MultichannelSingleSrcNegSDR(_Loss):
79
- def __init__(
80
- self,
81
- sdr_type: str,
82
- p: float = 2.0,
83
- zero_mean: bool = True,
84
- take_log: bool = True,
85
- reduction: str = "mean",
86
- EPS: float = 1e-8,
87
- ) -> None:
88
- assert reduction != "sum", NotImplementedError
89
- super().__init__(reduction=reduction)
90
-
91
- assert sdr_type in ["snr", "sisdr", "sdsdr"]
92
- self.sdr_type = sdr_type
93
- self.zero_mean = zero_mean
94
- self.take_log = take_log
95
- self.EPS = 1e-8
96
-
97
- self.p = p
98
-
99
- def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
100
- if target.size() != est_target.size() or target.ndim != 3:
101
- raise TypeError(
102
- f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead"
103
- )
104
- if self.zero_mean:
105
- mean_source = torch.mean(target, dim=[1, 2], keepdim=True)
106
- mean_estimate = torch.mean(est_target, dim=[1, 2], keepdim=True)
107
- target = target - mean_source
108
- est_target = est_target - mean_estimate
109
- if self.sdr_type in ["sisdr", "sdsdr"]:
110
- dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True)
111
- s_target_energy = torch.sum(target**2, dim=[1, 2], keepdim=True) + self.EPS
112
- scaled_target = dot * target / s_target_energy
113
- else:
114
- scaled_target = target
115
- if self.sdr_type in ["sdsdr", "snr"]:
116
- e_noise = est_target - target
117
- else:
118
- e_noise = est_target - scaled_target
119
-
120
- if self.p == 2.0:
121
- losses = torch.sum(scaled_target**2, dim=[1, 2]) / (
122
- torch.sum(e_noise**2, dim=[1, 2]) + self.EPS
123
- )
124
- else:
125
- losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / (
126
- torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS
127
- )
128
- if self.take_log:
129
- losses = 10 * torch.log10(losses + self.EPS)
130
- losses = losses.mean() if self.reduction == "mean" else losses
131
- return -losses
 
1
+ import torch
2
+ from torch.nn.modules.loss import _Loss
3
+ from torch.nn import functional as F
4
+
5
+
6
+ class SignalNoisePNormRatio(_Loss):
7
+ def __init__(
8
+ self,
9
+ p: float = 1.0,
10
+ scale_invariant: bool = False,
11
+ zero_mean: bool = False,
12
+ take_log: bool = True,
13
+ reduction: str = "mean",
14
+ EPS: float = 1e-3,
15
+ ) -> None:
16
+ assert reduction != "sum", NotImplementedError
17
+ super().__init__(reduction=reduction)
18
+ assert not zero_mean
19
+
20
+ self.p = p
21
+
22
+ self.EPS = EPS
23
+ self.take_log = take_log
24
+
25
+ self.scale_invariant = scale_invariant
26
+
27
+ def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
28
+
29
+ target_ = target
30
+ if self.scale_invariant:
31
+ ndim = target.ndim
32
+ dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True)
33
+ s_target_energy = torch.sum(
34
+ target * torch.conj(target), dim=-1, keepdim=True
35
+ )
36
+
37
+ if ndim > 2:
38
+ dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True)
39
+ s_target_energy = torch.sum(
40
+ s_target_energy, dim=list(range(1, ndim)), keepdim=True
41
+ )
42
+
43
+ target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8)
44
+ target = target_ * target_scaler
45
+
46
+ if torch.is_complex(est_target):
47
+ est_target = torch.view_as_real(est_target)
48
+ target = torch.view_as_real(target)
49
+
50
+ batch_size = est_target.shape[0]
51
+ est_target = est_target.reshape(batch_size, -1)
52
+ target = target.reshape(batch_size, -1)
53
+
54
+ if self.p == 1:
55
+ e_error = torch.abs(est_target - target).mean(dim=-1)
56
+ e_target = torch.abs(target).mean(dim=-1)
57
+ elif self.p == 2:
58
+ e_error = torch.square(est_target - target).mean(dim=-1)
59
+ e_target = torch.square(target).mean(dim=-1)
60
+ else:
61
+ raise NotImplementedError
62
+
63
+ if self.take_log:
64
+ loss = 10 * (
65
+ torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS)
66
+ )
67
+ else:
68
+ loss = (e_error + self.EPS) / (e_target + self.EPS)
69
+
70
+ if self.reduction == "mean":
71
+ loss = loss.mean()
72
+ elif self.reduction == "sum":
73
+ loss = loss.sum()
74
+
75
+ return loss
76
+
77
+
78
+ class MultichannelSingleSrcNegSDR(_Loss):
79
+ def __init__(
80
+ self,
81
+ sdr_type: str,
82
+ p: float = 2.0,
83
+ zero_mean: bool = True,
84
+ take_log: bool = True,
85
+ reduction: str = "mean",
86
+ EPS: float = 1e-8,
87
+ ) -> None:
88
+ assert reduction != "sum", NotImplementedError
89
+ super().__init__(reduction=reduction)
90
+
91
+ assert sdr_type in ["snr", "sisdr", "sdsdr"]
92
+ self.sdr_type = sdr_type
93
+ self.zero_mean = zero_mean
94
+ self.take_log = take_log
95
+ self.EPS = 1e-8
96
+
97
+ self.p = p
98
+
99
+ def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
100
+ if target.size() != est_target.size() or target.ndim != 3:
101
+ raise TypeError(
102
+ f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead"
103
+ )
104
+ if self.zero_mean:
105
+ mean_source = torch.mean(target, dim=[1, 2], keepdim=True)
106
+ mean_estimate = torch.mean(est_target, dim=[1, 2], keepdim=True)
107
+ target = target - mean_source
108
+ est_target = est_target - mean_estimate
109
+ if self.sdr_type in ["sisdr", "sdsdr"]:
110
+ dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True)
111
+ s_target_energy = torch.sum(target**2, dim=[1, 2], keepdim=True) + self.EPS
112
+ scaled_target = dot * target / s_target_energy
113
+ else:
114
+ scaled_target = target
115
+ if self.sdr_type in ["sdsdr", "snr"]:
116
+ e_noise = est_target - target
117
+ else:
118
+ e_noise = est_target - scaled_target
119
+
120
+ if self.p == 2.0:
121
+ losses = torch.sum(scaled_target**2, dim=[1, 2]) / (
122
+ torch.sum(e_noise**2, dim=[1, 2]) + self.EPS
123
+ )
124
+ else:
125
+ losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / (
126
+ torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS
127
+ )
128
+ if self.take_log:
129
+ losses = 10 * torch.log10(losses + self.EPS)
130
+ losses = losses.mean() if self.reduction == "mean" else losses
131
+ return -losses
mvsepless/models/bandit/core/metrics/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
- from .snr import (
2
- ChunkMedianScaleInvariantSignalDistortionRatio,
3
- ChunkMedianScaleInvariantSignalNoiseRatio,
4
- ChunkMedianSignalDistortionRatio,
5
- ChunkMedianSignalNoiseRatio,
6
- SafeSignalDistortionRatio,
7
- )
 
1
+ from .snr import (
2
+ ChunkMedianScaleInvariantSignalDistortionRatio,
3
+ ChunkMedianScaleInvariantSignalNoiseRatio,
4
+ ChunkMedianSignalDistortionRatio,
5
+ ChunkMedianSignalNoiseRatio,
6
+ SafeSignalDistortionRatio,
7
+ )
mvsepless/models/bandit/core/metrics/_squim.py CHANGED
@@ -1,350 +1,350 @@
1
- from dataclasses import dataclass
2
-
3
- from torchaudio._internal import load_state_dict_from_url
4
-
5
- import math
6
- from typing import List, Optional, Tuple
7
-
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
-
12
-
13
- def transform_wb_pesq_range(x: float) -> float:
14
- return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224))
15
-
16
-
17
- PESQRange: Tuple[float, float] = (
18
- 1.0,
19
- transform_wb_pesq_range(4.5),
20
- )
21
-
22
-
23
- class RangeSigmoid(nn.Module):
24
- def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
25
- super(RangeSigmoid, self).__init__()
26
- assert isinstance(val_range, tuple) and len(val_range) == 2
27
- self.val_range: Tuple[float, float] = val_range
28
- self.sigmoid: nn.modules.Module = nn.Sigmoid()
29
-
30
- def forward(self, x: torch.Tensor) -> torch.Tensor:
31
- out = (
32
- self.sigmoid(x) * (self.val_range[1] - self.val_range[0])
33
- + self.val_range[0]
34
- )
35
- return out
36
-
37
-
38
- class Encoder(nn.Module):
39
-
40
- def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
41
- super(Encoder, self).__init__()
42
-
43
- self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)
44
-
45
- def forward(self, x: torch.Tensor) -> torch.Tensor:
46
- out = x.unsqueeze(dim=1)
47
- out = F.relu(self.conv1d(out))
48
- return out
49
-
50
-
51
- class SingleRNN(nn.Module):
52
- def __init__(
53
- self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0
54
- ) -> None:
55
- super(SingleRNN, self).__init__()
56
-
57
- self.rnn_type = rnn_type
58
- self.input_size = input_size
59
- self.hidden_size = hidden_size
60
-
61
- self.rnn: nn.modules.Module = getattr(nn, rnn_type)(
62
- input_size,
63
- hidden_size,
64
- 1,
65
- dropout=dropout,
66
- batch_first=True,
67
- bidirectional=True,
68
- )
69
-
70
- self.proj = nn.Linear(hidden_size * 2, input_size)
71
-
72
- def forward(self, x: torch.Tensor) -> torch.Tensor:
73
- out, _ = self.rnn(x)
74
- out = self.proj(out)
75
- return out
76
-
77
-
78
- class DPRNN(nn.Module):
79
-
80
- def __init__(
81
- self,
82
- feat_dim: int = 64,
83
- hidden_dim: int = 128,
84
- num_blocks: int = 6,
85
- rnn_type: str = "LSTM",
86
- d_model: int = 256,
87
- chunk_size: int = 100,
88
- chunk_stride: int = 50,
89
- ) -> None:
90
- super(DPRNN, self).__init__()
91
-
92
- self.num_blocks = num_blocks
93
-
94
- self.row_rnn = nn.ModuleList([])
95
- self.col_rnn = nn.ModuleList([])
96
- self.row_norm = nn.ModuleList([])
97
- self.col_norm = nn.ModuleList([])
98
- for _ in range(num_blocks):
99
- self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
100
- self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
101
- self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
102
- self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
103
- self.conv = nn.Sequential(
104
- nn.Conv2d(feat_dim, d_model, 1),
105
- nn.PReLU(),
106
- )
107
- self.chunk_size = chunk_size
108
- self.chunk_stride = chunk_stride
109
-
110
- def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
111
- seq_len = x.shape[-1]
112
-
113
- rest = (
114
- self.chunk_size
115
- - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
116
- )
117
- out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])
118
-
119
- return out, rest
120
-
121
- def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
122
- out, rest = self.pad_chunk(x)
123
- batch_size, feat_dim, seq_len = out.shape
124
-
125
- segments1 = (
126
- out[:, :, : -self.chunk_stride]
127
- .contiguous()
128
- .view(batch_size, feat_dim, -1, self.chunk_size)
129
- )
130
- segments2 = (
131
- out[:, :, self.chunk_stride :]
132
- .contiguous()
133
- .view(batch_size, feat_dim, -1, self.chunk_size)
134
- )
135
- out = torch.cat([segments1, segments2], dim=3)
136
- out = (
137
- out.view(batch_size, feat_dim, -1, self.chunk_size)
138
- .transpose(2, 3)
139
- .contiguous()
140
- )
141
-
142
- return out, rest
143
-
144
- def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
145
- batch_size, dim, _, _ = x.shape
146
- out = (
147
- x.transpose(2, 3)
148
- .contiguous()
149
- .view(batch_size, dim, -1, self.chunk_size * 2)
150
- )
151
- out1 = (
152
- out[:, :, :, : self.chunk_size]
153
- .contiguous()
154
- .view(batch_size, dim, -1)[:, :, self.chunk_stride :]
155
- )
156
- out2 = (
157
- out[:, :, :, self.chunk_size :]
158
- .contiguous()
159
- .view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
160
- )
161
- out = out1 + out2
162
- if rest > 0:
163
- out = out[:, :, :-rest]
164
- out = out.contiguous()
165
- return out
166
-
167
- def forward(self, x: torch.Tensor) -> torch.Tensor:
168
- x, rest = self.chunking(x)
169
- batch_size, _, dim1, dim2 = x.shape
170
- out = x
171
- for row_rnn, row_norm, col_rnn, col_norm in zip(
172
- self.row_rnn, self.row_norm, self.col_rnn, self.col_norm
173
- ):
174
- row_in = (
175
- out.permute(0, 3, 2, 1)
176
- .contiguous()
177
- .view(batch_size * dim2, dim1, -1)
178
- .contiguous()
179
- )
180
- row_out = row_rnn(row_in)
181
- row_out = (
182
- row_out.view(batch_size, dim2, dim1, -1)
183
- .permute(0, 3, 2, 1)
184
- .contiguous()
185
- )
186
- row_out = row_norm(row_out)
187
- out = out + row_out
188
-
189
- col_in = (
190
- out.permute(0, 2, 3, 1)
191
- .contiguous()
192
- .view(batch_size * dim1, dim2, -1)
193
- .contiguous()
194
- )
195
- col_out = col_rnn(col_in)
196
- col_out = (
197
- col_out.view(batch_size, dim1, dim2, -1)
198
- .permute(0, 3, 1, 2)
199
- .contiguous()
200
- )
201
- col_out = col_norm(col_out)
202
- out = out + col_out
203
- out = self.conv(out)
204
- out = self.merging(out, rest)
205
- out = out.transpose(1, 2).contiguous()
206
- return out
207
-
208
-
209
- class AutoPool(nn.Module):
210
- def __init__(self, pool_dim: int = 1) -> None:
211
- super(AutoPool, self).__init__()
212
- self.pool_dim: int = pool_dim
213
- self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
214
- self.register_parameter("alpha", nn.Parameter(torch.ones(1)))
215
-
216
- def forward(self, x: torch.Tensor) -> torch.Tensor:
217
- weight = self.softmax(torch.mul(x, self.alpha))
218
- out = torch.sum(torch.mul(x, weight), dim=self.pool_dim)
219
- return out
220
-
221
-
222
- class SquimObjective(nn.Module):
223
-
224
- def __init__(
225
- self,
226
- encoder: nn.Module,
227
- dprnn: nn.Module,
228
- branches: nn.ModuleList,
229
- ):
230
- super(SquimObjective, self).__init__()
231
- self.encoder = encoder
232
- self.dprnn = dprnn
233
- self.branches = branches
234
-
235
- def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
236
- if x.ndim != 2:
237
- raise ValueError(
238
- f"The input must be a 2D Tensor. Found dimension {x.ndim}."
239
- )
240
- x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
241
- out = self.encoder(x)
242
- out = self.dprnn(out)
243
- scores = []
244
- for branch in self.branches:
245
- scores.append(branch(out).squeeze(dim=1))
246
- return scores
247
-
248
-
249
- def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
250
- layer1 = nn.TransformerEncoderLayer(
251
- d_model, nhead, d_model * 4, dropout=0.0, batch_first=True
252
- )
253
- layer2 = AutoPool()
254
- if metric == "stoi":
255
- layer3 = nn.Sequential(
256
- nn.Linear(d_model, d_model),
257
- nn.PReLU(),
258
- nn.Linear(d_model, 1),
259
- RangeSigmoid(),
260
- )
261
- elif metric == "pesq":
262
- layer3 = nn.Sequential(
263
- nn.Linear(d_model, d_model),
264
- nn.PReLU(),
265
- nn.Linear(d_model, 1),
266
- RangeSigmoid(val_range=PESQRange),
267
- )
268
- else:
269
- layer3: nn.modules.Module = nn.Sequential(
270
- nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1)
271
- )
272
- return nn.Sequential(layer1, layer2, layer3)
273
-
274
-
275
- def squim_objective_model(
276
- feat_dim: int,
277
- win_len: int,
278
- d_model: int,
279
- nhead: int,
280
- hidden_dim: int,
281
- num_blocks: int,
282
- rnn_type: str,
283
- chunk_size: int,
284
- chunk_stride: Optional[int] = None,
285
- ) -> SquimObjective:
286
- if chunk_stride is None:
287
- chunk_stride = chunk_size // 2
288
- encoder = Encoder(feat_dim, win_len)
289
- dprnn = DPRNN(
290
- feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride
291
- )
292
- branches = nn.ModuleList(
293
- [
294
- _create_branch(d_model, nhead, "stoi"),
295
- _create_branch(d_model, nhead, "pesq"),
296
- _create_branch(d_model, nhead, "sisdr"),
297
- ]
298
- )
299
- return SquimObjective(encoder, dprnn, branches)
300
-
301
-
302
- def squim_objective_base() -> SquimObjective:
303
- return squim_objective_model(
304
- feat_dim=256,
305
- win_len=64,
306
- d_model=256,
307
- nhead=4,
308
- hidden_dim=256,
309
- num_blocks=2,
310
- rnn_type="LSTM",
311
- chunk_size=71,
312
- )
313
-
314
-
315
- @dataclass
316
- class SquimObjectiveBundle:
317
-
318
- _path: str
319
- _sample_rate: float
320
-
321
- def _get_state_dict(self, dl_kwargs):
322
- url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
323
- dl_kwargs = {} if dl_kwargs is None else dl_kwargs
324
- state_dict = load_state_dict_from_url(url, **dl_kwargs)
325
- return state_dict
326
-
327
- def get_model(self, *, dl_kwargs=None) -> SquimObjective:
328
- model = squim_objective_base()
329
- model.load_state_dict(self._get_state_dict(dl_kwargs))
330
- model.eval()
331
- return model
332
-
333
- @property
334
- def sample_rate(self):
335
- return self._sample_rate
336
-
337
-
338
- SQUIM_OBJECTIVE = SquimObjectiveBundle(
339
- "squim_objective_dns2020.pth",
340
- _sample_rate=16000,
341
- )
342
- SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
343
- :cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.
344
-
345
- The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
346
- The weights are under `Creative Commons Attribution 4.0 International License
347
- <https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.
348
-
349
- Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
350
- """
 
1
+ from dataclasses import dataclass
2
+
3
+ from torchaudio._internal import load_state_dict_from_url
4
+
5
+ import math
6
+ from typing import List, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ def transform_wb_pesq_range(x: float) -> float:
14
+ return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224))
15
+
16
+
17
+ PESQRange: Tuple[float, float] = (
18
+ 1.0,
19
+ transform_wb_pesq_range(4.5),
20
+ )
21
+
22
+
23
+ class RangeSigmoid(nn.Module):
24
+ def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
25
+ super(RangeSigmoid, self).__init__()
26
+ assert isinstance(val_range, tuple) and len(val_range) == 2
27
+ self.val_range: Tuple[float, float] = val_range
28
+ self.sigmoid: nn.modules.Module = nn.Sigmoid()
29
+
30
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
31
+ out = (
32
+ self.sigmoid(x) * (self.val_range[1] - self.val_range[0])
33
+ + self.val_range[0]
34
+ )
35
+ return out
36
+
37
+
38
class Encoder(nn.Module):
    """Strided Conv1d front end mapping a raw waveform to framed features.

    Input is (batch, samples); output is (batch, feat_dim, frames) after a
    ReLU. The stride is half the window length (50% frame overlap).
    """

    def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
        super(Encoder, self).__init__()
        # bias=False so silence maps to exactly zero features.
        self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Add the singleton channel dim expected by Conv1d, then rectify.
        framed = self.conv1d(x.unsqueeze(dim=1))
        return F.relu(framed)
49
+
50
+
51
class SingleRNN(nn.Module):
    """Single-layer bidirectional RNN with a projection back to the input width.

    Output shape equals input shape, so the module can be used residually.
    """

    def __init__(
        self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0
    ) -> None:
        super(SingleRNN, self).__init__()

        self.rnn_type = rnn_type
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Resolve e.g. "LSTM" -> nn.LSTM; always one bidirectional layer.
        rnn_cls = getattr(nn, rnn_type)
        self.rnn: nn.modules.Module = rnn_cls(
            input_size,
            hidden_size,
            1,
            dropout=dropout,
            batch_first=True,
            bidirectional=True,
        )

        # Fold the concatenated forward/backward features back to input_size.
        self.proj = nn.Linear(hidden_size * 2, input_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden, _ = self.rnn(x)
        return self.proj(hidden)
76
+
77
+
78
class DPRNN(nn.Module):
    """Dual-path RNN over chunked feature sequences.

    The input (batch, feat_dim, seq_len) is segmented into overlapping
    chunks. Each block applies an intra-chunk ("row") and an inter-chunk
    ("col") bidirectional RNN, each with GroupNorm and a residual
    connection. A final 1x1 conv maps feat_dim -> d_model before the
    chunks are overlap-added back into a sequence of shape
    (batch, merged_len, d_model).
    """

    def __init__(
        self,
        feat_dim: int = 64,
        hidden_dim: int = 128,
        num_blocks: int = 6,
        rnn_type: str = "LSTM",
        d_model: int = 256,
        chunk_size: int = 100,
        chunk_stride: int = 50,
    ) -> None:
        super(DPRNN, self).__init__()

        self.num_blocks = num_blocks

        # One (row RNN, col RNN, row norm, col norm) quadruple per block.
        self.row_rnn = nn.ModuleList([])
        self.col_rnn = nn.ModuleList([])
        self.row_norm = nn.ModuleList([])
        self.col_norm = nn.ModuleList([])
        for _ in range(num_blocks):
            self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
            self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
            self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
            self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
        # Final pointwise projection feat_dim -> d_model.
        self.conv = nn.Sequential(
            nn.Conv2d(feat_dim, d_model, 1),
            nn.PReLU(),
        )
        self.chunk_size = chunk_size
        self.chunk_stride = chunk_stride

    def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        # Pad so the sequence divides evenly into chunk_stride-shifted chunks:
        # chunk_stride of padding on both ends, plus `rest` extra at the tail.
        # `rest` is returned so `merging` can trim it off again.
        seq_len = x.shape[-1]

        rest = (
            self.chunk_size
            - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
        )
        out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])

        return out, rest

    def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        # Segment (batch, feat, seq) into (batch, feat, chunk_size, n_chunks).
        # Two views shifted by chunk_stride are interleaved so consecutive
        # chunks overlap by chunk_stride samples.
        out, rest = self.pad_chunk(x)
        batch_size, feat_dim, seq_len = out.shape

        segments1 = (
            out[:, :, : -self.chunk_stride]
            .contiguous()
            .view(batch_size, feat_dim, -1, self.chunk_size)
        )
        segments2 = (
            out[:, :, self.chunk_stride :]
            .contiguous()
            .view(batch_size, feat_dim, -1, self.chunk_size)
        )
        out = torch.cat([segments1, segments2], dim=3)
        out = (
            out.view(batch_size, feat_dim, -1, self.chunk_size)
            .transpose(2, 3)
            .contiguous()
        )

        return out, rest

    def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
        # Inverse of `chunking`: overlap-add the two interleaved chunk streams
        # (dropping the chunk_stride padding at either end), then trim the
        # `rest` tail padding added by `pad_chunk`.
        batch_size, dim, _, _ = x.shape
        out = (
            x.transpose(2, 3)
            .contiguous()
            .view(batch_size, dim, -1, self.chunk_size * 2)
        )
        out1 = (
            out[:, :, :, : self.chunk_size]
            .contiguous()
            .view(batch_size, dim, -1)[:, :, self.chunk_stride :]
        )
        out2 = (
            out[:, :, :, self.chunk_size :]
            .contiguous()
            .view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
        )
        out = out1 + out2
        if rest > 0:
            out = out[:, :, :-rest]
        out = out.contiguous()
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # After chunking: dim1 = intra-chunk position, dim2 = chunk index.
        x, rest = self.chunking(x)
        batch_size, _, dim1, dim2 = x.shape
        out = x
        for row_rnn, row_norm, col_rnn, col_norm in zip(
            self.row_rnn, self.row_norm, self.col_rnn, self.col_norm
        ):
            # Row pass: RNN over positions within each chunk.
            row_in = (
                out.permute(0, 3, 2, 1)
                .contiguous()
                .view(batch_size * dim2, dim1, -1)
                .contiguous()
            )
            row_out = row_rnn(row_in)
            row_out = (
                row_out.view(batch_size, dim2, dim1, -1)
                .permute(0, 3, 2, 1)
                .contiguous()
            )
            row_out = row_norm(row_out)
            out = out + row_out  # residual connection

            # Column pass: RNN across chunks at each intra-chunk position.
            col_in = (
                out.permute(0, 2, 3, 1)
                .contiguous()
                .view(batch_size * dim1, dim2, -1)
                .contiguous()
            )
            col_out = col_rnn(col_in)
            col_out = (
                col_out.view(batch_size, dim1, dim2, -1)
                .permute(0, 3, 1, 2)
                .contiguous()
            )
            col_out = col_norm(col_out)
            out = out + col_out  # residual connection
        out = self.conv(out)
        out = self.merging(out, rest)
        # (batch, d_model, seq) -> (batch, seq, d_model)
        out = out.transpose(1, 2).contiguous()
        return out
207
+
208
+
209
class AutoPool(nn.Module):
    """Attention-weighted pooling with a learnable temperature ``alpha``.

    Weights are ``softmax(alpha * x)`` along ``pool_dim``; with alpha = 0
    this is mean pooling, and growing alpha approaches max pooling.
    """

    def __init__(self, pool_dim: int = 1) -> None:
        super(AutoPool, self).__init__()
        self.pool_dim: int = pool_dim
        self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
        self.register_parameter("alpha", nn.Parameter(torch.ones(1)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attention = self.softmax(x * self.alpha)
        return (x * attention).sum(dim=self.pool_dim)
220
+
221
+
222
class SquimObjective(nn.Module):
    """Predict objective speech-quality scores from a batch of waveforms.

    Args:
        encoder: waveform -> feature encoder.
        dprnn: sequence model applied to the encoded features.
        branches: one scoring head per metric; each yields one score per example.
    """

    def __init__(
        self,
        encoder: nn.Module,
        dprnn: nn.Module,
        branches: nn.ModuleList,
    ):
        super(SquimObjective, self).__init__()
        self.encoder = encoder
        self.dprnn = dprnn
        self.branches = branches

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        if x.ndim != 2:
            raise ValueError(
                f"The input must be a 2D Tensor. Found dimension {x.ndim}."
            )
        # Normalize each waveform by 20x its RMS level.
        rms = torch.mean(x**2, dim=1, keepdim=True) ** 0.5
        x = x / (rms * 20)
        feats = self.dprnn(self.encoder(x))
        # One score tensor per metric branch.
        return [branch(feats).squeeze(dim=1) for branch in self.branches]
247
+
248
+
249
def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
    """Build one metric head: transformer encoder layer -> AutoPool -> MLP.

    The output activation depends on ``metric``: "stoi" is squashed into
    (0, 1), "pesq" into the transformed PESQ range, and anything else
    (SI-SDR) stays unbounded.
    """
    attention = nn.TransformerEncoderLayer(
        d_model, nhead, d_model * 4, dropout=0.0, batch_first=True
    )
    pooling = AutoPool()
    mlp = [nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1)]
    if metric == "stoi":
        mlp.append(RangeSigmoid())
    elif metric == "pesq":
        mlp.append(RangeSigmoid(val_range=PESQRange))
    head: nn.modules.Module = nn.Sequential(*mlp)
    return nn.Sequential(attention, pooling, head)
273
+
274
+
275
def squim_objective_model(
    feat_dim: int,
    win_len: int,
    d_model: int,
    nhead: int,
    hidden_dim: int,
    num_blocks: int,
    rnn_type: str,
    chunk_size: int,
    chunk_stride: Optional[int] = None,
) -> SquimObjective:
    """Assemble a SquimObjective from hyperparameters.

    ``chunk_stride`` defaults to half of ``chunk_size`` (50% chunk overlap).
    """
    if chunk_stride is None:
        chunk_stride = chunk_size // 2
    # One head per metric, in the fixed (stoi, pesq, sisdr) order the
    # pretrained weights expect.
    branches = nn.ModuleList(
        [_create_branch(d_model, nhead, metric) for metric in ("stoi", "pesq", "sisdr")]
    )
    return SquimObjective(
        Encoder(feat_dim, win_len),
        DPRNN(
            feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride
        ),
        branches,
    )
300
+
301
+
302
def squim_objective_base() -> SquimObjective:
    """Build the reference SquimObjective configuration used by the released weights."""
    config = dict(
        feat_dim=256,
        win_len=64,
        d_model=256,
        nhead=4,
        hidden_dim=256,
        num_blocks=2,
        rnn_type="LSTM",
        chunk_size=71,
    )
    return squim_objective_model(**config)
313
+
314
+
315
@dataclass
class SquimObjectiveBundle:
    """Handle to pretrained SquimObjective weights on download.pytorch.org.

    Attributes:
        _path: checkpoint filename under the torchaudio models bucket.
        _sample_rate: sample rate the model was trained at.
    """

    _path: str
    _sample_rate: float

    def _get_state_dict(self, dl_kwargs):
        """Fetch (or load from cache) the checkpoint state dict."""
        url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
        kwargs = {} if dl_kwargs is None else dl_kwargs
        return load_state_dict_from_url(url, **kwargs)

    def get_model(self, *, dl_kwargs=None) -> SquimObjective:
        """Construct the base model, load the pretrained weights, set eval mode."""
        model = squim_objective_base()
        model.load_state_dict(self._get_state_dict(dl_kwargs))
        model.eval()
        return model

    @property
    def sample_rate(self):
        """Sample rate (Hz) of the audio the model expects."""
        return self._sample_rate
336
+
337
+
338
# Default pretrained bundle: SquimObjective trained on DNS 2020, 16 kHz audio.
SQUIM_OBJECTIVE = SquimObjectiveBundle(
    "squim_objective_dns2020.pth",
    _sample_rate=16000,
)
SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
:cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.

The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
The weights are under `Creative Commons Attribution 4.0 International License
<https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.

Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
"""
mvsepless/models/bandit/core/metrics/snr.py CHANGED
@@ -1,124 +1,124 @@
1
- from typing import Any, Callable
2
-
3
- import numpy as np
4
- import torch
5
- import torchmetrics as tm
6
- from torch._C import _LinAlgError
7
- from torchmetrics import functional as tmF
8
-
9
-
10
- class SafeSignalDistortionRatio(tm.SignalDistortionRatio):
11
- def __init__(self, **kwargs) -> None:
12
- super().__init__(**kwargs)
13
-
14
- def update(self, *args, **kwargs) -> Any:
15
- try:
16
- super().update(*args, **kwargs)
17
- except:
18
- pass
19
-
20
- def compute(self) -> Any:
21
- if self.total == 0:
22
- return torch.tensor(torch.nan)
23
- return super().compute()
24
-
25
-
26
- class BaseChunkMedianSignalRatio(tm.Metric):
27
- def __init__(
28
- self,
29
- func: Callable,
30
- window_size: int,
31
- hop_size: int = None,
32
- zero_mean: bool = False,
33
- ) -> None:
34
- super().__init__()
35
-
36
- self.func = func
37
- self.window_size = window_size
38
- if hop_size is None:
39
- hop_size = window_size
40
- self.hop_size = hop_size
41
-
42
- self.add_state("sum_snr", default=torch.tensor(0.0), dist_reduce_fx="sum")
43
- self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
44
-
45
- def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
46
-
47
- n_samples = target.shape[-1]
48
-
49
- n_chunks = int(np.ceil((n_samples - self.window_size) / self.hop_size) + 1)
50
-
51
- snr_chunk = []
52
-
53
- for i in range(n_chunks):
54
- start = i * self.hop_size
55
-
56
- if n_samples - start < self.window_size:
57
- continue
58
-
59
- end = start + self.window_size
60
-
61
- try:
62
- chunk_snr = self.func(preds[..., start:end], target[..., start:end])
63
-
64
- if torch.all(torch.isfinite(chunk_snr)):
65
- snr_chunk.append(chunk_snr)
66
- except _LinAlgError:
67
- pass
68
-
69
- snr_chunk = torch.stack(snr_chunk, dim=-1)
70
- snr_batch, _ = torch.nanmedian(snr_chunk, dim=-1)
71
-
72
- self.sum_snr += snr_batch.sum()
73
- self.total += snr_batch.numel()
74
-
75
- def compute(self) -> Any:
76
- return self.sum_snr / self.total
77
-
78
-
79
- class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio):
80
- def __init__(
81
- self, window_size: int, hop_size: int = None, zero_mean: bool = False
82
- ) -> None:
83
- super().__init__(
84
- func=tmF.signal_noise_ratio,
85
- window_size=window_size,
86
- hop_size=hop_size,
87
- zero_mean=zero_mean,
88
- )
89
-
90
-
91
- class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio):
92
- def __init__(
93
- self, window_size: int, hop_size: int = None, zero_mean: bool = False
94
- ) -> None:
95
- super().__init__(
96
- func=tmF.scale_invariant_signal_noise_ratio,
97
- window_size=window_size,
98
- hop_size=hop_size,
99
- zero_mean=zero_mean,
100
- )
101
-
102
-
103
- class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio):
104
- def __init__(
105
- self, window_size: int, hop_size: int = None, zero_mean: bool = False
106
- ) -> None:
107
- super().__init__(
108
- func=tmF.signal_distortion_ratio,
109
- window_size=window_size,
110
- hop_size=hop_size,
111
- zero_mean=zero_mean,
112
- )
113
-
114
-
115
- class ChunkMedianScaleInvariantSignalDistortionRatio(BaseChunkMedianSignalRatio):
116
- def __init__(
117
- self, window_size: int, hop_size: int = None, zero_mean: bool = False
118
- ) -> None:
119
- super().__init__(
120
- func=tmF.scale_invariant_signal_distortion_ratio,
121
- window_size=window_size,
122
- hop_size=hop_size,
123
- zero_mean=zero_mean,
124
- )
 
1
+ from typing import Any, Callable
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torchmetrics as tm
6
+ from torch._C import _LinAlgError
7
+ from torchmetrics import functional as tmF
8
+
9
+
10
class SafeSignalDistortionRatio(tm.SignalDistortionRatio):
    """SignalDistortionRatio that skips failing updates instead of crashing.

    SDR computation can raise on degenerate inputs; such batches are
    silently dropped, and ``compute`` returns NaN when no batch was ever
    accumulated successfully.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def update(self, *args, **kwargs) -> Any:
        try:
            super().update(*args, **kwargs)
        except Exception:
            # Deliberate best-effort: a failed batch is dropped. Narrowed from
            # a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
            pass

    def compute(self) -> Any:
        # self.total counts successfully accumulated batches (parent state).
        if self.total == 0:
            return torch.tensor(torch.nan)
        return super().compute()
24
+
25
+
26
class BaseChunkMedianSignalRatio(tm.Metric):
    """Score fixed-size chunks with ``func`` and average per-example medians.

    Args:
        func: metric functional taking ``(preds, target)`` for one chunk.
        window_size: chunk length in samples.
        hop_size: hop between chunk starts; defaults to ``window_size``
            (non-overlapping chunks).
        zero_mean: accepted for API compatibility; not used in this class.
            # NOTE(review): presumably meant to be forwarded to `func` — confirm.
    """

    def __init__(
        self,
        func: Callable,
        window_size: int,
        hop_size: int = None,
        zero_mean: bool = False,
    ) -> None:
        super().__init__()

        self.func = func
        self.window_size = window_size
        if hop_size is None:
            hop_size = window_size
        self.hop_size = hop_size

        # Distributed-safe accumulators: sum of per-example medians + count.
        self.add_state("sum_snr", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate the per-example median chunk score for one batch."""
        n_samples = target.shape[-1]

        n_chunks = int(np.ceil((n_samples - self.window_size) / self.hop_size) + 1)

        snr_chunk = []

        for i in range(n_chunks):
            start = i * self.hop_size

            # Drop trailing partial chunks.
            if n_samples - start < self.window_size:
                continue

            end = start + self.window_size

            try:
                chunk_snr = self.func(preds[..., start:end], target[..., start:end])

                # Keep only chunks whose score is finite for every example.
                if torch.all(torch.isfinite(chunk_snr)):
                    snr_chunk.append(chunk_snr)
            except _LinAlgError:
                pass

        if not snr_chunk:
            # Every chunk failed / was non-finite, or the input was shorter
            # than one window. Previously torch.stack([]) raised here; skip
            # the batch instead of crashing.
            return

        snr_chunk = torch.stack(snr_chunk, dim=-1)
        snr_batch, _ = torch.nanmedian(snr_chunk, dim=-1)

        self.sum_snr += snr_batch.sum()
        self.total += snr_batch.numel()

    def compute(self) -> Any:
        """Mean of the accumulated per-example medians."""
        return self.sum_snr / self.total
77
+
78
+
79
class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio):
    """Chunked-median SNR; see BaseChunkMedianSignalRatio for the mechanics."""

    def __init__(
        self, window_size: int, hop_size: int = None, zero_mean: bool = False
    ) -> None:
        super().__init__(
            tmF.signal_noise_ratio,
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
        )
89
+
90
+
91
class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio):
    """Chunked-median SI-SNR; see BaseChunkMedianSignalRatio for the mechanics."""

    def __init__(
        self, window_size: int, hop_size: int = None, zero_mean: bool = False
    ) -> None:
        super().__init__(
            tmF.scale_invariant_signal_noise_ratio,
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
        )
101
+
102
+
103
class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio):
    """Chunked-median SDR; see BaseChunkMedianSignalRatio for the mechanics."""

    def __init__(
        self, window_size: int, hop_size: int = None, zero_mean: bool = False
    ) -> None:
        super().__init__(
            tmF.signal_distortion_ratio,
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
        )
113
+
114
+
115
class ChunkMedianScaleInvariantSignalDistortionRatio(BaseChunkMedianSignalRatio):
    """Chunked-median SI-SDR; see BaseChunkMedianSignalRatio for the mechanics."""

    def __init__(
        self, window_size: int, hop_size: int = None, zero_mean: bool = False
    ) -> None:
        super().__init__(
            tmF.scale_invariant_signal_distortion_ratio,
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
        )
mvsepless/models/bandit/core/model/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
- from .bsrnn.wrapper import (
2
- MultiMaskMultiSourceBandSplitRNNSimple,
3
- )
 
1
+ from .bsrnn.wrapper import (
2
+ MultiMaskMultiSourceBandSplitRNNSimple,
3
+ )
mvsepless/models/bandit/core/model/_spectral.py CHANGED
@@ -1,54 +1,54 @@
1
- from typing import Dict, Optional
2
-
3
- import torch
4
- import torchaudio as ta
5
- from torch import nn
6
-
7
-
8
- class _SpectralComponent(nn.Module):
9
- def __init__(
10
- self,
11
- n_fft: int = 2048,
12
- win_length: Optional[int] = 2048,
13
- hop_length: int = 512,
14
- window_fn: str = "hann_window",
15
- wkwargs: Optional[Dict] = None,
16
- power: Optional[int] = None,
17
- center: bool = True,
18
- normalized: bool = True,
19
- pad_mode: str = "constant",
20
- onesided: bool = True,
21
- **kwargs,
22
- ) -> None:
23
- super().__init__()
24
-
25
- assert power is None
26
-
27
- window_fn = torch.__dict__[window_fn]
28
-
29
- self.stft = ta.transforms.Spectrogram(
30
- n_fft=n_fft,
31
- win_length=win_length,
32
- hop_length=hop_length,
33
- pad_mode=pad_mode,
34
- pad=0,
35
- window_fn=window_fn,
36
- wkwargs=wkwargs,
37
- power=power,
38
- normalized=normalized,
39
- center=center,
40
- onesided=onesided,
41
- )
42
-
43
- self.istft = ta.transforms.InverseSpectrogram(
44
- n_fft=n_fft,
45
- win_length=win_length,
46
- hop_length=hop_length,
47
- pad_mode=pad_mode,
48
- pad=0,
49
- window_fn=window_fn,
50
- wkwargs=wkwargs,
51
- normalized=normalized,
52
- center=center,
53
- onesided=onesided,
54
- )
 
1
+ from typing import Dict, Optional
2
+
3
+ import torch
4
+ import torchaudio as ta
5
+ from torch import nn
6
+
7
+
8
class _SpectralComponent(nn.Module):
    """Base module holding a matched STFT / inverse-STFT transform pair.

    Both transforms are constructed from identical parameters so that the
    inverse undoes the forward transform (up to edge effects).
    """

    def __init__(
        self,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        **kwargs,
    ) -> None:
        super().__init__()

        # A complex spectrogram (power=None) is required for inversion.
        assert power is None

        # Resolve the window constructor by name, e.g. "hann_window" -> torch.hann_window.
        window_fn = torch.__dict__[window_fn]

        # Parameters shared by the forward and inverse transforms.
        shared = dict(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )

        self.stft = ta.transforms.Spectrogram(power=power, **shared)
        self.istft = ta.transforms.InverseSpectrogram(**shared)
mvsepless/models/bandit/core/model/bsrnn/__init__.py CHANGED
@@ -1,23 +1,23 @@
1
- from abc import ABC
2
- from typing import Iterable, Mapping, Union
3
-
4
- from torch import nn
5
-
6
- from .bandsplit import BandSplitModule
7
- from .tfmodel import (
8
- SeqBandModellingModule,
9
- TransformerTimeFreqModule,
10
- )
11
-
12
-
13
- class BandsplitCoreBase(nn.Module, ABC):
14
- band_split: nn.Module
15
- tf_model: nn.Module
16
- mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]]
17
-
18
- def __init__(self) -> None:
19
- super().__init__()
20
-
21
- @staticmethod
22
- def mask(x, m):
23
- return x * m
 
1
+ from abc import ABC
2
+ from typing import Iterable, Mapping, Union
3
+
4
+ from torch import nn
5
+
6
+ from .bandsplit import BandSplitModule
7
+ from .tfmodel import (
8
+ SeqBandModellingModule,
9
+ TransformerTimeFreqModule,
10
+ )
11
+
12
+
13
class BandsplitCoreBase(nn.Module, ABC):
    """Abstract skeleton of a band-split separation core.

    Subclasses populate the three stages: a band-split encoder, a
    time-frequency sequence model, and one or more mask estimators.
    """

    # Declared (not assigned) here; concrete subclasses must set these.
    band_split: nn.Module
    tf_model: nn.Module
    mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]]

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def mask(x, m):
        # Apply a multiplicative mask to the input spectrogram.
        return x * m
mvsepless/models/bandit/core/model/bsrnn/bandsplit.py CHANGED
@@ -1,119 +1,119 @@
1
- from typing import List, Tuple
2
-
3
- import torch
4
- from torch import nn
5
-
6
- from .utils import (
7
- band_widths_from_specs,
8
- check_no_gap,
9
- check_no_overlap,
10
- check_nonzero_bandwidth,
11
- )
12
-
13
-
14
- class NormFC(nn.Module):
15
- def __init__(
16
- self,
17
- emb_dim: int,
18
- bandwidth: int,
19
- in_channel: int,
20
- normalize_channel_independently: bool = False,
21
- treat_channel_as_feature: bool = True,
22
- ) -> None:
23
- super().__init__()
24
-
25
- self.treat_channel_as_feature = treat_channel_as_feature
26
-
27
- if normalize_channel_independently:
28
- raise NotImplementedError
29
-
30
- reim = 2
31
-
32
- self.norm = nn.LayerNorm(in_channel * bandwidth * reim)
33
-
34
- fc_in = bandwidth * reim
35
-
36
- if treat_channel_as_feature:
37
- fc_in *= in_channel
38
- else:
39
- assert emb_dim % in_channel == 0
40
- emb_dim = emb_dim // in_channel
41
-
42
- self.fc = nn.Linear(fc_in, emb_dim)
43
-
44
- def forward(self, xb):
45
-
46
- batch, n_time, in_chan, ribw = xb.shape
47
- xb = self.norm(xb.reshape(batch, n_time, in_chan * ribw))
48
-
49
- if not self.treat_channel_as_feature:
50
- xb = xb.reshape(batch, n_time, in_chan, ribw)
51
-
52
- zb = self.fc(xb)
53
-
54
- if not self.treat_channel_as_feature:
55
- batch, n_time, in_chan, emb_dim_per_chan = zb.shape
56
- zb = zb.reshape((batch, n_time, in_chan * emb_dim_per_chan))
57
-
58
- return zb
59
-
60
-
61
- class BandSplitModule(nn.Module):
62
- def __init__(
63
- self,
64
- band_specs: List[Tuple[float, float]],
65
- emb_dim: int,
66
- in_channel: int,
67
- require_no_overlap: bool = False,
68
- require_no_gap: bool = True,
69
- normalize_channel_independently: bool = False,
70
- treat_channel_as_feature: bool = True,
71
- ) -> None:
72
- super().__init__()
73
-
74
- check_nonzero_bandwidth(band_specs)
75
-
76
- if require_no_gap:
77
- check_no_gap(band_specs)
78
-
79
- if require_no_overlap:
80
- check_no_overlap(band_specs)
81
-
82
- self.band_specs = band_specs
83
- self.band_widths = band_widths_from_specs(band_specs)
84
- self.n_bands = len(band_specs)
85
- self.emb_dim = emb_dim
86
-
87
- self.norm_fc_modules = nn.ModuleList(
88
- [ # type: ignore
89
- (
90
- NormFC(
91
- emb_dim=emb_dim,
92
- bandwidth=bw,
93
- in_channel=in_channel,
94
- normalize_channel_independently=normalize_channel_independently,
95
- treat_channel_as_feature=treat_channel_as_feature,
96
- )
97
- )
98
- for bw in self.band_widths
99
- ]
100
- )
101
-
102
- def forward(self, x: torch.Tensor):
103
-
104
- batch, in_chan, _, n_time = x.shape
105
-
106
- z = torch.zeros(
107
- size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
108
- )
109
-
110
- xr = torch.view_as_real(x)
111
- xr = torch.permute(xr, (0, 3, 1, 4, 2))
112
- batch, n_time, in_chan, reim, band_width = xr.shape
113
- for i, nfm in enumerate(self.norm_fc_modules):
114
- fstart, fend = self.band_specs[i]
115
- xb = xr[..., fstart:fend]
116
- xb = torch.reshape(xb, (batch, n_time, in_chan, -1))
117
- z[:, i, :, :] = nfm(xb.contiguous())
118
-
119
- return z
 
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from .utils import (
7
+ band_widths_from_specs,
8
+ check_no_gap,
9
+ check_no_overlap,
10
+ check_nonzero_bandwidth,
11
+ )
12
+
13
+
14
class NormFC(nn.Module):
    """LayerNorm + linear projection for one frequency band.

    Flattens a band's (channel, real/imag x bandwidth) features, normalizes
    them jointly, and projects to the embedding dimension — either with one
    joint projection over all channels, or per channel with the embedding
    split evenly across channels.
    """

    def __init__(
        self,
        emb_dim: int,
        bandwidth: int,
        in_channel: int,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        self.treat_channel_as_feature = treat_channel_as_feature

        if normalize_channel_independently:
            raise NotImplementedError

        reim = 2  # real + imaginary parts per frequency bin

        # Normalization is always joint over channels and band features.
        self.norm = nn.LayerNorm(in_channel * bandwidth * reim)

        fc_in = bandwidth * reim
        if treat_channel_as_feature:
            # Single projection over all channels at once.
            fc_in = fc_in * in_channel
        else:
            # Per-channel projection; each channel gets an equal embedding share.
            assert emb_dim % in_channel == 0
            emb_dim = emb_dim // in_channel

        self.fc = nn.Linear(fc_in, emb_dim)

    def forward(self, xb):
        batch, n_time, in_chan, ribw = xb.shape
        normed = self.norm(xb.reshape(batch, n_time, in_chan * ribw))

        if self.treat_channel_as_feature:
            return self.fc(normed)

        # Project each channel separately, then concatenate the embeddings.
        per_chan = self.fc(normed.reshape(batch, n_time, in_chan, ribw))
        batch, n_time, in_chan, emb_per_chan = per_chan.shape
        return per_chan.reshape((batch, n_time, in_chan * emb_per_chan))
59
+
60
+
61
class BandSplitModule(nn.Module):
    """Encode each frequency band of a complex spectrogram into an embedding.

    Each band gets its own NormFC; the per-band embeddings are stacked into a
    (batch, n_bands, n_time, emb_dim) tensor.
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        in_channel: int,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        # Validate the band layout before building per-band modules.
        check_nonzero_bandwidth(band_specs)
        if require_no_gap:
            check_no_gap(band_specs)
        if require_no_overlap:
            check_no_overlap(band_specs)

        self.band_specs = band_specs
        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)
        self.emb_dim = emb_dim

        self.norm_fc_modules = nn.ModuleList(
            [  # type: ignore
                NormFC(
                    emb_dim=emb_dim,
                    bandwidth=bandwidth,
                    in_channel=in_channel,
                    normalize_channel_independently=normalize_channel_independently,
                    treat_channel_as_feature=treat_channel_as_feature,
                )
                for bandwidth in self.band_widths
            ]
        )

    def forward(self, x: torch.Tensor):
        # x: complex spectrogram — assumed (batch, in_channel, n_freq, n_time).
        batch, in_chan, _, n_time = x.shape

        z = torch.zeros(
            size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
        )

        # complex (batch, chan, freq, time) -> real (batch, time, chan, reim, freq)
        xr = torch.view_as_real(x)
        xr = torch.permute(xr, (0, 3, 1, 4, 2))
        batch, n_time, in_chan, reim, band_width = xr.shape
        for band_idx, norm_fc in enumerate(self.norm_fc_modules):
            # NOTE(review): band_specs entries are used as integer bin indices
            # despite the float annotation — confirm callers pass ints.
            fstart, fend = self.band_specs[band_idx]
            xb = torch.reshape(xr[..., fstart:fend], (batch, n_time, in_chan, -1))
            z[:, band_idx, :, :] = norm_fc(xb.contiguous())

        return z
mvsepless/models/bandit/core/model/bsrnn/core.py CHANGED
@@ -1,619 +1,619 @@
1
- from typing import Dict, List, Optional, Tuple
2
-
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
-
7
- from . import BandsplitCoreBase
8
- from .bandsplit import BandSplitModule
9
- from .maskestim import (
10
- MaskEstimationModule,
11
- OverlappingMaskEstimationModule,
12
- )
13
- from .tfmodel import (
14
- ConvolutionalTimeFreqModule,
15
- SeqBandModellingModule,
16
- TransformerTimeFreqModule,
17
- )
18
-
19
-
20
- class MultiMaskBandSplitCoreBase(BandsplitCoreBase):
21
- def __init__(self) -> None:
22
- super().__init__()
23
-
24
- def forward(self, x, cond=None, compute_residual: bool = True):
25
- batch, in_chan, n_freq, n_time = x.shape
26
- x = torch.reshape(x, (-1, 1, n_freq, n_time))
27
-
28
- z = self.band_split(x)
29
-
30
- q = self.tf_model(z)
31
-
32
- out = {}
33
-
34
- for stem, mem in self.mask_estim.items():
35
- m = mem(q, cond=cond)
36
-
37
- s = self.mask(x, m)
38
- s = torch.reshape(s, (batch, in_chan, n_freq, n_time))
39
- out[stem] = s
40
-
41
- return {"spectrogram": out}
42
-
43
- def instantiate_mask_estim(
44
- self,
45
- in_channel: int,
46
- stems: List[str],
47
- band_specs: List[Tuple[float, float]],
48
- emb_dim: int,
49
- mlp_dim: int,
50
- cond_dim: int,
51
- hidden_activation: str,
52
- hidden_activation_kwargs: Optional[Dict] = None,
53
- complex_mask: bool = True,
54
- overlapping_band: bool = False,
55
- freq_weights: Optional[List[torch.Tensor]] = None,
56
- n_freq: Optional[int] = None,
57
- use_freq_weights: bool = True,
58
- mult_add_mask: bool = False,
59
- ):
60
- if hidden_activation_kwargs is None:
61
- hidden_activation_kwargs = {}
62
-
63
- if "mne:+" in stems:
64
- stems = [s for s in stems if s != "mne:+"]
65
-
66
- if overlapping_band:
67
- assert freq_weights is not None
68
- assert n_freq is not None
69
-
70
- if mult_add_mask:
71
-
72
- self.mask_estim = nn.ModuleDict(
73
- {
74
- stem: MultAddMaskEstimationModule(
75
- band_specs=band_specs,
76
- freq_weights=freq_weights,
77
- n_freq=n_freq,
78
- emb_dim=emb_dim,
79
- mlp_dim=mlp_dim,
80
- in_channel=in_channel,
81
- hidden_activation=hidden_activation,
82
- hidden_activation_kwargs=hidden_activation_kwargs,
83
- complex_mask=complex_mask,
84
- use_freq_weights=use_freq_weights,
85
- )
86
- for stem in stems
87
- }
88
- )
89
- else:
90
- self.mask_estim = nn.ModuleDict(
91
- {
92
- stem: OverlappingMaskEstimationModule(
93
- band_specs=band_specs,
94
- freq_weights=freq_weights,
95
- n_freq=n_freq,
96
- emb_dim=emb_dim,
97
- mlp_dim=mlp_dim,
98
- in_channel=in_channel,
99
- hidden_activation=hidden_activation,
100
- hidden_activation_kwargs=hidden_activation_kwargs,
101
- complex_mask=complex_mask,
102
- use_freq_weights=use_freq_weights,
103
- )
104
- for stem in stems
105
- }
106
- )
107
- else:
108
- self.mask_estim = nn.ModuleDict(
109
- {
110
- stem: MaskEstimationModule(
111
- band_specs=band_specs,
112
- emb_dim=emb_dim,
113
- mlp_dim=mlp_dim,
114
- cond_dim=cond_dim,
115
- in_channel=in_channel,
116
- hidden_activation=hidden_activation,
117
- hidden_activation_kwargs=hidden_activation_kwargs,
118
- complex_mask=complex_mask,
119
- )
120
- for stem in stems
121
- }
122
- )
123
-
124
- def instantiate_bandsplit(
125
- self,
126
- in_channel: int,
127
- band_specs: List[Tuple[float, float]],
128
- require_no_overlap: bool = False,
129
- require_no_gap: bool = True,
130
- normalize_channel_independently: bool = False,
131
- treat_channel_as_feature: bool = True,
132
- emb_dim: int = 128,
133
- ):
134
- self.band_split = BandSplitModule(
135
- in_channel=in_channel,
136
- band_specs=band_specs,
137
- require_no_overlap=require_no_overlap,
138
- require_no_gap=require_no_gap,
139
- normalize_channel_independently=normalize_channel_independently,
140
- treat_channel_as_feature=treat_channel_as_feature,
141
- emb_dim=emb_dim,
142
- )
143
-
144
-
145
- class SingleMaskBandsplitCoreBase(BandsplitCoreBase):
146
- def __init__(self, **kwargs) -> None:
147
- super().__init__()
148
-
149
- def forward(self, x):
150
- z = self.band_split(x)
151
- q = self.tf_model(z)
152
- m = self.mask_estim(q)
153
-
154
- s = self.mask(x, m)
155
-
156
- return s
157
-
158
-
159
- class SingleMaskBandsplitCoreRNN(
160
- SingleMaskBandsplitCoreBase,
161
- ):
162
- def __init__(
163
- self,
164
- in_channel: int,
165
- band_specs: List[Tuple[float, float]],
166
- require_no_overlap: bool = False,
167
- require_no_gap: bool = True,
168
- normalize_channel_independently: bool = False,
169
- treat_channel_as_feature: bool = True,
170
- n_sqm_modules: int = 12,
171
- emb_dim: int = 128,
172
- rnn_dim: int = 256,
173
- bidirectional: bool = True,
174
- rnn_type: str = "LSTM",
175
- mlp_dim: int = 512,
176
- hidden_activation: str = "Tanh",
177
- hidden_activation_kwargs: Optional[Dict] = None,
178
- complex_mask: bool = True,
179
- ) -> None:
180
- super().__init__()
181
- self.band_split = BandSplitModule(
182
- in_channel=in_channel,
183
- band_specs=band_specs,
184
- require_no_overlap=require_no_overlap,
185
- require_no_gap=require_no_gap,
186
- normalize_channel_independently=normalize_channel_independently,
187
- treat_channel_as_feature=treat_channel_as_feature,
188
- emb_dim=emb_dim,
189
- )
190
- self.tf_model = SeqBandModellingModule(
191
- n_modules=n_sqm_modules,
192
- emb_dim=emb_dim,
193
- rnn_dim=rnn_dim,
194
- bidirectional=bidirectional,
195
- rnn_type=rnn_type,
196
- )
197
- self.mask_estim = MaskEstimationModule(
198
- in_channel=in_channel,
199
- band_specs=band_specs,
200
- emb_dim=emb_dim,
201
- mlp_dim=mlp_dim,
202
- hidden_activation=hidden_activation,
203
- hidden_activation_kwargs=hidden_activation_kwargs,
204
- complex_mask=complex_mask,
205
- )
206
-
207
-
208
- class SingleMaskBandsplitCoreTransformer(
209
- SingleMaskBandsplitCoreBase,
210
- ):
211
- def __init__(
212
- self,
213
- in_channel: int,
214
- band_specs: List[Tuple[float, float]],
215
- require_no_overlap: bool = False,
216
- require_no_gap: bool = True,
217
- normalize_channel_independently: bool = False,
218
- treat_channel_as_feature: bool = True,
219
- n_sqm_modules: int = 12,
220
- emb_dim: int = 128,
221
- rnn_dim: int = 256,
222
- bidirectional: bool = True,
223
- tf_dropout: float = 0.0,
224
- mlp_dim: int = 512,
225
- hidden_activation: str = "Tanh",
226
- hidden_activation_kwargs: Optional[Dict] = None,
227
- complex_mask: bool = True,
228
- ) -> None:
229
- super().__init__()
230
- self.band_split = BandSplitModule(
231
- in_channel=in_channel,
232
- band_specs=band_specs,
233
- require_no_overlap=require_no_overlap,
234
- require_no_gap=require_no_gap,
235
- normalize_channel_independently=normalize_channel_independently,
236
- treat_channel_as_feature=treat_channel_as_feature,
237
- emb_dim=emb_dim,
238
- )
239
- self.tf_model = TransformerTimeFreqModule(
240
- n_modules=n_sqm_modules,
241
- emb_dim=emb_dim,
242
- rnn_dim=rnn_dim,
243
- bidirectional=bidirectional,
244
- dropout=tf_dropout,
245
- )
246
- self.mask_estim = MaskEstimationModule(
247
- in_channel=in_channel,
248
- band_specs=band_specs,
249
- emb_dim=emb_dim,
250
- mlp_dim=mlp_dim,
251
- hidden_activation=hidden_activation,
252
- hidden_activation_kwargs=hidden_activation_kwargs,
253
- complex_mask=complex_mask,
254
- )
255
-
256
-
257
- class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase):
258
- def __init__(
259
- self,
260
- in_channel: int,
261
- stems: List[str],
262
- band_specs: List[Tuple[float, float]],
263
- require_no_overlap: bool = False,
264
- require_no_gap: bool = True,
265
- normalize_channel_independently: bool = False,
266
- treat_channel_as_feature: bool = True,
267
- n_sqm_modules: int = 12,
268
- emb_dim: int = 128,
269
- rnn_dim: int = 256,
270
- bidirectional: bool = True,
271
- rnn_type: str = "LSTM",
272
- mlp_dim: int = 512,
273
- cond_dim: int = 0,
274
- hidden_activation: str = "Tanh",
275
- hidden_activation_kwargs: Optional[Dict] = None,
276
- complex_mask: bool = True,
277
- overlapping_band: bool = False,
278
- freq_weights: Optional[List[torch.Tensor]] = None,
279
- n_freq: Optional[int] = None,
280
- use_freq_weights: bool = True,
281
- mult_add_mask: bool = False,
282
- ) -> None:
283
-
284
- super().__init__()
285
- self.instantiate_bandsplit(
286
- in_channel=in_channel,
287
- band_specs=band_specs,
288
- require_no_overlap=require_no_overlap,
289
- require_no_gap=require_no_gap,
290
- normalize_channel_independently=normalize_channel_independently,
291
- treat_channel_as_feature=treat_channel_as_feature,
292
- emb_dim=emb_dim,
293
- )
294
-
295
- self.tf_model = SeqBandModellingModule(
296
- n_modules=n_sqm_modules,
297
- emb_dim=emb_dim,
298
- rnn_dim=rnn_dim,
299
- bidirectional=bidirectional,
300
- rnn_type=rnn_type,
301
- )
302
-
303
- self.mult_add_mask = mult_add_mask
304
-
305
- self.instantiate_mask_estim(
306
- in_channel=in_channel,
307
- stems=stems,
308
- band_specs=band_specs,
309
- emb_dim=emb_dim,
310
- mlp_dim=mlp_dim,
311
- cond_dim=cond_dim,
312
- hidden_activation=hidden_activation,
313
- hidden_activation_kwargs=hidden_activation_kwargs,
314
- complex_mask=complex_mask,
315
- overlapping_band=overlapping_band,
316
- freq_weights=freq_weights,
317
- n_freq=n_freq,
318
- use_freq_weights=use_freq_weights,
319
- mult_add_mask=mult_add_mask,
320
- )
321
-
322
- @staticmethod
323
- def _mult_add_mask(x, m):
324
-
325
- assert m.ndim == 5
326
-
327
- mm = m[..., 0]
328
- am = m[..., 1]
329
-
330
- return x * mm + am
331
-
332
- def mask(self, x, m):
333
- if self.mult_add_mask:
334
-
335
- return self._mult_add_mask(x, m)
336
- else:
337
- return super().mask(x, m)
338
-
339
-
340
- class MultiSourceMultiMaskBandSplitCoreTransformer(
341
- MultiMaskBandSplitCoreBase,
342
- ):
343
- def __init__(
344
- self,
345
- in_channel: int,
346
- stems: List[str],
347
- band_specs: List[Tuple[float, float]],
348
- require_no_overlap: bool = False,
349
- require_no_gap: bool = True,
350
- normalize_channel_independently: bool = False,
351
- treat_channel_as_feature: bool = True,
352
- n_sqm_modules: int = 12,
353
- emb_dim: int = 128,
354
- rnn_dim: int = 256,
355
- bidirectional: bool = True,
356
- tf_dropout: float = 0.0,
357
- mlp_dim: int = 512,
358
- hidden_activation: str = "Tanh",
359
- hidden_activation_kwargs: Optional[Dict] = None,
360
- complex_mask: bool = True,
361
- overlapping_band: bool = False,
362
- freq_weights: Optional[List[torch.Tensor]] = None,
363
- n_freq: Optional[int] = None,
364
- use_freq_weights: bool = True,
365
- rnn_type: str = "LSTM",
366
- cond_dim: int = 0,
367
- mult_add_mask: bool = False,
368
- ) -> None:
369
- super().__init__()
370
- self.instantiate_bandsplit(
371
- in_channel=in_channel,
372
- band_specs=band_specs,
373
- require_no_overlap=require_no_overlap,
374
- require_no_gap=require_no_gap,
375
- normalize_channel_independently=normalize_channel_independently,
376
- treat_channel_as_feature=treat_channel_as_feature,
377
- emb_dim=emb_dim,
378
- )
379
- self.tf_model = TransformerTimeFreqModule(
380
- n_modules=n_sqm_modules,
381
- emb_dim=emb_dim,
382
- rnn_dim=rnn_dim,
383
- bidirectional=bidirectional,
384
- dropout=tf_dropout,
385
- )
386
-
387
- self.instantiate_mask_estim(
388
- in_channel=in_channel,
389
- stems=stems,
390
- band_specs=band_specs,
391
- emb_dim=emb_dim,
392
- mlp_dim=mlp_dim,
393
- cond_dim=cond_dim,
394
- hidden_activation=hidden_activation,
395
- hidden_activation_kwargs=hidden_activation_kwargs,
396
- complex_mask=complex_mask,
397
- overlapping_band=overlapping_band,
398
- freq_weights=freq_weights,
399
- n_freq=n_freq,
400
- use_freq_weights=use_freq_weights,
401
- mult_add_mask=mult_add_mask,
402
- )
403
-
404
-
405
- class MultiSourceMultiMaskBandSplitCoreConv(
406
- MultiMaskBandSplitCoreBase,
407
- ):
408
- def __init__(
409
- self,
410
- in_channel: int,
411
- stems: List[str],
412
- band_specs: List[Tuple[float, float]],
413
- require_no_overlap: bool = False,
414
- require_no_gap: bool = True,
415
- normalize_channel_independently: bool = False,
416
- treat_channel_as_feature: bool = True,
417
- n_sqm_modules: int = 12,
418
- emb_dim: int = 128,
419
- rnn_dim: int = 256,
420
- bidirectional: bool = True,
421
- tf_dropout: float = 0.0,
422
- mlp_dim: int = 512,
423
- hidden_activation: str = "Tanh",
424
- hidden_activation_kwargs: Optional[Dict] = None,
425
- complex_mask: bool = True,
426
- overlapping_band: bool = False,
427
- freq_weights: Optional[List[torch.Tensor]] = None,
428
- n_freq: Optional[int] = None,
429
- use_freq_weights: bool = True,
430
- rnn_type: str = "LSTM",
431
- cond_dim: int = 0,
432
- mult_add_mask: bool = False,
433
- ) -> None:
434
- super().__init__()
435
- self.instantiate_bandsplit(
436
- in_channel=in_channel,
437
- band_specs=band_specs,
438
- require_no_overlap=require_no_overlap,
439
- require_no_gap=require_no_gap,
440
- normalize_channel_independently=normalize_channel_independently,
441
- treat_channel_as_feature=treat_channel_as_feature,
442
- emb_dim=emb_dim,
443
- )
444
- self.tf_model = ConvolutionalTimeFreqModule(
445
- n_modules=n_sqm_modules,
446
- emb_dim=emb_dim,
447
- rnn_dim=rnn_dim,
448
- bidirectional=bidirectional,
449
- dropout=tf_dropout,
450
- )
451
-
452
- self.instantiate_mask_estim(
453
- in_channel=in_channel,
454
- stems=stems,
455
- band_specs=band_specs,
456
- emb_dim=emb_dim,
457
- mlp_dim=mlp_dim,
458
- cond_dim=cond_dim,
459
- hidden_activation=hidden_activation,
460
- hidden_activation_kwargs=hidden_activation_kwargs,
461
- complex_mask=complex_mask,
462
- overlapping_band=overlapping_band,
463
- freq_weights=freq_weights,
464
- n_freq=n_freq,
465
- use_freq_weights=use_freq_weights,
466
- mult_add_mask=mult_add_mask,
467
- )
468
-
469
-
470
- class PatchingMaskBandsplitCoreBase(MultiMaskBandSplitCoreBase):
471
- def __init__(self) -> None:
472
- super().__init__()
473
-
474
- def mask(self, x, m):
475
-
476
- _, n_channel, kernel_freq, kernel_time, n_freq, n_time = m.shape
477
- padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2)
478
-
479
- xf = F.unfold(
480
- x,
481
- kernel_size=(kernel_freq, kernel_time),
482
- padding=padding,
483
- stride=(1, 1),
484
- )
485
-
486
- xf = xf.view(
487
- -1,
488
- n_channel,
489
- kernel_freq,
490
- kernel_time,
491
- n_freq,
492
- n_time,
493
- )
494
-
495
- sf = xf * m
496
-
497
- sf = sf.view(
498
- -1,
499
- n_channel * kernel_freq * kernel_time,
500
- n_freq * n_time,
501
- )
502
-
503
- s = F.fold(
504
- sf,
505
- output_size=(n_freq, n_time),
506
- kernel_size=(kernel_freq, kernel_time),
507
- padding=padding,
508
- stride=(1, 1),
509
- ).view(
510
- -1,
511
- n_channel,
512
- n_freq,
513
- n_time,
514
- )
515
-
516
- return s
517
-
518
- def old_mask(self, x, m):
519
-
520
- s = torch.zeros_like(x)
521
-
522
- _, n_channel, n_freq, n_time = x.shape
523
- kernel_freq, kernel_time, _, _, _, _ = m.shape
524
-
525
- kernel_freq_half = (kernel_freq - 1) // 2
526
- kernel_time_half = (kernel_time - 1) // 2
527
-
528
- for ifreq in range(kernel_freq):
529
- for itime in range(kernel_time):
530
- df, dt = kernel_freq_half - ifreq, kernel_time_half - itime
531
- x = x.roll(shifts=(df, dt), dims=(2, 3))
532
-
533
- fslice = slice(max(0, df), min(n_freq, n_freq + df))
534
- tslice = slice(max(0, dt), min(n_time, n_time + dt))
535
-
536
- s[:, :, fslice, tslice] += (
537
- x[:, :, fslice, tslice] * m[ifreq, itime, :, :, fslice, tslice]
538
- )
539
-
540
- return s
541
-
542
-
543
- class MultiSourceMultiPatchingMaskBandSplitCoreRNN(PatchingMaskBandsplitCoreBase):
544
- def __init__(
545
- self,
546
- in_channel: int,
547
- stems: List[str],
548
- band_specs: List[Tuple[float, float]],
549
- mask_kernel_freq: int,
550
- mask_kernel_time: int,
551
- conv_kernel_freq: int,
552
- conv_kernel_time: int,
553
- kernel_norm_mlp_version: int,
554
- require_no_overlap: bool = False,
555
- require_no_gap: bool = True,
556
- normalize_channel_independently: bool = False,
557
- treat_channel_as_feature: bool = True,
558
- n_sqm_modules: int = 12,
559
- emb_dim: int = 128,
560
- rnn_dim: int = 256,
561
- bidirectional: bool = True,
562
- rnn_type: str = "LSTM",
563
- mlp_dim: int = 512,
564
- hidden_activation: str = "Tanh",
565
- hidden_activation_kwargs: Optional[Dict] = None,
566
- complex_mask: bool = True,
567
- overlapping_band: bool = False,
568
- freq_weights: Optional[List[torch.Tensor]] = None,
569
- n_freq: Optional[int] = None,
570
- ) -> None:
571
-
572
- super().__init__()
573
- self.band_split = BandSplitModule(
574
- in_channel=in_channel,
575
- band_specs=band_specs,
576
- require_no_overlap=require_no_overlap,
577
- require_no_gap=require_no_gap,
578
- normalize_channel_independently=normalize_channel_independently,
579
- treat_channel_as_feature=treat_channel_as_feature,
580
- emb_dim=emb_dim,
581
- )
582
-
583
- self.tf_model = SeqBandModellingModule(
584
- n_modules=n_sqm_modules,
585
- emb_dim=emb_dim,
586
- rnn_dim=rnn_dim,
587
- bidirectional=bidirectional,
588
- rnn_type=rnn_type,
589
- )
590
-
591
- if hidden_activation_kwargs is None:
592
- hidden_activation_kwargs = {}
593
-
594
- if overlapping_band:
595
- assert freq_weights is not None
596
- assert n_freq is not None
597
- self.mask_estim = nn.ModuleDict(
598
- {
599
- stem: PatchingMaskEstimationModule(
600
- band_specs=band_specs,
601
- freq_weights=freq_weights,
602
- n_freq=n_freq,
603
- emb_dim=emb_dim,
604
- mlp_dim=mlp_dim,
605
- in_channel=in_channel,
606
- hidden_activation=hidden_activation,
607
- hidden_activation_kwargs=hidden_activation_kwargs,
608
- complex_mask=complex_mask,
609
- mask_kernel_freq=mask_kernel_freq,
610
- mask_kernel_time=mask_kernel_time,
611
- conv_kernel_freq=conv_kernel_freq,
612
- conv_kernel_time=conv_kernel_time,
613
- kernel_norm_mlp_version=kernel_norm_mlp_version,
614
- )
615
- for stem in stems
616
- }
617
- )
618
- else:
619
- raise NotImplementedError
 
1
+ from typing import Dict, List, Optional, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from . import BandsplitCoreBase
8
+ from .bandsplit import BandSplitModule
9
+ from .maskestim import (
10
+ MaskEstimationModule,
11
+ OverlappingMaskEstimationModule,
12
+ )
13
+ from .tfmodel import (
14
+ ConvolutionalTimeFreqModule,
15
+ SeqBandModellingModule,
16
+ TransformerTimeFreqModule,
17
+ )
18
+
19
+
20
+ class MultiMaskBandSplitCoreBase(BandsplitCoreBase):
21
+ def __init__(self) -> None:
22
+ super().__init__()
23
+
24
+ def forward(self, x, cond=None, compute_residual: bool = True):
25
+ batch, in_chan, n_freq, n_time = x.shape
26
+ x = torch.reshape(x, (-1, 1, n_freq, n_time))
27
+
28
+ z = self.band_split(x)
29
+
30
+ q = self.tf_model(z)
31
+
32
+ out = {}
33
+
34
+ for stem, mem in self.mask_estim.items():
35
+ m = mem(q, cond=cond)
36
+
37
+ s = self.mask(x, m)
38
+ s = torch.reshape(s, (batch, in_chan, n_freq, n_time))
39
+ out[stem] = s
40
+
41
+ return {"spectrogram": out}
42
+
43
+ def instantiate_mask_estim(
44
+ self,
45
+ in_channel: int,
46
+ stems: List[str],
47
+ band_specs: List[Tuple[float, float]],
48
+ emb_dim: int,
49
+ mlp_dim: int,
50
+ cond_dim: int,
51
+ hidden_activation: str,
52
+ hidden_activation_kwargs: Optional[Dict] = None,
53
+ complex_mask: bool = True,
54
+ overlapping_band: bool = False,
55
+ freq_weights: Optional[List[torch.Tensor]] = None,
56
+ n_freq: Optional[int] = None,
57
+ use_freq_weights: bool = True,
58
+ mult_add_mask: bool = False,
59
+ ):
60
+ if hidden_activation_kwargs is None:
61
+ hidden_activation_kwargs = {}
62
+
63
+ if "mne:+" in stems:
64
+ stems = [s for s in stems if s != "mne:+"]
65
+
66
+ if overlapping_band:
67
+ assert freq_weights is not None
68
+ assert n_freq is not None
69
+
70
+ if mult_add_mask:
71
+
72
+ self.mask_estim = nn.ModuleDict(
73
+ {
74
+ stem: MultAddMaskEstimationModule(
75
+ band_specs=band_specs,
76
+ freq_weights=freq_weights,
77
+ n_freq=n_freq,
78
+ emb_dim=emb_dim,
79
+ mlp_dim=mlp_dim,
80
+ in_channel=in_channel,
81
+ hidden_activation=hidden_activation,
82
+ hidden_activation_kwargs=hidden_activation_kwargs,
83
+ complex_mask=complex_mask,
84
+ use_freq_weights=use_freq_weights,
85
+ )
86
+ for stem in stems
87
+ }
88
+ )
89
+ else:
90
+ self.mask_estim = nn.ModuleDict(
91
+ {
92
+ stem: OverlappingMaskEstimationModule(
93
+ band_specs=band_specs,
94
+ freq_weights=freq_weights,
95
+ n_freq=n_freq,
96
+ emb_dim=emb_dim,
97
+ mlp_dim=mlp_dim,
98
+ in_channel=in_channel,
99
+ hidden_activation=hidden_activation,
100
+ hidden_activation_kwargs=hidden_activation_kwargs,
101
+ complex_mask=complex_mask,
102
+ use_freq_weights=use_freq_weights,
103
+ )
104
+ for stem in stems
105
+ }
106
+ )
107
+ else:
108
+ self.mask_estim = nn.ModuleDict(
109
+ {
110
+ stem: MaskEstimationModule(
111
+ band_specs=band_specs,
112
+ emb_dim=emb_dim,
113
+ mlp_dim=mlp_dim,
114
+ cond_dim=cond_dim,
115
+ in_channel=in_channel,
116
+ hidden_activation=hidden_activation,
117
+ hidden_activation_kwargs=hidden_activation_kwargs,
118
+ complex_mask=complex_mask,
119
+ )
120
+ for stem in stems
121
+ }
122
+ )
123
+
124
+ def instantiate_bandsplit(
125
+ self,
126
+ in_channel: int,
127
+ band_specs: List[Tuple[float, float]],
128
+ require_no_overlap: bool = False,
129
+ require_no_gap: bool = True,
130
+ normalize_channel_independently: bool = False,
131
+ treat_channel_as_feature: bool = True,
132
+ emb_dim: int = 128,
133
+ ):
134
+ self.band_split = BandSplitModule(
135
+ in_channel=in_channel,
136
+ band_specs=band_specs,
137
+ require_no_overlap=require_no_overlap,
138
+ require_no_gap=require_no_gap,
139
+ normalize_channel_independently=normalize_channel_independently,
140
+ treat_channel_as_feature=treat_channel_as_feature,
141
+ emb_dim=emb_dim,
142
+ )
143
+
144
+
145
+ class SingleMaskBandsplitCoreBase(BandsplitCoreBase):
146
+ def __init__(self, **kwargs) -> None:
147
+ super().__init__()
148
+
149
+ def forward(self, x):
150
+ z = self.band_split(x)
151
+ q = self.tf_model(z)
152
+ m = self.mask_estim(q)
153
+
154
+ s = self.mask(x, m)
155
+
156
+ return s
157
+
158
+
159
+ class SingleMaskBandsplitCoreRNN(
160
+ SingleMaskBandsplitCoreBase,
161
+ ):
162
+ def __init__(
163
+ self,
164
+ in_channel: int,
165
+ band_specs: List[Tuple[float, float]],
166
+ require_no_overlap: bool = False,
167
+ require_no_gap: bool = True,
168
+ normalize_channel_independently: bool = False,
169
+ treat_channel_as_feature: bool = True,
170
+ n_sqm_modules: int = 12,
171
+ emb_dim: int = 128,
172
+ rnn_dim: int = 256,
173
+ bidirectional: bool = True,
174
+ rnn_type: str = "LSTM",
175
+ mlp_dim: int = 512,
176
+ hidden_activation: str = "Tanh",
177
+ hidden_activation_kwargs: Optional[Dict] = None,
178
+ complex_mask: bool = True,
179
+ ) -> None:
180
+ super().__init__()
181
+ self.band_split = BandSplitModule(
182
+ in_channel=in_channel,
183
+ band_specs=band_specs,
184
+ require_no_overlap=require_no_overlap,
185
+ require_no_gap=require_no_gap,
186
+ normalize_channel_independently=normalize_channel_independently,
187
+ treat_channel_as_feature=treat_channel_as_feature,
188
+ emb_dim=emb_dim,
189
+ )
190
+ self.tf_model = SeqBandModellingModule(
191
+ n_modules=n_sqm_modules,
192
+ emb_dim=emb_dim,
193
+ rnn_dim=rnn_dim,
194
+ bidirectional=bidirectional,
195
+ rnn_type=rnn_type,
196
+ )
197
+ self.mask_estim = MaskEstimationModule(
198
+ in_channel=in_channel,
199
+ band_specs=band_specs,
200
+ emb_dim=emb_dim,
201
+ mlp_dim=mlp_dim,
202
+ hidden_activation=hidden_activation,
203
+ hidden_activation_kwargs=hidden_activation_kwargs,
204
+ complex_mask=complex_mask,
205
+ )
206
+
207
+
208
+ class SingleMaskBandsplitCoreTransformer(
209
+ SingleMaskBandsplitCoreBase,
210
+ ):
211
+ def __init__(
212
+ self,
213
+ in_channel: int,
214
+ band_specs: List[Tuple[float, float]],
215
+ require_no_overlap: bool = False,
216
+ require_no_gap: bool = True,
217
+ normalize_channel_independently: bool = False,
218
+ treat_channel_as_feature: bool = True,
219
+ n_sqm_modules: int = 12,
220
+ emb_dim: int = 128,
221
+ rnn_dim: int = 256,
222
+ bidirectional: bool = True,
223
+ tf_dropout: float = 0.0,
224
+ mlp_dim: int = 512,
225
+ hidden_activation: str = "Tanh",
226
+ hidden_activation_kwargs: Optional[Dict] = None,
227
+ complex_mask: bool = True,
228
+ ) -> None:
229
+ super().__init__()
230
+ self.band_split = BandSplitModule(
231
+ in_channel=in_channel,
232
+ band_specs=band_specs,
233
+ require_no_overlap=require_no_overlap,
234
+ require_no_gap=require_no_gap,
235
+ normalize_channel_independently=normalize_channel_independently,
236
+ treat_channel_as_feature=treat_channel_as_feature,
237
+ emb_dim=emb_dim,
238
+ )
239
+ self.tf_model = TransformerTimeFreqModule(
240
+ n_modules=n_sqm_modules,
241
+ emb_dim=emb_dim,
242
+ rnn_dim=rnn_dim,
243
+ bidirectional=bidirectional,
244
+ dropout=tf_dropout,
245
+ )
246
+ self.mask_estim = MaskEstimationModule(
247
+ in_channel=in_channel,
248
+ band_specs=band_specs,
249
+ emb_dim=emb_dim,
250
+ mlp_dim=mlp_dim,
251
+ hidden_activation=hidden_activation,
252
+ hidden_activation_kwargs=hidden_activation_kwargs,
253
+ complex_mask=complex_mask,
254
+ )
255
+
256
+
257
+ class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase):
258
+ def __init__(
259
+ self,
260
+ in_channel: int,
261
+ stems: List[str],
262
+ band_specs: List[Tuple[float, float]],
263
+ require_no_overlap: bool = False,
264
+ require_no_gap: bool = True,
265
+ normalize_channel_independently: bool = False,
266
+ treat_channel_as_feature: bool = True,
267
+ n_sqm_modules: int = 12,
268
+ emb_dim: int = 128,
269
+ rnn_dim: int = 256,
270
+ bidirectional: bool = True,
271
+ rnn_type: str = "LSTM",
272
+ mlp_dim: int = 512,
273
+ cond_dim: int = 0,
274
+ hidden_activation: str = "Tanh",
275
+ hidden_activation_kwargs: Optional[Dict] = None,
276
+ complex_mask: bool = True,
277
+ overlapping_band: bool = False,
278
+ freq_weights: Optional[List[torch.Tensor]] = None,
279
+ n_freq: Optional[int] = None,
280
+ use_freq_weights: bool = True,
281
+ mult_add_mask: bool = False,
282
+ ) -> None:
283
+
284
+ super().__init__()
285
+ self.instantiate_bandsplit(
286
+ in_channel=in_channel,
287
+ band_specs=band_specs,
288
+ require_no_overlap=require_no_overlap,
289
+ require_no_gap=require_no_gap,
290
+ normalize_channel_independently=normalize_channel_independently,
291
+ treat_channel_as_feature=treat_channel_as_feature,
292
+ emb_dim=emb_dim,
293
+ )
294
+
295
+ self.tf_model = SeqBandModellingModule(
296
+ n_modules=n_sqm_modules,
297
+ emb_dim=emb_dim,
298
+ rnn_dim=rnn_dim,
299
+ bidirectional=bidirectional,
300
+ rnn_type=rnn_type,
301
+ )
302
+
303
+ self.mult_add_mask = mult_add_mask
304
+
305
+ self.instantiate_mask_estim(
306
+ in_channel=in_channel,
307
+ stems=stems,
308
+ band_specs=band_specs,
309
+ emb_dim=emb_dim,
310
+ mlp_dim=mlp_dim,
311
+ cond_dim=cond_dim,
312
+ hidden_activation=hidden_activation,
313
+ hidden_activation_kwargs=hidden_activation_kwargs,
314
+ complex_mask=complex_mask,
315
+ overlapping_band=overlapping_band,
316
+ freq_weights=freq_weights,
317
+ n_freq=n_freq,
318
+ use_freq_weights=use_freq_weights,
319
+ mult_add_mask=mult_add_mask,
320
+ )
321
+
322
+ @staticmethod
323
+ def _mult_add_mask(x, m):
324
+
325
+ assert m.ndim == 5
326
+
327
+ mm = m[..., 0]
328
+ am = m[..., 1]
329
+
330
+ return x * mm + am
331
+
332
+ def mask(self, x, m):
333
+ if self.mult_add_mask:
334
+
335
+ return self._mult_add_mask(x, m)
336
+ else:
337
+ return super().mask(x, m)
338
+
339
+
340
+ class MultiSourceMultiMaskBandSplitCoreTransformer(
341
+ MultiMaskBandSplitCoreBase,
342
+ ):
343
+ def __init__(
344
+ self,
345
+ in_channel: int,
346
+ stems: List[str],
347
+ band_specs: List[Tuple[float, float]],
348
+ require_no_overlap: bool = False,
349
+ require_no_gap: bool = True,
350
+ normalize_channel_independently: bool = False,
351
+ treat_channel_as_feature: bool = True,
352
+ n_sqm_modules: int = 12,
353
+ emb_dim: int = 128,
354
+ rnn_dim: int = 256,
355
+ bidirectional: bool = True,
356
+ tf_dropout: float = 0.0,
357
+ mlp_dim: int = 512,
358
+ hidden_activation: str = "Tanh",
359
+ hidden_activation_kwargs: Optional[Dict] = None,
360
+ complex_mask: bool = True,
361
+ overlapping_band: bool = False,
362
+ freq_weights: Optional[List[torch.Tensor]] = None,
363
+ n_freq: Optional[int] = None,
364
+ use_freq_weights: bool = True,
365
+ rnn_type: str = "LSTM",
366
+ cond_dim: int = 0,
367
+ mult_add_mask: bool = False,
368
+ ) -> None:
369
+ super().__init__()
370
+ self.instantiate_bandsplit(
371
+ in_channel=in_channel,
372
+ band_specs=band_specs,
373
+ require_no_overlap=require_no_overlap,
374
+ require_no_gap=require_no_gap,
375
+ normalize_channel_independently=normalize_channel_independently,
376
+ treat_channel_as_feature=treat_channel_as_feature,
377
+ emb_dim=emb_dim,
378
+ )
379
+ self.tf_model = TransformerTimeFreqModule(
380
+ n_modules=n_sqm_modules,
381
+ emb_dim=emb_dim,
382
+ rnn_dim=rnn_dim,
383
+ bidirectional=bidirectional,
384
+ dropout=tf_dropout,
385
+ )
386
+
387
+ self.instantiate_mask_estim(
388
+ in_channel=in_channel,
389
+ stems=stems,
390
+ band_specs=band_specs,
391
+ emb_dim=emb_dim,
392
+ mlp_dim=mlp_dim,
393
+ cond_dim=cond_dim,
394
+ hidden_activation=hidden_activation,
395
+ hidden_activation_kwargs=hidden_activation_kwargs,
396
+ complex_mask=complex_mask,
397
+ overlapping_band=overlapping_band,
398
+ freq_weights=freq_weights,
399
+ n_freq=n_freq,
400
+ use_freq_weights=use_freq_weights,
401
+ mult_add_mask=mult_add_mask,
402
+ )
403
+
404
+
405
+ class MultiSourceMultiMaskBandSplitCoreConv(
406
+ MultiMaskBandSplitCoreBase,
407
+ ):
408
+ def __init__(
409
+ self,
410
+ in_channel: int,
411
+ stems: List[str],
412
+ band_specs: List[Tuple[float, float]],
413
+ require_no_overlap: bool = False,
414
+ require_no_gap: bool = True,
415
+ normalize_channel_independently: bool = False,
416
+ treat_channel_as_feature: bool = True,
417
+ n_sqm_modules: int = 12,
418
+ emb_dim: int = 128,
419
+ rnn_dim: int = 256,
420
+ bidirectional: bool = True,
421
+ tf_dropout: float = 0.0,
422
+ mlp_dim: int = 512,
423
+ hidden_activation: str = "Tanh",
424
+ hidden_activation_kwargs: Optional[Dict] = None,
425
+ complex_mask: bool = True,
426
+ overlapping_band: bool = False,
427
+ freq_weights: Optional[List[torch.Tensor]] = None,
428
+ n_freq: Optional[int] = None,
429
+ use_freq_weights: bool = True,
430
+ rnn_type: str = "LSTM",
431
+ cond_dim: int = 0,
432
+ mult_add_mask: bool = False,
433
+ ) -> None:
434
+ super().__init__()
435
+ self.instantiate_bandsplit(
436
+ in_channel=in_channel,
437
+ band_specs=band_specs,
438
+ require_no_overlap=require_no_overlap,
439
+ require_no_gap=require_no_gap,
440
+ normalize_channel_independently=normalize_channel_independently,
441
+ treat_channel_as_feature=treat_channel_as_feature,
442
+ emb_dim=emb_dim,
443
+ )
444
+ self.tf_model = ConvolutionalTimeFreqModule(
445
+ n_modules=n_sqm_modules,
446
+ emb_dim=emb_dim,
447
+ rnn_dim=rnn_dim,
448
+ bidirectional=bidirectional,
449
+ dropout=tf_dropout,
450
+ )
451
+
452
+ self.instantiate_mask_estim(
453
+ in_channel=in_channel,
454
+ stems=stems,
455
+ band_specs=band_specs,
456
+ emb_dim=emb_dim,
457
+ mlp_dim=mlp_dim,
458
+ cond_dim=cond_dim,
459
+ hidden_activation=hidden_activation,
460
+ hidden_activation_kwargs=hidden_activation_kwargs,
461
+ complex_mask=complex_mask,
462
+ overlapping_band=overlapping_band,
463
+ freq_weights=freq_weights,
464
+ n_freq=n_freq,
465
+ use_freq_weights=use_freq_weights,
466
+ mult_add_mask=mult_add_mask,
467
+ )
468
+
469
+
470
+ class PatchingMaskBandsplitCoreBase(MultiMaskBandSplitCoreBase):
471
+ def __init__(self) -> None:
472
+ super().__init__()
473
+
474
+ def mask(self, x, m):
475
+
476
+ _, n_channel, kernel_freq, kernel_time, n_freq, n_time = m.shape
477
+ padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2)
478
+
479
+ xf = F.unfold(
480
+ x,
481
+ kernel_size=(kernel_freq, kernel_time),
482
+ padding=padding,
483
+ stride=(1, 1),
484
+ )
485
+
486
+ xf = xf.view(
487
+ -1,
488
+ n_channel,
489
+ kernel_freq,
490
+ kernel_time,
491
+ n_freq,
492
+ n_time,
493
+ )
494
+
495
+ sf = xf * m
496
+
497
+ sf = sf.view(
498
+ -1,
499
+ n_channel * kernel_freq * kernel_time,
500
+ n_freq * n_time,
501
+ )
502
+
503
+ s = F.fold(
504
+ sf,
505
+ output_size=(n_freq, n_time),
506
+ kernel_size=(kernel_freq, kernel_time),
507
+ padding=padding,
508
+ stride=(1, 1),
509
+ ).view(
510
+ -1,
511
+ n_channel,
512
+ n_freq,
513
+ n_time,
514
+ )
515
+
516
+ return s
517
+
518
+ def old_mask(self, x, m):
519
+
520
+ s = torch.zeros_like(x)
521
+
522
+ _, n_channel, n_freq, n_time = x.shape
523
+ kernel_freq, kernel_time, _, _, _, _ = m.shape
524
+
525
+ kernel_freq_half = (kernel_freq - 1) // 2
526
+ kernel_time_half = (kernel_time - 1) // 2
527
+
528
+ for ifreq in range(kernel_freq):
529
+ for itime in range(kernel_time):
530
+ df, dt = kernel_freq_half - ifreq, kernel_time_half - itime
531
+ x = x.roll(shifts=(df, dt), dims=(2, 3))
532
+
533
+ fslice = slice(max(0, df), min(n_freq, n_freq + df))
534
+ tslice = slice(max(0, dt), min(n_time, n_time + dt))
535
+
536
+ s[:, :, fslice, tslice] += (
537
+ x[:, :, fslice, tslice] * m[ifreq, itime, :, :, fslice, tslice]
538
+ )
539
+
540
+ return s
541
+
542
+
543
+ class MultiSourceMultiPatchingMaskBandSplitCoreRNN(PatchingMaskBandsplitCoreBase):
544
+ def __init__(
545
+ self,
546
+ in_channel: int,
547
+ stems: List[str],
548
+ band_specs: List[Tuple[float, float]],
549
+ mask_kernel_freq: int,
550
+ mask_kernel_time: int,
551
+ conv_kernel_freq: int,
552
+ conv_kernel_time: int,
553
+ kernel_norm_mlp_version: int,
554
+ require_no_overlap: bool = False,
555
+ require_no_gap: bool = True,
556
+ normalize_channel_independently: bool = False,
557
+ treat_channel_as_feature: bool = True,
558
+ n_sqm_modules: int = 12,
559
+ emb_dim: int = 128,
560
+ rnn_dim: int = 256,
561
+ bidirectional: bool = True,
562
+ rnn_type: str = "LSTM",
563
+ mlp_dim: int = 512,
564
+ hidden_activation: str = "Tanh",
565
+ hidden_activation_kwargs: Optional[Dict] = None,
566
+ complex_mask: bool = True,
567
+ overlapping_band: bool = False,
568
+ freq_weights: Optional[List[torch.Tensor]] = None,
569
+ n_freq: Optional[int] = None,
570
+ ) -> None:
571
+
572
+ super().__init__()
573
+ self.band_split = BandSplitModule(
574
+ in_channel=in_channel,
575
+ band_specs=band_specs,
576
+ require_no_overlap=require_no_overlap,
577
+ require_no_gap=require_no_gap,
578
+ normalize_channel_independently=normalize_channel_independently,
579
+ treat_channel_as_feature=treat_channel_as_feature,
580
+ emb_dim=emb_dim,
581
+ )
582
+
583
+ self.tf_model = SeqBandModellingModule(
584
+ n_modules=n_sqm_modules,
585
+ emb_dim=emb_dim,
586
+ rnn_dim=rnn_dim,
587
+ bidirectional=bidirectional,
588
+ rnn_type=rnn_type,
589
+ )
590
+
591
+ if hidden_activation_kwargs is None:
592
+ hidden_activation_kwargs = {}
593
+
594
+ if overlapping_band:
595
+ assert freq_weights is not None
596
+ assert n_freq is not None
597
+ self.mask_estim = nn.ModuleDict(
598
+ {
599
+ stem: PatchingMaskEstimationModule(
600
+ band_specs=band_specs,
601
+ freq_weights=freq_weights,
602
+ n_freq=n_freq,
603
+ emb_dim=emb_dim,
604
+ mlp_dim=mlp_dim,
605
+ in_channel=in_channel,
606
+ hidden_activation=hidden_activation,
607
+ hidden_activation_kwargs=hidden_activation_kwargs,
608
+ complex_mask=complex_mask,
609
+ mask_kernel_freq=mask_kernel_freq,
610
+ mask_kernel_time=mask_kernel_time,
611
+ conv_kernel_freq=conv_kernel_freq,
612
+ conv_kernel_time=conv_kernel_time,
613
+ kernel_norm_mlp_version=kernel_norm_mlp_version,
614
+ )
615
+ for stem in stems
616
+ }
617
+ )
618
+ else:
619
+ raise NotImplementedError
mvsepless/models/bandit/core/model/bsrnn/maskestim.py CHANGED
@@ -1,327 +1,327 @@
1
- import warnings
2
- from typing import Dict, List, Optional, Tuple, Type
3
-
4
- import torch
5
- from torch import nn
6
- from torch.nn.modules import activation
7
-
8
- from .utils import (
9
- band_widths_from_specs,
10
- check_no_gap,
11
- check_no_overlap,
12
- check_nonzero_bandwidth,
13
- )
14
-
15
-
16
- class BaseNormMLP(nn.Module):
17
- def __init__(
18
- self,
19
- emb_dim: int,
20
- mlp_dim: int,
21
- bandwidth: int,
22
- in_channel: Optional[int],
23
- hidden_activation: str = "Tanh",
24
- hidden_activation_kwargs=None,
25
- complex_mask: bool = True,
26
- ):
27
-
28
- super().__init__()
29
- if hidden_activation_kwargs is None:
30
- hidden_activation_kwargs = {}
31
- self.hidden_activation_kwargs = hidden_activation_kwargs
32
- self.norm = nn.LayerNorm(emb_dim)
33
- self.hidden = torch.jit.script(
34
- nn.Sequential(
35
- nn.Linear(in_features=emb_dim, out_features=mlp_dim),
36
- activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
37
- )
38
- )
39
-
40
- self.bandwidth = bandwidth
41
- self.in_channel = in_channel
42
-
43
- self.complex_mask = complex_mask
44
- self.reim = 2 if complex_mask else 1
45
- self.glu_mult = 2
46
-
47
-
48
- class NormMLP(BaseNormMLP):
49
- def __init__(
50
- self,
51
- emb_dim: int,
52
- mlp_dim: int,
53
- bandwidth: int,
54
- in_channel: Optional[int],
55
- hidden_activation: str = "Tanh",
56
- hidden_activation_kwargs=None,
57
- complex_mask: bool = True,
58
- ) -> None:
59
- super().__init__(
60
- emb_dim=emb_dim,
61
- mlp_dim=mlp_dim,
62
- bandwidth=bandwidth,
63
- in_channel=in_channel,
64
- hidden_activation=hidden_activation,
65
- hidden_activation_kwargs=hidden_activation_kwargs,
66
- complex_mask=complex_mask,
67
- )
68
-
69
- self.output = torch.jit.script(
70
- nn.Sequential(
71
- nn.Linear(
72
- in_features=mlp_dim,
73
- out_features=bandwidth * in_channel * self.reim * 2,
74
- ),
75
- nn.GLU(dim=-1),
76
- )
77
- )
78
-
79
- def reshape_output(self, mb):
80
- batch, n_time, _ = mb.shape
81
- if self.complex_mask:
82
- mb = mb.reshape(
83
- batch, n_time, self.in_channel, self.bandwidth, self.reim
84
- ).contiguous()
85
- mb = torch.view_as_complex(mb)
86
- else:
87
- mb = mb.reshape(batch, n_time, self.in_channel, self.bandwidth)
88
-
89
- mb = torch.permute(mb, (0, 2, 3, 1))
90
-
91
- return mb
92
-
93
- def forward(self, qb):
94
-
95
- qb = self.norm(qb)
96
-
97
- qb = self.hidden(qb)
98
- mb = self.output(qb)
99
- mb = self.reshape_output(mb)
100
-
101
- return mb
102
-
103
-
104
- class MultAddNormMLP(NormMLP):
105
- def __init__(
106
- self,
107
- emb_dim: int,
108
- mlp_dim: int,
109
- bandwidth: int,
110
- in_channel: "int | None",
111
- hidden_activation: str = "Tanh",
112
- hidden_activation_kwargs=None,
113
- complex_mask: bool = True,
114
- ) -> None:
115
- super().__init__(
116
- emb_dim,
117
- mlp_dim,
118
- bandwidth,
119
- in_channel,
120
- hidden_activation,
121
- hidden_activation_kwargs,
122
- complex_mask,
123
- )
124
-
125
- self.output2 = torch.jit.script(
126
- nn.Sequential(
127
- nn.Linear(
128
- in_features=mlp_dim,
129
- out_features=bandwidth * in_channel * self.reim * 2,
130
- ),
131
- nn.GLU(dim=-1),
132
- )
133
- )
134
-
135
- def forward(self, qb):
136
-
137
- qb = self.norm(qb)
138
- qb = self.hidden(qb)
139
- mmb = self.output(qb)
140
- mmb = self.reshape_output(mmb)
141
- amb = self.output2(qb)
142
- amb = self.reshape_output(amb)
143
-
144
- return mmb, amb
145
-
146
-
147
- class MaskEstimationModuleSuperBase(nn.Module):
148
- pass
149
-
150
-
151
- class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
152
- def __init__(
153
- self,
154
- band_specs: List[Tuple[float, float]],
155
- emb_dim: int,
156
- mlp_dim: int,
157
- in_channel: Optional[int],
158
- hidden_activation: str = "Tanh",
159
- hidden_activation_kwargs: Dict = None,
160
- complex_mask: bool = True,
161
- norm_mlp_cls: Type[nn.Module] = NormMLP,
162
- norm_mlp_kwargs: Dict = None,
163
- ) -> None:
164
- super().__init__()
165
-
166
- self.band_widths = band_widths_from_specs(band_specs)
167
- self.n_bands = len(band_specs)
168
-
169
- if hidden_activation_kwargs is None:
170
- hidden_activation_kwargs = {}
171
-
172
- if norm_mlp_kwargs is None:
173
- norm_mlp_kwargs = {}
174
-
175
- self.norm_mlp = nn.ModuleList(
176
- [
177
- (
178
- norm_mlp_cls(
179
- bandwidth=self.band_widths[b],
180
- emb_dim=emb_dim,
181
- mlp_dim=mlp_dim,
182
- in_channel=in_channel,
183
- hidden_activation=hidden_activation,
184
- hidden_activation_kwargs=hidden_activation_kwargs,
185
- complex_mask=complex_mask,
186
- **norm_mlp_kwargs,
187
- )
188
- )
189
- for b in range(self.n_bands)
190
- ]
191
- )
192
-
193
- def compute_masks(self, q):
194
- batch, n_bands, n_time, emb_dim = q.shape
195
-
196
- masks = []
197
-
198
- for b, nmlp in enumerate(self.norm_mlp):
199
- qb = q[:, b, :, :]
200
- mb = nmlp(qb)
201
- masks.append(mb)
202
-
203
- return masks
204
-
205
-
206
- class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
207
- def __init__(
208
- self,
209
- in_channel: int,
210
- band_specs: List[Tuple[float, float]],
211
- freq_weights: List[torch.Tensor],
212
- n_freq: int,
213
- emb_dim: int,
214
- mlp_dim: int,
215
- cond_dim: int = 0,
216
- hidden_activation: str = "Tanh",
217
- hidden_activation_kwargs: Dict = None,
218
- complex_mask: bool = True,
219
- norm_mlp_cls: Type[nn.Module] = NormMLP,
220
- norm_mlp_kwargs: Dict = None,
221
- use_freq_weights: bool = True,
222
- ) -> None:
223
- check_nonzero_bandwidth(band_specs)
224
- check_no_gap(band_specs)
225
-
226
- super().__init__(
227
- band_specs=band_specs,
228
- emb_dim=emb_dim + cond_dim,
229
- mlp_dim=mlp_dim,
230
- in_channel=in_channel,
231
- hidden_activation=hidden_activation,
232
- hidden_activation_kwargs=hidden_activation_kwargs,
233
- complex_mask=complex_mask,
234
- norm_mlp_cls=norm_mlp_cls,
235
- norm_mlp_kwargs=norm_mlp_kwargs,
236
- )
237
-
238
- self.n_freq = n_freq
239
- self.band_specs = band_specs
240
- self.in_channel = in_channel
241
-
242
- if freq_weights is not None:
243
- for i, fw in enumerate(freq_weights):
244
- self.register_buffer(f"freq_weights/{i}", fw)
245
-
246
- self.use_freq_weights = use_freq_weights
247
- else:
248
- self.use_freq_weights = False
249
-
250
- self.cond_dim = cond_dim
251
-
252
- def forward(self, q, cond=None):
253
-
254
- batch, n_bands, n_time, emb_dim = q.shape
255
-
256
- if cond is not None:
257
- print(cond)
258
- if cond.ndim == 2:
259
- cond = cond[:, None, None, :].expand(-1, n_bands, n_time, -1)
260
- elif cond.ndim == 3:
261
- assert cond.shape[1] == n_time
262
- else:
263
- raise ValueError(f"Invalid cond shape: {cond.shape}")
264
-
265
- q = torch.cat([q, cond], dim=-1)
266
- elif self.cond_dim > 0:
267
- cond = torch.ones(
268
- (batch, n_bands, n_time, self.cond_dim),
269
- device=q.device,
270
- dtype=q.dtype,
271
- )
272
- q = torch.cat([q, cond], dim=-1)
273
- else:
274
- pass
275
-
276
- mask_list = self.compute_masks(q)
277
-
278
- masks = torch.zeros(
279
- (batch, self.in_channel, self.n_freq, n_time),
280
- device=q.device,
281
- dtype=mask_list[0].dtype,
282
- )
283
-
284
- for im, mask in enumerate(mask_list):
285
- fstart, fend = self.band_specs[im]
286
- if self.use_freq_weights:
287
- fw = self.get_buffer(f"freq_weights/{im}")[:, None]
288
- mask = mask * fw
289
- masks[:, :, fstart:fend, :] += mask
290
-
291
- return masks
292
-
293
-
294
- class MaskEstimationModule(OverlappingMaskEstimationModule):
295
- def __init__(
296
- self,
297
- band_specs: List[Tuple[float, float]],
298
- emb_dim: int,
299
- mlp_dim: int,
300
- in_channel: Optional[int],
301
- hidden_activation: str = "Tanh",
302
- hidden_activation_kwargs: Dict = None,
303
- complex_mask: bool = True,
304
- **kwargs,
305
- ) -> None:
306
- check_nonzero_bandwidth(band_specs)
307
- check_no_gap(band_specs)
308
- check_no_overlap(band_specs)
309
- super().__init__(
310
- in_channel=in_channel,
311
- band_specs=band_specs,
312
- freq_weights=None,
313
- n_freq=None,
314
- emb_dim=emb_dim,
315
- mlp_dim=mlp_dim,
316
- hidden_activation=hidden_activation,
317
- hidden_activation_kwargs=hidden_activation_kwargs,
318
- complex_mask=complex_mask,
319
- )
320
-
321
- def forward(self, q, cond=None):
322
-
323
- masks = self.compute_masks(q)
324
-
325
- masks = torch.concat(masks, dim=2)
326
-
327
- return masks
 
1
+ import warnings
2
+ from typing import Dict, List, Optional, Tuple, Type
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn.modules import activation
7
+
8
+ from .utils import (
9
+ band_widths_from_specs,
10
+ check_no_gap,
11
+ check_no_overlap,
12
+ check_nonzero_bandwidth,
13
+ )
14
+
15
+
16
+ class BaseNormMLP(nn.Module):
17
+ def __init__(
18
+ self,
19
+ emb_dim: int,
20
+ mlp_dim: int,
21
+ bandwidth: int,
22
+ in_channel: Optional[int],
23
+ hidden_activation: str = "Tanh",
24
+ hidden_activation_kwargs=None,
25
+ complex_mask: bool = True,
26
+ ):
27
+
28
+ super().__init__()
29
+ if hidden_activation_kwargs is None:
30
+ hidden_activation_kwargs = {}
31
+ self.hidden_activation_kwargs = hidden_activation_kwargs
32
+ self.norm = nn.LayerNorm(emb_dim)
33
+ self.hidden = torch.jit.script(
34
+ nn.Sequential(
35
+ nn.Linear(in_features=emb_dim, out_features=mlp_dim),
36
+ activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
37
+ )
38
+ )
39
+
40
+ self.bandwidth = bandwidth
41
+ self.in_channel = in_channel
42
+
43
+ self.complex_mask = complex_mask
44
+ self.reim = 2 if complex_mask else 1
45
+ self.glu_mult = 2
46
+
47
+
48
+ class NormMLP(BaseNormMLP):
49
+ def __init__(
50
+ self,
51
+ emb_dim: int,
52
+ mlp_dim: int,
53
+ bandwidth: int,
54
+ in_channel: Optional[int],
55
+ hidden_activation: str = "Tanh",
56
+ hidden_activation_kwargs=None,
57
+ complex_mask: bool = True,
58
+ ) -> None:
59
+ super().__init__(
60
+ emb_dim=emb_dim,
61
+ mlp_dim=mlp_dim,
62
+ bandwidth=bandwidth,
63
+ in_channel=in_channel,
64
+ hidden_activation=hidden_activation,
65
+ hidden_activation_kwargs=hidden_activation_kwargs,
66
+ complex_mask=complex_mask,
67
+ )
68
+
69
+ self.output = torch.jit.script(
70
+ nn.Sequential(
71
+ nn.Linear(
72
+ in_features=mlp_dim,
73
+ out_features=bandwidth * in_channel * self.reim * 2,
74
+ ),
75
+ nn.GLU(dim=-1),
76
+ )
77
+ )
78
+
79
+ def reshape_output(self, mb):
80
+ batch, n_time, _ = mb.shape
81
+ if self.complex_mask:
82
+ mb = mb.reshape(
83
+ batch, n_time, self.in_channel, self.bandwidth, self.reim
84
+ ).contiguous()
85
+ mb = torch.view_as_complex(mb)
86
+ else:
87
+ mb = mb.reshape(batch, n_time, self.in_channel, self.bandwidth)
88
+
89
+ mb = torch.permute(mb, (0, 2, 3, 1))
90
+
91
+ return mb
92
+
93
+ def forward(self, qb):
94
+
95
+ qb = self.norm(qb)
96
+
97
+ qb = self.hidden(qb)
98
+ mb = self.output(qb)
99
+ mb = self.reshape_output(mb)
100
+
101
+ return mb
102
+
103
+
104
+ class MultAddNormMLP(NormMLP):
105
+ def __init__(
106
+ self,
107
+ emb_dim: int,
108
+ mlp_dim: int,
109
+ bandwidth: int,
110
+ in_channel: "int | None",
111
+ hidden_activation: str = "Tanh",
112
+ hidden_activation_kwargs=None,
113
+ complex_mask: bool = True,
114
+ ) -> None:
115
+ super().__init__(
116
+ emb_dim,
117
+ mlp_dim,
118
+ bandwidth,
119
+ in_channel,
120
+ hidden_activation,
121
+ hidden_activation_kwargs,
122
+ complex_mask,
123
+ )
124
+
125
+ self.output2 = torch.jit.script(
126
+ nn.Sequential(
127
+ nn.Linear(
128
+ in_features=mlp_dim,
129
+ out_features=bandwidth * in_channel * self.reim * 2,
130
+ ),
131
+ nn.GLU(dim=-1),
132
+ )
133
+ )
134
+
135
+ def forward(self, qb):
136
+
137
+ qb = self.norm(qb)
138
+ qb = self.hidden(qb)
139
+ mmb = self.output(qb)
140
+ mmb = self.reshape_output(mmb)
141
+ amb = self.output2(qb)
142
+ amb = self.reshape_output(amb)
143
+
144
+ return mmb, amb
145
+
146
+
147
+ class MaskEstimationModuleSuperBase(nn.Module):
148
+ pass
149
+
150
+
151
+ class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
152
+ def __init__(
153
+ self,
154
+ band_specs: List[Tuple[float, float]],
155
+ emb_dim: int,
156
+ mlp_dim: int,
157
+ in_channel: Optional[int],
158
+ hidden_activation: str = "Tanh",
159
+ hidden_activation_kwargs: Dict = None,
160
+ complex_mask: bool = True,
161
+ norm_mlp_cls: Type[nn.Module] = NormMLP,
162
+ norm_mlp_kwargs: Dict = None,
163
+ ) -> None:
164
+ super().__init__()
165
+
166
+ self.band_widths = band_widths_from_specs(band_specs)
167
+ self.n_bands = len(band_specs)
168
+
169
+ if hidden_activation_kwargs is None:
170
+ hidden_activation_kwargs = {}
171
+
172
+ if norm_mlp_kwargs is None:
173
+ norm_mlp_kwargs = {}
174
+
175
+ self.norm_mlp = nn.ModuleList(
176
+ [
177
+ (
178
+ norm_mlp_cls(
179
+ bandwidth=self.band_widths[b],
180
+ emb_dim=emb_dim,
181
+ mlp_dim=mlp_dim,
182
+ in_channel=in_channel,
183
+ hidden_activation=hidden_activation,
184
+ hidden_activation_kwargs=hidden_activation_kwargs,
185
+ complex_mask=complex_mask,
186
+ **norm_mlp_kwargs,
187
+ )
188
+ )
189
+ for b in range(self.n_bands)
190
+ ]
191
+ )
192
+
193
+ def compute_masks(self, q):
194
+ batch, n_bands, n_time, emb_dim = q.shape
195
+
196
+ masks = []
197
+
198
+ for b, nmlp in enumerate(self.norm_mlp):
199
+ qb = q[:, b, :, :]
200
+ mb = nmlp(qb)
201
+ masks.append(mb)
202
+
203
+ return masks
204
+
205
+
206
+ class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
207
+ def __init__(
208
+ self,
209
+ in_channel: int,
210
+ band_specs: List[Tuple[float, float]],
211
+ freq_weights: List[torch.Tensor],
212
+ n_freq: int,
213
+ emb_dim: int,
214
+ mlp_dim: int,
215
+ cond_dim: int = 0,
216
+ hidden_activation: str = "Tanh",
217
+ hidden_activation_kwargs: Dict = None,
218
+ complex_mask: bool = True,
219
+ norm_mlp_cls: Type[nn.Module] = NormMLP,
220
+ norm_mlp_kwargs: Dict = None,
221
+ use_freq_weights: bool = True,
222
+ ) -> None:
223
+ check_nonzero_bandwidth(band_specs)
224
+ check_no_gap(band_specs)
225
+
226
+ super().__init__(
227
+ band_specs=band_specs,
228
+ emb_dim=emb_dim + cond_dim,
229
+ mlp_dim=mlp_dim,
230
+ in_channel=in_channel,
231
+ hidden_activation=hidden_activation,
232
+ hidden_activation_kwargs=hidden_activation_kwargs,
233
+ complex_mask=complex_mask,
234
+ norm_mlp_cls=norm_mlp_cls,
235
+ norm_mlp_kwargs=norm_mlp_kwargs,
236
+ )
237
+
238
+ self.n_freq = n_freq
239
+ self.band_specs = band_specs
240
+ self.in_channel = in_channel
241
+
242
+ if freq_weights is not None:
243
+ for i, fw in enumerate(freq_weights):
244
+ self.register_buffer(f"freq_weights/{i}", fw)
245
+
246
+ self.use_freq_weights = use_freq_weights
247
+ else:
248
+ self.use_freq_weights = False
249
+
250
+ self.cond_dim = cond_dim
251
+
252
+ def forward(self, q, cond=None):
253
+
254
+ batch, n_bands, n_time, emb_dim = q.shape
255
+
256
+ if cond is not None:
257
+ print(cond)
258
+ if cond.ndim == 2:
259
+ cond = cond[:, None, None, :].expand(-1, n_bands, n_time, -1)
260
+ elif cond.ndim == 3:
261
+ assert cond.shape[1] == n_time
262
+ else:
263
+ raise ValueError(f"Invalid cond shape: {cond.shape}")
264
+
265
+ q = torch.cat([q, cond], dim=-1)
266
+ elif self.cond_dim > 0:
267
+ cond = torch.ones(
268
+ (batch, n_bands, n_time, self.cond_dim),
269
+ device=q.device,
270
+ dtype=q.dtype,
271
+ )
272
+ q = torch.cat([q, cond], dim=-1)
273
+ else:
274
+ pass
275
+
276
+ mask_list = self.compute_masks(q)
277
+
278
+ masks = torch.zeros(
279
+ (batch, self.in_channel, self.n_freq, n_time),
280
+ device=q.device,
281
+ dtype=mask_list[0].dtype,
282
+ )
283
+
284
+ for im, mask in enumerate(mask_list):
285
+ fstart, fend = self.band_specs[im]
286
+ if self.use_freq_weights:
287
+ fw = self.get_buffer(f"freq_weights/{im}")[:, None]
288
+ mask = mask * fw
289
+ masks[:, :, fstart:fend, :] += mask
290
+
291
+ return masks
292
+
293
+
294
+ class MaskEstimationModule(OverlappingMaskEstimationModule):
295
+ def __init__(
296
+ self,
297
+ band_specs: List[Tuple[float, float]],
298
+ emb_dim: int,
299
+ mlp_dim: int,
300
+ in_channel: Optional[int],
301
+ hidden_activation: str = "Tanh",
302
+ hidden_activation_kwargs: Dict = None,
303
+ complex_mask: bool = True,
304
+ **kwargs,
305
+ ) -> None:
306
+ check_nonzero_bandwidth(band_specs)
307
+ check_no_gap(band_specs)
308
+ check_no_overlap(band_specs)
309
+ super().__init__(
310
+ in_channel=in_channel,
311
+ band_specs=band_specs,
312
+ freq_weights=None,
313
+ n_freq=None,
314
+ emb_dim=emb_dim,
315
+ mlp_dim=mlp_dim,
316
+ hidden_activation=hidden_activation,
317
+ hidden_activation_kwargs=hidden_activation_kwargs,
318
+ complex_mask=complex_mask,
319
+ )
320
+
321
+ def forward(self, q, cond=None):
322
+
323
+ masks = self.compute_masks(q)
324
+
325
+ masks = torch.concat(masks, dim=2)
326
+
327
+ return masks
mvsepless/models/bandit/core/model/bsrnn/tfmodel.py CHANGED
@@ -1,287 +1,287 @@
1
- import warnings
2
-
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
- from torch.nn.modules import rnn
7
-
8
- import torch.backends.cuda
9
-
10
-
11
- class TimeFrequencyModellingModule(nn.Module):
12
- def __init__(self) -> None:
13
- super().__init__()
14
-
15
-
16
- class ResidualRNN(nn.Module):
17
- def __init__(
18
- self,
19
- emb_dim: int,
20
- rnn_dim: int,
21
- bidirectional: bool = True,
22
- rnn_type: str = "LSTM",
23
- use_batch_trick: bool = True,
24
- use_layer_norm: bool = True,
25
- ) -> None:
26
- super().__init__()
27
-
28
- self.use_layer_norm = use_layer_norm
29
- if use_layer_norm:
30
- self.norm = nn.LayerNorm(emb_dim)
31
- else:
32
- self.norm = nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim)
33
-
34
- self.rnn = rnn.__dict__[rnn_type](
35
- input_size=emb_dim,
36
- hidden_size=rnn_dim,
37
- num_layers=1,
38
- batch_first=True,
39
- bidirectional=bidirectional,
40
- )
41
-
42
- self.fc = nn.Linear(
43
- in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
44
- )
45
-
46
- self.use_batch_trick = use_batch_trick
47
- if not self.use_batch_trick:
48
- warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")
49
-
50
- def forward(self, z):
51
-
52
- z0 = torch.clone(z)
53
-
54
- if self.use_layer_norm:
55
- z = self.norm(z)
56
- else:
57
- z = torch.permute(z, (0, 3, 1, 2))
58
-
59
- z = self.norm(z)
60
-
61
- z = torch.permute(z, (0, 2, 3, 1))
62
-
63
- batch, n_uncrossed, n_across, emb_dim = z.shape
64
-
65
- if self.use_batch_trick:
66
- z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
67
-
68
- z = self.rnn(z.contiguous())[0]
69
-
70
- z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
71
- else:
72
- zlist = []
73
- for i in range(n_uncrossed):
74
- zi = self.rnn(z[:, i, :, :])[0]
75
- zlist.append(zi)
76
-
77
- z = torch.stack(zlist, dim=1)
78
-
79
- z = self.fc(z)
80
-
81
- z = z + z0
82
-
83
- return z
84
-
85
-
86
- class SeqBandModellingModule(TimeFrequencyModellingModule):
87
- def __init__(
88
- self,
89
- n_modules: int = 12,
90
- emb_dim: int = 128,
91
- rnn_dim: int = 256,
92
- bidirectional: bool = True,
93
- rnn_type: str = "LSTM",
94
- parallel_mode=False,
95
- ) -> None:
96
- super().__init__()
97
- self.seqband = nn.ModuleList([])
98
-
99
- if parallel_mode:
100
- for _ in range(n_modules):
101
- self.seqband.append(
102
- nn.ModuleList(
103
- [
104
- ResidualRNN(
105
- emb_dim=emb_dim,
106
- rnn_dim=rnn_dim,
107
- bidirectional=bidirectional,
108
- rnn_type=rnn_type,
109
- ),
110
- ResidualRNN(
111
- emb_dim=emb_dim,
112
- rnn_dim=rnn_dim,
113
- bidirectional=bidirectional,
114
- rnn_type=rnn_type,
115
- ),
116
- ]
117
- )
118
- )
119
- else:
120
-
121
- for _ in range(2 * n_modules):
122
- self.seqband.append(
123
- ResidualRNN(
124
- emb_dim=emb_dim,
125
- rnn_dim=rnn_dim,
126
- bidirectional=bidirectional,
127
- rnn_type=rnn_type,
128
- )
129
- )
130
-
131
- self.parallel_mode = parallel_mode
132
-
133
- def forward(self, z):
134
-
135
- if self.parallel_mode:
136
- for sbm_pair in self.seqband:
137
- sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
138
- zt = sbm_t(z)
139
- zf = sbm_f(z.transpose(1, 2))
140
- z = zt + zf.transpose(1, 2)
141
- else:
142
- for sbm in self.seqband:
143
- z = sbm(z)
144
- z = z.transpose(1, 2)
145
-
146
- q = z
147
- return q
148
-
149
-
150
- class ResidualTransformer(nn.Module):
151
- def __init__(
152
- self,
153
- emb_dim: int = 128,
154
- rnn_dim: int = 256,
155
- bidirectional: bool = True,
156
- dropout: float = 0.0,
157
- ) -> None:
158
- super().__init__()
159
-
160
- self.tf = nn.TransformerEncoderLayer(
161
- d_model=emb_dim, nhead=4, dim_feedforward=rnn_dim, batch_first=True
162
- )
163
-
164
- self.is_causal = not bidirectional
165
- self.dropout = dropout
166
-
167
- def forward(self, z):
168
- batch, n_uncrossed, n_across, emb_dim = z.shape
169
- z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
170
- z = self.tf(z, is_causal=self.is_causal)
171
- z = torch.reshape(z, (batch, n_uncrossed, n_across, emb_dim))
172
-
173
- return z
174
-
175
-
176
- class TransformerTimeFreqModule(TimeFrequencyModellingModule):
177
- def __init__(
178
- self,
179
- n_modules: int = 12,
180
- emb_dim: int = 128,
181
- rnn_dim: int = 256,
182
- bidirectional: bool = True,
183
- dropout: float = 0.0,
184
- ) -> None:
185
- super().__init__()
186
- self.norm = nn.LayerNorm(emb_dim)
187
- self.seqband = nn.ModuleList([])
188
-
189
- for _ in range(2 * n_modules):
190
- self.seqband.append(
191
- ResidualTransformer(
192
- emb_dim=emb_dim,
193
- rnn_dim=rnn_dim,
194
- bidirectional=bidirectional,
195
- dropout=dropout,
196
- )
197
- )
198
-
199
- def forward(self, z):
200
- z = self.norm(z)
201
-
202
- for sbm in self.seqband:
203
- z = sbm(z)
204
- z = z.transpose(1, 2)
205
-
206
- q = z
207
- return q
208
-
209
-
210
- class ResidualConvolution(nn.Module):
211
- def __init__(
212
- self,
213
- emb_dim: int = 128,
214
- rnn_dim: int = 256,
215
- bidirectional: bool = True,
216
- dropout: float = 0.0,
217
- ) -> None:
218
- super().__init__()
219
- self.norm = nn.InstanceNorm2d(emb_dim, affine=True)
220
-
221
- self.conv = nn.Sequential(
222
- nn.Conv2d(
223
- in_channels=emb_dim,
224
- out_channels=rnn_dim,
225
- kernel_size=(3, 3),
226
- padding="same",
227
- stride=(1, 1),
228
- ),
229
- nn.Tanhshrink(),
230
- )
231
-
232
- self.is_causal = not bidirectional
233
- self.dropout = dropout
234
-
235
- self.fc = nn.Conv2d(
236
- in_channels=rnn_dim,
237
- out_channels=emb_dim,
238
- kernel_size=(1, 1),
239
- padding="same",
240
- stride=(1, 1),
241
- )
242
-
243
- def forward(self, z):
244
-
245
- z0 = torch.clone(z)
246
-
247
- z = self.norm(z)
248
- z = self.conv(z)
249
- z = self.fc(z)
250
- z = z + z0
251
-
252
- return z
253
-
254
-
255
- class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule):
256
- def __init__(
257
- self,
258
- n_modules: int = 12,
259
- emb_dim: int = 128,
260
- rnn_dim: int = 256,
261
- bidirectional: bool = True,
262
- dropout: float = 0.0,
263
- ) -> None:
264
- super().__init__()
265
- self.seqband = torch.jit.script(
266
- nn.Sequential(
267
- *[
268
- ResidualConvolution(
269
- emb_dim=emb_dim,
270
- rnn_dim=rnn_dim,
271
- bidirectional=bidirectional,
272
- dropout=dropout,
273
- )
274
- for _ in range(2 * n_modules)
275
- ]
276
- )
277
- )
278
-
279
- def forward(self, z):
280
-
281
- z = torch.permute(z, (0, 3, 1, 2))
282
-
283
- z = self.seqband(z)
284
-
285
- z = torch.permute(z, (0, 2, 3, 1))
286
-
287
- return z
 
1
+ import warnings
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from torch.nn.modules import rnn
7
+
8
+ import torch.backends.cuda
9
+
10
+
11
+ class TimeFrequencyModellingModule(nn.Module):
12
+ def __init__(self) -> None:
13
+ super().__init__()
14
+
15
+
16
+ class ResidualRNN(nn.Module):
17
+ def __init__(
18
+ self,
19
+ emb_dim: int,
20
+ rnn_dim: int,
21
+ bidirectional: bool = True,
22
+ rnn_type: str = "LSTM",
23
+ use_batch_trick: bool = True,
24
+ use_layer_norm: bool = True,
25
+ ) -> None:
26
+ super().__init__()
27
+
28
+ self.use_layer_norm = use_layer_norm
29
+ if use_layer_norm:
30
+ self.norm = nn.LayerNorm(emb_dim)
31
+ else:
32
+ self.norm = nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim)
33
+
34
+ self.rnn = rnn.__dict__[rnn_type](
35
+ input_size=emb_dim,
36
+ hidden_size=rnn_dim,
37
+ num_layers=1,
38
+ batch_first=True,
39
+ bidirectional=bidirectional,
40
+ )
41
+
42
+ self.fc = nn.Linear(
43
+ in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
44
+ )
45
+
46
+ self.use_batch_trick = use_batch_trick
47
+ if not self.use_batch_trick:
48
+ warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")
49
+
50
+ def forward(self, z):
51
+
52
+ z0 = torch.clone(z)
53
+
54
+ if self.use_layer_norm:
55
+ z = self.norm(z)
56
+ else:
57
+ z = torch.permute(z, (0, 3, 1, 2))
58
+
59
+ z = self.norm(z)
60
+
61
+ z = torch.permute(z, (0, 2, 3, 1))
62
+
63
+ batch, n_uncrossed, n_across, emb_dim = z.shape
64
+
65
+ if self.use_batch_trick:
66
+ z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
67
+
68
+ z = self.rnn(z.contiguous())[0]
69
+
70
+ z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
71
+ else:
72
+ zlist = []
73
+ for i in range(n_uncrossed):
74
+ zi = self.rnn(z[:, i, :, :])[0]
75
+ zlist.append(zi)
76
+
77
+ z = torch.stack(zlist, dim=1)
78
+
79
+ z = self.fc(z)
80
+
81
+ z = z + z0
82
+
83
+ return z
84
+
85
+
86
+ class SeqBandModellingModule(TimeFrequencyModellingModule):
87
+ def __init__(
88
+ self,
89
+ n_modules: int = 12,
90
+ emb_dim: int = 128,
91
+ rnn_dim: int = 256,
92
+ bidirectional: bool = True,
93
+ rnn_type: str = "LSTM",
94
+ parallel_mode=False,
95
+ ) -> None:
96
+ super().__init__()
97
+ self.seqband = nn.ModuleList([])
98
+
99
+ if parallel_mode:
100
+ for _ in range(n_modules):
101
+ self.seqband.append(
102
+ nn.ModuleList(
103
+ [
104
+ ResidualRNN(
105
+ emb_dim=emb_dim,
106
+ rnn_dim=rnn_dim,
107
+ bidirectional=bidirectional,
108
+ rnn_type=rnn_type,
109
+ ),
110
+ ResidualRNN(
111
+ emb_dim=emb_dim,
112
+ rnn_dim=rnn_dim,
113
+ bidirectional=bidirectional,
114
+ rnn_type=rnn_type,
115
+ ),
116
+ ]
117
+ )
118
+ )
119
+ else:
120
+
121
+ for _ in range(2 * n_modules):
122
+ self.seqband.append(
123
+ ResidualRNN(
124
+ emb_dim=emb_dim,
125
+ rnn_dim=rnn_dim,
126
+ bidirectional=bidirectional,
127
+ rnn_type=rnn_type,
128
+ )
129
+ )
130
+
131
+ self.parallel_mode = parallel_mode
132
+
133
+ def forward(self, z):
134
+
135
+ if self.parallel_mode:
136
+ for sbm_pair in self.seqband:
137
+ sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
138
+ zt = sbm_t(z)
139
+ zf = sbm_f(z.transpose(1, 2))
140
+ z = zt + zf.transpose(1, 2)
141
+ else:
142
+ for sbm in self.seqband:
143
+ z = sbm(z)
144
+ z = z.transpose(1, 2)
145
+
146
+ q = z
147
+ return q
148
+
149
+
150
+ class ResidualTransformer(nn.Module):
151
+ def __init__(
152
+ self,
153
+ emb_dim: int = 128,
154
+ rnn_dim: int = 256,
155
+ bidirectional: bool = True,
156
+ dropout: float = 0.0,
157
+ ) -> None:
158
+ super().__init__()
159
+
160
+ self.tf = nn.TransformerEncoderLayer(
161
+ d_model=emb_dim, nhead=4, dim_feedforward=rnn_dim, batch_first=True
162
+ )
163
+
164
+ self.is_causal = not bidirectional
165
+ self.dropout = dropout
166
+
167
+ def forward(self, z):
168
+ batch, n_uncrossed, n_across, emb_dim = z.shape
169
+ z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
170
+ z = self.tf(z, is_causal=self.is_causal)
171
+ z = torch.reshape(z, (batch, n_uncrossed, n_across, emb_dim))
172
+
173
+ return z
174
+
175
+
176
+ class TransformerTimeFreqModule(TimeFrequencyModellingModule):
177
+ def __init__(
178
+ self,
179
+ n_modules: int = 12,
180
+ emb_dim: int = 128,
181
+ rnn_dim: int = 256,
182
+ bidirectional: bool = True,
183
+ dropout: float = 0.0,
184
+ ) -> None:
185
+ super().__init__()
186
+ self.norm = nn.LayerNorm(emb_dim)
187
+ self.seqband = nn.ModuleList([])
188
+
189
+ for _ in range(2 * n_modules):
190
+ self.seqband.append(
191
+ ResidualTransformer(
192
+ emb_dim=emb_dim,
193
+ rnn_dim=rnn_dim,
194
+ bidirectional=bidirectional,
195
+ dropout=dropout,
196
+ )
197
+ )
198
+
199
+ def forward(self, z):
200
+ z = self.norm(z)
201
+
202
+ for sbm in self.seqband:
203
+ z = sbm(z)
204
+ z = z.transpose(1, 2)
205
+
206
+ q = z
207
+ return q
208
+
209
+
210
+ class ResidualConvolution(nn.Module):
211
+ def __init__(
212
+ self,
213
+ emb_dim: int = 128,
214
+ rnn_dim: int = 256,
215
+ bidirectional: bool = True,
216
+ dropout: float = 0.0,
217
+ ) -> None:
218
+ super().__init__()
219
+ self.norm = nn.InstanceNorm2d(emb_dim, affine=True)
220
+
221
+ self.conv = nn.Sequential(
222
+ nn.Conv2d(
223
+ in_channels=emb_dim,
224
+ out_channels=rnn_dim,
225
+ kernel_size=(3, 3),
226
+ padding="same",
227
+ stride=(1, 1),
228
+ ),
229
+ nn.Tanhshrink(),
230
+ )
231
+
232
+ self.is_causal = not bidirectional
233
+ self.dropout = dropout
234
+
235
+ self.fc = nn.Conv2d(
236
+ in_channels=rnn_dim,
237
+ out_channels=emb_dim,
238
+ kernel_size=(1, 1),
239
+ padding="same",
240
+ stride=(1, 1),
241
+ )
242
+
243
+ def forward(self, z):
244
+
245
+ z0 = torch.clone(z)
246
+
247
+ z = self.norm(z)
248
+ z = self.conv(z)
249
+ z = self.fc(z)
250
+ z = z + z0
251
+
252
+ return z
253
+
254
+
255
+ class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule):
256
+ def __init__(
257
+ self,
258
+ n_modules: int = 12,
259
+ emb_dim: int = 128,
260
+ rnn_dim: int = 256,
261
+ bidirectional: bool = True,
262
+ dropout: float = 0.0,
263
+ ) -> None:
264
+ super().__init__()
265
+ self.seqband = torch.jit.script(
266
+ nn.Sequential(
267
+ *[
268
+ ResidualConvolution(
269
+ emb_dim=emb_dim,
270
+ rnn_dim=rnn_dim,
271
+ bidirectional=bidirectional,
272
+ dropout=dropout,
273
+ )
274
+ for _ in range(2 * n_modules)
275
+ ]
276
+ )
277
+ )
278
+
279
+ def forward(self, z):
280
+
281
+ z = torch.permute(z, (0, 3, 1, 2))
282
+
283
+ z = self.seqband(z)
284
+
285
+ z = torch.permute(z, (0, 2, 3, 1))
286
+
287
+ return z
mvsepless/models/bandit/core/model/bsrnn/utils.py CHANGED
@@ -1,518 +1,518 @@
1
- import os
2
- from abc import abstractmethod
3
- from typing import Any, Callable
4
-
5
- import numpy as np
6
- import torch
7
- from librosa import hz_to_midi, midi_to_hz
8
- from torch import Tensor
9
- from torchaudio import functional as taF
10
- from spafe.fbanks import bark_fbanks
11
- from spafe.utils.converters import erb2hz, hz2bark, hz2erb
12
- from torchaudio.functional.functional import _create_triangular_filterbank
13
-
14
-
15
- def band_widths_from_specs(band_specs):
16
- return [e - i for i, e in band_specs]
17
-
18
-
19
- def check_nonzero_bandwidth(band_specs):
20
- for fstart, fend in band_specs:
21
- if fend - fstart <= 0:
22
- raise ValueError("Bands cannot be zero-width")
23
-
24
-
25
- def check_no_overlap(band_specs):
26
- fend_prev = -1
27
- for fstart_curr, fend_curr in band_specs:
28
- if fstart_curr <= fend_prev:
29
- raise ValueError("Bands cannot overlap")
30
-
31
-
32
- def check_no_gap(band_specs):
33
- fstart, _ = band_specs[0]
34
- assert fstart == 0
35
-
36
- fend_prev = -1
37
- for fstart_curr, fend_curr in band_specs:
38
- if fstart_curr - fend_prev > 1:
39
- raise ValueError("Bands cannot leave gap")
40
- fend_prev = fend_curr
41
-
42
-
43
- class BandsplitSpecification:
44
- def __init__(self, nfft: int, fs: int) -> None:
45
- self.fs = fs
46
- self.nfft = nfft
47
- self.nyquist = fs / 2
48
- self.max_index = nfft // 2 + 1
49
-
50
- self.split500 = self.hertz_to_index(500)
51
- self.split1k = self.hertz_to_index(1000)
52
- self.split2k = self.hertz_to_index(2000)
53
- self.split4k = self.hertz_to_index(4000)
54
- self.split8k = self.hertz_to_index(8000)
55
- self.split16k = self.hertz_to_index(16000)
56
- self.split20k = self.hertz_to_index(20000)
57
-
58
- self.above20k = [(self.split20k, self.max_index)]
59
- self.above16k = [(self.split16k, self.split20k)] + self.above20k
60
-
61
- def index_to_hertz(self, index: int):
62
- return index * self.fs / self.nfft
63
-
64
- def hertz_to_index(self, hz: float, round: bool = True):
65
- index = hz * self.nfft / self.fs
66
-
67
- if round:
68
- index = int(np.round(index))
69
-
70
- return index
71
-
72
- def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
73
- band_specs = []
74
- lower = start_index
75
-
76
- while lower < end_index:
77
- upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
78
- upper = min(upper, end_index)
79
-
80
- band_specs.append((lower, upper))
81
- lower = upper
82
-
83
- return band_specs
84
-
85
- @abstractmethod
86
- def get_band_specs(self):
87
- raise NotImplementedError
88
-
89
-
90
- class VocalBandsplitSpecification(BandsplitSpecification):
91
- def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
92
- super().__init__(nfft=nfft, fs=fs)
93
-
94
- self.version = version
95
-
96
- def get_band_specs(self):
97
- return getattr(self, f"version{self.version}")()
98
-
99
- @property
100
- def version1(self):
101
- return self.get_band_specs_with_bandwidth(
102
- start_index=0, end_index=self.max_index, bandwidth_hz=1000
103
- )
104
-
105
- def version2(self):
106
- below16k = self.get_band_specs_with_bandwidth(
107
- start_index=0, end_index=self.split16k, bandwidth_hz=1000
108
- )
109
- below20k = self.get_band_specs_with_bandwidth(
110
- start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
111
- )
112
-
113
- return below16k + below20k + self.above20k
114
-
115
- def version3(self):
116
- below8k = self.get_band_specs_with_bandwidth(
117
- start_index=0, end_index=self.split8k, bandwidth_hz=1000
118
- )
119
- below16k = self.get_band_specs_with_bandwidth(
120
- start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
121
- )
122
-
123
- return below8k + below16k + self.above16k
124
-
125
- def version4(self):
126
- below1k = self.get_band_specs_with_bandwidth(
127
- start_index=0, end_index=self.split1k, bandwidth_hz=100
128
- )
129
- below8k = self.get_band_specs_with_bandwidth(
130
- start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
131
- )
132
- below16k = self.get_band_specs_with_bandwidth(
133
- start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
134
- )
135
-
136
- return below1k + below8k + below16k + self.above16k
137
-
138
- def version5(self):
139
- below1k = self.get_band_specs_with_bandwidth(
140
- start_index=0, end_index=self.split1k, bandwidth_hz=100
141
- )
142
- below16k = self.get_band_specs_with_bandwidth(
143
- start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
144
- )
145
- below20k = self.get_band_specs_with_bandwidth(
146
- start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
147
- )
148
- return below1k + below16k + below20k + self.above20k
149
-
150
- def version6(self):
151
- below1k = self.get_band_specs_with_bandwidth(
152
- start_index=0, end_index=self.split1k, bandwidth_hz=100
153
- )
154
- below4k = self.get_band_specs_with_bandwidth(
155
- start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
156
- )
157
- below8k = self.get_band_specs_with_bandwidth(
158
- start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
159
- )
160
- below16k = self.get_band_specs_with_bandwidth(
161
- start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
162
- )
163
- return below1k + below4k + below8k + below16k + self.above16k
164
-
165
- def version7(self):
166
- below1k = self.get_band_specs_with_bandwidth(
167
- start_index=0, end_index=self.split1k, bandwidth_hz=100
168
- )
169
- below4k = self.get_band_specs_with_bandwidth(
170
- start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
171
- )
172
- below8k = self.get_band_specs_with_bandwidth(
173
- start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
174
- )
175
- below16k = self.get_band_specs_with_bandwidth(
176
- start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
177
- )
178
- below20k = self.get_band_specs_with_bandwidth(
179
- start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
180
- )
181
- return below1k + below4k + below8k + below16k + below20k + self.above20k
182
-
183
-
184
- class OtherBandsplitSpecification(VocalBandsplitSpecification):
185
- def __init__(self, nfft: int, fs: int) -> None:
186
- super().__init__(nfft=nfft, fs=fs, version="7")
187
-
188
-
189
- class BassBandsplitSpecification(BandsplitSpecification):
190
- def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
191
- super().__init__(nfft=nfft, fs=fs)
192
-
193
- def get_band_specs(self):
194
- below500 = self.get_band_specs_with_bandwidth(
195
- start_index=0, end_index=self.split500, bandwidth_hz=50
196
- )
197
- below1k = self.get_band_specs_with_bandwidth(
198
- start_index=self.split500, end_index=self.split1k, bandwidth_hz=100
199
- )
200
- below4k = self.get_band_specs_with_bandwidth(
201
- start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
202
- )
203
- below8k = self.get_band_specs_with_bandwidth(
204
- start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
205
- )
206
- below16k = self.get_band_specs_with_bandwidth(
207
- start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
208
- )
209
- above16k = [(self.split16k, self.max_index)]
210
-
211
- return below500 + below1k + below4k + below8k + below16k + above16k
212
-
213
-
214
- class DrumBandsplitSpecification(BandsplitSpecification):
215
- def __init__(self, nfft: int, fs: int) -> None:
216
- super().__init__(nfft=nfft, fs=fs)
217
-
218
- def get_band_specs(self):
219
- below1k = self.get_band_specs_with_bandwidth(
220
- start_index=0, end_index=self.split1k, bandwidth_hz=50
221
- )
222
- below2k = self.get_band_specs_with_bandwidth(
223
- start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100
224
- )
225
- below4k = self.get_band_specs_with_bandwidth(
226
- start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250
227
- )
228
- below8k = self.get_band_specs_with_bandwidth(
229
- start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
230
- )
231
- below16k = self.get_band_specs_with_bandwidth(
232
- start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
233
- )
234
- above16k = [(self.split16k, self.max_index)]
235
-
236
- return below1k + below2k + below4k + below8k + below16k + above16k
237
-
238
-
239
- class PerceptualBandsplitSpecification(BandsplitSpecification):
240
- def __init__(
241
- self,
242
- nfft: int,
243
- fs: int,
244
- fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
245
- n_bands: int,
246
- f_min: float = 0.0,
247
- f_max: float = None,
248
- ) -> None:
249
- super().__init__(nfft=nfft, fs=fs)
250
- self.n_bands = n_bands
251
- if f_max is None:
252
- f_max = fs / 2
253
-
254
- self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)
255
-
256
- weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True)
257
- normalized_mel_fb = self.filterbank / weight_per_bin
258
-
259
- freq_weights = []
260
- band_specs = []
261
- for i in range(self.n_bands):
262
- active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
263
- if isinstance(active_bins, int):
264
- active_bins = (active_bins, active_bins)
265
- if len(active_bins) == 0:
266
- continue
267
- start_index = active_bins[0]
268
- end_index = active_bins[-1] + 1
269
- band_specs.append((start_index, end_index))
270
- freq_weights.append(normalized_mel_fb[i, start_index:end_index])
271
-
272
- self.freq_weights = freq_weights
273
- self.band_specs = band_specs
274
-
275
- def get_band_specs(self):
276
- return self.band_specs
277
-
278
- def get_freq_weights(self):
279
- return self.freq_weights
280
-
281
- def save_to_file(self, dir_path: str) -> None:
282
-
283
- os.makedirs(dir_path, exist_ok=True)
284
-
285
- import pickle
286
-
287
- with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
288
- pickle.dump(
289
- {
290
- "band_specs": self.band_specs,
291
- "freq_weights": self.freq_weights,
292
- "filterbank": self.filterbank,
293
- },
294
- f,
295
- )
296
-
297
-
298
- def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
299
- fb = taF.melscale_fbanks(
300
- n_mels=n_bands,
301
- sample_rate=fs,
302
- f_min=f_min,
303
- f_max=f_max,
304
- n_freqs=n_freqs,
305
- ).T
306
-
307
- fb[0, 0] = 1.0
308
-
309
- return fb
310
-
311
-
312
- class MelBandsplitSpecification(PerceptualBandsplitSpecification):
313
- def __init__(
314
- self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
315
- ) -> None:
316
- super().__init__(
317
- fbank_fn=mel_filterbank,
318
- nfft=nfft,
319
- fs=fs,
320
- n_bands=n_bands,
321
- f_min=f_min,
322
- f_max=f_max,
323
- )
324
-
325
-
326
- def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
327
-
328
- nfft = 2 * (n_freqs - 1)
329
- df = fs / nfft
330
- f_max = f_max or fs / 2
331
- f_min = f_min or 0
332
- f_min = fs / nfft
333
-
334
- n_octaves = np.log2(f_max / f_min)
335
- n_octaves_per_band = n_octaves / n_bands
336
- bandwidth_mult = np.power(2.0, n_octaves_per_band)
337
-
338
- low_midi = max(0, hz_to_midi(f_min))
339
- high_midi = hz_to_midi(f_max)
340
- midi_points = np.linspace(low_midi, high_midi, n_bands)
341
- hz_pts = midi_to_hz(midi_points)
342
-
343
- low_pts = hz_pts / bandwidth_mult
344
- high_pts = hz_pts * bandwidth_mult
345
-
346
- low_bins = np.floor(low_pts / df).astype(int)
347
- high_bins = np.ceil(high_pts / df).astype(int)
348
-
349
- fb = np.zeros((n_bands, n_freqs))
350
-
351
- for i in range(n_bands):
352
- fb[i, low_bins[i] : high_bins[i] + 1] = 1.0
353
-
354
- fb[0, : low_bins[0]] = 1.0
355
- fb[-1, high_bins[-1] + 1 :] = 1.0
356
-
357
- return torch.as_tensor(fb)
358
-
359
-
360
- class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
361
- def __init__(
362
- self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
363
- ) -> None:
364
- super().__init__(
365
- fbank_fn=musical_filterbank,
366
- nfft=nfft,
367
- fs=fs,
368
- n_bands=n_bands,
369
- f_min=f_min,
370
- f_max=f_max,
371
- )
372
-
373
-
374
- def bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
375
- nfft = 2 * (n_freqs - 1)
376
- fb, _ = bark_fbanks.bark_filter_banks(
377
- nfilts=n_bands,
378
- nfft=nfft,
379
- fs=fs,
380
- low_freq=f_min,
381
- high_freq=f_max,
382
- scale="constant",
383
- )
384
-
385
- return torch.as_tensor(fb)
386
-
387
-
388
- class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
389
- def __init__(
390
- self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
391
- ) -> None:
392
- super().__init__(
393
- fbank_fn=bark_filterbank,
394
- nfft=nfft,
395
- fs=fs,
396
- n_bands=n_bands,
397
- f_min=f_min,
398
- f_max=f_max,
399
- )
400
-
401
-
402
- def triangular_bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
403
-
404
- all_freqs = torch.linspace(0, fs // 2, n_freqs)
405
-
406
- m_min = hz2bark(f_min)
407
- m_max = hz2bark(f_max)
408
-
409
- m_pts = torch.linspace(m_min, m_max, n_bands + 2)
410
- f_pts = 600 * torch.sinh(m_pts / 6)
411
-
412
- fb = _create_triangular_filterbank(all_freqs, f_pts)
413
-
414
- fb = fb.T
415
-
416
- first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
417
- first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
418
-
419
- fb[first_active_band, :first_active_bin] = 1.0
420
-
421
- return fb
422
-
423
-
424
- class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
425
- def __init__(
426
- self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
427
- ) -> None:
428
- super().__init__(
429
- fbank_fn=triangular_bark_filterbank,
430
- nfft=nfft,
431
- fs=fs,
432
- n_bands=n_bands,
433
- f_min=f_min,
434
- f_max=f_max,
435
- )
436
-
437
-
438
- def minibark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
439
- fb = bark_filterbank(n_bands, fs, f_min, f_max, n_freqs)
440
-
441
- fb[fb < np.sqrt(0.5)] = 0.0
442
-
443
- return fb
444
-
445
-
446
- class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
447
- def __init__(
448
- self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
449
- ) -> None:
450
- super().__init__(
451
- fbank_fn=minibark_filterbank,
452
- nfft=nfft,
453
- fs=fs,
454
- n_bands=n_bands,
455
- f_min=f_min,
456
- f_max=f_max,
457
- )
458
-
459
-
460
- def erb_filterbank(
461
- n_bands: int,
462
- fs: int,
463
- f_min: float,
464
- f_max: float,
465
- n_freqs: int,
466
- ) -> Tensor:
467
- A = (1000 * np.log(10)) / (24.7 * 4.37)
468
- all_freqs = torch.linspace(0, fs // 2, n_freqs)
469
-
470
- m_min = hz2erb(f_min)
471
- m_max = hz2erb(f_max)
472
-
473
- m_pts = torch.linspace(m_min, m_max, n_bands + 2)
474
- f_pts = (torch.pow(10, (m_pts / A)) - 1) / 0.00437
475
-
476
- fb = _create_triangular_filterbank(all_freqs, f_pts)
477
-
478
- fb = fb.T
479
-
480
- first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
481
- first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
482
-
483
- fb[first_active_band, :first_active_bin] = 1.0
484
-
485
- return fb
486
-
487
-
488
- class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
489
- def __init__(
490
- self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
491
- ) -> None:
492
- super().__init__(
493
- fbank_fn=erb_filterbank,
494
- nfft=nfft,
495
- fs=fs,
496
- n_bands=n_bands,
497
- f_min=f_min,
498
- f_max=f_max,
499
- )
500
-
501
-
502
- if __name__ == "__main__":
503
- import pandas as pd
504
-
505
- band_defs = []
506
-
507
- for bands in [VocalBandsplitSpecification]:
508
- band_name = bands.__name__.replace("BandsplitSpecification", "")
509
-
510
- mbs = bands(nfft=2048, fs=44100).get_band_specs()
511
-
512
- for i, (f_min, f_max) in enumerate(mbs):
513
- band_defs.append(
514
- {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
515
- )
516
-
517
- df = pd.DataFrame(band_defs)
518
- df.to_csv("vox7bands.csv", index=False)
 
1
+ import os
2
+ from abc import abstractmethod
3
+ from typing import Any, Callable
4
+
5
+ import numpy as np
6
+ import torch
7
+ from librosa import hz_to_midi, midi_to_hz
8
+ from torch import Tensor
9
+ from torchaudio import functional as taF
10
+ from spafe.fbanks import bark_fbanks
11
+ from spafe.utils.converters import erb2hz, hz2bark, hz2erb
12
+ from torchaudio.functional.functional import _create_triangular_filterbank
13
+
14
+
15
+ def band_widths_from_specs(band_specs):
16
+ return [e - i for i, e in band_specs]
17
+
18
+
19
+ def check_nonzero_bandwidth(band_specs):
20
+ for fstart, fend in band_specs:
21
+ if fend - fstart <= 0:
22
+ raise ValueError("Bands cannot be zero-width")
23
+
24
+
25
+ def check_no_overlap(band_specs):
26
+ fend_prev = -1
27
+ for fstart_curr, fend_curr in band_specs:
28
+ if fstart_curr <= fend_prev:
29
+ raise ValueError("Bands cannot overlap")
30
+
31
+
32
+ def check_no_gap(band_specs):
33
+ fstart, _ = band_specs[0]
34
+ assert fstart == 0
35
+
36
+ fend_prev = -1
37
+ for fstart_curr, fend_curr in band_specs:
38
+ if fstart_curr - fend_prev > 1:
39
+ raise ValueError("Bands cannot leave gap")
40
+ fend_prev = fend_curr
41
+
42
+
43
+ class BandsplitSpecification:
44
+ def __init__(self, nfft: int, fs: int) -> None:
45
+ self.fs = fs
46
+ self.nfft = nfft
47
+ self.nyquist = fs / 2
48
+ self.max_index = nfft // 2 + 1
49
+
50
+ self.split500 = self.hertz_to_index(500)
51
+ self.split1k = self.hertz_to_index(1000)
52
+ self.split2k = self.hertz_to_index(2000)
53
+ self.split4k = self.hertz_to_index(4000)
54
+ self.split8k = self.hertz_to_index(8000)
55
+ self.split16k = self.hertz_to_index(16000)
56
+ self.split20k = self.hertz_to_index(20000)
57
+
58
+ self.above20k = [(self.split20k, self.max_index)]
59
+ self.above16k = [(self.split16k, self.split20k)] + self.above20k
60
+
61
+ def index_to_hertz(self, index: int):
62
+ return index * self.fs / self.nfft
63
+
64
+ def hertz_to_index(self, hz: float, round: bool = True):
65
+ index = hz * self.nfft / self.fs
66
+
67
+ if round:
68
+ index = int(np.round(index))
69
+
70
+ return index
71
+
72
+ def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
73
+ band_specs = []
74
+ lower = start_index
75
+
76
+ while lower < end_index:
77
+ upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
78
+ upper = min(upper, end_index)
79
+
80
+ band_specs.append((lower, upper))
81
+ lower = upper
82
+
83
+ return band_specs
84
+
85
+ @abstractmethod
86
+ def get_band_specs(self):
87
+ raise NotImplementedError
88
+
89
+
90
+ class VocalBandsplitSpecification(BandsplitSpecification):
91
+ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
92
+ super().__init__(nfft=nfft, fs=fs)
93
+
94
+ self.version = version
95
+
96
+ def get_band_specs(self):
97
+ return getattr(self, f"version{self.version}")()
98
+
99
+ @property
100
+ def version1(self):
101
+ return self.get_band_specs_with_bandwidth(
102
+ start_index=0, end_index=self.max_index, bandwidth_hz=1000
103
+ )
104
+
105
+ def version2(self):
106
+ below16k = self.get_band_specs_with_bandwidth(
107
+ start_index=0, end_index=self.split16k, bandwidth_hz=1000
108
+ )
109
+ below20k = self.get_band_specs_with_bandwidth(
110
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
111
+ )
112
+
113
+ return below16k + below20k + self.above20k
114
+
115
+ def version3(self):
116
+ below8k = self.get_band_specs_with_bandwidth(
117
+ start_index=0, end_index=self.split8k, bandwidth_hz=1000
118
+ )
119
+ below16k = self.get_band_specs_with_bandwidth(
120
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
121
+ )
122
+
123
+ return below8k + below16k + self.above16k
124
+
125
+ def version4(self):
126
+ below1k = self.get_band_specs_with_bandwidth(
127
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
128
+ )
129
+ below8k = self.get_band_specs_with_bandwidth(
130
+ start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
131
+ )
132
+ below16k = self.get_band_specs_with_bandwidth(
133
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
134
+ )
135
+
136
+ return below1k + below8k + below16k + self.above16k
137
+
138
+ def version5(self):
139
+ below1k = self.get_band_specs_with_bandwidth(
140
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
141
+ )
142
+ below16k = self.get_band_specs_with_bandwidth(
143
+ start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
144
+ )
145
+ below20k = self.get_band_specs_with_bandwidth(
146
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
147
+ )
148
+ return below1k + below16k + below20k + self.above20k
149
+
150
+ def version6(self):
151
+ below1k = self.get_band_specs_with_bandwidth(
152
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
153
+ )
154
+ below4k = self.get_band_specs_with_bandwidth(
155
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
156
+ )
157
+ below8k = self.get_band_specs_with_bandwidth(
158
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
159
+ )
160
+ below16k = self.get_band_specs_with_bandwidth(
161
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
162
+ )
163
+ return below1k + below4k + below8k + below16k + self.above16k
164
+
165
+ def version7(self):
166
+ below1k = self.get_band_specs_with_bandwidth(
167
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
168
+ )
169
+ below4k = self.get_band_specs_with_bandwidth(
170
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
171
+ )
172
+ below8k = self.get_band_specs_with_bandwidth(
173
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
174
+ )
175
+ below16k = self.get_band_specs_with_bandwidth(
176
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
177
+ )
178
+ below20k = self.get_band_specs_with_bandwidth(
179
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
180
+ )
181
+ return below1k + below4k + below8k + below16k + below20k + self.above20k
182
+
183
+
184
+ class OtherBandsplitSpecification(VocalBandsplitSpecification):
185
+ def __init__(self, nfft: int, fs: int) -> None:
186
+ super().__init__(nfft=nfft, fs=fs, version="7")
187
+
188
+
189
+ class BassBandsplitSpecification(BandsplitSpecification):
190
+ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
191
+ super().__init__(nfft=nfft, fs=fs)
192
+
193
+ def get_band_specs(self):
194
+ below500 = self.get_band_specs_with_bandwidth(
195
+ start_index=0, end_index=self.split500, bandwidth_hz=50
196
+ )
197
+ below1k = self.get_band_specs_with_bandwidth(
198
+ start_index=self.split500, end_index=self.split1k, bandwidth_hz=100
199
+ )
200
+ below4k = self.get_band_specs_with_bandwidth(
201
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
202
+ )
203
+ below8k = self.get_band_specs_with_bandwidth(
204
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
205
+ )
206
+ below16k = self.get_band_specs_with_bandwidth(
207
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
208
+ )
209
+ above16k = [(self.split16k, self.max_index)]
210
+
211
+ return below500 + below1k + below4k + below8k + below16k + above16k
212
+
213
+
214
+ class DrumBandsplitSpecification(BandsplitSpecification):
215
+ def __init__(self, nfft: int, fs: int) -> None:
216
+ super().__init__(nfft=nfft, fs=fs)
217
+
218
+ def get_band_specs(self):
219
+ below1k = self.get_band_specs_with_bandwidth(
220
+ start_index=0, end_index=self.split1k, bandwidth_hz=50
221
+ )
222
+ below2k = self.get_band_specs_with_bandwidth(
223
+ start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100
224
+ )
225
+ below4k = self.get_band_specs_with_bandwidth(
226
+ start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250
227
+ )
228
+ below8k = self.get_band_specs_with_bandwidth(
229
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
230
+ )
231
+ below16k = self.get_band_specs_with_bandwidth(
232
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
233
+ )
234
+ above16k = [(self.split16k, self.max_index)]
235
+
236
+ return below1k + below2k + below4k + below8k + below16k + above16k
237
+
238
+
239
+ class PerceptualBandsplitSpecification(BandsplitSpecification):
240
+ def __init__(
241
+ self,
242
+ nfft: int,
243
+ fs: int,
244
+ fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
245
+ n_bands: int,
246
+ f_min: float = 0.0,
247
+ f_max: float = None,
248
+ ) -> None:
249
+ super().__init__(nfft=nfft, fs=fs)
250
+ self.n_bands = n_bands
251
+ if f_max is None:
252
+ f_max = fs / 2
253
+
254
+ self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)
255
+
256
+ weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True)
257
+ normalized_mel_fb = self.filterbank / weight_per_bin
258
+
259
+ freq_weights = []
260
+ band_specs = []
261
+ for i in range(self.n_bands):
262
+ active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
263
+ if isinstance(active_bins, int):
264
+ active_bins = (active_bins, active_bins)
265
+ if len(active_bins) == 0:
266
+ continue
267
+ start_index = active_bins[0]
268
+ end_index = active_bins[-1] + 1
269
+ band_specs.append((start_index, end_index))
270
+ freq_weights.append(normalized_mel_fb[i, start_index:end_index])
271
+
272
+ self.freq_weights = freq_weights
273
+ self.band_specs = band_specs
274
+
275
+ def get_band_specs(self):
276
+ return self.band_specs
277
+
278
+ def get_freq_weights(self):
279
+ return self.freq_weights
280
+
281
+ def save_to_file(self, dir_path: str) -> None:
282
+
283
+ os.makedirs(dir_path, exist_ok=True)
284
+
285
+ import pickle
286
+
287
+ with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
288
+ pickle.dump(
289
+ {
290
+ "band_specs": self.band_specs,
291
+ "freq_weights": self.freq_weights,
292
+ "filterbank": self.filterbank,
293
+ },
294
+ f,
295
+ )
296
+
297
+
298
+ def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
299
+ fb = taF.melscale_fbanks(
300
+ n_mels=n_bands,
301
+ sample_rate=fs,
302
+ f_min=f_min,
303
+ f_max=f_max,
304
+ n_freqs=n_freqs,
305
+ ).T
306
+
307
+ fb[0, 0] = 1.0
308
+
309
+ return fb
310
+
311
+
312
+ class MelBandsplitSpecification(PerceptualBandsplitSpecification):
313
+ def __init__(
314
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
315
+ ) -> None:
316
+ super().__init__(
317
+ fbank_fn=mel_filterbank,
318
+ nfft=nfft,
319
+ fs=fs,
320
+ n_bands=n_bands,
321
+ f_min=f_min,
322
+ f_max=f_max,
323
+ )
324
+
325
+
326
+ def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
327
+
328
+ nfft = 2 * (n_freqs - 1)
329
+ df = fs / nfft
330
+ f_max = f_max or fs / 2
331
+ f_min = f_min or 0
332
+ f_min = fs / nfft
333
+
334
+ n_octaves = np.log2(f_max / f_min)
335
+ n_octaves_per_band = n_octaves / n_bands
336
+ bandwidth_mult = np.power(2.0, n_octaves_per_band)
337
+
338
+ low_midi = max(0, hz_to_midi(f_min))
339
+ high_midi = hz_to_midi(f_max)
340
+ midi_points = np.linspace(low_midi, high_midi, n_bands)
341
+ hz_pts = midi_to_hz(midi_points)
342
+
343
+ low_pts = hz_pts / bandwidth_mult
344
+ high_pts = hz_pts * bandwidth_mult
345
+
346
+ low_bins = np.floor(low_pts / df).astype(int)
347
+ high_bins = np.ceil(high_pts / df).astype(int)
348
+
349
+ fb = np.zeros((n_bands, n_freqs))
350
+
351
+ for i in range(n_bands):
352
+ fb[i, low_bins[i] : high_bins[i] + 1] = 1.0
353
+
354
+ fb[0, : low_bins[0]] = 1.0
355
+ fb[-1, high_bins[-1] + 1 :] = 1.0
356
+
357
+ return torch.as_tensor(fb)
358
+
359
+
360
+ class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
361
+ def __init__(
362
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
363
+ ) -> None:
364
+ super().__init__(
365
+ fbank_fn=musical_filterbank,
366
+ nfft=nfft,
367
+ fs=fs,
368
+ n_bands=n_bands,
369
+ f_min=f_min,
370
+ f_max=f_max,
371
+ )
372
+
373
+
374
+ def bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
375
+ nfft = 2 * (n_freqs - 1)
376
+ fb, _ = bark_fbanks.bark_filter_banks(
377
+ nfilts=n_bands,
378
+ nfft=nfft,
379
+ fs=fs,
380
+ low_freq=f_min,
381
+ high_freq=f_max,
382
+ scale="constant",
383
+ )
384
+
385
+ return torch.as_tensor(fb)
386
+
387
+
388
+ class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
389
+ def __init__(
390
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
391
+ ) -> None:
392
+ super().__init__(
393
+ fbank_fn=bark_filterbank,
394
+ nfft=nfft,
395
+ fs=fs,
396
+ n_bands=n_bands,
397
+ f_min=f_min,
398
+ f_max=f_max,
399
+ )
400
+
401
+
402
+ def triangular_bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
403
+
404
+ all_freqs = torch.linspace(0, fs // 2, n_freqs)
405
+
406
+ m_min = hz2bark(f_min)
407
+ m_max = hz2bark(f_max)
408
+
409
+ m_pts = torch.linspace(m_min, m_max, n_bands + 2)
410
+ f_pts = 600 * torch.sinh(m_pts / 6)
411
+
412
+ fb = _create_triangular_filterbank(all_freqs, f_pts)
413
+
414
+ fb = fb.T
415
+
416
+ first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
417
+ first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
418
+
419
+ fb[first_active_band, :first_active_bin] = 1.0
420
+
421
+ return fb
422
+
423
+
424
+ class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
425
+ def __init__(
426
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
427
+ ) -> None:
428
+ super().__init__(
429
+ fbank_fn=triangular_bark_filterbank,
430
+ nfft=nfft,
431
+ fs=fs,
432
+ n_bands=n_bands,
433
+ f_min=f_min,
434
+ f_max=f_max,
435
+ )
436
+
437
+
438
+ def minibark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
439
+ fb = bark_filterbank(n_bands, fs, f_min, f_max, n_freqs)
440
+
441
+ fb[fb < np.sqrt(0.5)] = 0.0
442
+
443
+ return fb
444
+
445
+
446
+ class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
447
+ def __init__(
448
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
449
+ ) -> None:
450
+ super().__init__(
451
+ fbank_fn=minibark_filterbank,
452
+ nfft=nfft,
453
+ fs=fs,
454
+ n_bands=n_bands,
455
+ f_min=f_min,
456
+ f_max=f_max,
457
+ )
458
+
459
+
460
+ def erb_filterbank(
461
+ n_bands: int,
462
+ fs: int,
463
+ f_min: float,
464
+ f_max: float,
465
+ n_freqs: int,
466
+ ) -> Tensor:
467
+ A = (1000 * np.log(10)) / (24.7 * 4.37)
468
+ all_freqs = torch.linspace(0, fs // 2, n_freqs)
469
+
470
+ m_min = hz2erb(f_min)
471
+ m_max = hz2erb(f_max)
472
+
473
+ m_pts = torch.linspace(m_min, m_max, n_bands + 2)
474
+ f_pts = (torch.pow(10, (m_pts / A)) - 1) / 0.00437
475
+
476
+ fb = _create_triangular_filterbank(all_freqs, f_pts)
477
+
478
+ fb = fb.T
479
+
480
+ first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
481
+ first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
482
+
483
+ fb[first_active_band, :first_active_bin] = 1.0
484
+
485
+ return fb
486
+
487
+
488
+ class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
489
+ def __init__(
490
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
491
+ ) -> None:
492
+ super().__init__(
493
+ fbank_fn=erb_filterbank,
494
+ nfft=nfft,
495
+ fs=fs,
496
+ n_bands=n_bands,
497
+ f_min=f_min,
498
+ f_max=f_max,
499
+ )
500
+
501
+
502
+ if __name__ == "__main__":
503
+ import pandas as pd
504
+
505
+ band_defs = []
506
+
507
+ for bands in [VocalBandsplitSpecification]:
508
+ band_name = bands.__name__.replace("BandsplitSpecification", "")
509
+
510
+ mbs = bands(nfft=2048, fs=44100).get_band_specs()
511
+
512
+ for i, (f_min, f_max) in enumerate(mbs):
513
+ band_defs.append(
514
+ {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
515
+ )
516
+
517
+ df = pd.DataFrame(band_defs)
518
+ df.to_csv("vox7bands.csv", index=False)
mvsepless/models/bandit/core/model/bsrnn/wrapper.py CHANGED
@@ -1,828 +1,828 @@
1
- from pprint import pprint
2
- from typing import Dict, List, Optional, Tuple, Union
3
-
4
- import torch
5
- from torch import nn
6
-
7
- from .._spectral import _SpectralComponent
8
- from .utils import (
9
- BarkBandsplitSpecification,
10
- BassBandsplitSpecification,
11
- DrumBandsplitSpecification,
12
- EquivalentRectangularBandsplitSpecification,
13
- MelBandsplitSpecification,
14
- MusicalBandsplitSpecification,
15
- OtherBandsplitSpecification,
16
- TriangularBarkBandsplitSpecification,
17
- VocalBandsplitSpecification,
18
- )
19
- from .core import (
20
- MultiSourceMultiMaskBandSplitCoreConv,
21
- MultiSourceMultiMaskBandSplitCoreRNN,
22
- MultiSourceMultiMaskBandSplitCoreTransformer,
23
- MultiSourceMultiPatchingMaskBandSplitCoreRNN,
24
- SingleMaskBandsplitCoreRNN,
25
- SingleMaskBandsplitCoreTransformer,
26
- )
27
-
28
- import pytorch_lightning as pl
29
-
30
-
31
- def get_band_specs(band_specs, n_fft, fs, n_bands=None):
32
- if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]:
33
- bsm = VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs()
34
- freq_weights = None
35
- overlapping_band = False
36
- elif "tribark" in band_specs:
37
- assert n_bands is not None
38
- specs = TriangularBarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
39
- bsm = specs.get_band_specs()
40
- freq_weights = specs.get_freq_weights()
41
- overlapping_band = True
42
- elif "bark" in band_specs:
43
- assert n_bands is not None
44
- specs = BarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
45
- bsm = specs.get_band_specs()
46
- freq_weights = specs.get_freq_weights()
47
- overlapping_band = True
48
- elif "erb" in band_specs:
49
- assert n_bands is not None
50
- specs = EquivalentRectangularBandsplitSpecification(
51
- nfft=n_fft, fs=fs, n_bands=n_bands
52
- )
53
- bsm = specs.get_band_specs()
54
- freq_weights = specs.get_freq_weights()
55
- overlapping_band = True
56
- elif "musical" in band_specs:
57
- assert n_bands is not None
58
- specs = MusicalBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
59
- bsm = specs.get_band_specs()
60
- freq_weights = specs.get_freq_weights()
61
- overlapping_band = True
62
- elif band_specs == "dnr:mel" or "mel" in band_specs:
63
- assert n_bands is not None
64
- specs = MelBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
65
- bsm = specs.get_band_specs()
66
- freq_weights = specs.get_freq_weights()
67
- overlapping_band = True
68
- else:
69
- raise NameError
70
-
71
- return bsm, freq_weights, overlapping_band
72
-
73
-
74
- def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None):
75
- if band_specs_map == "musdb:all":
76
- bsm = {
77
- "vocals": VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
78
- "drums": DrumBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
79
- "bass": BassBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
80
- "other": OtherBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
81
- }
82
- freq_weights = None
83
- overlapping_band = False
84
- elif band_specs_map == "dnr:vox7":
85
- bsm_, freq_weights, overlapping_band = get_band_specs(
86
- "dnr:speech", n_fft, fs, n_bands
87
- )
88
- bsm = {"speech": bsm_, "music": bsm_, "effects": bsm_}
89
- elif "dnr:vox7:" in band_specs_map:
90
- stem = band_specs_map.split(":")[-1]
91
- bsm_, freq_weights, overlapping_band = get_band_specs(
92
- "dnr:speech", n_fft, fs, n_bands
93
- )
94
- bsm = {stem: bsm_}
95
- else:
96
- raise NameError
97
-
98
- return bsm, freq_weights, overlapping_band
99
-
100
-
101
- class BandSplitWrapperBase(pl.LightningModule):
102
- bsrnn: nn.Module
103
-
104
- def __init__(self, **kwargs):
105
- super().__init__()
106
-
107
-
108
- class SingleMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
109
- def __init__(
110
- self,
111
- band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
112
- fs: int = 44100,
113
- n_fft: int = 2048,
114
- win_length: Optional[int] = 2048,
115
- hop_length: int = 512,
116
- window_fn: str = "hann_window",
117
- wkwargs: Optional[Dict] = None,
118
- power: Optional[int] = None,
119
- center: bool = True,
120
- normalized: bool = True,
121
- pad_mode: str = "constant",
122
- onesided: bool = True,
123
- n_bands: int = None,
124
- ) -> None:
125
- super().__init__(
126
- n_fft=n_fft,
127
- win_length=win_length,
128
- hop_length=hop_length,
129
- window_fn=window_fn,
130
- wkwargs=wkwargs,
131
- power=power,
132
- center=center,
133
- normalized=normalized,
134
- pad_mode=pad_mode,
135
- onesided=onesided,
136
- )
137
-
138
- if isinstance(band_specs_map, str):
139
- self.band_specs_map, self.freq_weights, self.overlapping_band = (
140
- get_band_specs_map(band_specs_map, n_fft, fs, n_bands=n_bands)
141
- )
142
-
143
- self.stems = list(self.band_specs_map.keys())
144
-
145
- def forward(self, batch):
146
- audio = batch["audio"]
147
-
148
- with torch.no_grad():
149
- batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}
150
-
151
- X = batch["spectrogram"]["mixture"]
152
- length = batch["audio"]["mixture"].shape[-1]
153
-
154
- output = {"spectrogram": {}, "audio": {}}
155
-
156
- for stem, bsrnn in self.bsrnn.items():
157
- S = bsrnn(X)
158
- s = self.istft(S, length)
159
- output["spectrogram"][stem] = S
160
- output["audio"][stem] = s
161
-
162
- return batch, output
163
-
164
-
165
- class MultiMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
166
- def __init__(
167
- self,
168
- stems: List[str],
169
- band_specs: Union[str, List[Tuple[float, float]]],
170
- fs: int = 44100,
171
- n_fft: int = 2048,
172
- win_length: Optional[int] = 2048,
173
- hop_length: int = 512,
174
- window_fn: str = "hann_window",
175
- wkwargs: Optional[Dict] = None,
176
- power: Optional[int] = None,
177
- center: bool = True,
178
- normalized: bool = True,
179
- pad_mode: str = "constant",
180
- onesided: bool = True,
181
- n_bands: int = None,
182
- ) -> None:
183
- super().__init__(
184
- n_fft=n_fft,
185
- win_length=win_length,
186
- hop_length=hop_length,
187
- window_fn=window_fn,
188
- wkwargs=wkwargs,
189
- power=power,
190
- center=center,
191
- normalized=normalized,
192
- pad_mode=pad_mode,
193
- onesided=onesided,
194
- )
195
-
196
- if isinstance(band_specs, str):
197
- self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
198
- band_specs, n_fft, fs, n_bands
199
- )
200
-
201
- self.stems = stems
202
-
203
- def forward(self, batch):
204
- audio = batch["audio"]
205
- cond = batch.get("condition", None)
206
- with torch.no_grad():
207
- batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}
208
-
209
- X = batch["spectrogram"]["mixture"]
210
- length = batch["audio"]["mixture"].shape[-1]
211
-
212
- output = self.bsrnn(X, cond=cond)
213
- output["audio"] = {}
214
-
215
- for stem, S in output["spectrogram"].items():
216
- s = self.istft(S, length)
217
- output["audio"][stem] = s
218
-
219
- return batch, output
220
-
221
-
222
- class MultiMaskMultiSourceBandSplitBaseSimple(BandSplitWrapperBase, _SpectralComponent):
223
- def __init__(
224
- self,
225
- stems: List[str],
226
- band_specs: Union[str, List[Tuple[float, float]]],
227
- fs: int = 44100,
228
- n_fft: int = 2048,
229
- win_length: Optional[int] = 2048,
230
- hop_length: int = 512,
231
- window_fn: str = "hann_window",
232
- wkwargs: Optional[Dict] = None,
233
- power: Optional[int] = None,
234
- center: bool = True,
235
- normalized: bool = True,
236
- pad_mode: str = "constant",
237
- onesided: bool = True,
238
- n_bands: int = None,
239
- ) -> None:
240
- super().__init__(
241
- n_fft=n_fft,
242
- win_length=win_length,
243
- hop_length=hop_length,
244
- window_fn=window_fn,
245
- wkwargs=wkwargs,
246
- power=power,
247
- center=center,
248
- normalized=normalized,
249
- pad_mode=pad_mode,
250
- onesided=onesided,
251
- )
252
-
253
- if isinstance(band_specs, str):
254
- self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
255
- band_specs, n_fft, fs, n_bands
256
- )
257
-
258
- self.stems = stems
259
-
260
- def forward(self, batch):
261
- with torch.no_grad():
262
- X = self.stft(batch)
263
- length = batch.shape[-1]
264
- output = self.bsrnn(X, cond=None)
265
- res = []
266
- for stem, S in output["spectrogram"].items():
267
- s = self.istft(S, length)
268
- res.append(s)
269
- res = torch.stack(res, dim=1)
270
- return res
271
-
272
-
273
- class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase):
274
- def __init__(
275
- self,
276
- in_channel: int,
277
- band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
278
- fs: int = 44100,
279
- require_no_overlap: bool = False,
280
- require_no_gap: bool = True,
281
- normalize_channel_independently: bool = False,
282
- treat_channel_as_feature: bool = True,
283
- n_sqm_modules: int = 12,
284
- emb_dim: int = 128,
285
- rnn_dim: int = 256,
286
- bidirectional: bool = True,
287
- rnn_type: str = "LSTM",
288
- mlp_dim: int = 512,
289
- hidden_activation: str = "Tanh",
290
- hidden_activation_kwargs: Optional[Dict] = None,
291
- complex_mask: bool = True,
292
- n_fft: int = 2048,
293
- win_length: Optional[int] = 2048,
294
- hop_length: int = 512,
295
- window_fn: str = "hann_window",
296
- wkwargs: Optional[Dict] = None,
297
- power: Optional[int] = None,
298
- center: bool = True,
299
- normalized: bool = True,
300
- pad_mode: str = "constant",
301
- onesided: bool = True,
302
- ) -> None:
303
- super().__init__(
304
- band_specs_map=band_specs_map,
305
- fs=fs,
306
- n_fft=n_fft,
307
- win_length=win_length,
308
- hop_length=hop_length,
309
- window_fn=window_fn,
310
- wkwargs=wkwargs,
311
- power=power,
312
- center=center,
313
- normalized=normalized,
314
- pad_mode=pad_mode,
315
- onesided=onesided,
316
- )
317
-
318
- self.bsrnn = nn.ModuleDict(
319
- {
320
- src: SingleMaskBandsplitCoreRNN(
321
- band_specs=specs,
322
- in_channel=in_channel,
323
- require_no_overlap=require_no_overlap,
324
- require_no_gap=require_no_gap,
325
- normalize_channel_independently=normalize_channel_independently,
326
- treat_channel_as_feature=treat_channel_as_feature,
327
- n_sqm_modules=n_sqm_modules,
328
- emb_dim=emb_dim,
329
- rnn_dim=rnn_dim,
330
- bidirectional=bidirectional,
331
- rnn_type=rnn_type,
332
- mlp_dim=mlp_dim,
333
- hidden_activation=hidden_activation,
334
- hidden_activation_kwargs=hidden_activation_kwargs,
335
- complex_mask=complex_mask,
336
- )
337
- for src, specs in self.band_specs_map.items()
338
- }
339
- )
340
-
341
-
342
- class SingleMaskMultiSourceBandSplitTransformer(SingleMaskMultiSourceBandSplitBase):
343
- def __init__(
344
- self,
345
- in_channel: int,
346
- band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
347
- fs: int = 44100,
348
- require_no_overlap: bool = False,
349
- require_no_gap: bool = True,
350
- normalize_channel_independently: bool = False,
351
- treat_channel_as_feature: bool = True,
352
- n_sqm_modules: int = 12,
353
- emb_dim: int = 128,
354
- rnn_dim: int = 256,
355
- bidirectional: bool = True,
356
- tf_dropout: float = 0.0,
357
- mlp_dim: int = 512,
358
- hidden_activation: str = "Tanh",
359
- hidden_activation_kwargs: Optional[Dict] = None,
360
- complex_mask: bool = True,
361
- n_fft: int = 2048,
362
- win_length: Optional[int] = 2048,
363
- hop_length: int = 512,
364
- window_fn: str = "hann_window",
365
- wkwargs: Optional[Dict] = None,
366
- power: Optional[int] = None,
367
- center: bool = True,
368
- normalized: bool = True,
369
- pad_mode: str = "constant",
370
- onesided: bool = True,
371
- ) -> None:
372
- super().__init__(
373
- band_specs_map=band_specs_map,
374
- fs=fs,
375
- n_fft=n_fft,
376
- win_length=win_length,
377
- hop_length=hop_length,
378
- window_fn=window_fn,
379
- wkwargs=wkwargs,
380
- power=power,
381
- center=center,
382
- normalized=normalized,
383
- pad_mode=pad_mode,
384
- onesided=onesided,
385
- )
386
-
387
- self.bsrnn = nn.ModuleDict(
388
- {
389
- src: SingleMaskBandsplitCoreTransformer(
390
- band_specs=specs,
391
- in_channel=in_channel,
392
- require_no_overlap=require_no_overlap,
393
- require_no_gap=require_no_gap,
394
- normalize_channel_independently=normalize_channel_independently,
395
- treat_channel_as_feature=treat_channel_as_feature,
396
- n_sqm_modules=n_sqm_modules,
397
- emb_dim=emb_dim,
398
- rnn_dim=rnn_dim,
399
- bidirectional=bidirectional,
400
- tf_dropout=tf_dropout,
401
- mlp_dim=mlp_dim,
402
- hidden_activation=hidden_activation,
403
- hidden_activation_kwargs=hidden_activation_kwargs,
404
- complex_mask=complex_mask,
405
- )
406
- for src, specs in self.band_specs_map.items()
407
- }
408
- )
409
-
410
-
411
- class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
412
- def __init__(
413
- self,
414
- in_channel: int,
415
- stems: List[str],
416
- band_specs: Union[str, List[Tuple[float, float]]],
417
- fs: int = 44100,
418
- require_no_overlap: bool = False,
419
- require_no_gap: bool = True,
420
- normalize_channel_independently: bool = False,
421
- treat_channel_as_feature: bool = True,
422
- n_sqm_modules: int = 12,
423
- emb_dim: int = 128,
424
- rnn_dim: int = 256,
425
- cond_dim: int = 0,
426
- bidirectional: bool = True,
427
- rnn_type: str = "LSTM",
428
- mlp_dim: int = 512,
429
- hidden_activation: str = "Tanh",
430
- hidden_activation_kwargs: Optional[Dict] = None,
431
- complex_mask: bool = True,
432
- n_fft: int = 2048,
433
- win_length: Optional[int] = 2048,
434
- hop_length: int = 512,
435
- window_fn: str = "hann_window",
436
- wkwargs: Optional[Dict] = None,
437
- power: Optional[int] = None,
438
- center: bool = True,
439
- normalized: bool = True,
440
- pad_mode: str = "constant",
441
- onesided: bool = True,
442
- n_bands: int = None,
443
- use_freq_weights: bool = True,
444
- normalize_input: bool = False,
445
- mult_add_mask: bool = False,
446
- freeze_encoder: bool = False,
447
- ) -> None:
448
- super().__init__(
449
- stems=stems,
450
- band_specs=band_specs,
451
- fs=fs,
452
- n_fft=n_fft,
453
- win_length=win_length,
454
- hop_length=hop_length,
455
- window_fn=window_fn,
456
- wkwargs=wkwargs,
457
- power=power,
458
- center=center,
459
- normalized=normalized,
460
- pad_mode=pad_mode,
461
- onesided=onesided,
462
- n_bands=n_bands,
463
- )
464
-
465
- self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
466
- stems=stems,
467
- band_specs=self.band_specs,
468
- in_channel=in_channel,
469
- require_no_overlap=require_no_overlap,
470
- require_no_gap=require_no_gap,
471
- normalize_channel_independently=normalize_channel_independently,
472
- treat_channel_as_feature=treat_channel_as_feature,
473
- n_sqm_modules=n_sqm_modules,
474
- emb_dim=emb_dim,
475
- rnn_dim=rnn_dim,
476
- bidirectional=bidirectional,
477
- rnn_type=rnn_type,
478
- mlp_dim=mlp_dim,
479
- cond_dim=cond_dim,
480
- hidden_activation=hidden_activation,
481
- hidden_activation_kwargs=hidden_activation_kwargs,
482
- complex_mask=complex_mask,
483
- overlapping_band=self.overlapping_band,
484
- freq_weights=self.freq_weights,
485
- n_freq=n_fft // 2 + 1,
486
- use_freq_weights=use_freq_weights,
487
- mult_add_mask=mult_add_mask,
488
- )
489
-
490
- self.normalize_input = normalize_input
491
- self.cond_dim = cond_dim
492
-
493
- if freeze_encoder:
494
- for param in self.bsrnn.band_split.parameters():
495
- param.requires_grad = False
496
-
497
- for param in self.bsrnn.tf_model.parameters():
498
- param.requires_grad = False
499
-
500
-
501
- class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple):
502
- def __init__(
503
- self,
504
- in_channel: int,
505
- stems: List[str],
506
- band_specs: Union[str, List[Tuple[float, float]]],
507
- fs: int = 44100,
508
- require_no_overlap: bool = False,
509
- require_no_gap: bool = True,
510
- normalize_channel_independently: bool = False,
511
- treat_channel_as_feature: bool = True,
512
- n_sqm_modules: int = 12,
513
- emb_dim: int = 128,
514
- rnn_dim: int = 256,
515
- cond_dim: int = 0,
516
- bidirectional: bool = True,
517
- rnn_type: str = "LSTM",
518
- mlp_dim: int = 512,
519
- hidden_activation: str = "Tanh",
520
- hidden_activation_kwargs: Optional[Dict] = None,
521
- complex_mask: bool = True,
522
- n_fft: int = 2048,
523
- win_length: Optional[int] = 2048,
524
- hop_length: int = 512,
525
- window_fn: str = "hann_window",
526
- wkwargs: Optional[Dict] = None,
527
- power: Optional[int] = None,
528
- center: bool = True,
529
- normalized: bool = True,
530
- pad_mode: str = "constant",
531
- onesided: bool = True,
532
- n_bands: int = None,
533
- use_freq_weights: bool = True,
534
- normalize_input: bool = False,
535
- mult_add_mask: bool = False,
536
- freeze_encoder: bool = False,
537
- ) -> None:
538
- super().__init__(
539
- stems=stems,
540
- band_specs=band_specs,
541
- fs=fs,
542
- n_fft=n_fft,
543
- win_length=win_length,
544
- hop_length=hop_length,
545
- window_fn=window_fn,
546
- wkwargs=wkwargs,
547
- power=power,
548
- center=center,
549
- normalized=normalized,
550
- pad_mode=pad_mode,
551
- onesided=onesided,
552
- n_bands=n_bands,
553
- )
554
-
555
- self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
556
- stems=stems,
557
- band_specs=self.band_specs,
558
- in_channel=in_channel,
559
- require_no_overlap=require_no_overlap,
560
- require_no_gap=require_no_gap,
561
- normalize_channel_independently=normalize_channel_independently,
562
- treat_channel_as_feature=treat_channel_as_feature,
563
- n_sqm_modules=n_sqm_modules,
564
- emb_dim=emb_dim,
565
- rnn_dim=rnn_dim,
566
- bidirectional=bidirectional,
567
- rnn_type=rnn_type,
568
- mlp_dim=mlp_dim,
569
- cond_dim=cond_dim,
570
- hidden_activation=hidden_activation,
571
- hidden_activation_kwargs=hidden_activation_kwargs,
572
- complex_mask=complex_mask,
573
- overlapping_band=self.overlapping_band,
574
- freq_weights=self.freq_weights,
575
- n_freq=n_fft // 2 + 1,
576
- use_freq_weights=use_freq_weights,
577
- mult_add_mask=mult_add_mask,
578
- )
579
-
580
- self.normalize_input = normalize_input
581
- self.cond_dim = cond_dim
582
-
583
- if freeze_encoder:
584
- for param in self.bsrnn.band_split.parameters():
585
- param.requires_grad = False
586
-
587
- for param in self.bsrnn.tf_model.parameters():
588
- param.requires_grad = False
589
-
590
-
591
- class MultiMaskMultiSourceBandSplitTransformer(MultiMaskMultiSourceBandSplitBase):
592
- def __init__(
593
- self,
594
- in_channel: int,
595
- stems: List[str],
596
- band_specs: Union[str, List[Tuple[float, float]]],
597
- fs: int = 44100,
598
- require_no_overlap: bool = False,
599
- require_no_gap: bool = True,
600
- normalize_channel_independently: bool = False,
601
- treat_channel_as_feature: bool = True,
602
- n_sqm_modules: int = 12,
603
- emb_dim: int = 128,
604
- rnn_dim: int = 256,
605
- cond_dim: int = 0,
606
- bidirectional: bool = True,
607
- rnn_type: str = "LSTM",
608
- mlp_dim: int = 512,
609
- hidden_activation: str = "Tanh",
610
- hidden_activation_kwargs: Optional[Dict] = None,
611
- complex_mask: bool = True,
612
- n_fft: int = 2048,
613
- win_length: Optional[int] = 2048,
614
- hop_length: int = 512,
615
- window_fn: str = "hann_window",
616
- wkwargs: Optional[Dict] = None,
617
- power: Optional[int] = None,
618
- center: bool = True,
619
- normalized: bool = True,
620
- pad_mode: str = "constant",
621
- onesided: bool = True,
622
- n_bands: int = None,
623
- use_freq_weights: bool = True,
624
- normalize_input: bool = False,
625
- mult_add_mask: bool = False,
626
- ) -> None:
627
- super().__init__(
628
- stems=stems,
629
- band_specs=band_specs,
630
- fs=fs,
631
- n_fft=n_fft,
632
- win_length=win_length,
633
- hop_length=hop_length,
634
- window_fn=window_fn,
635
- wkwargs=wkwargs,
636
- power=power,
637
- center=center,
638
- normalized=normalized,
639
- pad_mode=pad_mode,
640
- onesided=onesided,
641
- n_bands=n_bands,
642
- )
643
-
644
- self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer(
645
- stems=stems,
646
- band_specs=self.band_specs,
647
- in_channel=in_channel,
648
- require_no_overlap=require_no_overlap,
649
- require_no_gap=require_no_gap,
650
- normalize_channel_independently=normalize_channel_independently,
651
- treat_channel_as_feature=treat_channel_as_feature,
652
- n_sqm_modules=n_sqm_modules,
653
- emb_dim=emb_dim,
654
- rnn_dim=rnn_dim,
655
- bidirectional=bidirectional,
656
- rnn_type=rnn_type,
657
- mlp_dim=mlp_dim,
658
- cond_dim=cond_dim,
659
- hidden_activation=hidden_activation,
660
- hidden_activation_kwargs=hidden_activation_kwargs,
661
- complex_mask=complex_mask,
662
- overlapping_band=self.overlapping_band,
663
- freq_weights=self.freq_weights,
664
- n_freq=n_fft // 2 + 1,
665
- use_freq_weights=use_freq_weights,
666
- mult_add_mask=mult_add_mask,
667
- )
668
-
669
-
670
- class MultiMaskMultiSourceBandSplitConv(MultiMaskMultiSourceBandSplitBase):
671
- def __init__(
672
- self,
673
- in_channel: int,
674
- stems: List[str],
675
- band_specs: Union[str, List[Tuple[float, float]]],
676
- fs: int = 44100,
677
- require_no_overlap: bool = False,
678
- require_no_gap: bool = True,
679
- normalize_channel_independently: bool = False,
680
- treat_channel_as_feature: bool = True,
681
- n_sqm_modules: int = 12,
682
- emb_dim: int = 128,
683
- rnn_dim: int = 256,
684
- cond_dim: int = 0,
685
- bidirectional: bool = True,
686
- rnn_type: str = "LSTM",
687
- mlp_dim: int = 512,
688
- hidden_activation: str = "Tanh",
689
- hidden_activation_kwargs: Optional[Dict] = None,
690
- complex_mask: bool = True,
691
- n_fft: int = 2048,
692
- win_length: Optional[int] = 2048,
693
- hop_length: int = 512,
694
- window_fn: str = "hann_window",
695
- wkwargs: Optional[Dict] = None,
696
- power: Optional[int] = None,
697
- center: bool = True,
698
- normalized: bool = True,
699
- pad_mode: str = "constant",
700
- onesided: bool = True,
701
- n_bands: int = None,
702
- use_freq_weights: bool = True,
703
- normalize_input: bool = False,
704
- mult_add_mask: bool = False,
705
- ) -> None:
706
- super().__init__(
707
- stems=stems,
708
- band_specs=band_specs,
709
- fs=fs,
710
- n_fft=n_fft,
711
- win_length=win_length,
712
- hop_length=hop_length,
713
- window_fn=window_fn,
714
- wkwargs=wkwargs,
715
- power=power,
716
- center=center,
717
- normalized=normalized,
718
- pad_mode=pad_mode,
719
- onesided=onesided,
720
- n_bands=n_bands,
721
- )
722
-
723
- self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv(
724
- stems=stems,
725
- band_specs=self.band_specs,
726
- in_channel=in_channel,
727
- require_no_overlap=require_no_overlap,
728
- require_no_gap=require_no_gap,
729
- normalize_channel_independently=normalize_channel_independently,
730
- treat_channel_as_feature=treat_channel_as_feature,
731
- n_sqm_modules=n_sqm_modules,
732
- emb_dim=emb_dim,
733
- rnn_dim=rnn_dim,
734
- bidirectional=bidirectional,
735
- rnn_type=rnn_type,
736
- mlp_dim=mlp_dim,
737
- cond_dim=cond_dim,
738
- hidden_activation=hidden_activation,
739
- hidden_activation_kwargs=hidden_activation_kwargs,
740
- complex_mask=complex_mask,
741
- overlapping_band=self.overlapping_band,
742
- freq_weights=self.freq_weights,
743
- n_freq=n_fft // 2 + 1,
744
- use_freq_weights=use_freq_weights,
745
- mult_add_mask=mult_add_mask,
746
- )
747
-
748
-
749
- class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
750
- def __init__(
751
- self,
752
- in_channel: int,
753
- stems: List[str],
754
- band_specs: Union[str, List[Tuple[float, float]]],
755
- kernel_norm_mlp_version: int = 1,
756
- mask_kernel_freq: int = 3,
757
- mask_kernel_time: int = 3,
758
- conv_kernel_freq: int = 1,
759
- conv_kernel_time: int = 1,
760
- fs: int = 44100,
761
- require_no_overlap: bool = False,
762
- require_no_gap: bool = True,
763
- normalize_channel_independently: bool = False,
764
- treat_channel_as_feature: bool = True,
765
- n_sqm_modules: int = 12,
766
- emb_dim: int = 128,
767
- rnn_dim: int = 256,
768
- bidirectional: bool = True,
769
- rnn_type: str = "LSTM",
770
- mlp_dim: int = 512,
771
- hidden_activation: str = "Tanh",
772
- hidden_activation_kwargs: Optional[Dict] = None,
773
- complex_mask: bool = True,
774
- n_fft: int = 2048,
775
- win_length: Optional[int] = 2048,
776
- hop_length: int = 512,
777
- window_fn: str = "hann_window",
778
- wkwargs: Optional[Dict] = None,
779
- power: Optional[int] = None,
780
- center: bool = True,
781
- normalized: bool = True,
782
- pad_mode: str = "constant",
783
- onesided: bool = True,
784
- n_bands: int = None,
785
- ) -> None:
786
- super().__init__(
787
- stems=stems,
788
- band_specs=band_specs,
789
- fs=fs,
790
- n_fft=n_fft,
791
- win_length=win_length,
792
- hop_length=hop_length,
793
- window_fn=window_fn,
794
- wkwargs=wkwargs,
795
- power=power,
796
- center=center,
797
- normalized=normalized,
798
- pad_mode=pad_mode,
799
- onesided=onesided,
800
- n_bands=n_bands,
801
- )
802
-
803
- self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN(
804
- stems=stems,
805
- band_specs=self.band_specs,
806
- in_channel=in_channel,
807
- require_no_overlap=require_no_overlap,
808
- require_no_gap=require_no_gap,
809
- normalize_channel_independently=normalize_channel_independently,
810
- treat_channel_as_feature=treat_channel_as_feature,
811
- n_sqm_modules=n_sqm_modules,
812
- emb_dim=emb_dim,
813
- rnn_dim=rnn_dim,
814
- bidirectional=bidirectional,
815
- rnn_type=rnn_type,
816
- mlp_dim=mlp_dim,
817
- hidden_activation=hidden_activation,
818
- hidden_activation_kwargs=hidden_activation_kwargs,
819
- complex_mask=complex_mask,
820
- overlapping_band=self.overlapping_band,
821
- freq_weights=self.freq_weights,
822
- n_freq=n_fft // 2 + 1,
823
- mask_kernel_freq=mask_kernel_freq,
824
- mask_kernel_time=mask_kernel_time,
825
- conv_kernel_freq=conv_kernel_freq,
826
- conv_kernel_time=conv_kernel_time,
827
- kernel_norm_mlp_version=kernel_norm_mlp_version,
828
- )
 
1
+ from pprint import pprint
2
+ from typing import Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from .._spectral import _SpectralComponent
8
+ from .utils import (
9
+ BarkBandsplitSpecification,
10
+ BassBandsplitSpecification,
11
+ DrumBandsplitSpecification,
12
+ EquivalentRectangularBandsplitSpecification,
13
+ MelBandsplitSpecification,
14
+ MusicalBandsplitSpecification,
15
+ OtherBandsplitSpecification,
16
+ TriangularBarkBandsplitSpecification,
17
+ VocalBandsplitSpecification,
18
+ )
19
+ from .core import (
20
+ MultiSourceMultiMaskBandSplitCoreConv,
21
+ MultiSourceMultiMaskBandSplitCoreRNN,
22
+ MultiSourceMultiMaskBandSplitCoreTransformer,
23
+ MultiSourceMultiPatchingMaskBandSplitCoreRNN,
24
+ SingleMaskBandsplitCoreRNN,
25
+ SingleMaskBandsplitCoreTransformer,
26
+ )
27
+
28
+ import pytorch_lightning as pl
29
+
30
+
31
def get_band_specs(band_specs, n_fft, fs, n_bands=None):
    """Resolve a band-specification preset name into concrete band specs.

    Args:
        band_specs: Preset name. The vox7 presets ("dnr:speech", "dnr:vox7",
            "musdb:vocals", "musdb:vox7") use the hand-designed vocal bands;
            names containing "tribark", "bark", "erb", "musical" or "mel"
            select the matching perceptual filterbank and require ``n_bands``.
        n_fft: FFT size the band edges are defined over.
        fs: Sampling rate in Hz.
        n_bands: Number of perceptual bands (ignored by the vox7 presets).

    Returns:
        Tuple ``(band_specs, freq_weights, overlapping_band)``.
        ``freq_weights`` is None and ``overlapping_band`` is False for the
        non-overlapping vox7 presets.

    Raises:
        NameError: If ``band_specs`` matches no known preset.
    """
    if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]:
        bsm = VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs()
        freq_weights = None
        overlapping_band = False
    # NOTE: "tribark" must be tested before "bark" ("bark" is a substring).
    elif "tribark" in band_specs:
        assert n_bands is not None, "n_bands is required for 'tribark' presets"
        specs = TriangularBarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "bark" in band_specs:
        assert n_bands is not None, "n_bands is required for 'bark' presets"
        specs = BarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "erb" in band_specs:
        assert n_bands is not None, "n_bands is required for 'erb' presets"
        specs = EquivalentRectangularBandsplitSpecification(
            nfft=n_fft, fs=fs, n_bands=n_bands
        )
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "musical" in band_specs:
        assert n_bands is not None, "n_bands is required for 'musical' presets"
        specs = MusicalBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "mel" in band_specs:  # also covers the former "dnr:mel" special case
        assert n_bands is not None, "n_bands is required for 'mel' presets"
        specs = MelBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    else:
        # Fix: the bare ``raise NameError`` gave no hint about what was wrong.
        raise NameError(f"Unknown band specification preset: {band_specs!r}")

    return bsm, freq_weights, overlapping_band
72
+
73
+
74
def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None):
    """Resolve a preset name into a per-stem band-specification mapping.

    Args:
        band_specs_map: Preset name. "musdb:all" yields the four MUSDB stems,
            each with its own hand-designed bands; "dnr:vox7" shares the vocal
            band layout across the three DnR stems; "dnr:vox7:<stem>" yields
            the vocal bands for a single stem.
        n_fft: FFT size the band edges are defined over.
        fs: Sampling rate in Hz.
        n_bands: Forwarded to :func:`get_band_specs` for the vox7 presets.

    Returns:
        Tuple ``({stem: band_specs}, freq_weights, overlapping_band)``.

    Raises:
        NameError: If ``band_specs_map`` matches no known preset.
    """
    if band_specs_map == "musdb:all":
        bsm = {
            "vocals": VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
            "drums": DrumBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
            "bass": BassBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
            "other": OtherBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
        }
        freq_weights = None
        overlapping_band = False
    elif band_specs_map == "dnr:vox7":
        # All three DnR stems share the same vocal band layout.
        bsm_, freq_weights, overlapping_band = get_band_specs(
            "dnr:speech", n_fft, fs, n_bands
        )
        bsm = {"speech": bsm_, "music": bsm_, "effects": bsm_}
    elif "dnr:vox7:" in band_specs_map:
        stem = band_specs_map.split(":")[-1]
        bsm_, freq_weights, overlapping_band = get_band_specs(
            "dnr:speech", n_fft, fs, n_bands
        )
        bsm = {stem: bsm_}
    else:
        # Fix: the bare ``raise NameError`` gave no hint about what was wrong.
        raise NameError(f"Unknown band specification map preset: {band_specs_map!r}")

    return bsm, freq_weights, overlapping_band
99
+
100
+
101
class BandSplitWrapperBase(pl.LightningModule):
    """Common Lightning base for all band-split wrapper models.

    Subclasses are expected to assign the separation core (a module or a
    ModuleDict of per-stem modules) to ``self.bsrnn``.
    """

    # Set by subclasses: the band-split separation core.
    bsrnn: nn.Module

    def __init__(self, **kwargs):
        # NOTE(review): kwargs are accepted for cooperative multiple
        # inheritance but are swallowed here — confirm that is intended.
        super().__init__()
106
+
107
+
108
class SingleMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
    """Base wrapper running one single-mask band-split core per stem.

    Each stem gets its own core in the ``self.bsrnn`` ModuleDict (assigned by
    subclasses); the mixture spectrogram is fed through every core
    independently and each masked spectrogram is inverted back to audio.
    """

    def __init__(
        self,
        band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
        fs: int = 44100,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: int = None,
    ) -> None:
        """Set up the STFT front end and resolve the per-stem band specs.

        Args:
            band_specs_map: Either a preset name (resolved through
                ``get_band_specs_map``) or an explicit
                ``{stem: [(f_min, f_max), ...]}`` mapping.
            fs: Sampling rate in Hz (used only to resolve presets).
            n_bands: Number of perceptual bands for presets that need one.
            Remaining arguments configure the STFT/iSTFT pair.
        """
        super().__init__(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        if isinstance(band_specs_map, str):
            # Named preset: resolve to an explicit per-stem mapping.
            self.band_specs_map, self.freq_weights, self.overlapping_band = (
                get_band_specs_map(band_specs_map, n_fft, fs, n_bands=n_bands)
            )
        else:
            # Fix: an explicit dict (a type the signature advertises) used to
            # leave these attributes unset, so the ``self.stems`` line below
            # crashed with AttributeError.
            self.band_specs_map = band_specs_map
            self.freq_weights = None
            self.overlapping_band = False

        self.stems = list(self.band_specs_map.keys())

    def forward(self, batch):
        """Separate ``batch["audio"]["mixture"]``.

        Returns:
            Tuple ``(batch, output)`` where ``output`` holds per-stem
            "spectrogram" and "audio" dicts.
        """
        audio = batch["audio"]

        # Input STFTs are not part of the computation graph.
        with torch.no_grad():
            batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}

        X = batch["spectrogram"]["mixture"]
        length = batch["audio"]["mixture"].shape[-1]

        output = {"spectrogram": {}, "audio": {}}

        for stem, bsrnn in self.bsrnn.items():
            S = bsrnn(X)
            s = self.istft(S, length)
            output["spectrogram"][stem] = S
            output["audio"][stem] = s

        return batch, output
163
+
164
+
165
class MultiMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
    """Base wrapper for multi-mask cores that emit all stems in one pass.

    Subclasses assign a single separation core to ``self.bsrnn``; its output
    dict (per-stem spectrograms under "spectrogram") is inverted back to
    audio here.
    """

    def __init__(
        self,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: int = None,
    ) -> None:
        """Set up the STFT front end and resolve the band specification.

        Args:
            stems: Names of the output stems, in core output order.
            band_specs: Either a preset name (resolved through
                ``get_band_specs``) or an explicit list of
                ``(f_min, f_max)`` band edges.
            fs: Sampling rate in Hz (used only to resolve presets).
            n_bands: Number of perceptual bands for presets that need one.
            Remaining arguments configure the STFT/iSTFT pair.
        """
        super().__init__(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        if isinstance(band_specs, str):
            # Named preset: resolve to an explicit list of band edges.
            self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
                band_specs, n_fft, fs, n_bands
            )
        else:
            # Fix: an explicit band list (a type the signature advertises)
            # used to leave these attributes unset, making subclasses crash
            # with AttributeError when reading ``self.band_specs`` etc.
            self.band_specs = band_specs
            self.freq_weights = None
            self.overlapping_band = False

        self.stems = stems

    def forward(self, batch):
        """Separate ``batch["audio"]["mixture"]``.

        Returns:
            Tuple ``(batch, output)`` where ``output`` holds per-stem
            "spectrogram" and "audio" dicts.
        """
        audio = batch["audio"]
        cond = batch.get("condition", None)
        # Input STFTs are not part of the computation graph.
        with torch.no_grad():
            batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}

        X = batch["spectrogram"]["mixture"]
        length = batch["audio"]["mixture"].shape[-1]

        output = self.bsrnn(X, cond=cond)
        output["audio"] = {}

        for stem, S in output["spectrogram"].items():
            s = self.istft(S, length)
            output["audio"][stem] = s

        return batch, output
220
+
221
+
222
+ class MultiMaskMultiSourceBandSplitBaseSimple(BandSplitWrapperBase, _SpectralComponent):
223
+ def __init__(
224
+ self,
225
+ stems: List[str],
226
+ band_specs: Union[str, List[Tuple[float, float]]],
227
+ fs: int = 44100,
228
+ n_fft: int = 2048,
229
+ win_length: Optional[int] = 2048,
230
+ hop_length: int = 512,
231
+ window_fn: str = "hann_window",
232
+ wkwargs: Optional[Dict] = None,
233
+ power: Optional[int] = None,
234
+ center: bool = True,
235
+ normalized: bool = True,
236
+ pad_mode: str = "constant",
237
+ onesided: bool = True,
238
+ n_bands: int = None,
239
+ ) -> None:
240
+ super().__init__(
241
+ n_fft=n_fft,
242
+ win_length=win_length,
243
+ hop_length=hop_length,
244
+ window_fn=window_fn,
245
+ wkwargs=wkwargs,
246
+ power=power,
247
+ center=center,
248
+ normalized=normalized,
249
+ pad_mode=pad_mode,
250
+ onesided=onesided,
251
+ )
252
+
253
+ if isinstance(band_specs, str):
254
+ self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
255
+ band_specs, n_fft, fs, n_bands
256
+ )
257
+
258
+ self.stems = stems
259
+
260
+ def forward(self, batch):
261
+ with torch.no_grad():
262
+ X = self.stft(batch)
263
+ length = batch.shape[-1]
264
+ output = self.bsrnn(X, cond=None)
265
+ res = []
266
+ for stem, S in output["spectrogram"].items():
267
+ s = self.istft(S, length)
268
+ res.append(s)
269
+ res = torch.stack(res, dim=1)
270
+ return res
271
+
272
+
273
+ class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase):
274
+ def __init__(
275
+ self,
276
+ in_channel: int,
277
+ band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
278
+ fs: int = 44100,
279
+ require_no_overlap: bool = False,
280
+ require_no_gap: bool = True,
281
+ normalize_channel_independently: bool = False,
282
+ treat_channel_as_feature: bool = True,
283
+ n_sqm_modules: int = 12,
284
+ emb_dim: int = 128,
285
+ rnn_dim: int = 256,
286
+ bidirectional: bool = True,
287
+ rnn_type: str = "LSTM",
288
+ mlp_dim: int = 512,
289
+ hidden_activation: str = "Tanh",
290
+ hidden_activation_kwargs: Optional[Dict] = None,
291
+ complex_mask: bool = True,
292
+ n_fft: int = 2048,
293
+ win_length: Optional[int] = 2048,
294
+ hop_length: int = 512,
295
+ window_fn: str = "hann_window",
296
+ wkwargs: Optional[Dict] = None,
297
+ power: Optional[int] = None,
298
+ center: bool = True,
299
+ normalized: bool = True,
300
+ pad_mode: str = "constant",
301
+ onesided: bool = True,
302
+ ) -> None:
303
+ super().__init__(
304
+ band_specs_map=band_specs_map,
305
+ fs=fs,
306
+ n_fft=n_fft,
307
+ win_length=win_length,
308
+ hop_length=hop_length,
309
+ window_fn=window_fn,
310
+ wkwargs=wkwargs,
311
+ power=power,
312
+ center=center,
313
+ normalized=normalized,
314
+ pad_mode=pad_mode,
315
+ onesided=onesided,
316
+ )
317
+
318
+ self.bsrnn = nn.ModuleDict(
319
+ {
320
+ src: SingleMaskBandsplitCoreRNN(
321
+ band_specs=specs,
322
+ in_channel=in_channel,
323
+ require_no_overlap=require_no_overlap,
324
+ require_no_gap=require_no_gap,
325
+ normalize_channel_independently=normalize_channel_independently,
326
+ treat_channel_as_feature=treat_channel_as_feature,
327
+ n_sqm_modules=n_sqm_modules,
328
+ emb_dim=emb_dim,
329
+ rnn_dim=rnn_dim,
330
+ bidirectional=bidirectional,
331
+ rnn_type=rnn_type,
332
+ mlp_dim=mlp_dim,
333
+ hidden_activation=hidden_activation,
334
+ hidden_activation_kwargs=hidden_activation_kwargs,
335
+ complex_mask=complex_mask,
336
+ )
337
+ for src, specs in self.band_specs_map.items()
338
+ }
339
+ )
340
+
341
+
342
+ class SingleMaskMultiSourceBandSplitTransformer(SingleMaskMultiSourceBandSplitBase):
343
+ def __init__(
344
+ self,
345
+ in_channel: int,
346
+ band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
347
+ fs: int = 44100,
348
+ require_no_overlap: bool = False,
349
+ require_no_gap: bool = True,
350
+ normalize_channel_independently: bool = False,
351
+ treat_channel_as_feature: bool = True,
352
+ n_sqm_modules: int = 12,
353
+ emb_dim: int = 128,
354
+ rnn_dim: int = 256,
355
+ bidirectional: bool = True,
356
+ tf_dropout: float = 0.0,
357
+ mlp_dim: int = 512,
358
+ hidden_activation: str = "Tanh",
359
+ hidden_activation_kwargs: Optional[Dict] = None,
360
+ complex_mask: bool = True,
361
+ n_fft: int = 2048,
362
+ win_length: Optional[int] = 2048,
363
+ hop_length: int = 512,
364
+ window_fn: str = "hann_window",
365
+ wkwargs: Optional[Dict] = None,
366
+ power: Optional[int] = None,
367
+ center: bool = True,
368
+ normalized: bool = True,
369
+ pad_mode: str = "constant",
370
+ onesided: bool = True,
371
+ ) -> None:
372
+ super().__init__(
373
+ band_specs_map=band_specs_map,
374
+ fs=fs,
375
+ n_fft=n_fft,
376
+ win_length=win_length,
377
+ hop_length=hop_length,
378
+ window_fn=window_fn,
379
+ wkwargs=wkwargs,
380
+ power=power,
381
+ center=center,
382
+ normalized=normalized,
383
+ pad_mode=pad_mode,
384
+ onesided=onesided,
385
+ )
386
+
387
+ self.bsrnn = nn.ModuleDict(
388
+ {
389
+ src: SingleMaskBandsplitCoreTransformer(
390
+ band_specs=specs,
391
+ in_channel=in_channel,
392
+ require_no_overlap=require_no_overlap,
393
+ require_no_gap=require_no_gap,
394
+ normalize_channel_independently=normalize_channel_independently,
395
+ treat_channel_as_feature=treat_channel_as_feature,
396
+ n_sqm_modules=n_sqm_modules,
397
+ emb_dim=emb_dim,
398
+ rnn_dim=rnn_dim,
399
+ bidirectional=bidirectional,
400
+ tf_dropout=tf_dropout,
401
+ mlp_dim=mlp_dim,
402
+ hidden_activation=hidden_activation,
403
+ hidden_activation_kwargs=hidden_activation_kwargs,
404
+ complex_mask=complex_mask,
405
+ )
406
+ for src, specs in self.band_specs_map.items()
407
+ }
408
+ )
409
+
410
+
411
+ class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
412
+ def __init__(
413
+ self,
414
+ in_channel: int,
415
+ stems: List[str],
416
+ band_specs: Union[str, List[Tuple[float, float]]],
417
+ fs: int = 44100,
418
+ require_no_overlap: bool = False,
419
+ require_no_gap: bool = True,
420
+ normalize_channel_independently: bool = False,
421
+ treat_channel_as_feature: bool = True,
422
+ n_sqm_modules: int = 12,
423
+ emb_dim: int = 128,
424
+ rnn_dim: int = 256,
425
+ cond_dim: int = 0,
426
+ bidirectional: bool = True,
427
+ rnn_type: str = "LSTM",
428
+ mlp_dim: int = 512,
429
+ hidden_activation: str = "Tanh",
430
+ hidden_activation_kwargs: Optional[Dict] = None,
431
+ complex_mask: bool = True,
432
+ n_fft: int = 2048,
433
+ win_length: Optional[int] = 2048,
434
+ hop_length: int = 512,
435
+ window_fn: str = "hann_window",
436
+ wkwargs: Optional[Dict] = None,
437
+ power: Optional[int] = None,
438
+ center: bool = True,
439
+ normalized: bool = True,
440
+ pad_mode: str = "constant",
441
+ onesided: bool = True,
442
+ n_bands: int = None,
443
+ use_freq_weights: bool = True,
444
+ normalize_input: bool = False,
445
+ mult_add_mask: bool = False,
446
+ freeze_encoder: bool = False,
447
+ ) -> None:
448
+ super().__init__(
449
+ stems=stems,
450
+ band_specs=band_specs,
451
+ fs=fs,
452
+ n_fft=n_fft,
453
+ win_length=win_length,
454
+ hop_length=hop_length,
455
+ window_fn=window_fn,
456
+ wkwargs=wkwargs,
457
+ power=power,
458
+ center=center,
459
+ normalized=normalized,
460
+ pad_mode=pad_mode,
461
+ onesided=onesided,
462
+ n_bands=n_bands,
463
+ )
464
+
465
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
466
+ stems=stems,
467
+ band_specs=self.band_specs,
468
+ in_channel=in_channel,
469
+ require_no_overlap=require_no_overlap,
470
+ require_no_gap=require_no_gap,
471
+ normalize_channel_independently=normalize_channel_independently,
472
+ treat_channel_as_feature=treat_channel_as_feature,
473
+ n_sqm_modules=n_sqm_modules,
474
+ emb_dim=emb_dim,
475
+ rnn_dim=rnn_dim,
476
+ bidirectional=bidirectional,
477
+ rnn_type=rnn_type,
478
+ mlp_dim=mlp_dim,
479
+ cond_dim=cond_dim,
480
+ hidden_activation=hidden_activation,
481
+ hidden_activation_kwargs=hidden_activation_kwargs,
482
+ complex_mask=complex_mask,
483
+ overlapping_band=self.overlapping_band,
484
+ freq_weights=self.freq_weights,
485
+ n_freq=n_fft // 2 + 1,
486
+ use_freq_weights=use_freq_weights,
487
+ mult_add_mask=mult_add_mask,
488
+ )
489
+
490
+ self.normalize_input = normalize_input
491
+ self.cond_dim = cond_dim
492
+
493
+ if freeze_encoder:
494
+ for param in self.bsrnn.band_split.parameters():
495
+ param.requires_grad = False
496
+
497
+ for param in self.bsrnn.tf_model.parameters():
498
+ param.requires_grad = False
499
+
500
+
501
+ class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple):
502
+ def __init__(
503
+ self,
504
+ in_channel: int,
505
+ stems: List[str],
506
+ band_specs: Union[str, List[Tuple[float, float]]],
507
+ fs: int = 44100,
508
+ require_no_overlap: bool = False,
509
+ require_no_gap: bool = True,
510
+ normalize_channel_independently: bool = False,
511
+ treat_channel_as_feature: bool = True,
512
+ n_sqm_modules: int = 12,
513
+ emb_dim: int = 128,
514
+ rnn_dim: int = 256,
515
+ cond_dim: int = 0,
516
+ bidirectional: bool = True,
517
+ rnn_type: str = "LSTM",
518
+ mlp_dim: int = 512,
519
+ hidden_activation: str = "Tanh",
520
+ hidden_activation_kwargs: Optional[Dict] = None,
521
+ complex_mask: bool = True,
522
+ n_fft: int = 2048,
523
+ win_length: Optional[int] = 2048,
524
+ hop_length: int = 512,
525
+ window_fn: str = "hann_window",
526
+ wkwargs: Optional[Dict] = None,
527
+ power: Optional[int] = None,
528
+ center: bool = True,
529
+ normalized: bool = True,
530
+ pad_mode: str = "constant",
531
+ onesided: bool = True,
532
+ n_bands: int = None,
533
+ use_freq_weights: bool = True,
534
+ normalize_input: bool = False,
535
+ mult_add_mask: bool = False,
536
+ freeze_encoder: bool = False,
537
+ ) -> None:
538
+ super().__init__(
539
+ stems=stems,
540
+ band_specs=band_specs,
541
+ fs=fs,
542
+ n_fft=n_fft,
543
+ win_length=win_length,
544
+ hop_length=hop_length,
545
+ window_fn=window_fn,
546
+ wkwargs=wkwargs,
547
+ power=power,
548
+ center=center,
549
+ normalized=normalized,
550
+ pad_mode=pad_mode,
551
+ onesided=onesided,
552
+ n_bands=n_bands,
553
+ )
554
+
555
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
556
+ stems=stems,
557
+ band_specs=self.band_specs,
558
+ in_channel=in_channel,
559
+ require_no_overlap=require_no_overlap,
560
+ require_no_gap=require_no_gap,
561
+ normalize_channel_independently=normalize_channel_independently,
562
+ treat_channel_as_feature=treat_channel_as_feature,
563
+ n_sqm_modules=n_sqm_modules,
564
+ emb_dim=emb_dim,
565
+ rnn_dim=rnn_dim,
566
+ bidirectional=bidirectional,
567
+ rnn_type=rnn_type,
568
+ mlp_dim=mlp_dim,
569
+ cond_dim=cond_dim,
570
+ hidden_activation=hidden_activation,
571
+ hidden_activation_kwargs=hidden_activation_kwargs,
572
+ complex_mask=complex_mask,
573
+ overlapping_band=self.overlapping_band,
574
+ freq_weights=self.freq_weights,
575
+ n_freq=n_fft // 2 + 1,
576
+ use_freq_weights=use_freq_weights,
577
+ mult_add_mask=mult_add_mask,
578
+ )
579
+
580
+ self.normalize_input = normalize_input
581
+ self.cond_dim = cond_dim
582
+
583
+ if freeze_encoder:
584
+ for param in self.bsrnn.band_split.parameters():
585
+ param.requires_grad = False
586
+
587
+ for param in self.bsrnn.tf_model.parameters():
588
+ param.requires_grad = False
589
+
590
+
591
+ class MultiMaskMultiSourceBandSplitTransformer(MultiMaskMultiSourceBandSplitBase):
592
+ def __init__(
593
+ self,
594
+ in_channel: int,
595
+ stems: List[str],
596
+ band_specs: Union[str, List[Tuple[float, float]]],
597
+ fs: int = 44100,
598
+ require_no_overlap: bool = False,
599
+ require_no_gap: bool = True,
600
+ normalize_channel_independently: bool = False,
601
+ treat_channel_as_feature: bool = True,
602
+ n_sqm_modules: int = 12,
603
+ emb_dim: int = 128,
604
+ rnn_dim: int = 256,
605
+ cond_dim: int = 0,
606
+ bidirectional: bool = True,
607
+ rnn_type: str = "LSTM",
608
+ mlp_dim: int = 512,
609
+ hidden_activation: str = "Tanh",
610
+ hidden_activation_kwargs: Optional[Dict] = None,
611
+ complex_mask: bool = True,
612
+ n_fft: int = 2048,
613
+ win_length: Optional[int] = 2048,
614
+ hop_length: int = 512,
615
+ window_fn: str = "hann_window",
616
+ wkwargs: Optional[Dict] = None,
617
+ power: Optional[int] = None,
618
+ center: bool = True,
619
+ normalized: bool = True,
620
+ pad_mode: str = "constant",
621
+ onesided: bool = True,
622
+ n_bands: int = None,
623
+ use_freq_weights: bool = True,
624
+ normalize_input: bool = False,
625
+ mult_add_mask: bool = False,
626
+ ) -> None:
627
+ super().__init__(
628
+ stems=stems,
629
+ band_specs=band_specs,
630
+ fs=fs,
631
+ n_fft=n_fft,
632
+ win_length=win_length,
633
+ hop_length=hop_length,
634
+ window_fn=window_fn,
635
+ wkwargs=wkwargs,
636
+ power=power,
637
+ center=center,
638
+ normalized=normalized,
639
+ pad_mode=pad_mode,
640
+ onesided=onesided,
641
+ n_bands=n_bands,
642
+ )
643
+
644
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer(
645
+ stems=stems,
646
+ band_specs=self.band_specs,
647
+ in_channel=in_channel,
648
+ require_no_overlap=require_no_overlap,
649
+ require_no_gap=require_no_gap,
650
+ normalize_channel_independently=normalize_channel_independently,
651
+ treat_channel_as_feature=treat_channel_as_feature,
652
+ n_sqm_modules=n_sqm_modules,
653
+ emb_dim=emb_dim,
654
+ rnn_dim=rnn_dim,
655
+ bidirectional=bidirectional,
656
+ rnn_type=rnn_type,
657
+ mlp_dim=mlp_dim,
658
+ cond_dim=cond_dim,
659
+ hidden_activation=hidden_activation,
660
+ hidden_activation_kwargs=hidden_activation_kwargs,
661
+ complex_mask=complex_mask,
662
+ overlapping_band=self.overlapping_band,
663
+ freq_weights=self.freq_weights,
664
+ n_freq=n_fft // 2 + 1,
665
+ use_freq_weights=use_freq_weights,
666
+ mult_add_mask=mult_add_mask,
667
+ )
668
+
669
+
670
+ class MultiMaskMultiSourceBandSplitConv(MultiMaskMultiSourceBandSplitBase):
671
+ def __init__(
672
+ self,
673
+ in_channel: int,
674
+ stems: List[str],
675
+ band_specs: Union[str, List[Tuple[float, float]]],
676
+ fs: int = 44100,
677
+ require_no_overlap: bool = False,
678
+ require_no_gap: bool = True,
679
+ normalize_channel_independently: bool = False,
680
+ treat_channel_as_feature: bool = True,
681
+ n_sqm_modules: int = 12,
682
+ emb_dim: int = 128,
683
+ rnn_dim: int = 256,
684
+ cond_dim: int = 0,
685
+ bidirectional: bool = True,
686
+ rnn_type: str = "LSTM",
687
+ mlp_dim: int = 512,
688
+ hidden_activation: str = "Tanh",
689
+ hidden_activation_kwargs: Optional[Dict] = None,
690
+ complex_mask: bool = True,
691
+ n_fft: int = 2048,
692
+ win_length: Optional[int] = 2048,
693
+ hop_length: int = 512,
694
+ window_fn: str = "hann_window",
695
+ wkwargs: Optional[Dict] = None,
696
+ power: Optional[int] = None,
697
+ center: bool = True,
698
+ normalized: bool = True,
699
+ pad_mode: str = "constant",
700
+ onesided: bool = True,
701
+ n_bands: int = None,
702
+ use_freq_weights: bool = True,
703
+ normalize_input: bool = False,
704
+ mult_add_mask: bool = False,
705
+ ) -> None:
706
+ super().__init__(
707
+ stems=stems,
708
+ band_specs=band_specs,
709
+ fs=fs,
710
+ n_fft=n_fft,
711
+ win_length=win_length,
712
+ hop_length=hop_length,
713
+ window_fn=window_fn,
714
+ wkwargs=wkwargs,
715
+ power=power,
716
+ center=center,
717
+ normalized=normalized,
718
+ pad_mode=pad_mode,
719
+ onesided=onesided,
720
+ n_bands=n_bands,
721
+ )
722
+
723
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv(
724
+ stems=stems,
725
+ band_specs=self.band_specs,
726
+ in_channel=in_channel,
727
+ require_no_overlap=require_no_overlap,
728
+ require_no_gap=require_no_gap,
729
+ normalize_channel_independently=normalize_channel_independently,
730
+ treat_channel_as_feature=treat_channel_as_feature,
731
+ n_sqm_modules=n_sqm_modules,
732
+ emb_dim=emb_dim,
733
+ rnn_dim=rnn_dim,
734
+ bidirectional=bidirectional,
735
+ rnn_type=rnn_type,
736
+ mlp_dim=mlp_dim,
737
+ cond_dim=cond_dim,
738
+ hidden_activation=hidden_activation,
739
+ hidden_activation_kwargs=hidden_activation_kwargs,
740
+ complex_mask=complex_mask,
741
+ overlapping_band=self.overlapping_band,
742
+ freq_weights=self.freq_weights,
743
+ n_freq=n_fft // 2 + 1,
744
+ use_freq_weights=use_freq_weights,
745
+ mult_add_mask=mult_add_mask,
746
+ )
747
+
748
+
749
+ class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
750
+ def __init__(
751
+ self,
752
+ in_channel: int,
753
+ stems: List[str],
754
+ band_specs: Union[str, List[Tuple[float, float]]],
755
+ kernel_norm_mlp_version: int = 1,
756
+ mask_kernel_freq: int = 3,
757
+ mask_kernel_time: int = 3,
758
+ conv_kernel_freq: int = 1,
759
+ conv_kernel_time: int = 1,
760
+ fs: int = 44100,
761
+ require_no_overlap: bool = False,
762
+ require_no_gap: bool = True,
763
+ normalize_channel_independently: bool = False,
764
+ treat_channel_as_feature: bool = True,
765
+ n_sqm_modules: int = 12,
766
+ emb_dim: int = 128,
767
+ rnn_dim: int = 256,
768
+ bidirectional: bool = True,
769
+ rnn_type: str = "LSTM",
770
+ mlp_dim: int = 512,
771
+ hidden_activation: str = "Tanh",
772
+ hidden_activation_kwargs: Optional[Dict] = None,
773
+ complex_mask: bool = True,
774
+ n_fft: int = 2048,
775
+ win_length: Optional[int] = 2048,
776
+ hop_length: int = 512,
777
+ window_fn: str = "hann_window",
778
+ wkwargs: Optional[Dict] = None,
779
+ power: Optional[int] = None,
780
+ center: bool = True,
781
+ normalized: bool = True,
782
+ pad_mode: str = "constant",
783
+ onesided: bool = True,
784
+ n_bands: int = None,
785
+ ) -> None:
786
+ super().__init__(
787
+ stems=stems,
788
+ band_specs=band_specs,
789
+ fs=fs,
790
+ n_fft=n_fft,
791
+ win_length=win_length,
792
+ hop_length=hop_length,
793
+ window_fn=window_fn,
794
+ wkwargs=wkwargs,
795
+ power=power,
796
+ center=center,
797
+ normalized=normalized,
798
+ pad_mode=pad_mode,
799
+ onesided=onesided,
800
+ n_bands=n_bands,
801
+ )
802
+
803
+ self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN(
804
+ stems=stems,
805
+ band_specs=self.band_specs,
806
+ in_channel=in_channel,
807
+ require_no_overlap=require_no_overlap,
808
+ require_no_gap=require_no_gap,
809
+ normalize_channel_independently=normalize_channel_independently,
810
+ treat_channel_as_feature=treat_channel_as_feature,
811
+ n_sqm_modules=n_sqm_modules,
812
+ emb_dim=emb_dim,
813
+ rnn_dim=rnn_dim,
814
+ bidirectional=bidirectional,
815
+ rnn_type=rnn_type,
816
+ mlp_dim=mlp_dim,
817
+ hidden_activation=hidden_activation,
818
+ hidden_activation_kwargs=hidden_activation_kwargs,
819
+ complex_mask=complex_mask,
820
+ overlapping_band=self.overlapping_band,
821
+ freq_weights=self.freq_weights,
822
+ n_freq=n_fft // 2 + 1,
823
+ mask_kernel_freq=mask_kernel_freq,
824
+ mask_kernel_time=mask_kernel_time,
825
+ conv_kernel_freq=conv_kernel_freq,
826
+ conv_kernel_time=conv_kernel_time,
827
+ kernel_norm_mlp_version=kernel_norm_mlp_version,
828
+ )
mvsepless/models/bandit/core/utils/audio.py CHANGED
@@ -1,324 +1,324 @@
1
- from collections import defaultdict
2
-
3
- from tqdm.auto import tqdm
4
- from typing import Callable, Dict, List, Optional, Tuple
5
-
6
- import numpy as np
7
- import torch
8
- from torch import nn
9
- from torch.nn import functional as F
10
-
11
-
12
- @torch.jit.script
13
- def merge(
14
- combined: torch.Tensor,
15
- original_batch_size: int,
16
- n_channel: int,
17
- n_chunks: int,
18
- chunk_size: int,
19
- ):
20
- combined = torch.reshape(
21
- combined, (original_batch_size, n_chunks, n_channel, chunk_size)
22
- )
23
- combined = torch.permute(combined, (0, 2, 3, 1)).reshape(
24
- original_batch_size * n_channel, chunk_size, n_chunks
25
- )
26
-
27
- return combined
28
-
29
-
30
- @torch.jit.script
31
- def unfold(
32
- padded_audio: torch.Tensor,
33
- original_batch_size: int,
34
- n_channel: int,
35
- chunk_size: int,
36
- hop_size: int,
37
- ) -> torch.Tensor:
38
-
39
- unfolded_input = F.unfold(
40
- padded_audio[:, :, None, :], kernel_size=(1, chunk_size), stride=(1, hop_size)
41
- )
42
-
43
- _, _, n_chunks = unfolded_input.shape
44
- unfolded_input = unfolded_input.view(
45
- original_batch_size, n_channel, chunk_size, n_chunks
46
- )
47
- unfolded_input = torch.permute(unfolded_input, (0, 3, 1, 2)).reshape(
48
- original_batch_size * n_chunks, n_channel, chunk_size
49
- )
50
-
51
- return unfolded_input
52
-
53
-
54
- @torch.jit.script
55
- def merge_chunks_all(
56
- combined: torch.Tensor,
57
- original_batch_size: int,
58
- n_channel: int,
59
- n_samples: int,
60
- n_padded_samples: int,
61
- n_chunks: int,
62
- chunk_size: int,
63
- hop_size: int,
64
- edge_frame_pad_sizes: Tuple[int, int],
65
- standard_window: torch.Tensor,
66
- first_window: torch.Tensor,
67
- last_window: torch.Tensor,
68
- ):
69
- combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)
70
-
71
- combined = combined * standard_window[:, None].to(combined.device)
72
-
73
- combined = F.fold(
74
- combined.to(torch.float32),
75
- output_size=(1, n_padded_samples),
76
- kernel_size=(1, chunk_size),
77
- stride=(1, hop_size),
78
- )
79
-
80
- combined = combined.view(original_batch_size, n_channel, n_padded_samples)
81
-
82
- pad_front, pad_back = edge_frame_pad_sizes
83
- combined = combined[..., pad_front:-pad_back]
84
-
85
- combined = combined[..., :n_samples]
86
-
87
- return combined
88
-
89
-
90
- def merge_chunks_edge(
91
- combined: torch.Tensor,
92
- original_batch_size: int,
93
- n_channel: int,
94
- n_samples: int,
95
- n_padded_samples: int,
96
- n_chunks: int,
97
- chunk_size: int,
98
- hop_size: int,
99
- edge_frame_pad_sizes: Tuple[int, int],
100
- standard_window: torch.Tensor,
101
- first_window: torch.Tensor,
102
- last_window: torch.Tensor,
103
- ):
104
- combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)
105
-
106
- combined[..., 0] = combined[..., 0] * first_window
107
- combined[..., -1] = combined[..., -1] * last_window
108
- combined[..., 1:-1] = combined[..., 1:-1] * standard_window[:, None]
109
-
110
- combined = F.fold(
111
- combined,
112
- output_size=(1, n_padded_samples),
113
- kernel_size=(1, chunk_size),
114
- stride=(1, hop_size),
115
- )
116
-
117
- combined = combined.view(original_batch_size, n_channel, n_padded_samples)
118
-
119
- combined = combined[..., :n_samples]
120
-
121
- return combined
122
-
123
-
124
- class BaseFader(nn.Module):
125
- def __init__(
126
- self,
127
- chunk_size_second: float,
128
- hop_size_second: float,
129
- fs: int,
130
- fade_edge_frames: bool,
131
- batch_size: int,
132
- ) -> None:
133
- super().__init__()
134
-
135
- self.chunk_size = int(chunk_size_second * fs)
136
- self.hop_size = int(hop_size_second * fs)
137
- self.overlap_size = self.chunk_size - self.hop_size
138
- self.fade_edge_frames = fade_edge_frames
139
- self.batch_size = batch_size
140
-
141
- def prepare(self, audio):
142
-
143
- if self.fade_edge_frames:
144
- audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect")
145
-
146
- n_samples = audio.shape[-1]
147
- n_chunks = int(np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1)
148
-
149
- padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size
150
- pad_size = padded_size - n_samples
151
-
152
- padded_audio = F.pad(audio, (0, pad_size))
153
-
154
- return padded_audio, n_chunks
155
-
156
- def forward(
157
- self,
158
- audio: torch.Tensor,
159
- model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
160
- ):
161
-
162
- original_dtype = audio.dtype
163
- original_device = audio.device
164
-
165
- audio = audio.to("cpu")
166
-
167
- original_batch_size, n_channel, n_samples = audio.shape
168
- padded_audio, n_chunks = self.prepare(audio)
169
- del audio
170
- n_padded_samples = padded_audio.shape[-1]
171
-
172
- if n_channel > 1:
173
- padded_audio = padded_audio.view(
174
- original_batch_size * n_channel, 1, n_padded_samples
175
- )
176
-
177
- unfolded_input = unfold(
178
- padded_audio, original_batch_size, n_channel, self.chunk_size, self.hop_size
179
- )
180
-
181
- n_total_chunks, n_channel, chunk_size = unfolded_input.shape
182
-
183
- n_batch = np.ceil(n_total_chunks / self.batch_size).astype(int)
184
-
185
- chunks_in = [
186
- unfolded_input[b * self.batch_size : (b + 1) * self.batch_size, ...].clone()
187
- for b in range(n_batch)
188
- ]
189
-
190
- all_chunks_out = defaultdict(
191
- lambda: torch.zeros_like(unfolded_input, device="cpu")
192
- )
193
-
194
- for b, cin in enumerate(chunks_in):
195
- if torch.allclose(cin, torch.tensor(0.0)):
196
- del cin
197
- continue
198
-
199
- chunks_out = model_fn(cin.to(original_device))
200
- del cin
201
- for s, c in chunks_out.items():
202
- all_chunks_out[s][
203
- b * self.batch_size : (b + 1) * self.batch_size, ...
204
- ] = c.cpu()
205
- del chunks_out
206
-
207
- del unfolded_input
208
- del padded_audio
209
-
210
- if self.fade_edge_frames:
211
- fn = merge_chunks_all
212
- else:
213
- fn = merge_chunks_edge
214
- outputs = {}
215
-
216
- torch.cuda.empty_cache()
217
-
218
- for s, c in all_chunks_out.items():
219
- combined: torch.Tensor = fn(
220
- c,
221
- original_batch_size,
222
- n_channel,
223
- n_samples,
224
- n_padded_samples,
225
- n_chunks,
226
- self.chunk_size,
227
- self.hop_size,
228
- self.edge_frame_pad_sizes,
229
- self.standard_window,
230
- self.__dict__.get("first_window", self.standard_window),
231
- self.__dict__.get("last_window", self.standard_window),
232
- )
233
-
234
- outputs[s] = combined.to(dtype=original_dtype, device=original_device)
235
-
236
- return {"audio": outputs}
237
-
238
-
239
- class LinearFader(BaseFader):
240
- def __init__(
241
- self,
242
- chunk_size_second: float,
243
- hop_size_second: float,
244
- fs: int,
245
- fade_edge_frames: bool = False,
246
- batch_size: int = 1,
247
- ) -> None:
248
-
249
- assert hop_size_second >= chunk_size_second / 2
250
-
251
- super().__init__(
252
- chunk_size_second=chunk_size_second,
253
- hop_size_second=hop_size_second,
254
- fs=fs,
255
- fade_edge_frames=fade_edge_frames,
256
- batch_size=batch_size,
257
- )
258
-
259
- in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1]
260
- out_fade = torch.linspace(1.0, 0.0, self.overlap_size + 1)[1:]
261
- center_ones = torch.ones(self.chunk_size - 2 * self.overlap_size)
262
- inout_ones = torch.ones(self.overlap_size)
263
-
264
- self.register_buffer(
265
- "standard_window", torch.concat([in_fade, center_ones, out_fade])
266
- )
267
-
268
- self.fade_edge_frames = fade_edge_frames
269
- self.edge_frame_pad_size = (self.overlap_size, self.overlap_size)
270
-
271
- if not self.fade_edge_frames:
272
- self.first_window = nn.Parameter(
273
- torch.concat([inout_ones, center_ones, out_fade]), requires_grad=False
274
- )
275
- self.last_window = nn.Parameter(
276
- torch.concat([in_fade, center_ones, inout_ones]), requires_grad=False
277
- )
278
-
279
-
280
- class OverlapAddFader(BaseFader):
281
- def __init__(
282
- self,
283
- window_type: str,
284
- chunk_size_second: float,
285
- hop_size_second: float,
286
- fs: int,
287
- batch_size: int = 1,
288
- ) -> None:
289
- assert (chunk_size_second / hop_size_second) % 2 == 0
290
- assert int(chunk_size_second * fs) % 2 == 0
291
-
292
- super().__init__(
293
- chunk_size_second=chunk_size_second,
294
- hop_size_second=hop_size_second,
295
- fs=fs,
296
- fade_edge_frames=True,
297
- batch_size=batch_size,
298
- )
299
-
300
- self.hop_multiplier = self.chunk_size / (2 * self.hop_size)
301
-
302
- self.edge_frame_pad_sizes = (2 * self.overlap_size, 2 * self.overlap_size)
303
-
304
- self.register_buffer(
305
- "standard_window",
306
- torch.windows.__dict__[window_type](
307
- self.chunk_size,
308
- sym=False,
309
- )
310
- / self.hop_multiplier,
311
- )
312
-
313
-
314
- if __name__ == "__main__":
315
- import torchaudio as ta
316
-
317
- fs = 44100
318
- ola = OverlapAddFader("hann", 6.0, 1.0, fs, batch_size=16)
319
- audio_, _ = ta.load(
320
- "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too " "Much/vocals.wav"
321
- )
322
- audio_ = audio_[None, ...]
323
- out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"]
324
- print(torch.allclose(out, audio_))
 
1
+ from collections import defaultdict
2
+
3
+ from tqdm.auto import tqdm
4
+ from typing import Callable, Dict, List, Optional, Tuple
5
+
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+
12
+ @torch.jit.script
13
+ def merge(
14
+ combined: torch.Tensor,
15
+ original_batch_size: int,
16
+ n_channel: int,
17
+ n_chunks: int,
18
+ chunk_size: int,
19
+ ):
20
+ combined = torch.reshape(
21
+ combined, (original_batch_size, n_chunks, n_channel, chunk_size)
22
+ )
23
+ combined = torch.permute(combined, (0, 2, 3, 1)).reshape(
24
+ original_batch_size * n_channel, chunk_size, n_chunks
25
+ )
26
+
27
+ return combined
28
+
29
+
30
+ @torch.jit.script
31
+ def unfold(
32
+ padded_audio: torch.Tensor,
33
+ original_batch_size: int,
34
+ n_channel: int,
35
+ chunk_size: int,
36
+ hop_size: int,
37
+ ) -> torch.Tensor:
38
+
39
+ unfolded_input = F.unfold(
40
+ padded_audio[:, :, None, :], kernel_size=(1, chunk_size), stride=(1, hop_size)
41
+ )
42
+
43
+ _, _, n_chunks = unfolded_input.shape
44
+ unfolded_input = unfolded_input.view(
45
+ original_batch_size, n_channel, chunk_size, n_chunks
46
+ )
47
+ unfolded_input = torch.permute(unfolded_input, (0, 3, 1, 2)).reshape(
48
+ original_batch_size * n_chunks, n_channel, chunk_size
49
+ )
50
+
51
+ return unfolded_input
52
+
53
+
54
+ @torch.jit.script
55
+ def merge_chunks_all(
56
+ combined: torch.Tensor,
57
+ original_batch_size: int,
58
+ n_channel: int,
59
+ n_samples: int,
60
+ n_padded_samples: int,
61
+ n_chunks: int,
62
+ chunk_size: int,
63
+ hop_size: int,
64
+ edge_frame_pad_sizes: Tuple[int, int],
65
+ standard_window: torch.Tensor,
66
+ first_window: torch.Tensor,
67
+ last_window: torch.Tensor,
68
+ ):
69
+ combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)
70
+
71
+ combined = combined * standard_window[:, None].to(combined.device)
72
+
73
+ combined = F.fold(
74
+ combined.to(torch.float32),
75
+ output_size=(1, n_padded_samples),
76
+ kernel_size=(1, chunk_size),
77
+ stride=(1, hop_size),
78
+ )
79
+
80
+ combined = combined.view(original_batch_size, n_channel, n_padded_samples)
81
+
82
+ pad_front, pad_back = edge_frame_pad_sizes
83
+ combined = combined[..., pad_front:-pad_back]
84
+
85
+ combined = combined[..., :n_samples]
86
+
87
+ return combined
88
+
89
+
90
def merge_chunks_edge(
    combined: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    n_samples: int,
    n_padded_samples: int,
    n_chunks: int,
    chunk_size: int,
    hop_size: int,
    edge_frame_pad_sizes: Tuple[int, int],
    standard_window: torch.Tensor,
    first_window: torch.Tensor,
    last_window: torch.Tensor,
):
    """Overlap-add chunks, with dedicated windows for the first and last
    chunks (used when the edges were not reflect-padded, so there is no
    edge padding to trim afterwards)."""
    stacked = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)

    # Edge chunks keep full gain on their outer side; interior chunks use
    # the standard cross-fade window.
    stacked[..., 0] = stacked[..., 0] * first_window
    stacked[..., -1] = stacked[..., -1] * last_window
    stacked[..., 1:-1] = stacked[..., 1:-1] * standard_window[:, None]

    folded = F.fold(
        stacked,
        output_size=(1, n_padded_samples),
        kernel_size=(1, chunk_size),
        stride=(1, hop_size),
    ).view(original_batch_size, n_channel, n_padded_samples)

    # Drop only the trailing zero padding appended by ``prepare``.
    return folded[..., :n_samples]
122
+
123
+
124
class BaseFader(nn.Module):
    """Splits long audio into overlapping chunks, runs a model on each
    batch of chunks, and cross-fades the chunk outputs back together.

    Subclasses must provide:
      * ``standard_window`` -- per-sample weights applied to each chunk;
      * ``edge_frame_pad_sizes`` -- (front, back) reflect-pad sizes used
        when ``fade_edge_frames`` is True;
      * optionally ``first_window``/``last_window`` for the edge chunks
        when ``fade_edge_frames`` is False.
    """

    def __init__(
        self,
        chunk_size_second: float,
        hop_size_second: float,
        fs: int,
        fade_edge_frames: bool,
        batch_size: int,
    ) -> None:
        super().__init__()

        self.chunk_size = int(chunk_size_second * fs)
        self.hop_size = int(hop_size_second * fs)
        self.overlap_size = self.chunk_size - self.hop_size
        self.fade_edge_frames = fade_edge_frames
        self.batch_size = batch_size

    def prepare(self, audio):
        """Reflect-pad the edges (if enabled) and zero-pad the tail so the
        signal divides evenly into hops.

        Returns:
            (padded_audio, n_chunks)
        """
        if self.fade_edge_frames:
            audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect")

        n_samples = audio.shape[-1]
        n_chunks = int(np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1)

        padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size
        pad_size = padded_size - n_samples

        padded_audio = F.pad(audio, (0, pad_size))

        return padded_audio, n_chunks

    def forward(
        self,
        audio: torch.Tensor,
        model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
    ):
        """Chunk ``audio`` (batch, channels, samples), apply ``model_fn``
        per batch of chunks on the original device, and overlap-add each
        returned stem.

        Returns:
            ``{"audio": {stem: tensor}}`` with tensors matching the input
            dtype, device and sample count.
        """
        original_dtype = audio.dtype
        original_device = audio.device

        # Chunk bookkeeping happens on CPU to limit accelerator memory use.
        audio = audio.to("cpu")

        original_batch_size, n_channel, n_samples = audio.shape
        padded_audio, n_chunks = self.prepare(audio)
        del audio
        n_padded_samples = padded_audio.shape[-1]

        if n_channel > 1:
            # Fold channels into the batch axis; chunks are processed mono.
            padded_audio = padded_audio.view(
                original_batch_size * n_channel, 1, n_padded_samples
            )

        unfolded_input = unfold(
            padded_audio, original_batch_size, n_channel, self.chunk_size, self.hop_size
        )

        n_total_chunks, n_channel, chunk_size = unfolded_input.shape

        n_batch = np.ceil(n_total_chunks / self.batch_size).astype(int)

        chunks_in = [
            unfolded_input[b * self.batch_size : (b + 1) * self.batch_size, ...].clone()
            for b in range(n_batch)
        ]

        all_chunks_out = defaultdict(
            lambda: torch.zeros_like(unfolded_input, device="cpu")
        )

        for b, cin in enumerate(chunks_in):
            # Fully-silent batches need no model call; output stays zero.
            if torch.allclose(cin, torch.tensor(0.0)):
                del cin
                continue

            chunks_out = model_fn(cin.to(original_device))
            del cin
            for s, c in chunks_out.items():
                all_chunks_out[s][
                    b * self.batch_size : (b + 1) * self.batch_size, ...
                ] = c.cpu()
            del chunks_out

        del unfolded_input
        del padded_audio

        if self.fade_edge_frames:
            fn = merge_chunks_all
        else:
            fn = merge_chunks_edge
        outputs = {}

        torch.cuda.empty_cache()

        for s, c in all_chunks_out.items():
            # BUG FIX: first_window/last_window are registered as
            # nn.Parameter/buffer attributes, which live in _parameters /
            # _buffers rather than __dict__, so the previous
            # self.__dict__.get(...) lookup always fell back to
            # standard_window.  getattr goes through nn.Module's attribute
            # resolution and finds them.
            combined: torch.Tensor = fn(
                c,
                original_batch_size,
                n_channel,
                n_samples,
                n_padded_samples,
                n_chunks,
                self.chunk_size,
                self.hop_size,
                self.edge_frame_pad_sizes,
                self.standard_window,
                getattr(self, "first_window", self.standard_window),
                getattr(self, "last_window", self.standard_window),
            )

            outputs[s] = combined.to(dtype=original_dtype, device=original_device)

        return {"audio": outputs}
237
+
238
+
239
class LinearFader(BaseFader):
    """Fader using linear cross-fades between neighbouring chunks.

    Requires hop >= chunk/2 so at most two chunks overlap at any sample.
    """

    def __init__(
        self,
        chunk_size_second: float,
        hop_size_second: float,
        fs: int,
        fade_edge_frames: bool = False,
        batch_size: int = 1,
    ) -> None:

        assert hop_size_second >= chunk_size_second / 2

        super().__init__(
            chunk_size_second=chunk_size_second,
            hop_size_second=hop_size_second,
            fs=fs,
            fade_edge_frames=fade_edge_frames,
            batch_size=batch_size,
        )

        # Linear ramps over the overlapping region; unity in the centre.
        in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1]
        out_fade = torch.linspace(1.0, 0.0, self.overlap_size + 1)[1:]
        center_ones = torch.ones(self.chunk_size - 2 * self.overlap_size)
        inout_ones = torch.ones(self.overlap_size)

        self.register_buffer(
            "standard_window", torch.concat([in_fade, center_ones, out_fade])
        )

        self.fade_edge_frames = fade_edge_frames
        # BUG FIX: BaseFader.prepare/forward read `edge_frame_pad_sizes`
        # (plural); this was previously stored under the singular name
        # `edge_frame_pad_size`, so fade_edge_frames=True raised
        # AttributeError.
        self.edge_frame_pad_sizes = (self.overlap_size, self.overlap_size)

        if not self.fade_edge_frames:
            # Edge chunks keep unit gain on their outer side.
            self.first_window = nn.Parameter(
                torch.concat([inout_ones, center_ones, out_fade]), requires_grad=False
            )
            self.last_window = nn.Parameter(
                torch.concat([in_fade, center_ones, inout_ones]), requires_grad=False
            )
278
+
279
+
280
class OverlapAddFader(BaseFader):
    """Fader performing classic windowed overlap-add (e.g. Hann window).

    Always fades/reflect-pads the edge frames.  The asserts require the
    chunk to be an even multiple of the hop so the scaled window
    overlap-adds back to unity gain.
    """

    def __init__(
        self,
        window_type: str,
        chunk_size_second: float,
        hop_size_second: float,
        fs: int,
        batch_size: int = 1,
    ) -> None:
        assert (chunk_size_second / hop_size_second) % 2 == 0
        assert int(chunk_size_second * fs) % 2 == 0

        super().__init__(
            chunk_size_second=chunk_size_second,
            hop_size_second=hop_size_second,
            fs=fs,
            fade_edge_frames=True,
            batch_size=batch_size,
        )

        # Normalization so overlapping windows sum to one.
        self.hop_multiplier = self.chunk_size / (2 * self.hop_size)

        self.edge_frame_pad_sizes = (2 * self.overlap_size, 2 * self.overlap_size)

        # BUG FIX: the window factories taking a `sym=` keyword live in
        # torch.signal.windows; `torch.windows` does not exist, so the
        # previous lookup raised AttributeError.  sym=False yields a
        # periodic window suitable for overlap-add.
        window_fn = getattr(torch.signal.windows, window_type)
        self.register_buffer(
            "standard_window",
            window_fn(self.chunk_size, sym=False) / self.hop_multiplier,
        )
312
+
313
+
314
if __name__ == "__main__":
    # Smoke test: an identity "model" run through the fader should
    # reconstruct the input up to float tolerance (windows sum to one).
    import torchaudio as ta

    fs = 44100
    fader = OverlapAddFader("hann", 6.0, 1.0, fs, batch_size=16)
    waveform, _ = ta.load(
        "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too " "Much/vocals.wav"
    )
    waveform = waveform[None, ...]
    result = fader(waveform, lambda x: {"stem": x})["audio"]["stem"]
    print(torch.allclose(result, waveform))
mvsepless/models/bandit/model_from_config.py CHANGED
@@ -1,26 +1,26 @@
1
- import sys
2
- import os.path
3
- import torch
4
-
5
- import yaml
6
- from ml_collections import ConfigDict
7
-
8
- torch.set_float32_matmul_precision("medium")
9
-
10
-
11
- def get_model(
12
- config_path,
13
- weights_path,
14
- device,
15
- ):
16
- from .core.model import MultiMaskMultiSourceBandSplitRNNSimple
17
-
18
- f = open(config_path)
19
- config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
20
- f.close()
21
-
22
- model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
23
- d = torch.load(code_path + "model_bandit_plus_dnr_sdr_11.47.chpt")
24
- model.load_state_dict(d)
25
- model.to(device)
26
- return model, config
 
1
import sys
import os.path
import torch

import yaml
from ml_collections import ConfigDict

torch.set_float32_matmul_precision("medium")


def get_model(
    config_path,
    weights_path,
    device,
):
    """Build a MultiMaskMultiSourceBandSplitRNNSimple from a YAML config
    and load its checkpoint weights.

    Args:
        config_path: path to the YAML model configuration file.
        weights_path: path to the checkpoint (state dict) file.
        device: torch device the model is moved to.

    Returns:
        (model, config) tuple.
    """
    from .core.model import MultiMaskMultiSourceBandSplitRNNSimple

    with open(config_path) as f:
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))

    model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
    # BUG FIX: previously loaded from an undefined `code_path` variable
    # with a hard-coded checkpoint filename (NameError at runtime).  Use
    # the `weights_path` argument and map the checkpoint onto the target
    # device to avoid CPU/GPU mismatches.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    return model, config
mvsepless/models/bandit_v2/bandit.py CHANGED
@@ -1,360 +1,360 @@
1
- from typing import Dict, List, Optional
2
-
3
- import torch
4
- import torchaudio as ta
5
- from torch import nn
6
- import pytorch_lightning as pl
7
-
8
- from .bandsplit import BandSplitModule
9
- from .maskestim import OverlappingMaskEstimationModule
10
- from .tfmodel import SeqBandModellingModule
11
- from .utils import MusicalBandsplitSpecification
12
-
13
-
14
- class BaseEndToEndModule(pl.LightningModule):
15
- def __init__(
16
- self,
17
- ) -> None:
18
- super().__init__()
19
-
20
-
21
- class BaseBandit(BaseEndToEndModule):
22
- def __init__(
23
- self,
24
- in_channels: int,
25
- fs: int,
26
- band_type: str = "musical",
27
- n_bands: int = 64,
28
- require_no_overlap: bool = False,
29
- require_no_gap: bool = True,
30
- normalize_channel_independently: bool = False,
31
- treat_channel_as_feature: bool = True,
32
- n_sqm_modules: int = 12,
33
- emb_dim: int = 128,
34
- rnn_dim: int = 256,
35
- bidirectional: bool = True,
36
- rnn_type: str = "LSTM",
37
- n_fft: int = 2048,
38
- win_length: Optional[int] = 2048,
39
- hop_length: int = 512,
40
- window_fn: str = "hann_window",
41
- wkwargs: Optional[Dict] = None,
42
- power: Optional[int] = None,
43
- center: bool = True,
44
- normalized: bool = True,
45
- pad_mode: str = "constant",
46
- onesided: bool = True,
47
- ):
48
- super().__init__()
49
-
50
- self.in_channels = in_channels
51
-
52
- self.instantitate_spectral(
53
- n_fft=n_fft,
54
- win_length=win_length,
55
- hop_length=hop_length,
56
- window_fn=window_fn,
57
- wkwargs=wkwargs,
58
- power=power,
59
- normalized=normalized,
60
- center=center,
61
- pad_mode=pad_mode,
62
- onesided=onesided,
63
- )
64
-
65
- self.instantiate_bandsplit(
66
- in_channels=in_channels,
67
- band_type=band_type,
68
- n_bands=n_bands,
69
- require_no_overlap=require_no_overlap,
70
- require_no_gap=require_no_gap,
71
- normalize_channel_independently=normalize_channel_independently,
72
- treat_channel_as_feature=treat_channel_as_feature,
73
- emb_dim=emb_dim,
74
- n_fft=n_fft,
75
- fs=fs,
76
- )
77
-
78
- self.instantiate_tf_modelling(
79
- n_sqm_modules=n_sqm_modules,
80
- emb_dim=emb_dim,
81
- rnn_dim=rnn_dim,
82
- bidirectional=bidirectional,
83
- rnn_type=rnn_type,
84
- )
85
-
86
- def instantitate_spectral(
87
- self,
88
- n_fft: int = 2048,
89
- win_length: Optional[int] = 2048,
90
- hop_length: int = 512,
91
- window_fn: str = "hann_window",
92
- wkwargs: Optional[Dict] = None,
93
- power: Optional[int] = None,
94
- normalized: bool = True,
95
- center: bool = True,
96
- pad_mode: str = "constant",
97
- onesided: bool = True,
98
- ):
99
- assert power is None
100
-
101
- window_fn = torch.__dict__[window_fn]
102
-
103
- self.stft = ta.transforms.Spectrogram(
104
- n_fft=n_fft,
105
- win_length=win_length,
106
- hop_length=hop_length,
107
- pad_mode=pad_mode,
108
- pad=0,
109
- window_fn=window_fn,
110
- wkwargs=wkwargs,
111
- power=power,
112
- normalized=normalized,
113
- center=center,
114
- onesided=onesided,
115
- )
116
-
117
- self.istft = ta.transforms.InverseSpectrogram(
118
- n_fft=n_fft,
119
- win_length=win_length,
120
- hop_length=hop_length,
121
- pad_mode=pad_mode,
122
- pad=0,
123
- window_fn=window_fn,
124
- wkwargs=wkwargs,
125
- normalized=normalized,
126
- center=center,
127
- onesided=onesided,
128
- )
129
-
130
- def instantiate_bandsplit(
131
- self,
132
- in_channels: int,
133
- band_type: str = "musical",
134
- n_bands: int = 64,
135
- require_no_overlap: bool = False,
136
- require_no_gap: bool = True,
137
- normalize_channel_independently: bool = False,
138
- treat_channel_as_feature: bool = True,
139
- emb_dim: int = 128,
140
- n_fft: int = 2048,
141
- fs: int = 44100,
142
- ):
143
- assert band_type == "musical"
144
-
145
- self.band_specs = MusicalBandsplitSpecification(
146
- nfft=n_fft, fs=fs, n_bands=n_bands
147
- )
148
-
149
- self.band_split = BandSplitModule(
150
- in_channels=in_channels,
151
- band_specs=self.band_specs.get_band_specs(),
152
- require_no_overlap=require_no_overlap,
153
- require_no_gap=require_no_gap,
154
- normalize_channel_independently=normalize_channel_independently,
155
- treat_channel_as_feature=treat_channel_as_feature,
156
- emb_dim=emb_dim,
157
- )
158
-
159
- def instantiate_tf_modelling(
160
- self,
161
- n_sqm_modules: int = 12,
162
- emb_dim: int = 128,
163
- rnn_dim: int = 256,
164
- bidirectional: bool = True,
165
- rnn_type: str = "LSTM",
166
- ):
167
- try:
168
- self.tf_model = torch.compile(
169
- SeqBandModellingModule(
170
- n_modules=n_sqm_modules,
171
- emb_dim=emb_dim,
172
- rnn_dim=rnn_dim,
173
- bidirectional=bidirectional,
174
- rnn_type=rnn_type,
175
- ),
176
- disable=True,
177
- )
178
- except Exception as e:
179
- self.tf_model = SeqBandModellingModule(
180
- n_modules=n_sqm_modules,
181
- emb_dim=emb_dim,
182
- rnn_dim=rnn_dim,
183
- bidirectional=bidirectional,
184
- rnn_type=rnn_type,
185
- )
186
-
187
- def mask(self, x, m):
188
- return x * m
189
-
190
- def forward(self, batch, mode="train"):
191
- init_shape = batch.shape
192
- if not isinstance(batch, dict):
193
- mono = batch.view(-1, 1, batch.shape[-1])
194
- batch = {"mixture": {"audio": mono}}
195
-
196
- with torch.no_grad():
197
- mixture = batch["mixture"]["audio"]
198
-
199
- x = self.stft(mixture)
200
- batch["mixture"]["spectrogram"] = x
201
-
202
- if "sources" in batch.keys():
203
- for stem in batch["sources"].keys():
204
- s = batch["sources"][stem]["audio"]
205
- s = self.stft(s)
206
- batch["sources"][stem]["spectrogram"] = s
207
-
208
- batch = self.separate(batch)
209
-
210
- if 1:
211
- b = []
212
- for s in self.stems:
213
- r = batch["estimates"][s]["audio"].view(
214
- -1, init_shape[1], init_shape[2]
215
- )
216
- b.append(r)
217
- batch = torch.stack(b, dim=1)
218
- return batch
219
-
220
- def encode(self, batch):
221
- x = batch["mixture"]["spectrogram"]
222
- length = batch["mixture"]["audio"].shape[-1]
223
-
224
- z = self.band_split(x)
225
- q = self.tf_model(z)
226
-
227
- return x, q, length
228
-
229
- def separate(self, batch):
230
- raise NotImplementedError
231
-
232
-
233
- class Bandit(BaseBandit):
234
- def __init__(
235
- self,
236
- in_channels: int,
237
- stems: List[str],
238
- band_type: str = "musical",
239
- n_bands: int = 64,
240
- require_no_overlap: bool = False,
241
- require_no_gap: bool = True,
242
- normalize_channel_independently: bool = False,
243
- treat_channel_as_feature: bool = True,
244
- n_sqm_modules: int = 12,
245
- emb_dim: int = 128,
246
- rnn_dim: int = 256,
247
- bidirectional: bool = True,
248
- rnn_type: str = "LSTM",
249
- mlp_dim: int = 512,
250
- hidden_activation: str = "Tanh",
251
- hidden_activation_kwargs: Dict | None = None,
252
- complex_mask: bool = True,
253
- use_freq_weights: bool = True,
254
- n_fft: int = 2048,
255
- win_length: int | None = 2048,
256
- hop_length: int = 512,
257
- window_fn: str = "hann_window",
258
- wkwargs: Dict | None = None,
259
- power: int | None = None,
260
- center: bool = True,
261
- normalized: bool = True,
262
- pad_mode: str = "constant",
263
- onesided: bool = True,
264
- fs: int = 44100,
265
- stft_precisions="32",
266
- bandsplit_precisions="bf16",
267
- tf_model_precisions="bf16",
268
- mask_estim_precisions="bf16",
269
- ):
270
- super().__init__(
271
- in_channels=in_channels,
272
- band_type=band_type,
273
- n_bands=n_bands,
274
- require_no_overlap=require_no_overlap,
275
- require_no_gap=require_no_gap,
276
- normalize_channel_independently=normalize_channel_independently,
277
- treat_channel_as_feature=treat_channel_as_feature,
278
- n_sqm_modules=n_sqm_modules,
279
- emb_dim=emb_dim,
280
- rnn_dim=rnn_dim,
281
- bidirectional=bidirectional,
282
- rnn_type=rnn_type,
283
- n_fft=n_fft,
284
- win_length=win_length,
285
- hop_length=hop_length,
286
- window_fn=window_fn,
287
- wkwargs=wkwargs,
288
- power=power,
289
- center=center,
290
- normalized=normalized,
291
- pad_mode=pad_mode,
292
- onesided=onesided,
293
- fs=fs,
294
- )
295
-
296
- self.stems = stems
297
-
298
- self.instantiate_mask_estim(
299
- in_channels=in_channels,
300
- stems=stems,
301
- emb_dim=emb_dim,
302
- mlp_dim=mlp_dim,
303
- hidden_activation=hidden_activation,
304
- hidden_activation_kwargs=hidden_activation_kwargs,
305
- complex_mask=complex_mask,
306
- n_freq=n_fft // 2 + 1,
307
- use_freq_weights=use_freq_weights,
308
- )
309
-
310
- def instantiate_mask_estim(
311
- self,
312
- in_channels: int,
313
- stems: List[str],
314
- emb_dim: int,
315
- mlp_dim: int,
316
- hidden_activation: str,
317
- hidden_activation_kwargs: Optional[Dict] = None,
318
- complex_mask: bool = True,
319
- n_freq: Optional[int] = None,
320
- use_freq_weights: bool = False,
321
- ):
322
- if hidden_activation_kwargs is None:
323
- hidden_activation_kwargs = {}
324
-
325
- assert n_freq is not None
326
-
327
- self.mask_estim = nn.ModuleDict(
328
- {
329
- stem: OverlappingMaskEstimationModule(
330
- band_specs=self.band_specs.get_band_specs(),
331
- freq_weights=self.band_specs.get_freq_weights(),
332
- n_freq=n_freq,
333
- emb_dim=emb_dim,
334
- mlp_dim=mlp_dim,
335
- in_channels=in_channels,
336
- hidden_activation=hidden_activation,
337
- hidden_activation_kwargs=hidden_activation_kwargs,
338
- complex_mask=complex_mask,
339
- use_freq_weights=use_freq_weights,
340
- )
341
- for stem in stems
342
- }
343
- )
344
-
345
- def separate(self, batch):
346
- batch["estimates"] = {}
347
-
348
- x, q, length = self.encode(batch)
349
-
350
- for stem, mem in self.mask_estim.items():
351
- m = mem(q)
352
-
353
- s = self.mask(x, m.to(x.dtype))
354
- s = torch.reshape(s, x.shape)
355
- batch["estimates"][stem] = {
356
- "audio": self.istft(s, length),
357
- "spectrogram": s,
358
- }
359
-
360
- return batch
 
1
+ from typing import Dict, List, Optional
2
+
3
+ import torch
4
+ import torchaudio as ta
5
+ from torch import nn
6
+ import pytorch_lightning as pl
7
+
8
+ from .bandsplit import BandSplitModule
9
+ from .maskestim import OverlappingMaskEstimationModule
10
+ from .tfmodel import SeqBandModellingModule
11
+ from .utils import MusicalBandsplitSpecification
12
+
13
+
14
class BaseEndToEndModule(pl.LightningModule):
    """Common LightningModule base class for end-to-end separation models."""

    def __init__(self) -> None:
        super().__init__()
19
+
20
+
21
class BaseBandit(BaseEndToEndModule):
    """Backbone of the Bandit source-separation model.

    Wires together an STFT/iSTFT front-end, a band-split encoder and a
    stack of sequence/band modelling RNN modules.  Subclasses attach the
    mask-estimation heads and implement `separate`.
    """

    def __init__(
        self,
        in_channels: int,
        fs: int,
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ):
        super().__init__()

        self.in_channels = in_channels

        # NOTE: the method name keeps its historical misspelling
        # ("instantitate") because external code may already call it.
        self.instantitate_spectral(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        self.instantiate_bandsplit(
            in_channels=in_channels,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
            n_fft=n_fft,
            fs=fs,
        )

        self.instantiate_tf_modelling(
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

    def instantitate_spectral(
        self,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        normalized: bool = True,
        center: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ):
        """Create the STFT/iSTFT transforms.

        ``power`` must be None so the spectrogram stays complex-valued
        (required for complex masking and inversion).
        """
        assert power is None

        window_fn = torch.__dict__[window_fn]

        self.stft = ta.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )

        self.istft = ta.transforms.InverseSpectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )

    def instantiate_bandsplit(
        self,
        in_channels: int,
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        emb_dim: int = 128,
        n_fft: int = 2048,
        fs: int = 44100,
    ):
        """Build the frequency band definitions and the band-split encoder.

        Only the "musical" band layout is implemented.
        """
        assert band_type == "musical"

        self.band_specs = MusicalBandsplitSpecification(
            nfft=n_fft, fs=fs, n_bands=n_bands
        )

        self.band_split = BandSplitModule(
            in_channels=in_channels,
            band_specs=self.band_specs.get_band_specs(),
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

    def instantiate_tf_modelling(
        self,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
    ):
        """Build the time/band sequence model, wrapped by torch.compile
        (with compilation disabled) when available; falls back to the
        plain module on torch builds without torch.compile.
        """
        try:
            self.tf_model = torch.compile(
                SeqBandModellingModule(
                    n_modules=n_sqm_modules,
                    emb_dim=emb_dim,
                    rnn_dim=rnn_dim,
                    bidirectional=bidirectional,
                    rnn_type=rnn_type,
                ),
                disable=True,
            )
        except Exception:
            self.tf_model = SeqBandModellingModule(
                n_modules=n_sqm_modules,
                emb_dim=emb_dim,
                rnn_dim=rnn_dim,
                bidirectional=bidirectional,
                rnn_type=rnn_type,
            )

    def mask(self, x, m):
        """Apply an estimated (complex) mask to the mixture spectrogram."""
        return x * m

    def forward(self, batch, mode="train"):
        """Separate a (batch, channels, samples) waveform into stems.

        Returns a (batch, n_stems, channels, samples) tensor.

        NOTE(review): dict inputs were never supported in practice --
        ``batch.shape`` on the first line raises for dicts -- so this
        effectively requires a tensor input.
        """
        init_shape = batch.shape
        if not isinstance(batch, dict):
            # Fold channels into the batch axis; the model runs mono.
            mono = batch.view(-1, 1, batch.shape[-1])
            batch = {"mixture": {"audio": mono}}

        with torch.no_grad():
            mixture = batch["mixture"]["audio"]

            x = self.stft(mixture)
            batch["mixture"]["spectrogram"] = x

            if "sources" in batch.keys():
                for stem in batch["sources"].keys():
                    s = batch["sources"][stem]["audio"]
                    batch["sources"][stem]["spectrogram"] = self.stft(s)

        batch = self.separate(batch)

        # Re-expand the folded channel axis and stack stems along dim 1.
        # (The dead `if 1:` wrapper around this block was removed.)
        estimates = [
            batch["estimates"][stem]["audio"].view(-1, init_shape[1], init_shape[2])
            for stem in self.stems
        ]
        return torch.stack(estimates, dim=1)

    def encode(self, batch):
        """Band-split and sequence-model the mixture spectrogram.

        Returns:
            (spectrogram, band embedding, original sample count)
        """
        x = batch["mixture"]["spectrogram"]
        length = batch["mixture"]["audio"].shape[-1]

        z = self.band_split(x)
        q = self.tf_model(z)

        return x, q, length

    def separate(self, batch):
        raise NotImplementedError
231
+
232
+
233
class Bandit(BaseBandit):
    """Bandit model with one overlapping mask-estimation head per stem."""

    def __init__(
        self,
        in_channels: int,
        stems: List[str],
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict | None = None,
        complex_mask: bool = True,
        use_freq_weights: bool = True,
        n_fft: int = 2048,
        win_length: int | None = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Dict | None = None,
        power: int | None = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        fs: int = 44100,
        stft_precisions="32",
        bandsplit_precisions="bf16",
        tf_model_precisions="bf16",
        mask_estim_precisions="bf16",
    ):
        super().__init__(
            in_channels=in_channels,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            fs=fs,
        )

        self.stems = stems

        self.instantiate_mask_estim(
            in_channels=in_channels,
            stems=stems,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
        )

    def instantiate_mask_estim(
        self,
        in_channels: int,
        stems: List[str],
        emb_dim: int,
        mlp_dim: int,
        hidden_activation: str,
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = False,
    ):
        """Create one OverlappingMaskEstimationModule per stem."""
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        assert n_freq is not None

        heads = {}
        for stem in stems:
            heads[stem] = OverlappingMaskEstimationModule(
                band_specs=self.band_specs.get_band_specs(),
                freq_weights=self.band_specs.get_freq_weights(),
                n_freq=n_freq,
                emb_dim=emb_dim,
                mlp_dim=mlp_dim,
                in_channels=in_channels,
                hidden_activation=hidden_activation,
                hidden_activation_kwargs=hidden_activation_kwargs,
                complex_mask=complex_mask,
                use_freq_weights=use_freq_weights,
            )
        self.mask_estim = nn.ModuleDict(heads)

    def separate(self, batch):
        """Estimate a mask per stem, apply it to the mixture spectrogram,
        and iSTFT the masked spectrogram back to audio."""
        batch["estimates"] = {}

        spec, embedding, length = self.encode(batch)

        for stem, estimator in self.mask_estim.items():
            mask = estimator(embedding)
            masked = torch.reshape(self.mask(spec, mask.to(spec.dtype)), spec.shape)
            batch["estimates"][stem] = {
                "audio": self.istft(masked, length),
                "spectrogram": masked,
            }

        return batch
mvsepless/models/bandit_v2/bandsplit.py CHANGED
@@ -1,127 +1,127 @@
1
- from typing import List, Tuple
2
-
3
- import torch
4
- from torch import nn
5
- from torch.utils.checkpoint import checkpoint_sequential
6
-
7
- from .utils import (
8
- band_widths_from_specs,
9
- check_no_gap,
10
- check_no_overlap,
11
- check_nonzero_bandwidth,
12
- )
13
-
14
-
15
- class NormFC(nn.Module):
16
- def __init__(
17
- self,
18
- emb_dim: int,
19
- bandwidth: int,
20
- in_channels: int,
21
- normalize_channel_independently: bool = False,
22
- treat_channel_as_feature: bool = True,
23
- ) -> None:
24
- super().__init__()
25
-
26
- if not treat_channel_as_feature:
27
- raise NotImplementedError
28
-
29
- self.treat_channel_as_feature = treat_channel_as_feature
30
-
31
- if normalize_channel_independently:
32
- raise NotImplementedError
33
-
34
- reim = 2
35
-
36
- norm = nn.LayerNorm(in_channels * bandwidth * reim)
37
-
38
- fc_in = bandwidth * reim
39
-
40
- if treat_channel_as_feature:
41
- fc_in *= in_channels
42
- else:
43
- assert emb_dim % in_channels == 0
44
- emb_dim = emb_dim // in_channels
45
-
46
- fc = nn.Linear(fc_in, emb_dim)
47
-
48
- self.combined = nn.Sequential(norm, fc)
49
-
50
- def forward(self, xb):
51
- return checkpoint_sequential(self.combined, 1, xb, use_reentrant=False)
52
-
53
-
54
- class BandSplitModule(nn.Module):
55
- def __init__(
56
- self,
57
- band_specs: List[Tuple[float, float]],
58
- emb_dim: int,
59
- in_channels: int,
60
- require_no_overlap: bool = False,
61
- require_no_gap: bool = True,
62
- normalize_channel_independently: bool = False,
63
- treat_channel_as_feature: bool = True,
64
- ) -> None:
65
- super().__init__()
66
-
67
- check_nonzero_bandwidth(band_specs)
68
-
69
- if require_no_gap:
70
- check_no_gap(band_specs)
71
-
72
- if require_no_overlap:
73
- check_no_overlap(band_specs)
74
-
75
- self.band_specs = band_specs
76
- self.band_widths = band_widths_from_specs(band_specs)
77
- self.n_bands = len(band_specs)
78
- self.emb_dim = emb_dim
79
-
80
- try:
81
- self.norm_fc_modules = nn.ModuleList(
82
- [ # type: ignore
83
- torch.compile(
84
- NormFC(
85
- emb_dim=emb_dim,
86
- bandwidth=bw,
87
- in_channels=in_channels,
88
- normalize_channel_independently=normalize_channel_independently,
89
- treat_channel_as_feature=treat_channel_as_feature,
90
- ),
91
- disable=True,
92
- )
93
- for bw in self.band_widths
94
- ]
95
- )
96
- except Exception as e:
97
- self.norm_fc_modules = nn.ModuleList(
98
- [ # type: ignore
99
- NormFC(
100
- emb_dim=emb_dim,
101
- bandwidth=bw,
102
- in_channels=in_channels,
103
- normalize_channel_independently=normalize_channel_independently,
104
- treat_channel_as_feature=treat_channel_as_feature,
105
- )
106
- for bw in self.band_widths
107
- ]
108
- )
109
-
110
- def forward(self, x: torch.Tensor):
111
-
112
- batch, in_chan, band_width, n_time = x.shape
113
-
114
- z = torch.zeros(
115
- size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
116
- )
117
-
118
- x = torch.permute(x, (0, 3, 1, 2)).contiguous()
119
-
120
- for i, nfm in enumerate(self.norm_fc_modules):
121
- fstart, fend = self.band_specs[i]
122
- xb = x[:, :, :, fstart:fend]
123
- xb = torch.view_as_real(xb)
124
- xb = torch.reshape(xb, (batch, n_time, -1))
125
- z[:, i, :, :] = nfm(xb)
126
-
127
- return z
 
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.utils.checkpoint import checkpoint_sequential
6
+
7
+ from .utils import (
8
+ band_widths_from_specs,
9
+ check_no_gap,
10
+ check_no_overlap,
11
+ check_nonzero_bandwidth,
12
+ )
13
+
14
+
15
class NormFC(nn.Module):
    """LayerNorm followed by a Linear projection of one sub-band
    (real/imag parts interleaved along the feature axis) into the
    embedding space."""

    def __init__(
        self,
        emb_dim: int,
        bandwidth: int,
        in_channels: int,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        if not treat_channel_as_feature:
            raise NotImplementedError

        self.treat_channel_as_feature = treat_channel_as_feature

        if normalize_channel_independently:
            raise NotImplementedError

        reim = 2  # real + imaginary parts

        layer_norm = nn.LayerNorm(in_channels * bandwidth * reim)

        fc_in = bandwidth * reim
        if treat_channel_as_feature:
            fc_in *= in_channels
        else:
            # Unreachable today (guarded above), kept for symmetry.
            assert emb_dim % in_channels == 0
            emb_dim = emb_dim // in_channels

        self.combined = nn.Sequential(layer_norm, nn.Linear(fc_in, emb_dim))

    def forward(self, xb):
        # Checkpoint the norm+linear pair to trade compute for memory.
        return checkpoint_sequential(self.combined, 1, xb, use_reentrant=False)
52
+
53
+
54
class BandSplitModule(nn.Module):
    """Splits a complex spectrogram into (possibly overlapping) frequency
    bands and projects each band to a fixed-size embedding via NormFC."""

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        in_channels: int,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        check_nonzero_bandwidth(band_specs)

        if require_no_gap:
            check_no_gap(band_specs)

        if require_no_overlap:
            check_no_overlap(band_specs)

        self.band_specs = band_specs
        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)
        self.emb_dim = emb_dim

        # Build one NormFC per band; factored out so the torch.compile
        # wrapper and its fallback share the construction logic.
        def _build(wrap):
            return nn.ModuleList(
                [  # type: ignore
                    wrap(
                        NormFC(
                            emb_dim=emb_dim,
                            bandwidth=bw,
                            in_channels=in_channels,
                            normalize_channel_independently=normalize_channel_independently,
                            treat_channel_as_feature=treat_channel_as_feature,
                        )
                    )
                    for bw in self.band_widths
                ]
            )

        try:
            # torch.compile(disable=True) is a no-op wrapper but may raise
            # on older torch builds; fall back to the plain modules.
            self.norm_fc_modules = _build(lambda m: torch.compile(m, disable=True))
        except Exception:
            self.norm_fc_modules = _build(lambda m: m)

    def forward(self, x: torch.Tensor):
        """x: (batch, channels, freq, time) complex spectrogram ->
        (batch, n_bands, time, emb_dim) band embeddings."""
        batch, _in_chan, _n_freq, n_time = x.shape

        z = torch.zeros(
            size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
        )

        # -> (batch, time, channels, freq) so each band slice is contiguous.
        x = torch.permute(x, (0, 3, 1, 2)).contiguous()

        for i, nfm in enumerate(self.norm_fc_modules):
            # NOTE(review): band_specs entries are used as slice indices,
            # so they are expected to be ints despite the float type hint
            # -- confirm against MusicalBandsplitSpecification.
            fstart, fend = self.band_specs[i]
            xb = torch.view_as_real(x[:, :, :, fstart:fend])
            z[:, i, :, :] = nfm(torch.reshape(xb, (batch, n_time, -1)))

        return z
mvsepless/models/bandit_v2/film.py CHANGED
@@ -1,23 +1,23 @@
1
- from torch import nn
2
- import torch
3
-
4
-
5
- class FiLM(nn.Module):
6
- def __init__(self):
7
- super().__init__()
8
-
9
- def forward(self, x, gamma, beta):
10
- return gamma * x + beta
11
-
12
-
13
- class BTFBroadcastedFiLM(nn.Module):
14
- def __init__(self):
15
- super().__init__()
16
- self.film = FiLM()
17
-
18
- def forward(self, x, gamma, beta):
19
-
20
- gamma = gamma[None, None, None, :]
21
- beta = beta[None, None, None, :]
22
-
23
- return self.film(x, gamma, beta)
 
1
+ from torch import nn
2
+ import torch
3
+
4
+
5
class FiLM(nn.Module):
    """Feature-wise Linear Modulation: scale by gamma, shift by beta."""

    def __init__(self):
        super().__init__()

    def forward(self, x, gamma, beta):
        # Broadcasting follows normal torch semantics.
        scaled = gamma * x
        return scaled + beta
11
+
12
+
13
class BTFBroadcastedFiLM(nn.Module):
    """FiLM with 1-D gamma/beta broadcast over the leading three axes.

    gamma/beta of shape (C,) are expanded to (1, 1, 1, C) so the
    modulation applies along the last axis of a 4-D input.
    """

    def __init__(self):
        super().__init__()
        self.film = FiLM()

    def forward(self, x, gamma, beta):
        gamma_b = gamma[None, None, None, :]
        beta_b = beta[None, None, None, :]
        return self.film(x, gamma_b, beta_b)
mvsepless/models/bandit_v2/maskestim.py CHANGED
@@ -1,269 +1,269 @@
1
- from typing import Dict, List, Optional, Tuple, Type
2
-
3
- import torch
4
- from torch import nn
5
- from torch.nn.modules import activation
6
- from torch.utils.checkpoint import checkpoint_sequential
7
-
8
- from .utils import (
9
- band_widths_from_specs,
10
- check_no_gap,
11
- check_no_overlap,
12
- check_nonzero_bandwidth,
13
- )
14
-
15
-
16
- class BaseNormMLP(nn.Module):
17
- def __init__(
18
- self,
19
- emb_dim: int,
20
- mlp_dim: int,
21
- bandwidth: int,
22
- in_channels: Optional[int],
23
- hidden_activation: str = "Tanh",
24
- hidden_activation_kwargs=None,
25
- complex_mask: bool = True,
26
- ):
27
- super().__init__()
28
- if hidden_activation_kwargs is None:
29
- hidden_activation_kwargs = {}
30
- self.hidden_activation_kwargs = hidden_activation_kwargs
31
- self.norm = nn.LayerNorm(emb_dim)
32
- self.hidden = nn.Sequential(
33
- nn.Linear(in_features=emb_dim, out_features=mlp_dim),
34
- activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
35
- )
36
-
37
- self.bandwidth = bandwidth
38
- self.in_channels = in_channels
39
-
40
- self.complex_mask = complex_mask
41
- self.reim = 2 if complex_mask else 1
42
- self.glu_mult = 2
43
-
44
-
45
- class NormMLP(BaseNormMLP):
46
- def __init__(
47
- self,
48
- emb_dim: int,
49
- mlp_dim: int,
50
- bandwidth: int,
51
- in_channels: Optional[int],
52
- hidden_activation: str = "Tanh",
53
- hidden_activation_kwargs=None,
54
- complex_mask: bool = True,
55
- ) -> None:
56
- super().__init__(
57
- emb_dim=emb_dim,
58
- mlp_dim=mlp_dim,
59
- bandwidth=bandwidth,
60
- in_channels=in_channels,
61
- hidden_activation=hidden_activation,
62
- hidden_activation_kwargs=hidden_activation_kwargs,
63
- complex_mask=complex_mask,
64
- )
65
-
66
- self.output = nn.Sequential(
67
- nn.Linear(
68
- in_features=mlp_dim,
69
- out_features=bandwidth * in_channels * self.reim * 2,
70
- ),
71
- nn.GLU(dim=-1),
72
- )
73
-
74
- try:
75
- self.combined = torch.compile(
76
- nn.Sequential(self.norm, self.hidden, self.output), disable=True
77
- )
78
- except Exception as e:
79
- self.combined = nn.Sequential(self.norm, self.hidden, self.output)
80
-
81
- def reshape_output(self, mb):
82
- batch, n_time, _ = mb.shape
83
- if self.complex_mask:
84
- mb = mb.reshape(
85
- batch, n_time, self.in_channels, self.bandwidth, self.reim
86
- ).contiguous()
87
- mb = torch.view_as_complex(mb)
88
- else:
89
- mb = mb.reshape(batch, n_time, self.in_channels, self.bandwidth)
90
-
91
- mb = torch.permute(mb, (0, 2, 3, 1))
92
-
93
- return mb
94
-
95
- def forward(self, qb):
96
-
97
- mb = checkpoint_sequential(self.combined, 2, qb, use_reentrant=False)
98
- mb = self.reshape_output(mb)
99
-
100
- return mb
101
-
102
-
103
- class MaskEstimationModuleSuperBase(nn.Module):
104
- pass
105
-
106
-
107
- class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
108
- def __init__(
109
- self,
110
- band_specs: List[Tuple[float, float]],
111
- emb_dim: int,
112
- mlp_dim: int,
113
- in_channels: Optional[int],
114
- hidden_activation: str = "Tanh",
115
- hidden_activation_kwargs: Dict = None,
116
- complex_mask: bool = True,
117
- norm_mlp_cls: Type[nn.Module] = NormMLP,
118
- norm_mlp_kwargs: Dict = None,
119
- ) -> None:
120
- super().__init__()
121
-
122
- self.band_widths = band_widths_from_specs(band_specs)
123
- self.n_bands = len(band_specs)
124
-
125
- if hidden_activation_kwargs is None:
126
- hidden_activation_kwargs = {}
127
-
128
- if norm_mlp_kwargs is None:
129
- norm_mlp_kwargs = {}
130
-
131
- self.norm_mlp = nn.ModuleList(
132
- [
133
- norm_mlp_cls(
134
- bandwidth=self.band_widths[b],
135
- emb_dim=emb_dim,
136
- mlp_dim=mlp_dim,
137
- in_channels=in_channels,
138
- hidden_activation=hidden_activation,
139
- hidden_activation_kwargs=hidden_activation_kwargs,
140
- complex_mask=complex_mask,
141
- **norm_mlp_kwargs,
142
- )
143
- for b in range(self.n_bands)
144
- ]
145
- )
146
-
147
- def compute_masks(self, q):
148
- batch, n_bands, n_time, emb_dim = q.shape
149
-
150
- masks = []
151
-
152
- for b, nmlp in enumerate(self.norm_mlp):
153
- qb = q[:, b, :, :]
154
- mb = nmlp(qb)
155
- masks.append(mb)
156
-
157
- return masks
158
-
159
- def compute_mask(self, q, b):
160
- batch, n_bands, n_time, emb_dim = q.shape
161
- qb = q[:, b, :, :]
162
- mb = self.norm_mlp[b](qb)
163
- return mb
164
-
165
-
166
- class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
167
- def __init__(
168
- self,
169
- in_channels: int,
170
- band_specs: List[Tuple[float, float]],
171
- freq_weights: List[torch.Tensor],
172
- n_freq: int,
173
- emb_dim: int,
174
- mlp_dim: int,
175
- cond_dim: int = 0,
176
- hidden_activation: str = "Tanh",
177
- hidden_activation_kwargs: Dict = None,
178
- complex_mask: bool = True,
179
- norm_mlp_cls: Type[nn.Module] = NormMLP,
180
- norm_mlp_kwargs: Dict = None,
181
- use_freq_weights: bool = False,
182
- ) -> None:
183
- check_nonzero_bandwidth(band_specs)
184
- check_no_gap(band_specs)
185
-
186
- if cond_dim > 0:
187
- raise NotImplementedError
188
-
189
- super().__init__(
190
- band_specs=band_specs,
191
- emb_dim=emb_dim + cond_dim,
192
- mlp_dim=mlp_dim,
193
- in_channels=in_channels,
194
- hidden_activation=hidden_activation,
195
- hidden_activation_kwargs=hidden_activation_kwargs,
196
- complex_mask=complex_mask,
197
- norm_mlp_cls=norm_mlp_cls,
198
- norm_mlp_kwargs=norm_mlp_kwargs,
199
- )
200
-
201
- self.n_freq = n_freq
202
- self.band_specs = band_specs
203
- self.in_channels = in_channels
204
-
205
- if freq_weights is not None and use_freq_weights:
206
- for i, fw in enumerate(freq_weights):
207
- self.register_buffer(f"freq_weights/{i}", fw)
208
-
209
- self.use_freq_weights = use_freq_weights
210
- else:
211
- self.use_freq_weights = False
212
-
213
- def forward(self, q):
214
-
215
- batch, n_bands, n_time, emb_dim = q.shape
216
-
217
- masks = torch.zeros(
218
- (batch, self.in_channels, self.n_freq, n_time),
219
- device=q.device,
220
- dtype=torch.complex64,
221
- )
222
-
223
- for im in range(n_bands):
224
- fstart, fend = self.band_specs[im]
225
-
226
- mask = self.compute_mask(q, im)
227
-
228
- if self.use_freq_weights:
229
- fw = self.get_buffer(f"freq_weights/{im}")[:, None]
230
- mask = mask * fw
231
- masks[:, :, fstart:fend, :] += mask
232
-
233
- return masks
234
-
235
-
236
- class MaskEstimationModule(OverlappingMaskEstimationModule):
237
- def __init__(
238
- self,
239
- band_specs: List[Tuple[float, float]],
240
- emb_dim: int,
241
- mlp_dim: int,
242
- in_channels: Optional[int],
243
- hidden_activation: str = "Tanh",
244
- hidden_activation_kwargs: Dict = None,
245
- complex_mask: bool = True,
246
- **kwargs,
247
- ) -> None:
248
- check_nonzero_bandwidth(band_specs)
249
- check_no_gap(band_specs)
250
- check_no_overlap(band_specs)
251
- super().__init__(
252
- in_channels=in_channels,
253
- band_specs=band_specs,
254
- freq_weights=None,
255
- n_freq=None,
256
- emb_dim=emb_dim,
257
- mlp_dim=mlp_dim,
258
- hidden_activation=hidden_activation,
259
- hidden_activation_kwargs=hidden_activation_kwargs,
260
- complex_mask=complex_mask,
261
- )
262
-
263
- def forward(self, q, cond=None):
264
-
265
- masks = self.compute_masks(q)
266
-
267
- masks = torch.concat(masks, dim=2)
268
-
269
- return masks
 
1
+ from typing import Dict, List, Optional, Tuple, Type
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn.modules import activation
6
+ from torch.utils.checkpoint import checkpoint_sequential
7
+
8
+ from .utils import (
9
+ band_widths_from_specs,
10
+ check_no_gap,
11
+ check_no_overlap,
12
+ check_nonzero_bandwidth,
13
+ )
14
+
15
+
16
class BaseNormMLP(nn.Module):
    """Shared plumbing for band-wise mask MLPs: LayerNorm + hidden layer.

    Subclasses add the output head; this base only wires the norm/hidden
    layers and records band geometry (bandwidth, channels, re/im factor).
    """

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ):
        super().__init__()
        self.hidden_activation_kwargs = hidden_activation_kwargs or {}
        self.norm = nn.LayerNorm(emb_dim)

        # Activation class looked up by name in torch.nn.modules.activation.
        act_cls = getattr(activation, hidden_activation)
        self.hidden = nn.Sequential(
            nn.Linear(in_features=emb_dim, out_features=mlp_dim),
            act_cls(**self.hidden_activation_kwargs),
        )

        self.bandwidth = bandwidth
        self.in_channels = in_channels

        self.complex_mask = complex_mask
        # Two output numbers (re, im) per bin for a complex mask, else one.
        self.reim = 2 if complex_mask else 1
        self.glu_mult = 2
43
+
44
+
45
class NormMLP(BaseNormMLP):
    """Band-wise mask head: LayerNorm -> Linear+activation -> GLU output.

    Produces a (complex or real) mask of shape
    (batch, in_channels, bandwidth, n_time) from one band embedding
    of shape (batch, n_time, emb_dim).
    """

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__(
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            bandwidth=bandwidth,
            in_channels=in_channels,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

        # GLU halves the feature dimension, hence the extra factor of 2.
        self.output = nn.Sequential(
            nn.Linear(
                in_features=mlp_dim,
                out_features=bandwidth * in_channels * self.reim * 2,
            ),
            nn.GLU(dim=-1),
        )

        stack = nn.Sequential(self.norm, self.hidden, self.output)
        try:
            # disable=True keeps eager semantics; guarded for builds
            # where the torch.compile call itself fails.
            self.combined = torch.compile(stack, disable=True)
        except Exception:
            self.combined = stack

    def reshape_output(self, mb):
        """Fold the flat MLP output into (batch, channel, bandwidth, time)."""
        batch, n_time, _ = mb.shape
        if self.complex_mask:
            mb = mb.reshape(
                batch, n_time, self.in_channels, self.bandwidth, self.reim
            ).contiguous()
            # Last axis holds interleaved (re, im) pairs.
            mb = torch.view_as_complex(mb)
        else:
            mb = mb.reshape(batch, n_time, self.in_channels, self.bandwidth)
        return torch.permute(mb, (0, 2, 3, 1))

    def forward(self, qb):
        # Checkpoint the norm/hidden/output stack in two segments to
        # trade compute for activation memory.
        mb = checkpoint_sequential(self.combined, 2, qb, use_reentrant=False)
        return self.reshape_output(mb)
101
+
102
+
103
class MaskEstimationModuleSuperBase(nn.Module):
    """Marker base class shared by all mask-estimation modules."""
105
+
106
+
107
class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
    """Holds one NormMLP head per band and evaluates them on embeddings."""

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Dict = None,
    ) -> None:
        super().__init__()

        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)

        hidden_activation_kwargs = hidden_activation_kwargs or {}
        norm_mlp_kwargs = norm_mlp_kwargs or {}

        # One mask head per band; bandwidth varies per band.
        self.norm_mlp = nn.ModuleList(
            [
                norm_mlp_cls(
                    bandwidth=bw,
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channels=in_channels,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    **norm_mlp_kwargs,
                )
                for bw in self.band_widths
            ]
        )

    def compute_masks(self, q):
        """Run every band head; q is (batch, n_bands, n_time, emb_dim)."""
        return [head(q[:, b, :, :]) for b, head in enumerate(self.norm_mlp)]

    def compute_mask(self, q, b):
        """Run only band b's head on its slice of q."""
        return self.norm_mlp[b](q[:, b, :, :])
164
+
165
+
166
class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
    """Mask estimator for (possibly) overlapping bands.

    Per-band masks are summed into a full-frequency complex mask;
    optional per-band frequency weights rescale each band's contribution.
    """

    def __init__(
        self,
        in_channels: int,
        band_specs: List[Tuple[float, float]],
        freq_weights: List[torch.Tensor],
        n_freq: int,
        emb_dim: int,
        mlp_dim: int,
        cond_dim: int = 0,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Dict = None,
        use_freq_weights: bool = False,
    ) -> None:
        # Overlap is allowed here, but bands must be non-empty and gap-free.
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)

        if cond_dim > 0:
            # Conditioning is not supported yet.
            raise NotImplementedError

        super().__init__(
            band_specs=band_specs,
            emb_dim=emb_dim + cond_dim,
            mlp_dim=mlp_dim,
            in_channels=in_channels,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            norm_mlp_cls=norm_mlp_cls,
            norm_mlp_kwargs=norm_mlp_kwargs,
        )

        self.n_freq = n_freq
        self.band_specs = band_specs
        self.in_channels = in_channels

        if freq_weights is None or not use_freq_weights:
            self.use_freq_weights = False
        else:
            # Buffers move with the module (device/dtype) and land in the
            # state_dict, unlike plain attributes.
            for i, fw in enumerate(freq_weights):
                self.register_buffer(f"freq_weights/{i}", fw)
            self.use_freq_weights = use_freq_weights

    def forward(self, q):
        """q: (batch, n_bands, n_time, emb_dim) ->
        complex mask (batch, in_channels, n_freq, n_time)."""
        batch, n_bands, n_time, _ = q.shape

        out = torch.zeros(
            (batch, self.in_channels, self.n_freq, n_time),
            device=q.device,
            dtype=torch.complex64,
        )

        for band in range(n_bands):
            fstart, fend = self.band_specs[band]
            mask = self.compute_mask(q, band)
            if self.use_freq_weights:
                weight = self.get_buffer(f"freq_weights/{band}")[:, None]
                mask = mask * weight
            # Overlapping bands accumulate additively.
            out[:, :, fstart:fend, :] += mask

        return out
+ return masks
234
+
235
+
236
class MaskEstimationModule(OverlappingMaskEstimationModule):
    """Mask estimator for strictly non-overlapping, gap-free bands.

    Band masks are concatenated along frequency instead of summed.
    Extra keyword arguments are accepted for signature compatibility
    and ignored.
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict = None,
        complex_mask: bool = True,
        **kwargs,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)
        check_no_overlap(band_specs)
        super().__init__(
            in_channels=in_channels,
            band_specs=band_specs,
            freq_weights=None,
            n_freq=None,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

    def forward(self, q, cond=None):
        # Bands are disjoint and contiguous, so concatenating along the
        # frequency axis (dim=2) reconstructs the full mask.
        band_masks = self.compute_masks(q)
        return torch.concat(band_masks, dim=2)
mvsepless/models/bandit_v2/tfmodel.py CHANGED
@@ -1,141 +1,141 @@
1
- import warnings
2
-
3
- import torch
4
- import torch.backends.cuda
5
- from torch import nn
6
- from torch.nn.modules import rnn
7
- from torch.utils.checkpoint import checkpoint_sequential
8
-
9
-
10
- class TimeFrequencyModellingModule(nn.Module):
11
- def __init__(self) -> None:
12
- super().__init__()
13
-
14
-
15
- class ResidualRNN(nn.Module):
16
- def __init__(
17
- self,
18
- emb_dim: int,
19
- rnn_dim: int,
20
- bidirectional: bool = True,
21
- rnn_type: str = "LSTM",
22
- use_batch_trick: bool = True,
23
- use_layer_norm: bool = True,
24
- ) -> None:
25
- super().__init__()
26
-
27
- assert use_layer_norm
28
- assert use_batch_trick
29
-
30
- self.use_layer_norm = use_layer_norm
31
- self.norm = nn.LayerNorm(emb_dim)
32
- self.rnn = rnn.__dict__[rnn_type](
33
- input_size=emb_dim,
34
- hidden_size=rnn_dim,
35
- num_layers=1,
36
- batch_first=True,
37
- bidirectional=bidirectional,
38
- )
39
-
40
- self.fc = nn.Linear(
41
- in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
42
- )
43
-
44
- self.use_batch_trick = use_batch_trick
45
- if not self.use_batch_trick:
46
- warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")
47
-
48
- def forward(self, z):
49
-
50
- z0 = torch.clone(z)
51
- z = self.norm(z)
52
-
53
- batch, n_uncrossed, n_across, emb_dim = z.shape
54
- z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
55
- z = self.rnn(z)[0]
56
- z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
57
-
58
- z = self.fc(z)
59
-
60
- z = z + z0
61
-
62
- return z
63
-
64
-
65
- class Transpose(nn.Module):
66
- def __init__(self, dim0: int, dim1: int) -> None:
67
- super().__init__()
68
- self.dim0 = dim0
69
- self.dim1 = dim1
70
-
71
- def forward(self, z):
72
- return z.transpose(self.dim0, self.dim1)
73
-
74
-
75
- class SeqBandModellingModule(TimeFrequencyModellingModule):
76
- def __init__(
77
- self,
78
- n_modules: int = 12,
79
- emb_dim: int = 128,
80
- rnn_dim: int = 256,
81
- bidirectional: bool = True,
82
- rnn_type: str = "LSTM",
83
- parallel_mode=False,
84
- ) -> None:
85
- super().__init__()
86
-
87
- self.n_modules = n_modules
88
-
89
- if parallel_mode:
90
- self.seqband = nn.ModuleList([])
91
- for _ in range(n_modules):
92
- self.seqband.append(
93
- nn.ModuleList(
94
- [
95
- ResidualRNN(
96
- emb_dim=emb_dim,
97
- rnn_dim=rnn_dim,
98
- bidirectional=bidirectional,
99
- rnn_type=rnn_type,
100
- ),
101
- ResidualRNN(
102
- emb_dim=emb_dim,
103
- rnn_dim=rnn_dim,
104
- bidirectional=bidirectional,
105
- rnn_type=rnn_type,
106
- ),
107
- ]
108
- )
109
- )
110
- else:
111
- seqband = []
112
- for _ in range(2 * n_modules):
113
- seqband += [
114
- ResidualRNN(
115
- emb_dim=emb_dim,
116
- rnn_dim=rnn_dim,
117
- bidirectional=bidirectional,
118
- rnn_type=rnn_type,
119
- ),
120
- Transpose(1, 2),
121
- ]
122
-
123
- self.seqband = nn.Sequential(*seqband)
124
-
125
- self.parallel_mode = parallel_mode
126
-
127
- def forward(self, z):
128
-
129
- if self.parallel_mode:
130
- for sbm_pair in self.seqband:
131
- sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
132
- zt = sbm_t(z)
133
- zf = sbm_f(z.transpose(1, 2))
134
- z = zt + zf.transpose(1, 2)
135
- else:
136
- z = checkpoint_sequential(
137
- self.seqband, self.n_modules, z, use_reentrant=False
138
- )
139
-
140
- q = z
141
- return q
 
1
+ import warnings
2
+
3
+ import torch
4
+ import torch.backends.cuda
5
+ from torch import nn
6
+ from torch.nn.modules import rnn
7
+ from torch.utils.checkpoint import checkpoint_sequential
8
+
9
+
10
class TimeFrequencyModellingModule(nn.Module):
    """Base class for modules that model across time and frequency axes."""

    def __init__(self) -> None:
        super().__init__()
13
+
14
+
15
class ResidualRNN(nn.Module):
    """LayerNorm -> (bi)RNN -> Linear with a residual connection.

    Input/output shape: (batch, n_uncrossed, n_across, emb_dim). The RNN
    runs along the n_across axis with n_uncrossed folded into the batch
    dimension (the "batch trick").
    """

    def __init__(
        self,
        emb_dim: int,
        rnn_dim: int,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        use_batch_trick: bool = True,
        use_layer_norm: bool = True,
    ) -> None:
        super().__init__()

        # Only the normalized, batch-trick variant is implemented.
        assert use_layer_norm
        assert use_batch_trick

        self.use_layer_norm = use_layer_norm
        self.norm = nn.LayerNorm(emb_dim)

        # RNN class (LSTM/GRU/...) looked up by name in torch.nn.modules.rnn.
        rnn_cls = rnn.__dict__[rnn_type]
        self.rnn = rnn_cls(
            input_size=emb_dim,
            hidden_size=rnn_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional,
        )

        self.fc = nn.Linear(
            in_features=rnn_dim * (2 if bidirectional else 1),
            out_features=emb_dim,
        )

        self.use_batch_trick = use_batch_trick
        if not self.use_batch_trick:
            warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")

    def forward(self, z):
        residual = torch.clone(z)
        z = self.norm(z)

        batch, n_uncrossed, n_across, emb_dim = z.shape
        # Fold the uncrossed axis into the batch so one RNN call covers it.
        z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
        z, _ = self.rnn(z)
        z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))

        return self.fc(z) + residual
63
+
64
+
65
class Transpose(nn.Module):
    """nn.Module wrapper around tensor transpose, usable in nn.Sequential."""

    def __init__(self, dim0: int, dim1: int) -> None:
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, z):
        return torch.transpose(z, self.dim0, self.dim1)
73
+
74
+
75
class SeqBandModellingModule(TimeFrequencyModellingModule):
    """Alternating time/band RNN stack.

    Sequential mode interleaves ResidualRNN and Transpose(1, 2) so each
    consecutive pair models one axis then the other; parallel mode runs
    two RNNs on both axes of the same input and sums the results.
    """

    def __init__(
        self,
        n_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        parallel_mode=False,
    ) -> None:
        super().__init__()

        self.n_modules = n_modules

        def make_rnn():
            return ResidualRNN(
                emb_dim=emb_dim,
                rnn_dim=rnn_dim,
                bidirectional=bidirectional,
                rnn_type=rnn_type,
            )

        if parallel_mode:
            self.seqband = nn.ModuleList(
                [
                    nn.ModuleList([make_rnn(), make_rnn()])
                    for _ in range(n_modules)
                ]
            )
        else:
            layers = []
            for _ in range(2 * n_modules):
                layers.append(make_rnn())
                layers.append(Transpose(1, 2))
            self.seqband = nn.Sequential(*layers)

        self.parallel_mode = parallel_mode

    def forward(self, z):
        if self.parallel_mode:
            for pair in self.seqband:
                time_rnn, band_rnn = pair[0], pair[1]
                zt = time_rnn(z)
                zf = band_rnn(z.transpose(1, 2))
                z = zt + zf.transpose(1, 2)
        else:
            # 2*n_modules (RNN, Transpose) pairs, checkpointed in
            # n_modules segments.
            z = checkpoint_sequential(
                self.seqband, self.n_modules, z, use_reentrant=False
            )
        return z
mvsepless/models/bandit_v2/utils.py CHANGED
@@ -1,384 +1,384 @@
1
- import os
2
- from abc import abstractmethod
3
- from typing import Callable
4
-
5
- import numpy as np
6
- import torch
7
- from librosa import hz_to_midi, midi_to_hz
8
- from torchaudio import functional as taF
9
-
10
-
11
def band_widths_from_specs(band_specs):
    """Return the width (end - start) of each (start, end) band spec."""
    return [end - start for start, end in band_specs]
13
-
14
-
15
def check_nonzero_bandwidth(band_specs):
    """Raise ValueError if any (start, end) band has non-positive width."""
    for fstart, fend in band_specs:
        if fend <= fstart:
            raise ValueError("Bands cannot be zero-width")
19
-
20
-
21
def check_no_overlap(band_specs):
    """Raise ValueError if consecutive half-open bands overlap.

    Band specs are (start, end) index pairs with an exclusive end
    (widths are computed as end - start and spectra are sliced as
    [start:end] elsewhere in this package), so contiguous bands like
    (0, 10), (10, 20) do NOT overlap.

    Bug fix: the previous implementation never updated ``fend_prev``
    inside the loop, so every band was compared against the initial -1
    and no overlap could ever be detected.
    """
    fend_prev = -1
    for fstart_curr, fend_curr in band_specs:
        # Strict '<' because the end index is exclusive.
        if fstart_curr < fend_prev:
            raise ValueError("Bands cannot overlap")
        fend_prev = fend_curr
26
-
27
-
28
def check_no_gap(band_specs):
    """Raise ValueError if bands do not start at bin 0 or leave a gap.

    A difference of at most 1 between a band's start and the previous
    band's end is tolerated (the original threshold is preserved).
    Raises IndexError on an empty spec list, as before.
    """
    fstart, _ = band_specs[0]
    # Previously an `assert`, which is stripped under `python -O`;
    # raise explicitly, consistent with the sibling checks.
    if fstart != 0:
        raise ValueError("Bands must start at frequency bin 0")

    fend_prev = -1
    for fstart_curr, fend_curr in band_specs:
        if fstart_curr - fend_prev > 1:
            raise ValueError("Bands cannot leave gap")
        fend_prev = fend_curr
37
-
38
-
39
class BandsplitSpecification:
    """Maps Hz boundaries to STFT bin indices and builds band layouts.

    Precomputes bin indices of common split points (500 Hz .. 20 kHz)
    and the band lists above 16/20 kHz up to the last rfft bin.
    """

    def __init__(self, nfft: int, fs: int) -> None:
        self.fs = fs
        self.nfft = nfft
        self.nyquist = fs / 2
        # Number of non-redundant rfft bins.
        self.max_index = nfft // 2 + 1

        self.split500 = self.hertz_to_index(500)
        self.split1k = self.hertz_to_index(1000)
        self.split2k = self.hertz_to_index(2000)
        self.split4k = self.hertz_to_index(4000)
        self.split8k = self.hertz_to_index(8000)
        self.split16k = self.hertz_to_index(16000)
        self.split20k = self.hertz_to_index(20000)

        self.above20k = [(self.split20k, self.max_index)]
        self.above16k = [(self.split16k, self.split20k)] + self.above20k

    def index_to_hertz(self, index: int):
        """Convert an STFT bin index back to Hz."""
        return index * self.fs / self.nfft

    def hertz_to_index(self, hz: float, round: bool = True):
        """Convert Hz to an STFT bin index (rounded by default)."""
        index = hz * self.nfft / self.fs
        return int(np.round(index)) if round else index

    def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
        """Tile [start_index, end_index) with bands ~bandwidth_hz wide."""
        specs = []
        lower = start_index

        while lower < end_index:
            # The last band is clipped to end_index.
            upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
            upper = min(upper, end_index)

            specs.append((lower, upper))
            lower = upper

        return specs

    @abstractmethod
    def get_band_specs(self):
        raise NotImplementedError
84
-
85
-
86
class VocalBandsplitSpecification(BandsplitSpecification):
    """Hand-designed band layouts for the vocal stem (versions 1-7).

    ``get_band_specs`` dispatches to ``version{self.version}`` by name.
    Later versions use finer bands at low frequencies where harmonics
    are dense, and coarser bands toward the top of the spectrum.
    """

    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

        self.version = version

    def get_band_specs(self):
        # Name-based dispatch: version "7" -> self.version7(), etc.
        return getattr(self, f"version{self.version}")()

    def version1(self):
        # FIX: this was decorated with @property, so the dispatch above
        # (getattr(...)()) would call the *returned list* and raise
        # TypeError for version="1". Now a plain method like version2-7.
        return self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.max_index, bandwidth_hz=1000
        )

    def version2(self):
        below16k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )

        return below16k + below20k + self.above20k

    def version3(self):
        below8k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )

        return below8k + below16k + self.above16k

    def version4(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )

        return below1k + below8k + below16k + self.above16k

    def version5(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )
        return below1k + below16k + below20k + self.above20k

    def version6(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + self.above16k

    def version7(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + below20k + self.above20k
178
-
179
-
180
class OtherBandsplitSpecification(VocalBandsplitSpecification):
    """Version-7 vocal band layout, reused for the 'other' stem."""

    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs, version="7")
183
-
184
-
185
class BassBandsplitSpecification(BandsplitSpecification):
    """Band layout for the bass stem: finest (50 Hz) bands below 500 Hz.

    NOTE(review): the ``version`` parameter is accepted for signature
    parity with the vocal spec but is ignored.
    """

    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        # (start, end, bandwidth_hz) segments, finest at the bottom.
        segments = [
            (0, self.split500, 50),
            (self.split500, self.split1k, 100),
            (self.split1k, self.split4k, 500),
            (self.split4k, self.split8k, 1000),
            (self.split8k, self.split16k, 2000),
        ]
        specs = []
        for start, end, bw_hz in segments:
            specs += self.get_band_specs_with_bandwidth(
                start_index=start, end_index=end, bandwidth_hz=bw_hz
            )
        # Everything above 16 kHz is one band up to the last bin.
        return specs + [(self.split16k, self.max_index)]
208
-
209
-
210
class DrumBandsplitSpecification(BandsplitSpecification):
    """Band layout for the drum stem: fine 50 Hz bands up to 1 kHz."""

    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        # (start, end, bandwidth_hz) segments, finest at the bottom.
        segments = [
            (0, self.split1k, 50),
            (self.split1k, self.split2k, 100),
            (self.split2k, self.split4k, 250),
            (self.split4k, self.split8k, 500),
            (self.split8k, self.split16k, 1000),
        ]
        specs = []
        for start, end, bw_hz in segments:
            specs += self.get_band_specs_with_bandwidth(
                start_index=start, end_index=end, bandwidth_hz=bw_hz
            )
        # Everything above 16 kHz is one band up to the last bin.
        return specs + [(self.split16k, self.max_index)]
233
-
234
-
235
class PerceptualBandsplitSpecification(BandsplitSpecification):
    """Derives band specs and per-bin weights from a perceptual filterbank.

    ``fbank_fn(n_bands, fs, f_min, f_max, n_freqs)`` must return an
    (n_bands, n_freqs) filterbank. Each band spans the non-zero support
    of one filter; the weights are the column-normalized filter values
    over that span. Bands may overlap.
    """

    def __init__(
        self,
        nfft: int,
        fs: int,
        fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
        n_bands: int,
        f_min: float = 0.0,
        f_max: float = None,
    ) -> None:
        super().__init__(nfft=nfft, fs=fs)
        self.n_bands = n_bands
        if f_max is None:
            f_max = fs / 2

        self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)

        # Normalize so the filter weights covering each bin sum to one.
        weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True)
        normalized_fb = self.filterbank / weight_per_bin

        freq_weights = []
        band_specs = []
        for band in range(self.n_bands):
            active_bins = torch.nonzero(self.filterbank[band, :]).squeeze().tolist()
            if isinstance(active_bins, int):
                # squeeze() collapsed a single active bin to a scalar.
                active_bins = (active_bins, active_bins)
            if len(active_bins) == 0:
                # A filter with empty support contributes no band.
                continue
            start_index = active_bins[0]
            end_index = active_bins[-1] + 1
            band_specs.append((start_index, end_index))
            freq_weights.append(normalized_fb[band, start_index:end_index])

        self.freq_weights = freq_weights
        self.band_specs = band_specs

    def get_band_specs(self):
        return self.band_specs

    def get_freq_weights(self):
        return self.freq_weights

    def save_to_file(self, dir_path: str) -> None:
        """Pickle band specs, weights, and the raw filterbank into dir_path."""
        os.makedirs(dir_path, exist_ok=True)

        import pickle

        with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
            pickle.dump(
                {
                    "band_specs": self.band_specs,
                    "freq_weights": self.freq_weights,
                    "filterbank": self.filterbank,
                },
                f,
            )
291
-
292
-
293
- def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
294
- fb = taF.melscale_fbanks(
295
- n_mels=n_bands,
296
- sample_rate=fs,
297
- f_min=f_min,
298
- f_max=f_max,
299
- n_freqs=n_freqs,
300
- ).T
301
-
302
- fb[0, 0] = 1.0
303
-
304
- return fb
305
-
306
-
307
- class MelBandsplitSpecification(PerceptualBandsplitSpecification):
308
- def __init__(
309
- self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
310
- ) -> None:
311
- super().__init__(
312
- fbank_fn=mel_filterbank,
313
- nfft=nfft,
314
- fs=fs,
315
- n_bands=n_bands,
316
- f_min=f_min,
317
- f_max=f_max,
318
- )
319
-
320
-
321
- def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
322
- nfft = 2 * (n_freqs - 1)
323
- df = fs / nfft
324
- f_max = f_max or fs / 2
325
- f_min = f_min or 0
326
- f_min = fs / nfft
327
-
328
- n_octaves = np.log2(f_max / f_min)
329
- n_octaves_per_band = n_octaves / n_bands
330
- bandwidth_mult = np.power(2.0, n_octaves_per_band)
331
-
332
- low_midi = max(0, hz_to_midi(f_min))
333
- high_midi = hz_to_midi(f_max)
334
- midi_points = np.linspace(low_midi, high_midi, n_bands)
335
- hz_pts = midi_to_hz(midi_points)
336
-
337
- low_pts = hz_pts / bandwidth_mult
338
- high_pts = hz_pts * bandwidth_mult
339
-
340
- low_bins = np.floor(low_pts / df).astype(int)
341
- high_bins = np.ceil(high_pts / df).astype(int)
342
-
343
- fb = np.zeros((n_bands, n_freqs))
344
-
345
- for i in range(n_bands):
346
- fb[i, low_bins[i] : high_bins[i] + 1] = 1.0
347
-
348
- fb[0, : low_bins[0]] = 1.0
349
- fb[-1, high_bins[-1] + 1 :] = 1.0
350
-
351
- return torch.as_tensor(fb)
352
-
353
-
354
- class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
355
- def __init__(
356
- self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
357
- ) -> None:
358
- super().__init__(
359
- fbank_fn=musical_filterbank,
360
- nfft=nfft,
361
- fs=fs,
362
- n_bands=n_bands,
363
- f_min=f_min,
364
- f_max=f_max,
365
- )
366
-
367
-
368
- if __name__ == "__main__":
369
- import pandas as pd
370
-
371
- band_defs = []
372
-
373
- for bands in [VocalBandsplitSpecification]:
374
- band_name = bands.__name__.replace("BandsplitSpecification", "")
375
-
376
- mbs = bands(nfft=2048, fs=44100).get_band_specs()
377
-
378
- for i, (f_min, f_max) in enumerate(mbs):
379
- band_defs.append(
380
- {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
381
- )
382
-
383
- df = pd.DataFrame(band_defs)
384
- df.to_csv("vox7bands.csv", index=False)
 
1
+ import os
2
+ from abc import abstractmethod
3
+ from typing import Callable
4
+
5
+ import numpy as np
6
+ import torch
7
+ from librosa import hz_to_midi, midi_to_hz
8
+ from torchaudio import functional as taF
9
+
10
+
11
+ def band_widths_from_specs(band_specs):
12
+ return [e - i for i, e in band_specs]
13
+
14
+
15
+ def check_nonzero_bandwidth(band_specs):
16
+ for fstart, fend in band_specs:
17
+ if fend - fstart <= 0:
18
+ raise ValueError("Bands cannot be zero-width")
19
+
20
+
21
+ def check_no_overlap(band_specs):
22
+ fend_prev = -1
23
+ for fstart_curr, fend_curr in band_specs:
24
+ if fstart_curr <= fend_prev:
25
+ raise ValueError("Bands cannot overlap")
26
+
27
+
28
+ def check_no_gap(band_specs):
29
+ fstart, _ = band_specs[0]
30
+ assert fstart == 0
31
+
32
+ fend_prev = -1
33
+ for fstart_curr, fend_curr in band_specs:
34
+ if fstart_curr - fend_prev > 1:
35
+ raise ValueError("Bands cannot leave gap")
36
+ fend_prev = fend_curr
37
+
38
+
39
+ class BandsplitSpecification:
40
+ def __init__(self, nfft: int, fs: int) -> None:
41
+ self.fs = fs
42
+ self.nfft = nfft
43
+ self.nyquist = fs / 2
44
+ self.max_index = nfft // 2 + 1
45
+
46
+ self.split500 = self.hertz_to_index(500)
47
+ self.split1k = self.hertz_to_index(1000)
48
+ self.split2k = self.hertz_to_index(2000)
49
+ self.split4k = self.hertz_to_index(4000)
50
+ self.split8k = self.hertz_to_index(8000)
51
+ self.split16k = self.hertz_to_index(16000)
52
+ self.split20k = self.hertz_to_index(20000)
53
+
54
+ self.above20k = [(self.split20k, self.max_index)]
55
+ self.above16k = [(self.split16k, self.split20k)] + self.above20k
56
+
57
+ def index_to_hertz(self, index: int):
58
+ return index * self.fs / self.nfft
59
+
60
+ def hertz_to_index(self, hz: float, round: bool = True):
61
+ index = hz * self.nfft / self.fs
62
+
63
+ if round:
64
+ index = int(np.round(index))
65
+
66
+ return index
67
+
68
+ def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
69
+ band_specs = []
70
+ lower = start_index
71
+
72
+ while lower < end_index:
73
+ upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
74
+ upper = min(upper, end_index)
75
+
76
+ band_specs.append((lower, upper))
77
+ lower = upper
78
+
79
+ return band_specs
80
+
81
+ @abstractmethod
82
+ def get_band_specs(self):
83
+ raise NotImplementedError
84
+
85
+
86
+ class VocalBandsplitSpecification(BandsplitSpecification):
87
+ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
88
+ super().__init__(nfft=nfft, fs=fs)
89
+
90
+ self.version = version
91
+
92
+ def get_band_specs(self):
93
+ return getattr(self, f"version{self.version}")()
94
+
95
+ @property
96
+ def version1(self):
97
+ return self.get_band_specs_with_bandwidth(
98
+ start_index=0, end_index=self.max_index, bandwidth_hz=1000
99
+ )
100
+
101
+ def version2(self):
102
+ below16k = self.get_band_specs_with_bandwidth(
103
+ start_index=0, end_index=self.split16k, bandwidth_hz=1000
104
+ )
105
+ below20k = self.get_band_specs_with_bandwidth(
106
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
107
+ )
108
+
109
+ return below16k + below20k + self.above20k
110
+
111
+ def version3(self):
112
+ below8k = self.get_band_specs_with_bandwidth(
113
+ start_index=0, end_index=self.split8k, bandwidth_hz=1000
114
+ )
115
+ below16k = self.get_band_specs_with_bandwidth(
116
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
117
+ )
118
+
119
+ return below8k + below16k + self.above16k
120
+
121
+ def version4(self):
122
+ below1k = self.get_band_specs_with_bandwidth(
123
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
124
+ )
125
+ below8k = self.get_band_specs_with_bandwidth(
126
+ start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
127
+ )
128
+ below16k = self.get_band_specs_with_bandwidth(
129
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
130
+ )
131
+
132
+ return below1k + below8k + below16k + self.above16k
133
+
134
+ def version5(self):
135
+ below1k = self.get_band_specs_with_bandwidth(
136
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
137
+ )
138
+ below16k = self.get_band_specs_with_bandwidth(
139
+ start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
140
+ )
141
+ below20k = self.get_band_specs_with_bandwidth(
142
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
143
+ )
144
+ return below1k + below16k + below20k + self.above20k
145
+
146
+ def version6(self):
147
+ below1k = self.get_band_specs_with_bandwidth(
148
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
149
+ )
150
+ below4k = self.get_band_specs_with_bandwidth(
151
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
152
+ )
153
+ below8k = self.get_band_specs_with_bandwidth(
154
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
155
+ )
156
+ below16k = self.get_band_specs_with_bandwidth(
157
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
158
+ )
159
+ return below1k + below4k + below8k + below16k + self.above16k
160
+
161
+ def version7(self):
162
+ below1k = self.get_band_specs_with_bandwidth(
163
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
164
+ )
165
+ below4k = self.get_band_specs_with_bandwidth(
166
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
167
+ )
168
+ below8k = self.get_band_specs_with_bandwidth(
169
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
170
+ )
171
+ below16k = self.get_band_specs_with_bandwidth(
172
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
173
+ )
174
+ below20k = self.get_band_specs_with_bandwidth(
175
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
176
+ )
177
+ return below1k + below4k + below8k + below16k + below20k + self.above20k
178
+
179
+
180
+ class OtherBandsplitSpecification(VocalBandsplitSpecification):
181
+ def __init__(self, nfft: int, fs: int) -> None:
182
+ super().__init__(nfft=nfft, fs=fs, version="7")
183
+
184
+
185
+ class BassBandsplitSpecification(BandsplitSpecification):
186
+ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
187
+ super().__init__(nfft=nfft, fs=fs)
188
+
189
+ def get_band_specs(self):
190
+ below500 = self.get_band_specs_with_bandwidth(
191
+ start_index=0, end_index=self.split500, bandwidth_hz=50
192
+ )
193
+ below1k = self.get_band_specs_with_bandwidth(
194
+ start_index=self.split500, end_index=self.split1k, bandwidth_hz=100
195
+ )
196
+ below4k = self.get_band_specs_with_bandwidth(
197
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
198
+ )
199
+ below8k = self.get_band_specs_with_bandwidth(
200
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
201
+ )
202
+ below16k = self.get_band_specs_with_bandwidth(
203
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
204
+ )
205
+ above16k = [(self.split16k, self.max_index)]
206
+
207
+ return below500 + below1k + below4k + below8k + below16k + above16k
208
+
209
+
210
+ class DrumBandsplitSpecification(BandsplitSpecification):
211
+ def __init__(self, nfft: int, fs: int) -> None:
212
+ super().__init__(nfft=nfft, fs=fs)
213
+
214
+ def get_band_specs(self):
215
+ below1k = self.get_band_specs_with_bandwidth(
216
+ start_index=0, end_index=self.split1k, bandwidth_hz=50
217
+ )
218
+ below2k = self.get_band_specs_with_bandwidth(
219
+ start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100
220
+ )
221
+ below4k = self.get_band_specs_with_bandwidth(
222
+ start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250
223
+ )
224
+ below8k = self.get_band_specs_with_bandwidth(
225
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
226
+ )
227
+ below16k = self.get_band_specs_with_bandwidth(
228
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
229
+ )
230
+ above16k = [(self.split16k, self.max_index)]
231
+
232
+ return below1k + below2k + below4k + below8k + below16k + above16k
233
+
234
+
235
+ class PerceptualBandsplitSpecification(BandsplitSpecification):
236
+ def __init__(
237
+ self,
238
+ nfft: int,
239
+ fs: int,
240
+ fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
241
+ n_bands: int,
242
+ f_min: float = 0.0,
243
+ f_max: float = None,
244
+ ) -> None:
245
+ super().__init__(nfft=nfft, fs=fs)
246
+ self.n_bands = n_bands
247
+ if f_max is None:
248
+ f_max = fs / 2
249
+
250
+ self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)
251
+
252
+ weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True)
253
+ normalized_mel_fb = self.filterbank / weight_per_bin
254
+
255
+ freq_weights = []
256
+ band_specs = []
257
+ for i in range(self.n_bands):
258
+ active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
259
+ if isinstance(active_bins, int):
260
+ active_bins = (active_bins, active_bins)
261
+ if len(active_bins) == 0:
262
+ continue
263
+ start_index = active_bins[0]
264
+ end_index = active_bins[-1] + 1
265
+ band_specs.append((start_index, end_index))
266
+ freq_weights.append(normalized_mel_fb[i, start_index:end_index])
267
+
268
+ self.freq_weights = freq_weights
269
+ self.band_specs = band_specs
270
+
271
+ def get_band_specs(self):
272
+ return self.band_specs
273
+
274
+ def get_freq_weights(self):
275
+ return self.freq_weights
276
+
277
+ def save_to_file(self, dir_path: str) -> None:
278
+ os.makedirs(dir_path, exist_ok=True)
279
+
280
+ import pickle
281
+
282
+ with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
283
+ pickle.dump(
284
+ {
285
+ "band_specs": self.band_specs,
286
+ "freq_weights": self.freq_weights,
287
+ "filterbank": self.filterbank,
288
+ },
289
+ f,
290
+ )
291
+
292
+
293
+ def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
294
+ fb = taF.melscale_fbanks(
295
+ n_mels=n_bands,
296
+ sample_rate=fs,
297
+ f_min=f_min,
298
+ f_max=f_max,
299
+ n_freqs=n_freqs,
300
+ ).T
301
+
302
+ fb[0, 0] = 1.0
303
+
304
+ return fb
305
+
306
+
307
+ class MelBandsplitSpecification(PerceptualBandsplitSpecification):
308
+ def __init__(
309
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
310
+ ) -> None:
311
+ super().__init__(
312
+ fbank_fn=mel_filterbank,
313
+ nfft=nfft,
314
+ fs=fs,
315
+ n_bands=n_bands,
316
+ f_min=f_min,
317
+ f_max=f_max,
318
+ )
319
+
320
+
321
+ def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
322
+ nfft = 2 * (n_freqs - 1)
323
+ df = fs / nfft
324
+ f_max = f_max or fs / 2
325
+ f_min = f_min or 0
326
+ f_min = fs / nfft
327
+
328
+ n_octaves = np.log2(f_max / f_min)
329
+ n_octaves_per_band = n_octaves / n_bands
330
+ bandwidth_mult = np.power(2.0, n_octaves_per_band)
331
+
332
+ low_midi = max(0, hz_to_midi(f_min))
333
+ high_midi = hz_to_midi(f_max)
334
+ midi_points = np.linspace(low_midi, high_midi, n_bands)
335
+ hz_pts = midi_to_hz(midi_points)
336
+
337
+ low_pts = hz_pts / bandwidth_mult
338
+ high_pts = hz_pts * bandwidth_mult
339
+
340
+ low_bins = np.floor(low_pts / df).astype(int)
341
+ high_bins = np.ceil(high_pts / df).astype(int)
342
+
343
+ fb = np.zeros((n_bands, n_freqs))
344
+
345
+ for i in range(n_bands):
346
+ fb[i, low_bins[i] : high_bins[i] + 1] = 1.0
347
+
348
+ fb[0, : low_bins[0]] = 1.0
349
+ fb[-1, high_bins[-1] + 1 :] = 1.0
350
+
351
+ return torch.as_tensor(fb)
352
+
353
+
354
+ class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
355
+ def __init__(
356
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
357
+ ) -> None:
358
+ super().__init__(
359
+ fbank_fn=musical_filterbank,
360
+ nfft=nfft,
361
+ fs=fs,
362
+ n_bands=n_bands,
363
+ f_min=f_min,
364
+ f_max=f_max,
365
+ )
366
+
367
+
368
+ if __name__ == "__main__":
369
+ import pandas as pd
370
+
371
+ band_defs = []
372
+
373
+ for bands in [VocalBandsplitSpecification]:
374
+ band_name = bands.__name__.replace("BandsplitSpecification", "")
375
+
376
+ mbs = bands(nfft=2048, fs=44100).get_band_specs()
377
+
378
+ for i, (f_min, f_max) in enumerate(mbs):
379
+ band_defs.append(
380
+ {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
381
+ )
382
+
383
+ df = pd.DataFrame(band_defs)
384
+ df.to_csv("vox7bands.csv", index=False)
mvsepless/models/bs_roformer/__init__.py CHANGED
@@ -1,14 +1,14 @@
1
- from .bs_roformer import BSRoformer
2
- from .bs_conformer import BSConformer
3
- from .bs_roformer_sw import BSRoformer_SW
4
- try:
5
- from neuralop.models import FNO1d
6
- from .bs_roformer_fno import BSRoformer_FNO
7
- except:
8
- pass
9
- from .bs_roformer_hyperace import BSRoformerHyperACE
10
- from .bs_roformer_hyperace2 import BSRoformerHyperACE_2
11
- from .bs_roformer_conditional import BSRoformer_Conditional
12
- from .bs_roformer_unwa_inst_large_2 import BSRoformer_2
13
- from .mel_band_roformer import MelBandRoformer
14
- from .mel_band_conformer import MelBandConformer
 
1
+ from .bs_roformer import BSRoformer
2
+ from .bs_conformer import BSConformer
3
+ from .bs_roformer_sw import BSRoformer_SW
4
+ try:
5
+ from neuralop.models import FNO1d
6
+ from .bs_roformer_fno import BSRoformer_FNO
7
+ except:
8
+ pass
9
+ from .bs_roformer_hyperace import BSRoformerHyperACE
10
+ from .bs_roformer_hyperace2 import BSRoformerHyperACE_2
11
+ from .bs_roformer_conditional import BSRoformer_Conditional
12
+ from .bs_roformer_unwa_inst_large_2 import BSRoformer_2
13
+ from .mel_band_roformer import MelBandRoformer
14
+ from .mel_band_conformer import MelBandConformer
mvsepless/models/bs_roformer/attend.py CHANGED
@@ -1,128 +1,128 @@
1
- from functools import wraps
2
- from packaging import version
3
- from collections import namedtuple
4
-
5
- import os
6
- import torch
7
- from torch import nn, einsum
8
- import torch.nn.functional as F
9
-
10
- from einops import rearrange, reduce
11
-
12
-
13
- FlashAttentionConfig = namedtuple(
14
- "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
15
- )
16
-
17
-
18
- def exists(val):
19
- return val is not None
20
-
21
-
22
- def default(v, d):
23
- return v if exists(v) else d
24
-
25
-
26
- def once(fn):
27
- called = False
28
-
29
- @wraps(fn)
30
- def inner(x):
31
- nonlocal called
32
- if called:
33
- return
34
- called = True
35
- return fn(x)
36
-
37
- return inner
38
-
39
-
40
- print_once = once(print)
41
-
42
-
43
- class Attend(nn.Module):
44
- def __init__(self, dropout=0.0, flash=False, scale=None):
45
- super().__init__()
46
- self.scale = scale
47
- self.dropout = dropout
48
- self.attn_dropout = nn.Dropout(dropout)
49
-
50
- self.flash = flash
51
- self.use_torch_2_sdpa = False
52
- self._config_checked = False
53
-
54
- # Проверяем версию PyTorch при первом вызове
55
- if flash and not self._config_checked:
56
- if version.parse(torch.__version__) >= version.parse("2.0.0"):
57
- print_once("PyTorch >= 2.0 detected, will use SDPA if available.")
58
- self.use_torch_2_sdpa = True
59
-
60
- # Настройки для PyTorch >= 2.0
61
- self.cpu_config = FlashAttentionConfig(True, True, True)
62
- self.cuda_config = None
63
-
64
- if torch.cuda.is_available():
65
- device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
66
- device_version = version.parse(
67
- f"{device_properties.major}.{device_properties.minor}"
68
- )
69
-
70
- if device_version >= version.parse("8.0"):
71
- if os.name == "nt":
72
- print_once(
73
- "Windows OS detected, using math or mem efficient attention if input tensor is on cuda"
74
- )
75
- self.cuda_config = FlashAttentionConfig(False, True, True)
76
- else:
77
- print_once(
78
- "GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda"
79
- )
80
- self.cuda_config = FlashAttentionConfig(True, False, False)
81
- else:
82
- print_once(
83
- "GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda"
84
- )
85
- self.cuda_config = FlashAttentionConfig(False, True, True)
86
- else:
87
- print_once("PyTorch < 2.0 detected, flash attention will use einsum fallback.")
88
- self.use_torch_2_sdpa = False
89
-
90
- self._config_checked = True
91
-
92
- def flash_attn_torch2(self, q, k, v):
93
- """SDPA для PyTorch >= 2.0"""
94
- if exists(self.scale):
95
- default_scale = q.shape[-1] ** -0.5
96
- q = q * (self.scale / default_scale)
97
-
98
- is_cuda = q.is_cuda
99
- config = self.cuda_config if is_cuda else self.cpu_config
100
-
101
- with torch.backends.cuda.sdp_kernel(**config._asdict()):
102
- out = F.scaled_dot_product_attention(
103
- q, k, v, dropout_p=self.dropout if self.training else 0.0
104
- )
105
-
106
- return out
107
-
108
- def forward(self, q, k, v):
109
- q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
110
-
111
- scale = default(self.scale, q.shape[-1] ** -0.5)
112
-
113
- if self.flash and self.use_torch_2_sdpa:
114
- try:
115
- return self.flash_attn_torch2(q, k, v)
116
- except Exception as e:
117
- print(f"Flash attention failed: {e}. Falling back to einsum.")
118
- self.use_torch_2_sdpa = False
119
-
120
- # Fallback для PyTorch < 2.0 или если flash отключен
121
- sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
122
-
123
- attn = sim.softmax(dim=-1)
124
- attn = self.attn_dropout(attn)
125
-
126
- out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
127
-
128
  return out
 
1
+ from functools import wraps
2
+ from packaging import version
3
+ from collections import namedtuple
4
+
5
+ import os
6
+ import torch
7
+ from torch import nn, einsum
8
+ import torch.nn.functional as F
9
+
10
+ from einops import rearrange, reduce
11
+
12
+
13
+ FlashAttentionConfig = namedtuple(
14
+ "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
15
+ )
16
+
17
+
18
+ def exists(val):
19
+ return val is not None
20
+
21
+
22
+ def default(v, d):
23
+ return v if exists(v) else d
24
+
25
+
26
+ def once(fn):
27
+ called = False
28
+
29
+ @wraps(fn)
30
+ def inner(x):
31
+ nonlocal called
32
+ if called:
33
+ return
34
+ called = True
35
+ return fn(x)
36
+
37
+ return inner
38
+
39
+
40
+ print_once = once(print)
41
+
42
+
43
+ class Attend(nn.Module):
44
+ def __init__(self, dropout=0.0, flash=False, scale=None):
45
+ super().__init__()
46
+ self.scale = scale
47
+ self.dropout = dropout
48
+ self.attn_dropout = nn.Dropout(dropout)
49
+
50
+ self.flash = flash
51
+ self.use_torch_2_sdpa = False
52
+ self._config_checked = False
53
+
54
+ # Проверяем версию PyTorch при первом вызове
55
+ if flash and not self._config_checked:
56
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
57
+ print_once("PyTorch >= 2.0 detected, will use SDPA if available.")
58
+ self.use_torch_2_sdpa = True
59
+
60
+ # Настройки для PyTorch >= 2.0
61
+ self.cpu_config = FlashAttentionConfig(True, True, True)
62
+ self.cuda_config = None
63
+
64
+ if torch.cuda.is_available():
65
+ device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
66
+ device_version = version.parse(
67
+ f"{device_properties.major}.{device_properties.minor}"
68
+ )
69
+
70
+ if device_version >= version.parse("8.0"):
71
+ if os.name == "nt":
72
+ print_once(
73
+ "Windows OS detected, using math or mem efficient attention if input tensor is on cuda"
74
+ )
75
+ self.cuda_config = FlashAttentionConfig(False, True, True)
76
+ else:
77
+ print_once(
78
+ "GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda"
79
+ )
80
+ self.cuda_config = FlashAttentionConfig(True, False, False)
81
+ else:
82
+ print_once(
83
+ "GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda"
84
+ )
85
+ self.cuda_config = FlashAttentionConfig(False, True, True)
86
+ else:
87
+ print_once("PyTorch < 2.0 detected, flash attention will use einsum fallback.")
88
+ self.use_torch_2_sdpa = False
89
+
90
+ self._config_checked = True
91
+
92
+ def flash_attn_torch2(self, q, k, v):
93
+ """SDPA для PyTorch >= 2.0"""
94
+ if exists(self.scale):
95
+ default_scale = q.shape[-1] ** -0.5
96
+ q = q * (self.scale / default_scale)
97
+
98
+ is_cuda = q.is_cuda
99
+ config = self.cuda_config if is_cuda else self.cpu_config
100
+
101
+ with torch.backends.cuda.sdp_kernel(**config._asdict()):
102
+ out = F.scaled_dot_product_attention(
103
+ q, k, v, dropout_p=self.dropout if self.training else 0.0
104
+ )
105
+
106
+ return out
107
+
108
+ def forward(self, q, k, v):
109
+ q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
110
+
111
+ scale = default(self.scale, q.shape[-1] ** -0.5)
112
+
113
+ if self.flash and self.use_torch_2_sdpa:
114
+ try:
115
+ return self.flash_attn_torch2(q, k, v)
116
+ except Exception as e:
117
+ print(f"Flash attention failed: {e}. Falling back to einsum.")
118
+ self.use_torch_2_sdpa = False
119
+
120
+ # Fallback для PyTorch < 2.0 или если flash отключен
121
+ sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
122
+
123
+ attn = sim.softmax(dim=-1)
124
+ attn = self.attn_dropout(attn)
125
+
126
+ out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
127
+
128
  return out
mvsepless/models/bs_roformer/attend_sage.py CHANGED
@@ -1,147 +1,147 @@
1
- from functools import wraps
2
- from packaging import version
3
- from collections import namedtuple
4
-
5
- import os
6
- import torch
7
- from torch import nn, einsum
8
- import torch.nn.functional as F
9
-
10
- from einops import rearrange, reduce
11
-
12
-
13
- def _print_once(msg):
14
- printed = False
15
-
16
- def inner():
17
- nonlocal printed
18
- if not printed:
19
- print(msg)
20
- printed = True
21
-
22
- return inner
23
-
24
-
25
- # Проверяем доступность SageAttention
26
- try:
27
- from sageattention import sageattn
28
- _has_sage_attention = True
29
- except ImportError:
30
- _has_sage_attention = False
31
- _print_sage_not_found = _print_once(
32
- "SageAttention not found. Will fall back to PyTorch SDPA (if available) or manual einsum."
33
- )
34
- _print_sage_not_found()
35
-
36
-
37
- def exists(val):
38
- return val is not None
39
-
40
-
41
- def default(v, d):
42
- return v if exists(v) else d
43
-
44
-
45
- class Attend(nn.Module):
46
- def __init__(self, dropout=0.0, flash=False, scale=None):
47
- super().__init__()
48
- self.scale = scale
49
- self.dropout = dropout
50
-
51
- self.use_sage = flash and _has_sage_attention
52
- self.use_pytorch_sdpa = False
53
- self._sdpa_checked = False
54
- self.flash = flash
55
-
56
- # Инициализируем сообщения
57
- self._init_messages = False
58
-
59
- if flash and not self.use_sage:
60
- if not self._sdpa_checked:
61
- if version.parse(torch.__version__) >= version.parse("2.0.0"):
62
- self.use_pytorch_sdpa = True
63
- self._sdpa_checked = True
64
-
65
- self.attn_dropout = nn.Dropout(dropout)
66
-
67
- def _print_init_messages(self):
68
- """Печатаем сообщения инициализации один раз"""
69
- if self._init_messages:
70
- return
71
-
72
- if self.flash:
73
- if self.use_sage:
74
- print_once = _print_once("Using SageAttention backend.")
75
- print_once()
76
- elif self.use_pytorch_sdpa:
77
- print_once = _print_once(
78
- "Using PyTorch SDPA backend (FlashAttention-2, Memory-Efficient, or Math)."
79
- )
80
- print_once()
81
- else:
82
- print_once = _print_once(
83
- "Flash attention requested but Pytorch < 2.0 and SageAttention not found. Falling back to einsum."
84
- )
85
- print_once()
86
-
87
- self._init_messages = True
88
-
89
- def forward(self, q, k, v):
90
- q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
91
-
92
- # Печатаем сообщения инициализации при первом вызове
93
- self._print_init_messages()
94
-
95
- # Пробуем SageAttention если доступен
96
- if self.use_sage and self.flash:
97
- try:
98
- # Исправленный вызов: убрали повторный try-except
99
- out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
100
- return out
101
- except Exception as e:
102
- print(f"SageAttention failed with error: {e}. Falling back.")
103
- self.use_sage = False
104
- if not self._sdpa_checked:
105
- if version.parse(torch.__version__) >= version.parse("2.0.0"):
106
- self.use_pytorch_sdpa = True
107
- print_once = _print_once(
108
- "Falling back to PyTorch SDPA."
109
- )
110
- print_once()
111
- else:
112
- print_once = _print_once("Falling back to einsum.")
113
- print_once()
114
- self._sdpa_checked = True
115
-
116
- # Пробуем PyTorch SDPA если доступен
117
- if self.use_pytorch_sdpa and self.flash:
118
- try:
119
- # Для PyTorch >= 2.0
120
- if version.parse(torch.__version__) >= version.parse("2.0.0"):
121
- with torch.backends.cuda.sdp_kernel(
122
- enable_flash=True, enable_math=True, enable_mem_efficient=True
123
- ):
124
- out = F.scaled_dot_product_attention(
125
- q,
126
- k,
127
- v,
128
- attn_mask=None,
129
- dropout_p=self.dropout if self.training else 0.0,
130
- is_causal=False,
131
- )
132
- return out
133
- except Exception as e:
134
- print(f"PyTorch SDPA failed with error: {e}. Falling back to einsum.")
135
- self.use_pytorch_sdpa = False
136
-
137
- # Fallback на einsum (работает в PyTorch 1.13+)
138
- scale = default(self.scale, q.shape[-1] ** -0.5)
139
-
140
- sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
141
-
142
- attn = sim.softmax(dim=-1)
143
- attn = self.attn_dropout(attn)
144
-
145
- out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
146
-
147
  return out
 
1
+ from functools import wraps
2
+ from packaging import version
3
+ from collections import namedtuple
4
+
5
+ import os
6
+ import torch
7
+ from torch import nn, einsum
8
+ import torch.nn.functional as F
9
+
10
+ from einops import rearrange, reduce
11
+
12
+
13
+ def _print_once(msg):
14
+ printed = False
15
+
16
+ def inner():
17
+ nonlocal printed
18
+ if not printed:
19
+ print(msg)
20
+ printed = True
21
+
22
+ return inner
23
+
24
+
25
+ # Проверяем доступность SageAttention
26
+ try:
27
+ from sageattention import sageattn
28
+ _has_sage_attention = True
29
+ except ImportError:
30
+ _has_sage_attention = False
31
+ _print_sage_not_found = _print_once(
32
+ "SageAttention not found. Will fall back to PyTorch SDPA (if available) or manual einsum."
33
+ )
34
+ _print_sage_not_found()
35
+
36
+
37
+ def exists(val):
38
+ return val is not None
39
+
40
+
41
+ def default(v, d):
42
+ return v if exists(v) else d
43
+
44
+
45
+ class Attend(nn.Module):
46
+ def __init__(self, dropout=0.0, flash=False, scale=None):
47
+ super().__init__()
48
+ self.scale = scale
49
+ self.dropout = dropout
50
+
51
+ self.use_sage = flash and _has_sage_attention
52
+ self.use_pytorch_sdpa = False
53
+ self._sdpa_checked = False
54
+ self.flash = flash
55
+
56
+ # Инициализируем сообщения
57
+ self._init_messages = False
58
+
59
+ if flash and not self.use_sage:
60
+ if not self._sdpa_checked:
61
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
62
+ self.use_pytorch_sdpa = True
63
+ self._sdpa_checked = True
64
+
65
+ self.attn_dropout = nn.Dropout(dropout)
66
+
67
+ def _print_init_messages(self):
68
+ """Печатаем сообщения инициализации один раз"""
69
+ if self._init_messages:
70
+ return
71
+
72
+ if self.flash:
73
+ if self.use_sage:
74
+ print_once = _print_once("Using SageAttention backend.")
75
+ print_once()
76
+ elif self.use_pytorch_sdpa:
77
+ print_once = _print_once(
78
+ "Using PyTorch SDPA backend (FlashAttention-2, Memory-Efficient, or Math)."
79
+ )
80
+ print_once()
81
+ else:
82
+ print_once = _print_once(
83
+ "Flash attention requested but Pytorch < 2.0 and SageAttention not found. Falling back to einsum."
84
+ )
85
+ print_once()
86
+
87
+ self._init_messages = True
88
+
89
+ def forward(self, q, k, v):
90
+ q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
91
+
92
+ # Печатаем сообщения инициализации при первом вызове
93
+ self._print_init_messages()
94
+
95
+ # Пробуем SageAttention если доступен
96
+ if self.use_sage and self.flash:
97
+ try:
98
+ # Исправленный вызов: убрали повторный try-except
99
+ out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
100
+ return out
101
+ except Exception as e:
102
+ print(f"SageAttention failed with error: {e}. Falling back.")
103
+ self.use_sage = False
104
+ if not self._sdpa_checked:
105
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
106
+ self.use_pytorch_sdpa = True
107
+ print_once = _print_once(
108
+ "Falling back to PyTorch SDPA."
109
+ )
110
+ print_once()
111
+ else:
112
+ print_once = _print_once("Falling back to einsum.")
113
+ print_once()
114
+ self._sdpa_checked = True
115
+
116
+ # Пробуем PyTorch SDPA если доступен
117
+ if self.use_pytorch_sdpa and self.flash:
118
+ try:
119
+ # Для PyTorch >= 2.0
120
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
121
+ with torch.backends.cuda.sdp_kernel(
122
+ enable_flash=True, enable_math=True, enable_mem_efficient=True
123
+ ):
124
+ out = F.scaled_dot_product_attention(
125
+ q,
126
+ k,
127
+ v,
128
+ attn_mask=None,
129
+ dropout_p=self.dropout if self.training else 0.0,
130
+ is_causal=False,
131
+ )
132
+ return out
133
+ except Exception as e:
134
+ print(f"PyTorch SDPA failed with error: {e}. Falling back to einsum.")
135
+ self.use_pytorch_sdpa = False
136
+
137
+ # Fallback на einsum (работает в PyTorch 1.13+)
138
+ scale = default(self.scale, q.shape[-1] ** -0.5)
139
+
140
+ sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
141
+
142
+ attn = sim.softmax(dim=-1)
143
+ attn = self.attn_dropout(attn)
144
+
145
+ out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
146
+
147
  return out
mvsepless/models/bs_roformer/attend_sw.py CHANGED
@@ -1,88 +1,88 @@
1
- import logging
2
- import os
3
-
4
- import torch
5
- import torch.nn.functional as F
6
- from packaging import version
7
- from torch import Tensor, einsum, nn
8
- from torch.nn.attention import SDPBackend, sdpa_kernel
9
-
10
- logger = logging.getLogger(__name__)
11
-
12
-
13
- class Attend(nn.Module):
14
- def __init__(self, dropout: float = 0.0, flash: bool = False, scale=None):
15
- super().__init__()
16
- self.scale = scale
17
- self.dropout = dropout
18
- self.attn_dropout = nn.Dropout(dropout)
19
-
20
- self.flash = flash
21
- assert not (
22
- flash and version.parse(torch.__version__) < version.parse("2.0.0")
23
- ), "expected pytorch >= 2.0.0 to use flash attention"
24
-
25
- self.cpu_backends = [
26
- SDPBackend.FLASH_ATTENTION,
27
- SDPBackend.EFFICIENT_ATTENTION,
28
- SDPBackend.MATH,
29
- ]
30
- self.cuda_backends: list | None = None
31
-
32
- if not torch.cuda.is_available() or not flash:
33
- return
34
-
35
- device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
36
- device_version = version.parse(
37
- f"{device_properties.major}.{device_properties.minor}"
38
- )
39
-
40
- if device_version >= version.parse("8.0"):
41
- if os.name == "nt":
42
- cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
43
- logger.info(f"windows detected, {cuda_backends=}")
44
- else:
45
- cuda_backends = [SDPBackend.FLASH_ATTENTION]
46
- logger.info(f"gpu compute capability >= 8.0, {cuda_backends=}")
47
- else:
48
- cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
49
- logger.info(f"gpu compute capability < 8.0, {cuda_backends=}")
50
-
51
- self.cuda_backends = cuda_backends
52
-
53
- def flash_attn(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
54
- _, _heads, _q_len, _, _k_len, is_cuda, _device = (
55
- *q.shape,
56
- k.shape[-2],
57
- q.is_cuda,
58
- q.device,
59
- ) # type: ignore
60
-
61
- if self.scale is not None:
62
- default_scale = q.shape[-1] ** -0.5
63
- q = q * (self.scale / default_scale)
64
-
65
- backends = self.cuda_backends if is_cuda else self.cpu_backends
66
- with sdpa_kernel(backends=backends): # type: ignore
67
- out = F.scaled_dot_product_attention(
68
- q, k, v, dropout_p=self.dropout if self.training else 0.0
69
- )
70
-
71
- return out
72
-
73
- def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
74
- _q_len, _k_len, _device = q.shape[-2], k.shape[-2], q.device
75
-
76
- scale = self.scale or q.shape[-1] ** -0.5
77
-
78
- if self.flash:
79
- return self.flash_attn(q, k, v)
80
-
81
- sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
82
-
83
- attn = sim.softmax(dim=-1)
84
- attn = self.attn_dropout(attn)
85
-
86
- out = einsum("b h i j, b h j d -> b h i d", attn, v)
87
-
88
- return out
 
1
+ import logging
2
+ import os
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from packaging import version
7
+ from torch import Tensor, einsum, nn
8
+ from torch.nn.attention import SDPBackend, sdpa_kernel
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class Attend(nn.Module):
14
+ def __init__(self, dropout: float = 0.0, flash: bool = False, scale=None):
15
+ super().__init__()
16
+ self.scale = scale
17
+ self.dropout = dropout
18
+ self.attn_dropout = nn.Dropout(dropout)
19
+
20
+ self.flash = flash
21
+ assert not (
22
+ flash and version.parse(torch.__version__) < version.parse("2.0.0")
23
+ ), "expected pytorch >= 2.0.0 to use flash attention"
24
+
25
+ self.cpu_backends = [
26
+ SDPBackend.FLASH_ATTENTION,
27
+ SDPBackend.EFFICIENT_ATTENTION,
28
+ SDPBackend.MATH,
29
+ ]
30
+ self.cuda_backends: list | None = None
31
+
32
+ if not torch.cuda.is_available() or not flash:
33
+ return
34
+
35
+ device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
36
+ device_version = version.parse(
37
+ f"{device_properties.major}.{device_properties.minor}"
38
+ )
39
+
40
+ if device_version >= version.parse("8.0"):
41
+ if os.name == "nt":
42
+ cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
43
+ logger.info(f"windows detected, {cuda_backends=}")
44
+ else:
45
+ cuda_backends = [SDPBackend.FLASH_ATTENTION]
46
+ logger.info(f"gpu compute capability >= 8.0, {cuda_backends=}")
47
+ else:
48
+ cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
49
+ logger.info(f"gpu compute capability < 8.0, {cuda_backends=}")
50
+
51
+ self.cuda_backends = cuda_backends
52
+
53
+ def flash_attn(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
54
+ _, _heads, _q_len, _, _k_len, is_cuda, _device = (
55
+ *q.shape,
56
+ k.shape[-2],
57
+ q.is_cuda,
58
+ q.device,
59
+ ) # type: ignore
60
+
61
+ if self.scale is not None:
62
+ default_scale = q.shape[-1] ** -0.5
63
+ q = q * (self.scale / default_scale)
64
+
65
+ backends = self.cuda_backends if is_cuda else self.cpu_backends
66
+ with sdpa_kernel(backends=backends): # type: ignore
67
+ out = F.scaled_dot_product_attention(
68
+ q, k, v, dropout_p=self.dropout if self.training else 0.0
69
+ )
70
+
71
+ return out
72
+
73
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
74
+ _q_len, _k_len, _device = q.shape[-2], k.shape[-2], q.device
75
+
76
+ scale = self.scale or q.shape[-1] ** -0.5
77
+
78
+ if self.flash:
79
+ return self.flash_attn(q, k, v)
80
+
81
+ sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
82
+
83
+ attn = sim.softmax(dim=-1)
84
+ attn = self.attn_dropout(attn)
85
+
86
+ out = einsum("b h i j, b h j d -> b h i d", attn, v)
87
+
88
+ return out
mvsepless/models/bs_roformer/bs_roformer.py CHANGED
@@ -1,696 +1,696 @@
1
- from functools import partial
2
-
3
- import torch
4
- from torch import nn, einsum, tensor, Tensor
5
- from torch.nn import Module, ModuleList
6
- import torch.nn.functional as F
7
-
8
- from .attend import Attend
9
-
10
- try:
11
- from .attend_sage import Attend as AttendSage
12
- except:
13
- pass
14
- from torch.utils.checkpoint import checkpoint
15
-
16
- from beartype.typing import Tuple, Optional, List, Callable
17
- from beartype import beartype
18
-
19
- from rotary_embedding_torch import RotaryEmbedding
20
-
21
- from einops import rearrange, pack, unpack
22
- from einops.layers.torch import Rearrange
23
-
24
-
25
- def exists(val):
26
- return val is not None
27
-
28
-
29
- def default(v, d):
30
- return v if exists(v) else d
31
-
32
-
33
- def pack_one(t, pattern):
34
- return pack([t], pattern)
35
-
36
-
37
- def unpack_one(t, ps, pattern):
38
- return unpack(t, ps, pattern)[0]
39
-
40
-
41
- def l2norm(t):
42
- return F.normalize(t, dim=-1, p=2)
43
-
44
-
45
- class RMSNorm(Module):
46
- def __init__(self, dim):
47
- super().__init__()
48
- self.scale = dim**0.5
49
- self.gamma = nn.Parameter(torch.ones(dim))
50
-
51
- def forward(self, x):
52
- return F.normalize(x, dim=-1) * self.scale * self.gamma
53
-
54
-
55
- class FeedForward(Module):
56
- def __init__(self, dim, mult=4, dropout=0.0):
57
- super().__init__()
58
- dim_inner = int(dim * mult)
59
- self.net = nn.Sequential(
60
- RMSNorm(dim),
61
- nn.Linear(dim, dim_inner),
62
- nn.GELU(),
63
- nn.Dropout(dropout),
64
- nn.Linear(dim_inner, dim),
65
- nn.Dropout(dropout),
66
- )
67
-
68
- def forward(self, x):
69
- return self.net(x)
70
-
71
-
72
- class Attention(Module):
73
- def __init__(
74
- self,
75
- dim,
76
- heads=8,
77
- dim_head=64,
78
- dropout=0.0,
79
- rotary_embed=None,
80
- flash=True,
81
- sage_attention=False,
82
- ):
83
- super().__init__()
84
- self.heads = heads
85
- self.scale = dim_head**-0.5
86
- dim_inner = heads * dim_head
87
-
88
- self.rotary_embed = rotary_embed
89
-
90
- if sage_attention:
91
- self.attend = AttendSage(flash=flash, dropout=dropout)
92
- else:
93
- self.attend = Attend(flash=flash, dropout=dropout)
94
-
95
- self.norm = RMSNorm(dim)
96
- self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
97
-
98
- self.to_gates = nn.Linear(dim, heads)
99
-
100
- self.to_out = nn.Sequential(
101
- nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
102
- )
103
-
104
- def forward(self, x):
105
- x = self.norm(x)
106
-
107
- q, k, v = rearrange(
108
- self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
109
- )
110
-
111
- if exists(self.rotary_embed):
112
- q = self.rotary_embed.rotate_queries_or_keys(q)
113
- k = self.rotary_embed.rotate_queries_or_keys(k)
114
-
115
- out = self.attend(q, k, v)
116
-
117
- gates = self.to_gates(x)
118
- out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
119
-
120
- out = rearrange(out, "b h n d -> b n (h d)")
121
- return self.to_out(out)
122
-
123
-
124
- class LinearAttention(Module):
125
-
126
- @beartype
127
- def __init__(
128
- self,
129
- *,
130
- dim,
131
- dim_head=32,
132
- heads=8,
133
- scale=8,
134
- flash=True,
135
- dropout=0.0,
136
- sage_attention=False,
137
- ):
138
- super().__init__()
139
- dim_inner = dim_head * heads
140
- self.norm = RMSNorm(dim)
141
-
142
- self.to_qkv = nn.Sequential(
143
- nn.Linear(dim, dim_inner * 3, bias=False),
144
- Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
145
- )
146
-
147
- self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
148
-
149
- if sage_attention:
150
- self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash)
151
- else:
152
- self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
153
-
154
- self.to_out = nn.Sequential(
155
- Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
156
- )
157
-
158
- def forward(self, x):
159
- x = self.norm(x)
160
-
161
- q, k, v = self.to_qkv(x)
162
-
163
- q, k = map(l2norm, (q, k))
164
- q = q * self.temperature.exp()
165
-
166
- out = self.attend(q, k, v)
167
-
168
- return self.to_out(out)
169
-
170
-
171
- class Transformer(Module):
172
- def __init__(
173
- self,
174
- *,
175
- dim,
176
- depth,
177
- dim_head=64,
178
- heads=8,
179
- attn_dropout=0.0,
180
- ff_dropout=0.0,
181
- ff_mult=4,
182
- norm_output=True,
183
- rotary_embed=None,
184
- flash_attn=True,
185
- linear_attn=False,
186
- sage_attention=False,
187
- ):
188
- super().__init__()
189
- self.layers = ModuleList([])
190
-
191
- for _ in range(depth):
192
- if linear_attn:
193
- attn = LinearAttention(
194
- dim=dim,
195
- dim_head=dim_head,
196
- heads=heads,
197
- dropout=attn_dropout,
198
- flash=flash_attn,
199
- sage_attention=sage_attention,
200
- )
201
- else:
202
- attn = Attention(
203
- dim=dim,
204
- dim_head=dim_head,
205
- heads=heads,
206
- dropout=attn_dropout,
207
- rotary_embed=rotary_embed,
208
- flash=flash_attn,
209
- sage_attention=sage_attention,
210
- )
211
-
212
- self.layers.append(
213
- ModuleList(
214
- [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
215
- )
216
- )
217
-
218
- self.norm = RMSNorm(dim) if norm_output else nn.Identity()
219
-
220
- def forward(self, x):
221
-
222
- for attn, ff in self.layers:
223
- x = attn(x) + x
224
- x = ff(x) + x
225
-
226
- return self.norm(x)
227
-
228
-
229
- class BandSplit(Module):
230
- @beartype
231
- def __init__(self, dim, dim_inputs: Tuple[int, ...]):
232
- super().__init__()
233
- self.dim_inputs = dim_inputs
234
- self.to_features = ModuleList([])
235
-
236
- for dim_in in dim_inputs:
237
- net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
238
-
239
- self.to_features.append(net)
240
-
241
- def forward(self, x):
242
- x = x.split(self.dim_inputs, dim=-1)
243
-
244
- outs = []
245
- for split_input, to_feature in zip(x, self.to_features):
246
- split_output = to_feature(split_input)
247
- outs.append(split_output)
248
-
249
- return torch.stack(outs, dim=-2)
250
-
251
-
252
- def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
253
- dim_hidden = default(dim_hidden, dim_in)
254
-
255
- net = []
256
- dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
257
-
258
- for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
259
- is_last = ind == (len(dims) - 2)
260
-
261
- net.append(nn.Linear(layer_dim_in, layer_dim_out))
262
-
263
- if is_last:
264
- continue
265
-
266
- net.append(activation())
267
-
268
- return nn.Sequential(*net)
269
-
270
-
271
- class MaskEstimator(Module):
272
- @beartype
273
- def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
274
- super().__init__()
275
- self.dim_inputs = dim_inputs
276
- self.to_freqs = ModuleList([])
277
- dim_hidden = dim * mlp_expansion_factor
278
-
279
- for dim_in in dim_inputs:
280
- net = []
281
-
282
- mlp = nn.Sequential(
283
- MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1)
284
- )
285
-
286
- self.to_freqs.append(mlp)
287
-
288
- def forward(self, x):
289
- x = x.unbind(dim=-2)
290
-
291
- outs = []
292
-
293
- for band_features, mlp in zip(x, self.to_freqs):
294
- freq_out = mlp(band_features)
295
- outs.append(freq_out)
296
-
297
- return torch.cat(outs, dim=-1)
298
-
299
-
300
- DEFAULT_FREQS_PER_BANDS = (
301
- 2,
302
- 2,
303
- 2,
304
- 2,
305
- 2,
306
- 2,
307
- 2,
308
- 2,
309
- 2,
310
- 2,
311
- 2,
312
- 2,
313
- 2,
314
- 2,
315
- 2,
316
- 2,
317
- 2,
318
- 2,
319
- 2,
320
- 2,
321
- 2,
322
- 2,
323
- 2,
324
- 2,
325
- 4,
326
- 4,
327
- 4,
328
- 4,
329
- 4,
330
- 4,
331
- 4,
332
- 4,
333
- 4,
334
- 4,
335
- 4,
336
- 4,
337
- 12,
338
- 12,
339
- 12,
340
- 12,
341
- 12,
342
- 12,
343
- 12,
344
- 12,
345
- 24,
346
- 24,
347
- 24,
348
- 24,
349
- 24,
350
- 24,
351
- 24,
352
- 24,
353
- 48,
354
- 48,
355
- 48,
356
- 48,
357
- 48,
358
- 48,
359
- 48,
360
- 48,
361
- 128,
362
- 129,
363
- )
364
-
365
-
366
- class BSRoformer(Module):
367
-
368
- @beartype
369
- def __init__(
370
- self,
371
- dim,
372
- *,
373
- depth,
374
- stereo=False,
375
- num_stems=1,
376
- time_transformer_depth=2,
377
- freq_transformer_depth=2,
378
- linear_transformer_depth=0,
379
- freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
380
- dim_head=64,
381
- heads=8,
382
- attn_dropout=0.0,
383
- ff_dropout=0.0,
384
- flash_attn=True,
385
- dim_freqs_in=1025,
386
- stft_n_fft=2048,
387
- stft_hop_length=512,
388
- stft_win_length=2048,
389
- stft_normalized=False,
390
- stft_window_fn: Optional[Callable] = None,
391
- zero_dc=True,
392
- mask_estimator_depth=2,
393
- multi_stft_resolution_loss_weight=1.0,
394
- multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
395
- 4096,
396
- 2048,
397
- 1024,
398
- 512,
399
- 256,
400
- ),
401
- multi_stft_hop_size=147,
402
- multi_stft_normalized=False,
403
- multi_stft_window_fn: Callable = torch.hann_window,
404
- mlp_expansion_factor=4,
405
- use_torch_checkpoint=False,
406
- skip_connection=False,
407
- sage_attention=False,
408
- ):
409
- super().__init__()
410
-
411
- self.stereo = stereo
412
- self.audio_channels = 2 if stereo else 1
413
- self.num_stems = num_stems
414
- self.use_torch_checkpoint = use_torch_checkpoint
415
- self.skip_connection = skip_connection
416
-
417
- self.layers = ModuleList([])
418
-
419
- if sage_attention:
420
- print("Use Sage Attention")
421
-
422
- transformer_kwargs = dict(
423
- dim=dim,
424
- heads=heads,
425
- dim_head=dim_head,
426
- attn_dropout=attn_dropout,
427
- ff_dropout=ff_dropout,
428
- flash_attn=flash_attn,
429
- norm_output=False,
430
- sage_attention=sage_attention,
431
- )
432
-
433
- time_rotary_embed = RotaryEmbedding(dim=dim_head)
434
- freq_rotary_embed = RotaryEmbedding(dim=dim_head)
435
-
436
- for _ in range(depth):
437
- tran_modules = []
438
- if linear_transformer_depth > 0:
439
- tran_modules.append(
440
- Transformer(
441
- depth=linear_transformer_depth,
442
- linear_attn=True,
443
- **transformer_kwargs,
444
- )
445
- )
446
- tran_modules.append(
447
- Transformer(
448
- depth=time_transformer_depth,
449
- rotary_embed=time_rotary_embed,
450
- **transformer_kwargs,
451
- )
452
- )
453
- tran_modules.append(
454
- Transformer(
455
- depth=freq_transformer_depth,
456
- rotary_embed=freq_rotary_embed,
457
- **transformer_kwargs,
458
- )
459
- )
460
- self.layers.append(nn.ModuleList(tran_modules))
461
-
462
- self.final_norm = RMSNorm(dim)
463
-
464
- self.stft_kwargs = dict(
465
- n_fft=stft_n_fft,
466
- hop_length=stft_hop_length,
467
- win_length=stft_win_length,
468
- normalized=stft_normalized,
469
- )
470
-
471
- self.stft_window_fn = partial(
472
- default(stft_window_fn, torch.hann_window), stft_win_length
473
- )
474
-
475
- freqs = torch.stft(
476
- torch.randn(1, 4096),
477
- **self.stft_kwargs,
478
- window=torch.ones(stft_win_length),
479
- return_complex=True,
480
- ).shape[1]
481
-
482
- assert len(freqs_per_bands) > 1
483
- assert (
484
- sum(freqs_per_bands) == freqs
485
- ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
486
-
487
- freqs_per_bands_with_complex = tuple(
488
- 2 * f * self.audio_channels for f in freqs_per_bands
489
- )
490
-
491
- self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
492
-
493
- self.mask_estimators = nn.ModuleList([])
494
-
495
- for _ in range(num_stems):
496
- mask_estimator = MaskEstimator(
497
- dim=dim,
498
- dim_inputs=freqs_per_bands_with_complex,
499
- depth=mask_estimator_depth,
500
- mlp_expansion_factor=mlp_expansion_factor,
501
- )
502
-
503
- self.mask_estimators.append(mask_estimator)
504
-
505
- self.zero_dc = zero_dc
506
-
507
- self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
508
- self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
509
- self.multi_stft_n_fft = stft_n_fft
510
- self.multi_stft_window_fn = multi_stft_window_fn
511
-
512
- self.multi_stft_kwargs = dict(
513
- hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
514
- )
515
-
516
- def forward(self, raw_audio, target=None, return_loss_breakdown=False):
517
-
518
- device = raw_audio.device
519
-
520
- x_is_mps = True if device.type == "mps" else False
521
-
522
- if raw_audio.ndim == 2:
523
- raw_audio = rearrange(raw_audio, "b t -> b 1 t")
524
-
525
- channels = raw_audio.shape[1]
526
- assert (not self.stereo and channels == 1) or (
527
- self.stereo and channels == 2
528
- ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
529
-
530
- raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
531
-
532
- stft_window = self.stft_window_fn(device=device)
533
-
534
- try:
535
- stft_repr = torch.stft(
536
- raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
537
- )
538
- except:
539
- stft_repr = torch.stft(
540
- raw_audio.cpu() if x_is_mps else raw_audio,
541
- **self.stft_kwargs,
542
- window=stft_window.cpu() if x_is_mps else stft_window,
543
- return_complex=True,
544
- ).to(device)
545
- stft_repr = torch.view_as_real(stft_repr)
546
-
547
- stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
548
-
549
- stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
550
-
551
- x = rearrange(stft_repr, "b f t c -> b t (f c)")
552
-
553
- if self.use_torch_checkpoint:
554
- x = checkpoint(self.band_split, x, use_reentrant=False)
555
- else:
556
- x = self.band_split(x)
557
-
558
- store = [None] * len(self.layers)
559
- for i, transformer_block in enumerate(self.layers):
560
-
561
- if len(transformer_block) == 3:
562
- linear_transformer, time_transformer, freq_transformer = (
563
- transformer_block
564
- )
565
-
566
- x, ft_ps = pack([x], "b * d")
567
- if self.use_torch_checkpoint:
568
- x = checkpoint(linear_transformer, x, use_reentrant=False)
569
- else:
570
- x = linear_transformer(x)
571
- (x,) = unpack(x, ft_ps, "b * d")
572
- else:
573
- time_transformer, freq_transformer = transformer_block
574
-
575
- if self.skip_connection:
576
- for j in range(i):
577
- x = x + store[j]
578
-
579
- x = rearrange(x, "b t f d -> b f t d")
580
- x, ps = pack([x], "* t d")
581
-
582
- if self.use_torch_checkpoint:
583
- x = checkpoint(time_transformer, x, use_reentrant=False)
584
- else:
585
- x = time_transformer(x)
586
-
587
- (x,) = unpack(x, ps, "* t d")
588
- x = rearrange(x, "b f t d -> b t f d")
589
- x, ps = pack([x], "* f d")
590
-
591
- if self.use_torch_checkpoint:
592
- x = checkpoint(freq_transformer, x, use_reentrant=False)
593
- else:
594
- x = freq_transformer(x)
595
-
596
- (x,) = unpack(x, ps, "* f d")
597
-
598
- if self.skip_connection:
599
- store[i] = x
600
-
601
- x = self.final_norm(x)
602
-
603
- num_stems = len(self.mask_estimators)
604
-
605
- if self.use_torch_checkpoint:
606
- mask = torch.stack(
607
- [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators],
608
- dim=1,
609
- )
610
- else:
611
- mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
612
- mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
613
-
614
- stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
615
-
616
- stft_repr = torch.view_as_complex(stft_repr)
617
- mask = torch.view_as_complex(mask)
618
-
619
- stft_repr = stft_repr * mask
620
-
621
- stft_repr = rearrange(
622
- stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
623
- )
624
-
625
- if self.zero_dc:
626
- stft_repr = stft_repr.index_fill(1, tensor(0, device=device), 0.0)
627
-
628
- try:
629
- recon_audio = torch.istft(
630
- stft_repr,
631
- **self.stft_kwargs,
632
- window=stft_window,
633
- return_complex=False,
634
- length=raw_audio.shape[-1],
635
- )
636
- except:
637
- recon_audio = torch.istft(
638
- stft_repr.cpu() if x_is_mps else stft_repr,
639
- **self.stft_kwargs,
640
- window=stft_window.cpu() if x_is_mps else stft_window,
641
- return_complex=False,
642
- length=raw_audio.shape[-1],
643
- ).to(device)
644
-
645
- recon_audio = rearrange(
646
- recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
647
- )
648
-
649
- if num_stems == 1:
650
- recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
651
-
652
- if not exists(target):
653
- return recon_audio
654
-
655
- if self.num_stems > 1:
656
- assert target.ndim == 4 and target.shape[1] == self.num_stems
657
-
658
- if target.ndim == 2:
659
- target = rearrange(target, "... t -> ... 1 t")
660
-
661
- target = target[..., : recon_audio.shape[-1]]
662
-
663
- loss = F.l1_loss(recon_audio, target)
664
-
665
- multi_stft_resolution_loss = 0.0
666
-
667
- for window_size in self.multi_stft_resolutions_window_sizes:
668
- res_stft_kwargs = dict(
669
- n_fft=max(window_size, self.multi_stft_n_fft),
670
- win_length=window_size,
671
- return_complex=True,
672
- window=self.multi_stft_window_fn(window_size, device=device),
673
- **self.multi_stft_kwargs,
674
- )
675
-
676
- recon_Y = torch.stft(
677
- rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
678
- )
679
- target_Y = torch.stft(
680
- rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
681
- )
682
-
683
- multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
684
- recon_Y, target_Y
685
- )
686
-
687
- weighted_multi_resolution_loss = (
688
- multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
689
- )
690
-
691
- total_loss = loss + weighted_multi_resolution_loss
692
-
693
- if not return_loss_breakdown:
694
- return total_loss
695
-
696
- return total_loss, (loss, multi_stft_resolution_loss)
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn, einsum, tensor, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ import torch.nn.functional as F
7
+
8
+ from .attend import Attend
9
+
10
+ try:
11
+ from .attend_sage import Attend as AttendSage
12
+ except:
13
+ pass
14
+ from torch.utils.checkpoint import checkpoint
15
+
16
+ from beartype.typing import Tuple, Optional, List, Callable
17
+ from beartype import beartype
18
+
19
+ from rotary_embedding_torch import RotaryEmbedding
20
+
21
+ from einops import rearrange, pack, unpack
22
+ from einops.layers.torch import Rearrange
23
+
24
+
25
+ def exists(val):
26
+ return val is not None
27
+
28
+
29
+ def default(v, d):
30
+ return v if exists(v) else d
31
+
32
+
33
+ def pack_one(t, pattern):
34
+ return pack([t], pattern)
35
+
36
+
37
+ def unpack_one(t, ps, pattern):
38
+ return unpack(t, ps, pattern)[0]
39
+
40
+
41
+ def l2norm(t):
42
+ return F.normalize(t, dim=-1, p=2)
43
+
44
+
45
+ class RMSNorm(Module):
46
+ def __init__(self, dim):
47
+ super().__init__()
48
+ self.scale = dim**0.5
49
+ self.gamma = nn.Parameter(torch.ones(dim))
50
+
51
+ def forward(self, x):
52
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
53
+
54
+
55
+ class FeedForward(Module):
56
+ def __init__(self, dim, mult=4, dropout=0.0):
57
+ super().__init__()
58
+ dim_inner = int(dim * mult)
59
+ self.net = nn.Sequential(
60
+ RMSNorm(dim),
61
+ nn.Linear(dim, dim_inner),
62
+ nn.GELU(),
63
+ nn.Dropout(dropout),
64
+ nn.Linear(dim_inner, dim),
65
+ nn.Dropout(dropout),
66
+ )
67
+
68
+ def forward(self, x):
69
+ return self.net(x)
70
+
71
+
72
+ class Attention(Module):
73
+ def __init__(
74
+ self,
75
+ dim,
76
+ heads=8,
77
+ dim_head=64,
78
+ dropout=0.0,
79
+ rotary_embed=None,
80
+ flash=True,
81
+ sage_attention=False,
82
+ ):
83
+ super().__init__()
84
+ self.heads = heads
85
+ self.scale = dim_head**-0.5
86
+ dim_inner = heads * dim_head
87
+
88
+ self.rotary_embed = rotary_embed
89
+
90
+ if sage_attention:
91
+ self.attend = AttendSage(flash=flash, dropout=dropout)
92
+ else:
93
+ self.attend = Attend(flash=flash, dropout=dropout)
94
+
95
+ self.norm = RMSNorm(dim)
96
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
97
+
98
+ self.to_gates = nn.Linear(dim, heads)
99
+
100
+ self.to_out = nn.Sequential(
101
+ nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
102
+ )
103
+
104
+ def forward(self, x):
105
+ x = self.norm(x)
106
+
107
+ q, k, v = rearrange(
108
+ self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
109
+ )
110
+
111
+ if exists(self.rotary_embed):
112
+ q = self.rotary_embed.rotate_queries_or_keys(q)
113
+ k = self.rotary_embed.rotate_queries_or_keys(k)
114
+
115
+ out = self.attend(q, k, v)
116
+
117
+ gates = self.to_gates(x)
118
+ out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
119
+
120
+ out = rearrange(out, "b h n d -> b n (h d)")
121
+ return self.to_out(out)
122
+
123
+
124
+ class LinearAttention(Module):
125
+
126
+ @beartype
127
+ def __init__(
128
+ self,
129
+ *,
130
+ dim,
131
+ dim_head=32,
132
+ heads=8,
133
+ scale=8,
134
+ flash=True,
135
+ dropout=0.0,
136
+ sage_attention=False,
137
+ ):
138
+ super().__init__()
139
+ dim_inner = dim_head * heads
140
+ self.norm = RMSNorm(dim)
141
+
142
+ self.to_qkv = nn.Sequential(
143
+ nn.Linear(dim, dim_inner * 3, bias=False),
144
+ Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
145
+ )
146
+
147
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
148
+
149
+ if sage_attention:
150
+ self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash)
151
+ else:
152
+ self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
153
+
154
+ self.to_out = nn.Sequential(
155
+ Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
156
+ )
157
+
158
+ def forward(self, x):
159
+ x = self.norm(x)
160
+
161
+ q, k, v = self.to_qkv(x)
162
+
163
+ q, k = map(l2norm, (q, k))
164
+ q = q * self.temperature.exp()
165
+
166
+ out = self.attend(q, k, v)
167
+
168
+ return self.to_out(out)
169
+
170
+
171
+ class Transformer(Module):
172
+ def __init__(
173
+ self,
174
+ *,
175
+ dim,
176
+ depth,
177
+ dim_head=64,
178
+ heads=8,
179
+ attn_dropout=0.0,
180
+ ff_dropout=0.0,
181
+ ff_mult=4,
182
+ norm_output=True,
183
+ rotary_embed=None,
184
+ flash_attn=True,
185
+ linear_attn=False,
186
+ sage_attention=False,
187
+ ):
188
+ super().__init__()
189
+ self.layers = ModuleList([])
190
+
191
+ for _ in range(depth):
192
+ if linear_attn:
193
+ attn = LinearAttention(
194
+ dim=dim,
195
+ dim_head=dim_head,
196
+ heads=heads,
197
+ dropout=attn_dropout,
198
+ flash=flash_attn,
199
+ sage_attention=sage_attention,
200
+ )
201
+ else:
202
+ attn = Attention(
203
+ dim=dim,
204
+ dim_head=dim_head,
205
+ heads=heads,
206
+ dropout=attn_dropout,
207
+ rotary_embed=rotary_embed,
208
+ flash=flash_attn,
209
+ sage_attention=sage_attention,
210
+ )
211
+
212
+ self.layers.append(
213
+ ModuleList(
214
+ [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
215
+ )
216
+ )
217
+
218
+ self.norm = RMSNorm(dim) if norm_output else nn.Identity()
219
+
220
+ def forward(self, x):
221
+
222
+ for attn, ff in self.layers:
223
+ x = attn(x) + x
224
+ x = ff(x) + x
225
+
226
+ return self.norm(x)
227
+
228
+
229
+ class BandSplit(Module):
230
+ @beartype
231
+ def __init__(self, dim, dim_inputs: Tuple[int, ...]):
232
+ super().__init__()
233
+ self.dim_inputs = dim_inputs
234
+ self.to_features = ModuleList([])
235
+
236
+ for dim_in in dim_inputs:
237
+ net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
238
+
239
+ self.to_features.append(net)
240
+
241
+ def forward(self, x):
242
+ x = x.split(self.dim_inputs, dim=-1)
243
+
244
+ outs = []
245
+ for split_input, to_feature in zip(x, self.to_features):
246
+ split_output = to_feature(split_input)
247
+ outs.append(split_output)
248
+
249
+ return torch.stack(outs, dim=-2)
250
+
251
+
252
+ def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
253
+ dim_hidden = default(dim_hidden, dim_in)
254
+
255
+ net = []
256
+ dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
257
+
258
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
259
+ is_last = ind == (len(dims) - 2)
260
+
261
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
262
+
263
+ if is_last:
264
+ continue
265
+
266
+ net.append(activation())
267
+
268
+ return nn.Sequential(*net)
269
+
270
+
271
+ class MaskEstimator(Module):
272
+ @beartype
273
+ def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
274
+ super().__init__()
275
+ self.dim_inputs = dim_inputs
276
+ self.to_freqs = ModuleList([])
277
+ dim_hidden = dim * mlp_expansion_factor
278
+
279
+ for dim_in in dim_inputs:
280
+ net = []
281
+
282
+ mlp = nn.Sequential(
283
+ MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1)
284
+ )
285
+
286
+ self.to_freqs.append(mlp)
287
+
288
+ def forward(self, x):
289
+ x = x.unbind(dim=-2)
290
+
291
+ outs = []
292
+
293
+ for band_features, mlp in zip(x, self.to_freqs):
294
+ freq_out = mlp(band_features)
295
+ outs.append(freq_out)
296
+
297
+ return torch.cat(outs, dim=-1)
298
+
299
+
300
+ DEFAULT_FREQS_PER_BANDS = (
301
+ 2,
302
+ 2,
303
+ 2,
304
+ 2,
305
+ 2,
306
+ 2,
307
+ 2,
308
+ 2,
309
+ 2,
310
+ 2,
311
+ 2,
312
+ 2,
313
+ 2,
314
+ 2,
315
+ 2,
316
+ 2,
317
+ 2,
318
+ 2,
319
+ 2,
320
+ 2,
321
+ 2,
322
+ 2,
323
+ 2,
324
+ 2,
325
+ 4,
326
+ 4,
327
+ 4,
328
+ 4,
329
+ 4,
330
+ 4,
331
+ 4,
332
+ 4,
333
+ 4,
334
+ 4,
335
+ 4,
336
+ 4,
337
+ 12,
338
+ 12,
339
+ 12,
340
+ 12,
341
+ 12,
342
+ 12,
343
+ 12,
344
+ 12,
345
+ 24,
346
+ 24,
347
+ 24,
348
+ 24,
349
+ 24,
350
+ 24,
351
+ 24,
352
+ 24,
353
+ 48,
354
+ 48,
355
+ 48,
356
+ 48,
357
+ 48,
358
+ 48,
359
+ 48,
360
+ 48,
361
+ 128,
362
+ 129,
363
+ )
364
+
365
+
366
+ class BSRoformer(Module):
367
+
368
+ @beartype
369
+ def __init__(
370
+ self,
371
+ dim,
372
+ *,
373
+ depth,
374
+ stereo=False,
375
+ num_stems=1,
376
+ time_transformer_depth=2,
377
+ freq_transformer_depth=2,
378
+ linear_transformer_depth=0,
379
+ freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
380
+ dim_head=64,
381
+ heads=8,
382
+ attn_dropout=0.0,
383
+ ff_dropout=0.0,
384
+ flash_attn=True,
385
+ dim_freqs_in=1025,
386
+ stft_n_fft=2048,
387
+ stft_hop_length=512,
388
+ stft_win_length=2048,
389
+ stft_normalized=False,
390
+ stft_window_fn: Optional[Callable] = None,
391
+ zero_dc=True,
392
+ mask_estimator_depth=2,
393
+ multi_stft_resolution_loss_weight=1.0,
394
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
395
+ 4096,
396
+ 2048,
397
+ 1024,
398
+ 512,
399
+ 256,
400
+ ),
401
+ multi_stft_hop_size=147,
402
+ multi_stft_normalized=False,
403
+ multi_stft_window_fn: Callable = torch.hann_window,
404
+ mlp_expansion_factor=4,
405
+ use_torch_checkpoint=False,
406
+ skip_connection=False,
407
+ sage_attention=False,
408
+ ):
409
+ super().__init__()
410
+
411
+ self.stereo = stereo
412
+ self.audio_channels = 2 if stereo else 1
413
+ self.num_stems = num_stems
414
+ self.use_torch_checkpoint = use_torch_checkpoint
415
+ self.skip_connection = skip_connection
416
+
417
+ self.layers = ModuleList([])
418
+
419
+ if sage_attention:
420
+ print("Use Sage Attention")
421
+
422
+ transformer_kwargs = dict(
423
+ dim=dim,
424
+ heads=heads,
425
+ dim_head=dim_head,
426
+ attn_dropout=attn_dropout,
427
+ ff_dropout=ff_dropout,
428
+ flash_attn=flash_attn,
429
+ norm_output=False,
430
+ sage_attention=sage_attention,
431
+ )
432
+
433
+ time_rotary_embed = RotaryEmbedding(dim=dim_head)
434
+ freq_rotary_embed = RotaryEmbedding(dim=dim_head)
435
+
436
+ for _ in range(depth):
437
+ tran_modules = []
438
+ if linear_transformer_depth > 0:
439
+ tran_modules.append(
440
+ Transformer(
441
+ depth=linear_transformer_depth,
442
+ linear_attn=True,
443
+ **transformer_kwargs,
444
+ )
445
+ )
446
+ tran_modules.append(
447
+ Transformer(
448
+ depth=time_transformer_depth,
449
+ rotary_embed=time_rotary_embed,
450
+ **transformer_kwargs,
451
+ )
452
+ )
453
+ tran_modules.append(
454
+ Transformer(
455
+ depth=freq_transformer_depth,
456
+ rotary_embed=freq_rotary_embed,
457
+ **transformer_kwargs,
458
+ )
459
+ )
460
+ self.layers.append(nn.ModuleList(tran_modules))
461
+
462
+ self.final_norm = RMSNorm(dim)
463
+
464
+ self.stft_kwargs = dict(
465
+ n_fft=stft_n_fft,
466
+ hop_length=stft_hop_length,
467
+ win_length=stft_win_length,
468
+ normalized=stft_normalized,
469
+ )
470
+
471
+ self.stft_window_fn = partial(
472
+ default(stft_window_fn, torch.hann_window), stft_win_length
473
+ )
474
+
475
+ freqs = torch.stft(
476
+ torch.randn(1, 4096),
477
+ **self.stft_kwargs,
478
+ window=torch.ones(stft_win_length),
479
+ return_complex=True,
480
+ ).shape[1]
481
+
482
+ assert len(freqs_per_bands) > 1
483
+ assert (
484
+ sum(freqs_per_bands) == freqs
485
+ ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
486
+
487
+ freqs_per_bands_with_complex = tuple(
488
+ 2 * f * self.audio_channels for f in freqs_per_bands
489
+ )
490
+
491
+ self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
492
+
493
+ self.mask_estimators = nn.ModuleList([])
494
+
495
+ for _ in range(num_stems):
496
+ mask_estimator = MaskEstimator(
497
+ dim=dim,
498
+ dim_inputs=freqs_per_bands_with_complex,
499
+ depth=mask_estimator_depth,
500
+ mlp_expansion_factor=mlp_expansion_factor,
501
+ )
502
+
503
+ self.mask_estimators.append(mask_estimator)
504
+
505
+ self.zero_dc = zero_dc
506
+
507
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
508
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
509
+ self.multi_stft_n_fft = stft_n_fft
510
+ self.multi_stft_window_fn = multi_stft_window_fn
511
+
512
+ self.multi_stft_kwargs = dict(
513
+ hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
514
+ )
515
+
516
+ def forward(self, raw_audio, target=None, return_loss_breakdown=False):
517
+
518
+ device = raw_audio.device
519
+
520
+ x_is_mps = True if device.type == "mps" else False
521
+
522
+ if raw_audio.ndim == 2:
523
+ raw_audio = rearrange(raw_audio, "b t -> b 1 t")
524
+
525
+ channels = raw_audio.shape[1]
526
+ assert (not self.stereo and channels == 1) or (
527
+ self.stereo and channels == 2
528
+ ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
529
+
530
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
531
+
532
+ stft_window = self.stft_window_fn(device=device)
533
+
534
+ try:
535
+ stft_repr = torch.stft(
536
+ raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
537
+ )
538
+ except:
539
+ stft_repr = torch.stft(
540
+ raw_audio.cpu() if x_is_mps else raw_audio,
541
+ **self.stft_kwargs,
542
+ window=stft_window.cpu() if x_is_mps else stft_window,
543
+ return_complex=True,
544
+ ).to(device)
545
+ stft_repr = torch.view_as_real(stft_repr)
546
+
547
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
548
+
549
+ stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
550
+
551
+ x = rearrange(stft_repr, "b f t c -> b t (f c)")
552
+
553
+ if self.use_torch_checkpoint:
554
+ x = checkpoint(self.band_split, x, use_reentrant=False)
555
+ else:
556
+ x = self.band_split(x)
557
+
558
+ store = [None] * len(self.layers)
559
+ for i, transformer_block in enumerate(self.layers):
560
+
561
+ if len(transformer_block) == 3:
562
+ linear_transformer, time_transformer, freq_transformer = (
563
+ transformer_block
564
+ )
565
+
566
+ x, ft_ps = pack([x], "b * d")
567
+ if self.use_torch_checkpoint:
568
+ x = checkpoint(linear_transformer, x, use_reentrant=False)
569
+ else:
570
+ x = linear_transformer(x)
571
+ (x,) = unpack(x, ft_ps, "b * d")
572
+ else:
573
+ time_transformer, freq_transformer = transformer_block
574
+
575
+ if self.skip_connection:
576
+ for j in range(i):
577
+ x = x + store[j]
578
+
579
+ x = rearrange(x, "b t f d -> b f t d")
580
+ x, ps = pack([x], "* t d")
581
+
582
+ if self.use_torch_checkpoint:
583
+ x = checkpoint(time_transformer, x, use_reentrant=False)
584
+ else:
585
+ x = time_transformer(x)
586
+
587
+ (x,) = unpack(x, ps, "* t d")
588
+ x = rearrange(x, "b f t d -> b t f d")
589
+ x, ps = pack([x], "* f d")
590
+
591
+ if self.use_torch_checkpoint:
592
+ x = checkpoint(freq_transformer, x, use_reentrant=False)
593
+ else:
594
+ x = freq_transformer(x)
595
+
596
+ (x,) = unpack(x, ps, "* f d")
597
+
598
+ if self.skip_connection:
599
+ store[i] = x
600
+
601
+ x = self.final_norm(x)
602
+
603
+ num_stems = len(self.mask_estimators)
604
+
605
+ if self.use_torch_checkpoint:
606
+ mask = torch.stack(
607
+ [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators],
608
+ dim=1,
609
+ )
610
+ else:
611
+ mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
612
+ mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
613
+
614
+ stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
615
+
616
+ stft_repr = torch.view_as_complex(stft_repr)
617
+ mask = torch.view_as_complex(mask)
618
+
619
+ stft_repr = stft_repr * mask
620
+
621
+ stft_repr = rearrange(
622
+ stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
623
+ )
624
+
625
+ if self.zero_dc:
626
+ stft_repr = stft_repr.index_fill(1, tensor(0, device=device), 0.0)
627
+
628
+ try:
629
+ recon_audio = torch.istft(
630
+ stft_repr,
631
+ **self.stft_kwargs,
632
+ window=stft_window,
633
+ return_complex=False,
634
+ length=raw_audio.shape[-1],
635
+ )
636
+ except:
637
+ recon_audio = torch.istft(
638
+ stft_repr.cpu() if x_is_mps else stft_repr,
639
+ **self.stft_kwargs,
640
+ window=stft_window.cpu() if x_is_mps else stft_window,
641
+ return_complex=False,
642
+ length=raw_audio.shape[-1],
643
+ ).to(device)
644
+
645
+ recon_audio = rearrange(
646
+ recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
647
+ )
648
+
649
+ if num_stems == 1:
650
+ recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
651
+
652
+ if not exists(target):
653
+ return recon_audio
654
+
655
+ if self.num_stems > 1:
656
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
657
+
658
+ if target.ndim == 2:
659
+ target = rearrange(target, "... t -> ... 1 t")
660
+
661
+ target = target[..., : recon_audio.shape[-1]]
662
+
663
+ loss = F.l1_loss(recon_audio, target)
664
+
665
+ multi_stft_resolution_loss = 0.0
666
+
667
+ for window_size in self.multi_stft_resolutions_window_sizes:
668
+ res_stft_kwargs = dict(
669
+ n_fft=max(window_size, self.multi_stft_n_fft),
670
+ win_length=window_size,
671
+ return_complex=True,
672
+ window=self.multi_stft_window_fn(window_size, device=device),
673
+ **self.multi_stft_kwargs,
674
+ )
675
+
676
+ recon_Y = torch.stft(
677
+ rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
678
+ )
679
+ target_Y = torch.stft(
680
+ rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
681
+ )
682
+
683
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
684
+ recon_Y, target_Y
685
+ )
686
+
687
+ weighted_multi_resolution_loss = (
688
+ multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
689
+ )
690
+
691
+ total_loss = loss + weighted_multi_resolution_loss
692
+
693
+ if not return_loss_breakdown:
694
+ return total_loss
695
+
696
+ return total_loss, (loss, multi_stft_resolution_loss)
mvsepless/models/bs_roformer/bs_roformer_fno.py CHANGED
@@ -1,704 +1,704 @@
1
- from functools import partial
2
-
3
- import torch
4
- from torch import nn, einsum, Tensor
5
- from torch.nn import Module, ModuleList
6
- import torch.nn.functional as F
7
- from neuralop.models import FNO1d
8
-
9
- from .attend import Attend
10
-
11
- try:
12
- from .attend_sage import Attend as AttendSage
13
- except:
14
- pass
15
- from torch.utils.checkpoint import checkpoint
16
-
17
- from beartype.typing import Tuple, Optional, List, Callable
18
- from beartype import beartype
19
-
20
- from rotary_embedding_torch import RotaryEmbedding
21
-
22
- from einops import rearrange, pack, unpack
23
- from einops.layers.torch import Rearrange
24
-
25
-
26
- def exists(val):
27
- return val is not None
28
-
29
-
30
- def default(v, d):
31
- return v if exists(v) else d
32
-
33
-
34
- def pack_one(t, pattern):
35
- return pack([t], pattern)
36
-
37
-
38
- def unpack_one(t, ps, pattern):
39
- return unpack(t, ps, pattern)[0]
40
-
41
-
42
- def l2norm(t):
43
- return F.normalize(t, dim=-1, p=2)
44
-
45
-
46
- class RMSNorm(Module):
47
- def __init__(self, dim):
48
- super().__init__()
49
- self.scale = dim**0.5
50
- self.gamma = nn.Parameter(torch.ones(dim))
51
-
52
- def forward(self, x):
53
- return F.normalize(x, dim=-1) * self.scale * self.gamma
54
-
55
-
56
- class FeedForward(Module):
57
- def __init__(self, dim, mult=4, dropout=0.0):
58
- super().__init__()
59
- dim_inner = int(dim * mult)
60
- self.net = nn.Sequential(
61
- RMSNorm(dim),
62
- nn.Linear(dim, dim_inner),
63
- nn.GELU(),
64
- nn.Dropout(dropout),
65
- nn.Linear(dim_inner, dim),
66
- nn.Dropout(dropout),
67
- )
68
-
69
- def forward(self, x):
70
- return self.net(x)
71
-
72
-
73
- class Attention(Module):
74
- def __init__(
75
- self,
76
- dim,
77
- heads=8,
78
- dim_head=64,
79
- dropout=0.0,
80
- rotary_embed=None,
81
- flash=True,
82
- sage_attention=False,
83
- ):
84
- super().__init__()
85
- self.heads = heads
86
- self.scale = dim_head**-0.5
87
- dim_inner = heads * dim_head
88
-
89
- self.rotary_embed = rotary_embed
90
-
91
- if sage_attention:
92
- self.attend = AttendSage(flash=flash, dropout=dropout)
93
- else:
94
- self.attend = Attend(flash=flash, dropout=dropout)
95
-
96
- self.norm = RMSNorm(dim)
97
- self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
98
-
99
- self.to_gates = nn.Linear(dim, heads)
100
-
101
- self.to_out = nn.Sequential(
102
- nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
103
- )
104
-
105
- def forward(self, x):
106
- x = self.norm(x)
107
-
108
- q, k, v = rearrange(
109
- self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
110
- )
111
-
112
- if exists(self.rotary_embed):
113
- q = self.rotary_embed.rotate_queries_or_keys(q)
114
- k = self.rotary_embed.rotate_queries_or_keys(k)
115
-
116
- out = self.attend(q, k, v)
117
-
118
- gates = self.to_gates(x)
119
- out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
120
-
121
- out = rearrange(out, "b h n d -> b n (h d)")
122
- return self.to_out(out)
123
-
124
-
125
- class LinearAttention(Module):
126
-
127
- @beartype
128
- def __init__(
129
- self,
130
- *,
131
- dim,
132
- dim_head=32,
133
- heads=8,
134
- scale=8,
135
- flash=False,
136
- dropout=0.0,
137
- sage_attention=False,
138
- ):
139
- super().__init__()
140
- dim_inner = dim_head * heads
141
- self.norm = RMSNorm(dim)
142
-
143
- self.to_qkv = nn.Sequential(
144
- nn.Linear(dim, dim_inner * 3, bias=False),
145
- Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
146
- )
147
-
148
- self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
149
-
150
- if sage_attention:
151
- self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash)
152
- else:
153
- self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
154
-
155
- self.to_out = nn.Sequential(
156
- Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
157
- )
158
-
159
- def forward(self, x):
160
- x = self.norm(x)
161
-
162
- q, k, v = self.to_qkv(x)
163
-
164
- q, k = map(l2norm, (q, k))
165
- q = q * self.temperature.exp()
166
-
167
- out = self.attend(q, k, v)
168
-
169
- return self.to_out(out)
170
-
171
-
172
- class Transformer(Module):
173
- def __init__(
174
- self,
175
- *,
176
- dim,
177
- depth,
178
- dim_head=64,
179
- heads=8,
180
- attn_dropout=0.0,
181
- ff_dropout=0.0,
182
- ff_mult=4,
183
- norm_output=True,
184
- rotary_embed=None,
185
- flash_attn=True,
186
- linear_attn=False,
187
- sage_attention=False,
188
- ):
189
- super().__init__()
190
- self.layers = ModuleList([])
191
-
192
- for _ in range(depth):
193
- if linear_attn:
194
- attn = LinearAttention(
195
- dim=dim,
196
- dim_head=dim_head,
197
- heads=heads,
198
- dropout=attn_dropout,
199
- flash=flash_attn,
200
- sage_attention=sage_attention,
201
- )
202
- else:
203
- attn = Attention(
204
- dim=dim,
205
- dim_head=dim_head,
206
- heads=heads,
207
- dropout=attn_dropout,
208
- rotary_embed=rotary_embed,
209
- flash=flash_attn,
210
- sage_attention=sage_attention,
211
- )
212
-
213
- self.layers.append(
214
- ModuleList(
215
- [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
216
- )
217
- )
218
-
219
- self.norm = RMSNorm(dim) if norm_output else nn.Identity()
220
-
221
- def forward(self, x):
222
-
223
- for attn, ff in self.layers:
224
- x = attn(x) + x
225
- x = ff(x) + x
226
-
227
- return self.norm(x)
228
-
229
-
230
- class BandSplit(Module):
231
- @beartype
232
- def __init__(self, dim, dim_inputs: Tuple[int, ...]):
233
- super().__init__()
234
- self.dim_inputs = dim_inputs
235
- self.to_features = ModuleList([])
236
-
237
- for dim_in in dim_inputs:
238
- net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
239
-
240
- self.to_features.append(net)
241
-
242
- def forward(self, x):
243
- x = x.split(self.dim_inputs, dim=-1)
244
-
245
- outs = []
246
- for split_input, to_feature in zip(x, self.to_features):
247
- split_output = to_feature(split_input)
248
- outs.append(split_output)
249
-
250
- return torch.stack(outs, dim=-2)
251
-
252
-
253
- def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
254
- dim_hidden = default(dim_hidden, dim_in)
255
-
256
- net = []
257
- dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
258
-
259
- for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
260
- is_last = ind == (len(dims) - 2)
261
-
262
- net.append(nn.Linear(layer_dim_in, layer_dim_out))
263
-
264
- if is_last:
265
- continue
266
-
267
- net.append(activation())
268
-
269
- return nn.Sequential(*net)
270
-
271
-
272
- class MaskEstimator(Module):
273
- @beartype
274
- def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
275
- super().__init__()
276
- self.dim_inputs = dim_inputs
277
- self.to_freqs = ModuleList([])
278
- dim_hidden = dim * mlp_expansion_factor
279
-
280
- for dim_in in dim_inputs:
281
- net = []
282
-
283
- mlp = nn.Sequential(
284
- FNO1d(
285
- n_modes_height=64,
286
- hidden_channels=dim,
287
- in_channels=dim,
288
- out_channels=dim_in * 2,
289
- lifting_channels=dim,
290
- projection_channels=dim,
291
- n_layers=3,
292
- separable=True,
293
- ),
294
- nn.GLU(dim=-2),
295
- )
296
-
297
- self.to_freqs.append(mlp)
298
-
299
- def forward(self, x):
300
- x = x.unbind(dim=-2)
301
-
302
- outs = []
303
-
304
- for band_features, mlp in zip(x, self.to_freqs):
305
- band_features = rearrange(band_features, "b t c -> b c t")
306
- with torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32):
307
- freq_out = mlp(band_features).float()
308
- freq_out = rearrange(freq_out, "b c t -> b t c")
309
- outs.append(freq_out)
310
-
311
- return torch.cat(outs, dim=-1)
312
-
313
-
314
- DEFAULT_FREQS_PER_BANDS = (
315
- 2,
316
- 2,
317
- 2,
318
- 2,
319
- 2,
320
- 2,
321
- 2,
322
- 2,
323
- 2,
324
- 2,
325
- 2,
326
- 2,
327
- 2,
328
- 2,
329
- 2,
330
- 2,
331
- 2,
332
- 2,
333
- 2,
334
- 2,
335
- 2,
336
- 2,
337
- 2,
338
- 2,
339
- 4,
340
- 4,
341
- 4,
342
- 4,
343
- 4,
344
- 4,
345
- 4,
346
- 4,
347
- 4,
348
- 4,
349
- 4,
350
- 4,
351
- 12,
352
- 12,
353
- 12,
354
- 12,
355
- 12,
356
- 12,
357
- 12,
358
- 12,
359
- 24,
360
- 24,
361
- 24,
362
- 24,
363
- 24,
364
- 24,
365
- 24,
366
- 24,
367
- 48,
368
- 48,
369
- 48,
370
- 48,
371
- 48,
372
- 48,
373
- 48,
374
- 48,
375
- 128,
376
- 129,
377
- )
378
-
379
-
380
- class BSRoformer_FNO(Module):
381
-
382
- @beartype
383
- def __init__(
384
- self,
385
- dim,
386
- *,
387
- depth,
388
- stereo=False,
389
- num_stems=1,
390
- time_transformer_depth=2,
391
- freq_transformer_depth=2,
392
- linear_transformer_depth=0,
393
- freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
394
- dim_head=64,
395
- heads=8,
396
- attn_dropout=0.0,
397
- ff_dropout=0.0,
398
- flash_attn=True,
399
- dim_freqs_in=1025,
400
- stft_n_fft=2048,
401
- stft_hop_length=512,
402
- stft_win_length=2048,
403
- stft_normalized=False,
404
- stft_window_fn: Optional[Callable] = None,
405
- mask_estimator_depth=2,
406
- multi_stft_resolution_loss_weight=1.0,
407
- multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
408
- 4096,
409
- 2048,
410
- 1024,
411
- 512,
412
- 256,
413
- ),
414
- multi_stft_hop_size=147,
415
- multi_stft_normalized=False,
416
- multi_stft_window_fn: Callable = torch.hann_window,
417
- mlp_expansion_factor=4,
418
- use_torch_checkpoint=False,
419
- skip_connection=False,
420
- sage_attention=False,
421
- ):
422
- super().__init__()
423
-
424
- self.stereo = stereo
425
- self.audio_channels = 2 if stereo else 1
426
- self.num_stems = num_stems
427
- self.use_torch_checkpoint = use_torch_checkpoint
428
- self.skip_connection = skip_connection
429
-
430
- self.layers = ModuleList([])
431
-
432
- if sage_attention:
433
- print("Use Sage Attention")
434
-
435
- transformer_kwargs = dict(
436
- dim=dim,
437
- heads=heads,
438
- dim_head=dim_head,
439
- attn_dropout=attn_dropout,
440
- ff_dropout=ff_dropout,
441
- flash_attn=flash_attn,
442
- norm_output=False,
443
- sage_attention=sage_attention,
444
- )
445
-
446
- time_rotary_embed = RotaryEmbedding(dim=dim_head)
447
- freq_rotary_embed = RotaryEmbedding(dim=dim_head)
448
-
449
- for _ in range(depth):
450
- tran_modules = []
451
- if linear_transformer_depth > 0:
452
- tran_modules.append(
453
- Transformer(
454
- depth=linear_transformer_depth,
455
- linear_attn=True,
456
- **transformer_kwargs,
457
- )
458
- )
459
- tran_modules.append(
460
- Transformer(
461
- depth=time_transformer_depth,
462
- rotary_embed=time_rotary_embed,
463
- **transformer_kwargs,
464
- )
465
- )
466
- tran_modules.append(
467
- Transformer(
468
- depth=freq_transformer_depth,
469
- rotary_embed=freq_rotary_embed,
470
- **transformer_kwargs,
471
- )
472
- )
473
- self.layers.append(nn.ModuleList(tran_modules))
474
-
475
- self.final_norm = RMSNorm(dim)
476
-
477
- self.stft_kwargs = dict(
478
- n_fft=stft_n_fft,
479
- hop_length=stft_hop_length,
480
- win_length=stft_win_length,
481
- normalized=stft_normalized,
482
- )
483
-
484
- self.stft_window_fn = partial(
485
- default(stft_window_fn, torch.hann_window), stft_win_length
486
- )
487
-
488
- freqs = torch.stft(
489
- torch.randn(1, 4096),
490
- **self.stft_kwargs,
491
- window=torch.ones(stft_win_length),
492
- return_complex=True,
493
- ).shape[1]
494
-
495
- assert len(freqs_per_bands) > 1
496
- assert (
497
- sum(freqs_per_bands) == freqs
498
- ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
499
-
500
- freqs_per_bands_with_complex = tuple(
501
- 2 * f * self.audio_channels for f in freqs_per_bands
502
- )
503
-
504
- self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
505
-
506
- self.mask_estimators = nn.ModuleList([])
507
-
508
- for _ in range(num_stems):
509
- mask_estimator = MaskEstimator(
510
- dim=dim,
511
- dim_inputs=freqs_per_bands_with_complex,
512
- depth=mask_estimator_depth,
513
- mlp_expansion_factor=mlp_expansion_factor,
514
- )
515
-
516
- self.mask_estimators.append(mask_estimator)
517
-
518
- self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
519
- self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
520
- self.multi_stft_n_fft = stft_n_fft
521
- self.multi_stft_window_fn = multi_stft_window_fn
522
-
523
- self.multi_stft_kwargs = dict(
524
- hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
525
- )
526
-
527
- def forward(self, raw_audio, target=None, return_loss_breakdown=False):
528
-
529
- device = raw_audio.device
530
-
531
- x_is_mps = True if device.type == "mps" else False
532
-
533
- if raw_audio.ndim == 2:
534
- raw_audio = rearrange(raw_audio, "b t -> b 1 t")
535
-
536
- channels = raw_audio.shape[1]
537
- assert (not self.stereo and channels == 1) or (
538
- self.stereo and channels == 2
539
- ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
540
-
541
- raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
542
-
543
- stft_window = self.stft_window_fn(device=device)
544
-
545
- try:
546
- stft_repr = torch.stft(
547
- raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
548
- )
549
- except:
550
- stft_repr = torch.stft(
551
- raw_audio.cpu() if x_is_mps else raw_audio,
552
- **self.stft_kwargs,
553
- window=stft_window.cpu() if x_is_mps else stft_window,
554
- return_complex=True,
555
- ).to(device)
556
- stft_repr = torch.view_as_real(stft_repr)
557
-
558
- stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
559
-
560
- stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
561
-
562
- x = rearrange(stft_repr, "b f t c -> b t (f c)")
563
-
564
- if self.use_torch_checkpoint:
565
- x = checkpoint(self.band_split, x, use_reentrant=False)
566
- else:
567
- x = self.band_split(x)
568
-
569
- store = [None] * len(self.layers)
570
- for i, transformer_block in enumerate(self.layers):
571
-
572
- if len(transformer_block) == 3:
573
- linear_transformer, time_transformer, freq_transformer = (
574
- transformer_block
575
- )
576
-
577
- x, ft_ps = pack([x], "b * d")
578
- if self.use_torch_checkpoint:
579
- x = checkpoint(linear_transformer, x, use_reentrant=False)
580
- else:
581
- x = linear_transformer(x)
582
- (x,) = unpack(x, ft_ps, "b * d")
583
- else:
584
- time_transformer, freq_transformer = transformer_block
585
-
586
- if self.skip_connection:
587
- for j in range(i):
588
- x = x + store[j]
589
-
590
- x = rearrange(x, "b t f d -> b f t d")
591
- x, ps = pack([x], "* t d")
592
-
593
- if self.use_torch_checkpoint:
594
- x = checkpoint(time_transformer, x, use_reentrant=False)
595
- else:
596
- x = time_transformer(x)
597
-
598
- (x,) = unpack(x, ps, "* t d")
599
- x = rearrange(x, "b f t d -> b t f d")
600
- x, ps = pack([x], "* f d")
601
-
602
- if self.use_torch_checkpoint:
603
- x = checkpoint(freq_transformer, x, use_reentrant=False)
604
- else:
605
- x = freq_transformer(x)
606
-
607
- (x,) = unpack(x, ps, "* f d")
608
-
609
- if self.skip_connection:
610
- store[i] = x
611
-
612
- x = self.final_norm(x)
613
-
614
- num_stems = len(self.mask_estimators)
615
-
616
- if self.use_torch_checkpoint:
617
- mask = torch.stack(
618
- [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators],
619
- dim=1,
620
- )
621
- else:
622
- mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
623
- mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
624
-
625
- stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
626
-
627
- stft_repr = torch.view_as_complex(stft_repr)
628
- mask = torch.view_as_complex(mask)
629
-
630
- stft_repr = stft_repr * mask
631
-
632
- stft_repr = rearrange(
633
- stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
634
- )
635
-
636
- try:
637
- recon_audio = torch.istft(
638
- stft_repr,
639
- **self.stft_kwargs,
640
- window=stft_window,
641
- return_complex=False,
642
- length=raw_audio.shape[-1],
643
- )
644
- except:
645
- recon_audio = torch.istft(
646
- stft_repr.cpu() if x_is_mps else stft_repr,
647
- **self.stft_kwargs,
648
- window=stft_window.cpu() if x_is_mps else stft_window,
649
- return_complex=False,
650
- length=raw_audio.shape[-1],
651
- ).to(device)
652
-
653
- recon_audio = rearrange(
654
- recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
655
- )
656
-
657
- if num_stems == 1:
658
- recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
659
-
660
- if not exists(target):
661
- return recon_audio
662
-
663
- if self.num_stems > 1:
664
- assert target.ndim == 4 and target.shape[1] == self.num_stems
665
-
666
- if target.ndim == 2:
667
- target = rearrange(target, "... t -> ... 1 t")
668
-
669
- target = target[..., : recon_audio.shape[-1]]
670
-
671
- loss = F.l1_loss(recon_audio, target)
672
-
673
- multi_stft_resolution_loss = 0.0
674
-
675
- for window_size in self.multi_stft_resolutions_window_sizes:
676
- res_stft_kwargs = dict(
677
- n_fft=max(window_size, self.multi_stft_n_fft),
678
- win_length=window_size,
679
- return_complex=True,
680
- window=self.multi_stft_window_fn(window_size, device=device),
681
- **self.multi_stft_kwargs,
682
- )
683
-
684
- recon_Y = torch.stft(
685
- rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
686
- )
687
- target_Y = torch.stft(
688
- rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
689
- )
690
-
691
- multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
692
- recon_Y, target_Y
693
- )
694
-
695
- weighted_multi_resolution_loss = (
696
- multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
697
- )
698
-
699
- total_loss = loss + weighted_multi_resolution_loss
700
-
701
- if not return_loss_breakdown:
702
- return total_loss
703
-
704
- return total_loss, (loss, multi_stft_resolution_loss)
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn, einsum, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ import torch.nn.functional as F
7
+ from neuralop.models import FNO1d
8
+
9
+ from .attend import Attend
10
+
11
+ try:
12
+ from .attend_sage import Attend as AttendSage
13
+ except:
14
+ pass
15
+ from torch.utils.checkpoint import checkpoint
16
+
17
+ from beartype.typing import Tuple, Optional, List, Callable
18
+ from beartype import beartype
19
+
20
+ from rotary_embedding_torch import RotaryEmbedding
21
+
22
+ from einops import rearrange, pack, unpack
23
+ from einops.layers.torch import Rearrange
24
+
25
+
26
+ def exists(val):
27
+ return val is not None
28
+
29
+
30
+ def default(v, d):
31
+ return v if exists(v) else d
32
+
33
+
34
+ def pack_one(t, pattern):
35
+ return pack([t], pattern)
36
+
37
+
38
+ def unpack_one(t, ps, pattern):
39
+ return unpack(t, ps, pattern)[0]
40
+
41
+
42
+ def l2norm(t):
43
+ return F.normalize(t, dim=-1, p=2)
44
+
45
+
46
+ class RMSNorm(Module):
47
+ def __init__(self, dim):
48
+ super().__init__()
49
+ self.scale = dim**0.5
50
+ self.gamma = nn.Parameter(torch.ones(dim))
51
+
52
+ def forward(self, x):
53
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
54
+
55
+
56
+ class FeedForward(Module):
57
+ def __init__(self, dim, mult=4, dropout=0.0):
58
+ super().__init__()
59
+ dim_inner = int(dim * mult)
60
+ self.net = nn.Sequential(
61
+ RMSNorm(dim),
62
+ nn.Linear(dim, dim_inner),
63
+ nn.GELU(),
64
+ nn.Dropout(dropout),
65
+ nn.Linear(dim_inner, dim),
66
+ nn.Dropout(dropout),
67
+ )
68
+
69
+ def forward(self, x):
70
+ return self.net(x)
71
+
72
+
73
+ class Attention(Module):
74
+ def __init__(
75
+ self,
76
+ dim,
77
+ heads=8,
78
+ dim_head=64,
79
+ dropout=0.0,
80
+ rotary_embed=None,
81
+ flash=True,
82
+ sage_attention=False,
83
+ ):
84
+ super().__init__()
85
+ self.heads = heads
86
+ self.scale = dim_head**-0.5
87
+ dim_inner = heads * dim_head
88
+
89
+ self.rotary_embed = rotary_embed
90
+
91
+ if sage_attention:
92
+ self.attend = AttendSage(flash=flash, dropout=dropout)
93
+ else:
94
+ self.attend = Attend(flash=flash, dropout=dropout)
95
+
96
+ self.norm = RMSNorm(dim)
97
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
98
+
99
+ self.to_gates = nn.Linear(dim, heads)
100
+
101
+ self.to_out = nn.Sequential(
102
+ nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
103
+ )
104
+
105
+ def forward(self, x):
106
+ x = self.norm(x)
107
+
108
+ q, k, v = rearrange(
109
+ self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
110
+ )
111
+
112
+ if exists(self.rotary_embed):
113
+ q = self.rotary_embed.rotate_queries_or_keys(q)
114
+ k = self.rotary_embed.rotate_queries_or_keys(k)
115
+
116
+ out = self.attend(q, k, v)
117
+
118
+ gates = self.to_gates(x)
119
+ out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
120
+
121
+ out = rearrange(out, "b h n d -> b n (h d)")
122
+ return self.to_out(out)
123
+
124
+
125
+ class LinearAttention(Module):
126
+
127
+ @beartype
128
+ def __init__(
129
+ self,
130
+ *,
131
+ dim,
132
+ dim_head=32,
133
+ heads=8,
134
+ scale=8,
135
+ flash=False,
136
+ dropout=0.0,
137
+ sage_attention=False,
138
+ ):
139
+ super().__init__()
140
+ dim_inner = dim_head * heads
141
+ self.norm = RMSNorm(dim)
142
+
143
+ self.to_qkv = nn.Sequential(
144
+ nn.Linear(dim, dim_inner * 3, bias=False),
145
+ Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
146
+ )
147
+
148
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
149
+
150
+ if sage_attention:
151
+ self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash)
152
+ else:
153
+ self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
154
+
155
+ self.to_out = nn.Sequential(
156
+ Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
157
+ )
158
+
159
+ def forward(self, x):
160
+ x = self.norm(x)
161
+
162
+ q, k, v = self.to_qkv(x)
163
+
164
+ q, k = map(l2norm, (q, k))
165
+ q = q * self.temperature.exp()
166
+
167
+ out = self.attend(q, k, v)
168
+
169
+ return self.to_out(out)
170
+
171
+
172
+ class Transformer(Module):
173
+ def __init__(
174
+ self,
175
+ *,
176
+ dim,
177
+ depth,
178
+ dim_head=64,
179
+ heads=8,
180
+ attn_dropout=0.0,
181
+ ff_dropout=0.0,
182
+ ff_mult=4,
183
+ norm_output=True,
184
+ rotary_embed=None,
185
+ flash_attn=True,
186
+ linear_attn=False,
187
+ sage_attention=False,
188
+ ):
189
+ super().__init__()
190
+ self.layers = ModuleList([])
191
+
192
+ for _ in range(depth):
193
+ if linear_attn:
194
+ attn = LinearAttention(
195
+ dim=dim,
196
+ dim_head=dim_head,
197
+ heads=heads,
198
+ dropout=attn_dropout,
199
+ flash=flash_attn,
200
+ sage_attention=sage_attention,
201
+ )
202
+ else:
203
+ attn = Attention(
204
+ dim=dim,
205
+ dim_head=dim_head,
206
+ heads=heads,
207
+ dropout=attn_dropout,
208
+ rotary_embed=rotary_embed,
209
+ flash=flash_attn,
210
+ sage_attention=sage_attention,
211
+ )
212
+
213
+ self.layers.append(
214
+ ModuleList(
215
+ [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
216
+ )
217
+ )
218
+
219
+ self.norm = RMSNorm(dim) if norm_output else nn.Identity()
220
+
221
+ def forward(self, x):
222
+
223
+ for attn, ff in self.layers:
224
+ x = attn(x) + x
225
+ x = ff(x) + x
226
+
227
+ return self.norm(x)
228
+
229
+
230
+ class BandSplit(Module):
231
+ @beartype
232
+ def __init__(self, dim, dim_inputs: Tuple[int, ...]):
233
+ super().__init__()
234
+ self.dim_inputs = dim_inputs
235
+ self.to_features = ModuleList([])
236
+
237
+ for dim_in in dim_inputs:
238
+ net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
239
+
240
+ self.to_features.append(net)
241
+
242
+ def forward(self, x):
243
+ x = x.split(self.dim_inputs, dim=-1)
244
+
245
+ outs = []
246
+ for split_input, to_feature in zip(x, self.to_features):
247
+ split_output = to_feature(split_input)
248
+ outs.append(split_output)
249
+
250
+ return torch.stack(outs, dim=-2)
251
+
252
+
253
+ def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
254
+ dim_hidden = default(dim_hidden, dim_in)
255
+
256
+ net = []
257
+ dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
258
+
259
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
260
+ is_last = ind == (len(dims) - 2)
261
+
262
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
263
+
264
+ if is_last:
265
+ continue
266
+
267
+ net.append(activation())
268
+
269
+ return nn.Sequential(*net)
270
+
271
+
272
+ class MaskEstimator(Module):
273
+ @beartype
274
+ def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
275
+ super().__init__()
276
+ self.dim_inputs = dim_inputs
277
+ self.to_freqs = ModuleList([])
278
+ dim_hidden = dim * mlp_expansion_factor
279
+
280
+ for dim_in in dim_inputs:
281
+ net = []
282
+
283
+ mlp = nn.Sequential(
284
+ FNO1d(
285
+ n_modes_height=64,
286
+ hidden_channels=dim,
287
+ in_channels=dim,
288
+ out_channels=dim_in * 2,
289
+ lifting_channels=dim,
290
+ projection_channels=dim,
291
+ n_layers=3,
292
+ separable=True,
293
+ ),
294
+ nn.GLU(dim=-2),
295
+ )
296
+
297
+ self.to_freqs.append(mlp)
298
+
299
+ def forward(self, x):
300
+ x = x.unbind(dim=-2)
301
+
302
+ outs = []
303
+
304
+ for band_features, mlp in zip(x, self.to_freqs):
305
+ band_features = rearrange(band_features, "b t c -> b c t")
306
+ with torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32):
307
+ freq_out = mlp(band_features).float()
308
+ freq_out = rearrange(freq_out, "b c t -> b t c")
309
+ outs.append(freq_out)
310
+
311
+ return torch.cat(outs, dim=-1)
312
+
313
+
314
+ DEFAULT_FREQS_PER_BANDS = (
315
+ 2,
316
+ 2,
317
+ 2,
318
+ 2,
319
+ 2,
320
+ 2,
321
+ 2,
322
+ 2,
323
+ 2,
324
+ 2,
325
+ 2,
326
+ 2,
327
+ 2,
328
+ 2,
329
+ 2,
330
+ 2,
331
+ 2,
332
+ 2,
333
+ 2,
334
+ 2,
335
+ 2,
336
+ 2,
337
+ 2,
338
+ 2,
339
+ 4,
340
+ 4,
341
+ 4,
342
+ 4,
343
+ 4,
344
+ 4,
345
+ 4,
346
+ 4,
347
+ 4,
348
+ 4,
349
+ 4,
350
+ 4,
351
+ 12,
352
+ 12,
353
+ 12,
354
+ 12,
355
+ 12,
356
+ 12,
357
+ 12,
358
+ 12,
359
+ 24,
360
+ 24,
361
+ 24,
362
+ 24,
363
+ 24,
364
+ 24,
365
+ 24,
366
+ 24,
367
+ 48,
368
+ 48,
369
+ 48,
370
+ 48,
371
+ 48,
372
+ 48,
373
+ 48,
374
+ 48,
375
+ 128,
376
+ 129,
377
+ )
378
+
379
+
380
+ class BSRoformer_FNO(Module):
381
+
382
+ @beartype
383
+ def __init__(
384
+ self,
385
+ dim,
386
+ *,
387
+ depth,
388
+ stereo=False,
389
+ num_stems=1,
390
+ time_transformer_depth=2,
391
+ freq_transformer_depth=2,
392
+ linear_transformer_depth=0,
393
+ freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
394
+ dim_head=64,
395
+ heads=8,
396
+ attn_dropout=0.0,
397
+ ff_dropout=0.0,
398
+ flash_attn=True,
399
+ dim_freqs_in=1025,
400
+ stft_n_fft=2048,
401
+ stft_hop_length=512,
402
+ stft_win_length=2048,
403
+ stft_normalized=False,
404
+ stft_window_fn: Optional[Callable] = None,
405
+ mask_estimator_depth=2,
406
+ multi_stft_resolution_loss_weight=1.0,
407
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
408
+ 4096,
409
+ 2048,
410
+ 1024,
411
+ 512,
412
+ 256,
413
+ ),
414
+ multi_stft_hop_size=147,
415
+ multi_stft_normalized=False,
416
+ multi_stft_window_fn: Callable = torch.hann_window,
417
+ mlp_expansion_factor=4,
418
+ use_torch_checkpoint=False,
419
+ skip_connection=False,
420
+ sage_attention=False,
421
+ ):
422
+ super().__init__()
423
+
424
+ self.stereo = stereo
425
+ self.audio_channels = 2 if stereo else 1
426
+ self.num_stems = num_stems
427
+ self.use_torch_checkpoint = use_torch_checkpoint
428
+ self.skip_connection = skip_connection
429
+
430
+ self.layers = ModuleList([])
431
+
432
+ if sage_attention:
433
+ print("Use Sage Attention")
434
+
435
+ transformer_kwargs = dict(
436
+ dim=dim,
437
+ heads=heads,
438
+ dim_head=dim_head,
439
+ attn_dropout=attn_dropout,
440
+ ff_dropout=ff_dropout,
441
+ flash_attn=flash_attn,
442
+ norm_output=False,
443
+ sage_attention=sage_attention,
444
+ )
445
+
446
+ time_rotary_embed = RotaryEmbedding(dim=dim_head)
447
+ freq_rotary_embed = RotaryEmbedding(dim=dim_head)
448
+
449
+ for _ in range(depth):
450
+ tran_modules = []
451
+ if linear_transformer_depth > 0:
452
+ tran_modules.append(
453
+ Transformer(
454
+ depth=linear_transformer_depth,
455
+ linear_attn=True,
456
+ **transformer_kwargs,
457
+ )
458
+ )
459
+ tran_modules.append(
460
+ Transformer(
461
+ depth=time_transformer_depth,
462
+ rotary_embed=time_rotary_embed,
463
+ **transformer_kwargs,
464
+ )
465
+ )
466
+ tran_modules.append(
467
+ Transformer(
468
+ depth=freq_transformer_depth,
469
+ rotary_embed=freq_rotary_embed,
470
+ **transformer_kwargs,
471
+ )
472
+ )
473
+ self.layers.append(nn.ModuleList(tran_modules))
474
+
475
+ self.final_norm = RMSNorm(dim)
476
+
477
+ self.stft_kwargs = dict(
478
+ n_fft=stft_n_fft,
479
+ hop_length=stft_hop_length,
480
+ win_length=stft_win_length,
481
+ normalized=stft_normalized,
482
+ )
483
+
484
+ self.stft_window_fn = partial(
485
+ default(stft_window_fn, torch.hann_window), stft_win_length
486
+ )
487
+
488
+ freqs = torch.stft(
489
+ torch.randn(1, 4096),
490
+ **self.stft_kwargs,
491
+ window=torch.ones(stft_win_length),
492
+ return_complex=True,
493
+ ).shape[1]
494
+
495
+ assert len(freqs_per_bands) > 1
496
+ assert (
497
+ sum(freqs_per_bands) == freqs
498
+ ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
499
+
500
+ freqs_per_bands_with_complex = tuple(
501
+ 2 * f * self.audio_channels for f in freqs_per_bands
502
+ )
503
+
504
+ self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
505
+
506
+ self.mask_estimators = nn.ModuleList([])
507
+
508
+ for _ in range(num_stems):
509
+ mask_estimator = MaskEstimator(
510
+ dim=dim,
511
+ dim_inputs=freqs_per_bands_with_complex,
512
+ depth=mask_estimator_depth,
513
+ mlp_expansion_factor=mlp_expansion_factor,
514
+ )
515
+
516
+ self.mask_estimators.append(mask_estimator)
517
+
518
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
519
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
520
+ self.multi_stft_n_fft = stft_n_fft
521
+ self.multi_stft_window_fn = multi_stft_window_fn
522
+
523
+ self.multi_stft_kwargs = dict(
524
+ hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
525
+ )
526
+
527
+ def forward(self, raw_audio, target=None, return_loss_breakdown=False):
528
+
529
+ device = raw_audio.device
530
+
531
+ x_is_mps = True if device.type == "mps" else False
532
+
533
+ if raw_audio.ndim == 2:
534
+ raw_audio = rearrange(raw_audio, "b t -> b 1 t")
535
+
536
+ channels = raw_audio.shape[1]
537
+ assert (not self.stereo and channels == 1) or (
538
+ self.stereo and channels == 2
539
+ ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
540
+
541
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
542
+
543
+ stft_window = self.stft_window_fn(device=device)
544
+
545
+ try:
546
+ stft_repr = torch.stft(
547
+ raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
548
+ )
549
+ except:
550
+ stft_repr = torch.stft(
551
+ raw_audio.cpu() if x_is_mps else raw_audio,
552
+ **self.stft_kwargs,
553
+ window=stft_window.cpu() if x_is_mps else stft_window,
554
+ return_complex=True,
555
+ ).to(device)
556
+ stft_repr = torch.view_as_real(stft_repr)
557
+
558
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
559
+
560
+ stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
561
+
562
+ x = rearrange(stft_repr, "b f t c -> b t (f c)")
563
+
564
+ if self.use_torch_checkpoint:
565
+ x = checkpoint(self.band_split, x, use_reentrant=False)
566
+ else:
567
+ x = self.band_split(x)
568
+
569
+ store = [None] * len(self.layers)
570
+ for i, transformer_block in enumerate(self.layers):
571
+
572
+ if len(transformer_block) == 3:
573
+ linear_transformer, time_transformer, freq_transformer = (
574
+ transformer_block
575
+ )
576
+
577
+ x, ft_ps = pack([x], "b * d")
578
+ if self.use_torch_checkpoint:
579
+ x = checkpoint(linear_transformer, x, use_reentrant=False)
580
+ else:
581
+ x = linear_transformer(x)
582
+ (x,) = unpack(x, ft_ps, "b * d")
583
+ else:
584
+ time_transformer, freq_transformer = transformer_block
585
+
586
+ if self.skip_connection:
587
+ for j in range(i):
588
+ x = x + store[j]
589
+
590
+ x = rearrange(x, "b t f d -> b f t d")
591
+ x, ps = pack([x], "* t d")
592
+
593
+ if self.use_torch_checkpoint:
594
+ x = checkpoint(time_transformer, x, use_reentrant=False)
595
+ else:
596
+ x = time_transformer(x)
597
+
598
+ (x,) = unpack(x, ps, "* t d")
599
+ x = rearrange(x, "b f t d -> b t f d")
600
+ x, ps = pack([x], "* f d")
601
+
602
+ if self.use_torch_checkpoint:
603
+ x = checkpoint(freq_transformer, x, use_reentrant=False)
604
+ else:
605
+ x = freq_transformer(x)
606
+
607
+ (x,) = unpack(x, ps, "* f d")
608
+
609
+ if self.skip_connection:
610
+ store[i] = x
611
+
612
+ x = self.final_norm(x)
613
+
614
+ num_stems = len(self.mask_estimators)
615
+
616
+ if self.use_torch_checkpoint:
617
+ mask = torch.stack(
618
+ [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators],
619
+ dim=1,
620
+ )
621
+ else:
622
+ mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
623
+ mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
624
+
625
+ stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
626
+
627
+ stft_repr = torch.view_as_complex(stft_repr)
628
+ mask = torch.view_as_complex(mask)
629
+
630
+ stft_repr = stft_repr * mask
631
+
632
+ stft_repr = rearrange(
633
+ stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
634
+ )
635
+
636
+ try:
637
+ recon_audio = torch.istft(
638
+ stft_repr,
639
+ **self.stft_kwargs,
640
+ window=stft_window,
641
+ return_complex=False,
642
+ length=raw_audio.shape[-1],
643
+ )
644
+ except:
645
+ recon_audio = torch.istft(
646
+ stft_repr.cpu() if x_is_mps else stft_repr,
647
+ **self.stft_kwargs,
648
+ window=stft_window.cpu() if x_is_mps else stft_window,
649
+ return_complex=False,
650
+ length=raw_audio.shape[-1],
651
+ ).to(device)
652
+
653
+ recon_audio = rearrange(
654
+ recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
655
+ )
656
+
657
+ if num_stems == 1:
658
+ recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
659
+
660
+ if not exists(target):
661
+ return recon_audio
662
+
663
+ if self.num_stems > 1:
664
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
665
+
666
+ if target.ndim == 2:
667
+ target = rearrange(target, "... t -> ... 1 t")
668
+
669
+ target = target[..., : recon_audio.shape[-1]]
670
+
671
+ loss = F.l1_loss(recon_audio, target)
672
+
673
+ multi_stft_resolution_loss = 0.0
674
+
675
+ for window_size in self.multi_stft_resolutions_window_sizes:
676
+ res_stft_kwargs = dict(
677
+ n_fft=max(window_size, self.multi_stft_n_fft),
678
+ win_length=window_size,
679
+ return_complex=True,
680
+ window=self.multi_stft_window_fn(window_size, device=device),
681
+ **self.multi_stft_kwargs,
682
+ )
683
+
684
+ recon_Y = torch.stft(
685
+ rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
686
+ )
687
+ target_Y = torch.stft(
688
+ rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
689
+ )
690
+
691
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
692
+ recon_Y, target_Y
693
+ )
694
+
695
+ weighted_multi_resolution_loss = (
696
+ multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
697
+ )
698
+
699
+ total_loss = loss + weighted_multi_resolution_loss
700
+
701
+ if not return_loss_breakdown:
702
+ return total_loss
703
+
704
+ return total_loss, (loss, multi_stft_resolution_loss)
mvsepless/models/bs_roformer/bs_roformer_hyperace.py CHANGED
@@ -1,1122 +1,1122 @@
1
- from functools import partial
2
-
3
- import torch
4
- from torch import nn, einsum, Tensor
5
- from torch.nn import Module, ModuleList
6
- import torch.nn.functional as F
7
-
8
- from .attend import Attend
9
-
10
- try:
11
- from .attend_sage import Attend as AttendSage
12
- except:
13
- pass
14
- from torch.utils.checkpoint import checkpoint
15
-
16
- from beartype.typing import Tuple, Optional, List, Callable
17
- from beartype import beartype
18
-
19
- from rotary_embedding_torch import RotaryEmbedding
20
-
21
- from einops import rearrange, pack, unpack
22
- from einops.layers.torch import Rearrange
23
- import torchaudio
24
-
25
-
26
- def exists(val):
27
- return val is not None
28
-
29
-
30
- def default(v, d):
31
- return v if exists(v) else d
32
-
33
-
34
- def pack_one(t, pattern):
35
- return pack([t], pattern)
36
-
37
-
38
- def unpack_one(t, ps, pattern):
39
- return unpack(t, ps, pattern)[0]
40
-
41
-
42
- def l2norm(t):
43
- return F.normalize(t, dim=-1, p=2)
44
-
45
-
46
- class RMSNorm(Module):
47
- def __init__(self, dim):
48
- super().__init__()
49
- self.scale = dim**0.5
50
- self.gamma = nn.Parameter(torch.ones(dim))
51
-
52
- def forward(self, x):
53
- return F.normalize(x, dim=-1) * self.scale * self.gamma
54
-
55
-
56
- class FeedForward(Module):
57
- def __init__(self, dim, mult=4, dropout=0.0):
58
- super().__init__()
59
- dim_inner = int(dim * mult)
60
- self.net = nn.Sequential(
61
- RMSNorm(dim),
62
- nn.Linear(dim, dim_inner),
63
- nn.GELU(),
64
- nn.Dropout(dropout),
65
- nn.Linear(dim_inner, dim),
66
- nn.Dropout(dropout),
67
- )
68
-
69
- def forward(self, x):
70
- return self.net(x)
71
-
72
-
73
- class Attention(Module):
74
- def __init__(
75
- self,
76
- dim,
77
- heads=8,
78
- dim_head=64,
79
- dropout=0.0,
80
- rotary_embed=None,
81
- flash=True,
82
- sage_attention=False,
83
- ):
84
- super().__init__()
85
- self.heads = heads
86
- self.scale = dim_head**-0.5
87
- dim_inner = heads * dim_head
88
-
89
- self.rotary_embed = rotary_embed
90
-
91
- if sage_attention:
92
- self.attend = AttendSage(flash=flash, dropout=dropout)
93
- else:
94
- self.attend = Attend(flash=flash, dropout=dropout)
95
-
96
- self.norm = RMSNorm(dim)
97
- self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
98
-
99
- self.to_gates = nn.Linear(dim, heads)
100
-
101
- self.to_out = nn.Sequential(
102
- nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
103
- )
104
-
105
- def forward(self, x):
106
- x = self.norm(x)
107
-
108
- q, k, v = rearrange(
109
- self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
110
- )
111
-
112
- if exists(self.rotary_embed):
113
- q = self.rotary_embed.rotate_queries_or_keys(q)
114
- k = self.rotary_embed.rotate_queries_or_keys(k)
115
-
116
- out = self.attend(q, k, v)
117
-
118
- gates = self.to_gates(x)
119
- out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
120
-
121
- out = rearrange(out, "b h n d -> b n (h d)")
122
- return self.to_out(out)
123
-
124
-
125
- class LinearAttention(Module):
126
-
127
- @beartype
128
- def __init__(
129
- self,
130
- *,
131
- dim,
132
- dim_head=32,
133
- heads=8,
134
- scale=8,
135
- flash=True,
136
- dropout=0.0,
137
- sage_attention=False,
138
- ):
139
- super().__init__()
140
- dim_inner = dim_head * heads
141
- self.norm = RMSNorm(dim)
142
-
143
- self.to_qkv = nn.Sequential(
144
- nn.Linear(dim, dim_inner * 3, bias=False),
145
- Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
146
- )
147
-
148
- self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
149
-
150
- if sage_attention:
151
- self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash)
152
- else:
153
- self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
154
-
155
- self.to_out = nn.Sequential(
156
- Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
157
- )
158
-
159
- def forward(self, x):
160
- x = self.norm(x)
161
-
162
- q, k, v = self.to_qkv(x)
163
-
164
- q, k = map(l2norm, (q, k))
165
- q = q * self.temperature.exp()
166
-
167
- out = self.attend(q, k, v)
168
-
169
- return self.to_out(out)
170
-
171
-
172
- class Transformer(Module):
173
- def __init__(
174
- self,
175
- *,
176
- dim,
177
- depth,
178
- dim_head=64,
179
- heads=8,
180
- attn_dropout=0.0,
181
- ff_dropout=0.0,
182
- ff_mult=4,
183
- norm_output=True,
184
- rotary_embed=None,
185
- flash_attn=True,
186
- linear_attn=False,
187
- sage_attention=False,
188
- ):
189
- super().__init__()
190
- self.layers = ModuleList([])
191
-
192
- for _ in range(depth):
193
- if linear_attn:
194
- attn = LinearAttention(
195
- dim=dim,
196
- dim_head=dim_head,
197
- heads=heads,
198
- dropout=attn_dropout,
199
- flash=flash_attn,
200
- sage_attention=sage_attention,
201
- )
202
- else:
203
- attn = Attention(
204
- dim=dim,
205
- dim_head=dim_head,
206
- heads=heads,
207
- dropout=attn_dropout,
208
- rotary_embed=rotary_embed,
209
- flash=flash_attn,
210
- sage_attention=sage_attention,
211
- )
212
-
213
- self.layers.append(
214
- ModuleList(
215
- [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
216
- )
217
- )
218
-
219
- self.norm = RMSNorm(dim) if norm_output else nn.Identity()
220
-
221
- def forward(self, x):
222
-
223
- for attn, ff in self.layers:
224
- x = attn(x) + x
225
- x = ff(x) + x
226
-
227
- return self.norm(x)
228
-
229
-
230
- class BandSplit(Module):
231
- @beartype
232
- def __init__(self, dim, dim_inputs: Tuple[int, ...]):
233
- super().__init__()
234
- self.dim_inputs = dim_inputs
235
- self.to_features = ModuleList([])
236
-
237
- for dim_in in dim_inputs:
238
- net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
239
-
240
- self.to_features.append(net)
241
-
242
- def forward(self, x):
243
-
244
- x = x.split(self.dim_inputs, dim=-1)
245
-
246
- outs = []
247
- for split_input, to_feature in zip(x, self.to_features):
248
- split_output = to_feature(split_input)
249
- outs.append(split_output)
250
-
251
- x = torch.stack(outs, dim=-2)
252
-
253
- return x
254
-
255
-
256
- class Conv(nn.Module):
257
- def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
258
- super().__init__()
259
- self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
260
- self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
261
- self.act = nn.SiLU() if act else nn.Identity()
262
-
263
- def forward(self, x):
264
- return self.act(self.bn(self.conv(x)))
265
-
266
-
267
- def autopad(k, p=None):
268
- if p is None:
269
- p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
270
- return p
271
-
272
-
273
- class DSConv(nn.Module):
274
- def __init__(self, c1, c2, k=3, s=1, p=None, act=True):
275
- super().__init__()
276
- self.dwconv = nn.Conv2d(c1, c1, k, s, autopad(k, p), groups=c1, bias=False)
277
- self.pwconv = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
278
- self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
279
- self.act = nn.SiLU() if act else nn.Identity()
280
-
281
- def forward(self, x):
282
- return self.act(self.bn(self.pwconv(self.dwconv(x))))
283
-
284
-
285
- class DS_Bottleneck(nn.Module):
286
- def __init__(self, c1, c2, k=3, shortcut=True):
287
- super().__init__()
288
- c_ = c1
289
- self.dsconv1 = DSConv(c1, c_, k=3, s=1)
290
- self.dsconv2 = DSConv(c_, c2, k=k, s=1)
291
- self.shortcut = shortcut and c1 == c2
292
-
293
- def forward(self, x):
294
- return (
295
- x + self.dsconv2(self.dsconv1(x))
296
- if self.shortcut
297
- else self.dsconv2(self.dsconv1(x))
298
- )
299
-
300
-
301
- class DS_C3k(nn.Module):
302
- def __init__(self, c1, c2, n=1, k=3, e=0.5):
303
- super().__init__()
304
- c_ = int(c2 * e)
305
- self.cv1 = Conv(c1, c_, 1, 1)
306
- self.cv2 = Conv(c1, c_, 1, 1)
307
- self.cv3 = Conv(2 * c_, c2, 1, 1)
308
- self.m = nn.Sequential(
309
- *[DS_Bottleneck(c_, c_, k=k, shortcut=True) for _ in range(n)]
310
- )
311
-
312
- def forward(self, x):
313
- return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
314
-
315
-
316
- class DS_C3k2(nn.Module):
317
- def __init__(self, c1, c2, n=1, k=3, e=0.5):
318
- super().__init__()
319
- c_ = int(c2 * e)
320
- self.cv1 = Conv(c1, c_, 1, 1)
321
- self.m = DS_C3k(c_, c_, n=n, k=k, e=1.0)
322
- self.cv2 = Conv(c_, c2, 1, 1)
323
-
324
- def forward(self, x):
325
- x_ = self.cv1(x)
326
- x_ = self.m(x_)
327
- return self.cv2(x_)
328
-
329
-
330
- class AdaptiveHyperedgeGeneration(nn.Module):
331
- def __init__(self, in_channels, num_hyperedges, num_heads=8):
332
- super().__init__()
333
- self.num_hyperedges = num_hyperedges
334
- self.num_heads = num_heads
335
- self.head_dim = in_channels // num_heads
336
-
337
- self.global_proto = nn.Parameter(torch.randn(num_hyperedges, in_channels))
338
-
339
- self.context_mapper = nn.Linear(
340
- 2 * in_channels, num_hyperedges * in_channels, bias=False
341
- )
342
-
343
- self.query_proj = nn.Linear(in_channels, in_channels, bias=False)
344
-
345
- self.scale = self.head_dim**-0.5
346
-
347
- def forward(self, x):
348
- B, N, C = x.shape
349
-
350
- f_avg = F.adaptive_avg_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
351
- f_max = F.adaptive_max_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
352
- f_ctx = torch.cat((f_avg, f_max), dim=1)
353
-
354
- delta_P = self.context_mapper(f_ctx).view(B, self.num_hyperedges, C)
355
- P = self.global_proto.unsqueeze(0) + delta_P
356
-
357
- z = self.query_proj(x)
358
-
359
- z = z.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
360
-
361
- P = P.view(B, self.num_hyperedges, self.num_heads, self.head_dim).permute(
362
- 0, 2, 3, 1
363
- )
364
-
365
- sim = (z @ P) * self.scale
366
-
367
- s_bar = sim.mean(dim=1)
368
-
369
- A = F.softmax(s_bar.permute(0, 2, 1), dim=-1)
370
-
371
- return A
372
-
373
-
374
- class HypergraphConvolution(nn.Module):
375
- def __init__(self, in_channels, out_channels):
376
- super().__init__()
377
- self.W_e = nn.Linear(in_channels, in_channels, bias=False)
378
- self.W_v = nn.Linear(in_channels, out_channels, bias=False)
379
- self.act = nn.SiLU()
380
-
381
- def forward(self, x, A):
382
- f_m = torch.bmm(A, x)
383
- f_m = self.act(self.W_e(f_m))
384
-
385
- x_out = torch.bmm(A.transpose(1, 2), f_m)
386
- x_out = self.act(self.W_v(x_out))
387
-
388
- return x + x_out
389
-
390
-
391
- class AdaptiveHypergraphComputation(nn.Module):
392
- def __init__(self, in_channels, out_channels, num_hyperedges=8, num_heads=8):
393
- super().__init__()
394
- self.adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(
395
- in_channels, num_hyperedges, num_heads
396
- )
397
- self.hypergraph_conv = HypergraphConvolution(in_channels, out_channels)
398
-
399
- def forward(self, x):
400
- B, C, H, W = x.shape
401
- x_flat = x.flatten(2).permute(0, 2, 1)
402
-
403
- A = self.adaptive_hyperedge_gen(x_flat)
404
-
405
- x_out_flat = self.hypergraph_conv(x_flat, A)
406
-
407
- x_out = x_out_flat.permute(0, 2, 1).view(B, -1, H, W)
408
- return x_out
409
-
410
-
411
- class C3AH(nn.Module):
412
- def __init__(self, c1, c2, num_hyperedges=8, num_heads=8, e=0.5):
413
- super().__init__()
414
- c_ = int(c1 * e)
415
- self.cv1 = Conv(c1, c_, 1, 1)
416
- self.cv2 = Conv(c1, c_, 1, 1)
417
- self.ahc = AdaptiveHypergraphComputation(c_, c_, num_hyperedges, num_heads)
418
- self.cv3 = Conv(2 * c_, c2, 1, 1)
419
-
420
- def forward(self, x):
421
- x_lateral = self.cv1(x)
422
- x_ahc = self.ahc(self.cv2(x))
423
- return self.cv3(torch.cat((x_ahc, x_lateral), dim=1))
424
-
425
-
426
- class HyperACE(nn.Module):
427
- def __init__(
428
- self,
429
- in_channels: List[int],
430
- out_channels: int,
431
- num_hyperedges=8,
432
- num_heads=8,
433
- k=2,
434
- l=1,
435
- c_h=0.5,
436
- c_l=0.25,
437
- ):
438
- super().__init__()
439
-
440
- c2, c3, c4, c5 = in_channels
441
- c_mid = c4
442
-
443
- self.fuse_conv = Conv(c2 + c3 + c4 + c5, c_mid, 1, 1)
444
-
445
- self.c_h = int(c_mid * c_h)
446
- self.c_l = int(c_mid * c_l)
447
- self.c_s = c_mid - self.c_h - self.c_l
448
- assert self.c_s > 0, "Channel split error"
449
-
450
- self.high_order_branch = nn.ModuleList(
451
- [
452
- C3AH(self.c_h, self.c_h, num_hyperedges, num_heads, e=1.0)
453
- for _ in range(k)
454
- ]
455
- )
456
- self.high_order_fuse = Conv(self.c_h * k, self.c_h, 1, 1)
457
-
458
- self.low_order_branch = nn.Sequential(
459
- *[DS_C3k(self.c_l, self.c_l, n=1, k=3, e=1.0) for _ in range(l)]
460
- )
461
-
462
- self.final_fuse = Conv(self.c_h + self.c_l + self.c_s, out_channels, 1, 1)
463
-
464
- def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
465
- B2, B3, B4, B5 = x
466
-
467
- B, _, H4, W4 = B4.shape
468
-
469
- B2_resized = F.interpolate(
470
- B2, size=(H4, W4), mode="bilinear", align_corners=False
471
- )
472
- B3_resized = F.interpolate(
473
- B3, size=(H4, W4), mode="bilinear", align_corners=False
474
- )
475
- B5_resized = F.interpolate(
476
- B5, size=(H4, W4), mode="bilinear", align_corners=False
477
- )
478
-
479
- x_b = self.fuse_conv(torch.cat((B2_resized, B3_resized, B4, B5_resized), dim=1))
480
-
481
- x_h, x_l, x_s = torch.split(x_b, [self.c_h, self.c_l, self.c_s], dim=1)
482
-
483
- x_h_outs = [m(x_h) for m in self.high_order_branch]
484
- x_h_fused = self.high_order_fuse(torch.cat(x_h_outs, dim=1))
485
-
486
- x_l_out = self.low_order_branch(x_l)
487
-
488
- y = self.final_fuse(torch.cat((x_h_fused, x_l_out, x_s), dim=1))
489
-
490
- return y
491
-
492
-
493
- class GatedFusion(nn.Module):
494
- def __init__(self, in_channels):
495
- super().__init__()
496
- self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
497
-
498
- def forward(self, f_in, h):
499
- if f_in.shape[1] != h.shape[1]:
500
- raise ValueError(f"Channel mismatch: f_in={f_in.shape}, h={h.shape}")
501
- return f_in + self.gamma * h
502
-
503
-
504
- class Backbone(nn.Module):
505
- def __init__(self, in_channels=256, base_channels=64, base_depth=3):
506
- super().__init__()
507
- c = base_channels
508
- c2 = base_channels
509
- c3 = 256
510
- c4 = 384
511
- c5 = 512
512
- c6 = 768
513
-
514
- self.stem = DSConv(in_channels, c2, k=3, s=(2, 1), p=1)
515
-
516
- self.p2 = nn.Sequential(
517
- DSConv(c2, c3, k=3, s=(2, 1), p=1), DS_C3k2(c3, c3, n=base_depth)
518
- )
519
-
520
- self.p3 = nn.Sequential(
521
- DSConv(c3, c4, k=3, s=(2, 1), p=1), DS_C3k2(c4, c4, n=base_depth * 2)
522
- )
523
-
524
- self.p4 = nn.Sequential(
525
- DSConv(c4, c5, k=3, s=(2, 1), p=1), DS_C3k2(c5, c5, n=base_depth * 2)
526
- )
527
-
528
- self.p5 = nn.Sequential(
529
- DSConv(c5, c6, k=3, s=(2, 1), p=1), DS_C3k2(c6, c6, n=base_depth)
530
- )
531
-
532
- self.out_channels = [c3, c4, c5, c6]
533
-
534
- def forward(self, x):
535
- x = self.stem(x)
536
- x2 = self.p2(x)
537
- x3 = self.p3(x2)
538
- x4 = self.p4(x3)
539
- x5 = self.p5(x4)
540
- return [x2, x3, x4, x5]
541
-
542
-
543
- class Decoder(nn.Module):
544
- def __init__(
545
- self,
546
- encoder_channels: List[int],
547
- hyperace_out_c: int,
548
- decoder_channels: List[int],
549
- ):
550
- super().__init__()
551
- c_p2, c_p3, c_p4, c_p5 = encoder_channels
552
- c_d2, c_d3, c_d4, c_d5 = decoder_channels
553
-
554
- self.h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
555
- self.h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
556
- self.h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
557
- self.h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)
558
-
559
- self.fusion_d5 = GatedFusion(c_d5)
560
- self.fusion_d4 = GatedFusion(c_d4)
561
- self.fusion_d3 = GatedFusion(c_d3)
562
- self.fusion_d2 = GatedFusion(c_d2)
563
-
564
- self.skip_p5 = Conv(c_p5, c_d5, 1, 1)
565
- self.skip_p4 = Conv(c_p4, c_d4, 1, 1)
566
- self.skip_p3 = Conv(c_p3, c_d3, 1, 1)
567
- self.skip_p2 = Conv(c_p2, c_d2, 1, 1)
568
-
569
- self.up_d5 = DS_C3k2(c_d5, c_d4, n=1)
570
- self.up_d4 = DS_C3k2(c_d4, c_d3, n=1)
571
- self.up_d3 = DS_C3k2(c_d3, c_d2, n=1)
572
-
573
- self.final_d2 = DS_C3k2(c_d2, c_d2, n=1)
574
-
575
- def forward(self, enc_feats: List[torch.Tensor], h_ace: torch.Tensor):
576
- p2, p3, p4, p5 = enc_feats
577
-
578
- d5 = self.skip_p5(p5)
579
- h_d5 = self.h_to_d5(F.interpolate(h_ace, size=d5.shape[2:], mode="bilinear"))
580
- d5 = self.fusion_d5(d5, h_d5)
581
-
582
- d5_up = F.interpolate(d5, size=p4.shape[2:], mode="bilinear")
583
- d4_skip = self.skip_p4(p4)
584
- d4 = self.up_d5(d5_up) + d4_skip
585
-
586
- h_d4 = self.h_to_d4(F.interpolate(h_ace, size=d4.shape[2:], mode="bilinear"))
587
- d4 = self.fusion_d4(d4, h_d4)
588
-
589
- d4_up = F.interpolate(d4, size=p3.shape[2:], mode="bilinear")
590
- d3_skip = self.skip_p3(p3)
591
- d3 = self.up_d4(d4_up) + d3_skip
592
-
593
- h_d3 = self.h_to_d3(F.interpolate(h_ace, size=d3.shape[2:], mode="bilinear"))
594
- d3 = self.fusion_d3(d3, h_d3)
595
-
596
- d3_up = F.interpolate(d3, size=p2.shape[2:], mode="bilinear")
597
- d2_skip = self.skip_p2(p2)
598
- d2 = self.up_d3(d3_up) + d2_skip
599
-
600
- h_d2 = self.h_to_d2(F.interpolate(h_ace, size=d2.shape[2:], mode="bilinear"))
601
- d2 = self.fusion_d2(d2, h_d2)
602
-
603
- d2_final = self.final_d2(d2)
604
-
605
- return d2_final
606
-
607
-
608
- class FreqPixelShuffle(nn.Module):
609
- def __init__(self, in_channels, out_channels, scale=2):
610
- super().__init__()
611
- self.scale = scale
612
- self.conv = DSConv(in_channels, out_channels * scale, k=3, s=1, p=1)
613
- self.act = nn.SiLU()
614
-
615
- def forward(self, x):
616
- x = self.conv(x)
617
- B, C_r, H, W = x.shape
618
- out_c = C_r // self.scale
619
-
620
- x = x.view(B, out_c, self.scale, H, W)
621
-
622
- x = x.permute(0, 1, 3, 4, 2).contiguous()
623
- x = x.view(B, out_c, H, W * self.scale)
624
-
625
- return x
626
-
627
-
628
- class ProgressiveUpsampleHead(nn.Module):
629
- def __init__(self, in_channels, out_channels, target_bins=1025):
630
- super().__init__()
631
- self.target_bins = target_bins
632
-
633
- c = in_channels
634
-
635
- self.block1 = FreqPixelShuffle(c, c, scale=2)
636
- self.block2 = FreqPixelShuffle(c, c // 2, scale=2)
637
- self.block3 = FreqPixelShuffle(c // 2, c // 2, scale=2)
638
- self.block4 = FreqPixelShuffle(c // 2, c // 4, scale=2)
639
-
640
- self.final_conv = nn.Conv2d(c // 4, out_channels, kernel_size=1, bias=False)
641
-
642
- def forward(self, x):
643
-
644
- x = self.block1(x)
645
- x = self.block2(x)
646
- x = self.block3(x)
647
- x = self.block4(x)
648
-
649
- if x.shape[-1] != self.target_bins:
650
- x = F.interpolate(
651
- x,
652
- size=(x.shape[2], self.target_bins),
653
- mode="bilinear",
654
- align_corners=False,
655
- )
656
-
657
- x = self.final_conv(x)
658
- return x
659
-
660
-
661
- class SegmModel(nn.Module):
662
- def __init__(
663
- self,
664
- in_bands=62,
665
- in_dim=256,
666
- out_bins=1025,
667
- out_channels=4,
668
- base_channels=64,
669
- base_depth=2,
670
- num_hyperedges=16,
671
- num_heads=8,
672
- ):
673
- super().__init__()
674
-
675
- self.backbone = Backbone(
676
- in_channels=in_dim, base_channels=base_channels, base_depth=base_depth
677
- )
678
- enc_channels = self.backbone.out_channels
679
- c2, c3, c4, c5 = enc_channels
680
-
681
- hyperace_in_channels = enc_channels
682
- hyperace_out_channels = c4
683
- self.hyperace = HyperACE(
684
- hyperace_in_channels,
685
- hyperace_out_channels,
686
- num_hyperedges,
687
- num_heads,
688
- k=3,
689
- l=2,
690
- )
691
-
692
- decoder_channels = [c2, c3, c4, c5]
693
- self.decoder = Decoder(enc_channels, hyperace_out_channels, decoder_channels)
694
-
695
- self.upsample_head = ProgressiveUpsampleHead(
696
- in_channels=decoder_channels[0],
697
- out_channels=out_channels,
698
- target_bins=out_bins,
699
- )
700
-
701
- def forward(self, x):
702
- H, W = x.shape[2:]
703
-
704
- enc_feats = self.backbone(x)
705
-
706
- h_ace_feats = self.hyperace(enc_feats)
707
-
708
- dec_feat = self.decoder(enc_feats, h_ace_feats)
709
-
710
- feat_time_restored = F.interpolate(
711
- dec_feat, size=(H, dec_feat.shape[-1]), mode="bilinear", align_corners=False
712
- )
713
-
714
- out = self.upsample_head(feat_time_restored)
715
-
716
- return out
717
-
718
-
719
- def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
720
- dim_hidden = default(dim_hidden, dim_in)
721
-
722
- net = []
723
- dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
724
-
725
- for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
726
- is_last = ind == (len(dims) - 2)
727
-
728
- net.append(nn.Linear(layer_dim_in, layer_dim_out))
729
-
730
- if is_last:
731
- continue
732
-
733
- net.append(activation())
734
-
735
- return nn.Sequential(*net)
736
-
737
-
738
- class MaskEstimator(Module):
739
- @beartype
740
- def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
741
- super().__init__()
742
- self.dim_inputs = dim_inputs
743
- self.to_freqs = ModuleList([])
744
- dim_hidden = dim * mlp_expansion_factor
745
-
746
- for dim_in in dim_inputs:
747
- net = []
748
-
749
- mlp = nn.Sequential(
750
- MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1)
751
- )
752
-
753
- self.to_freqs.append(mlp)
754
-
755
- self.segm = SegmModel(
756
- in_bands=len(dim_inputs), in_dim=dim, out_bins=sum(dim_inputs) // 4
757
- )
758
-
759
- def forward(self, x):
760
- y = rearrange(x, "b t f c -> b c t f")
761
- y = self.segm(y)
762
- y = rearrange(y, "b c t f -> b t (f c)")
763
-
764
- x = x.unbind(dim=-2)
765
-
766
- outs = []
767
-
768
- for band_features, mlp in zip(x, self.to_freqs):
769
- freq_out = mlp(band_features)
770
- outs.append(freq_out)
771
-
772
- return torch.cat(outs, dim=-1) + y
773
-
774
-
775
- DEFAULT_FREQS_PER_BANDS = (
776
- 2,
777
- 2,
778
- 2,
779
- 2,
780
- 2,
781
- 2,
782
- 2,
783
- 2,
784
- 2,
785
- 2,
786
- 2,
787
- 2,
788
- 2,
789
- 2,
790
- 2,
791
- 2,
792
- 2,
793
- 2,
794
- 2,
795
- 2,
796
- 2,
797
- 2,
798
- 2,
799
- 2,
800
- 4,
801
- 4,
802
- 4,
803
- 4,
804
- 4,
805
- 4,
806
- 4,
807
- 4,
808
- 4,
809
- 4,
810
- 4,
811
- 4,
812
- 12,
813
- 12,
814
- 12,
815
- 12,
816
- 12,
817
- 12,
818
- 12,
819
- 12,
820
- 24,
821
- 24,
822
- 24,
823
- 24,
824
- 24,
825
- 24,
826
- 24,
827
- 24,
828
- 48,
829
- 48,
830
- 48,
831
- 48,
832
- 48,
833
- 48,
834
- 48,
835
- 48,
836
- 128,
837
- 129,
838
- )
839
-
840
-
841
- class BSRoformerHyperACE(Module):
842
-
843
- @beartype
844
- def __init__(
845
- self,
846
- dim,
847
- *,
848
- depth,
849
- stereo=False,
850
- num_stems=1,
851
- time_transformer_depth=2,
852
- freq_transformer_depth=2,
853
- linear_transformer_depth=0,
854
- freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
855
- dim_head=64,
856
- heads=8,
857
- attn_dropout=0.0,
858
- ff_dropout=0.0,
859
- flash_attn=True,
860
- dim_freqs_in=1025,
861
- stft_n_fft=2048,
862
- stft_hop_length=512,
863
- stft_win_length=2048,
864
- stft_normalized=False,
865
- stft_window_fn: Optional[Callable] = None,
866
- mask_estimator_depth=2,
867
- multi_stft_resolution_loss_weight=1.0,
868
- multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
869
- 4096,
870
- 2048,
871
- 1024,
872
- 512,
873
- 256,
874
- ),
875
- multi_stft_hop_size=147,
876
- multi_stft_normalized=False,
877
- multi_stft_window_fn: Callable = torch.hann_window,
878
- mlp_expansion_factor=4,
879
- use_torch_checkpoint=False,
880
- skip_connection=False,
881
- sage_attention=False,
882
- ):
883
- super().__init__()
884
-
885
- self.stereo = stereo
886
- self.audio_channels = 2 if stereo else 1
887
- self.num_stems = num_stems
888
- self.use_torch_checkpoint = use_torch_checkpoint
889
- self.skip_connection = skip_connection
890
-
891
- self.layers = ModuleList([])
892
-
893
- if sage_attention:
894
- print("Use Sage Attention")
895
-
896
- transformer_kwargs = dict(
897
- dim=dim,
898
- heads=heads,
899
- dim_head=dim_head,
900
- attn_dropout=attn_dropout,
901
- ff_dropout=ff_dropout,
902
- flash_attn=flash_attn,
903
- norm_output=False,
904
- sage_attention=sage_attention,
905
- )
906
-
907
- time_rotary_embed = RotaryEmbedding(dim=dim_head)
908
- freq_rotary_embed = RotaryEmbedding(dim=dim_head)
909
-
910
- for _ in range(depth):
911
- tran_modules = []
912
- tran_modules.append(
913
- Transformer(
914
- depth=time_transformer_depth,
915
- rotary_embed=time_rotary_embed,
916
- **transformer_kwargs,
917
- )
918
- )
919
- tran_modules.append(
920
- Transformer(
921
- depth=freq_transformer_depth,
922
- rotary_embed=freq_rotary_embed,
923
- **transformer_kwargs,
924
- )
925
- )
926
- self.layers.append(nn.ModuleList(tran_modules))
927
-
928
- self.final_norm = RMSNorm(dim)
929
-
930
- self.stft_kwargs = dict(
931
- n_fft=stft_n_fft,
932
- hop_length=stft_hop_length,
933
- win_length=stft_win_length,
934
- normalized=stft_normalized,
935
- )
936
-
937
- self.stft_window_fn = partial(
938
- default(stft_window_fn, torch.hann_window), stft_win_length
939
- )
940
-
941
- freqs = torch.stft(
942
- torch.randn(1, 4096),
943
- **self.stft_kwargs,
944
- window=torch.ones(stft_win_length),
945
- return_complex=True,
946
- ).shape[1]
947
-
948
- assert len(freqs_per_bands) > 1
949
- assert (
950
- sum(freqs_per_bands) == freqs
951
- ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
952
-
953
- freqs_per_bands_with_complex = tuple(
954
- 2 * f * self.audio_channels for f in freqs_per_bands
955
- )
956
-
957
- self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
958
-
959
- self.mask_estimators = nn.ModuleList([])
960
-
961
- for _ in range(num_stems):
962
- mask_estimator = MaskEstimator(
963
- dim=dim,
964
- dim_inputs=freqs_per_bands_with_complex,
965
- depth=mask_estimator_depth,
966
- mlp_expansion_factor=mlp_expansion_factor,
967
- )
968
-
969
- self.mask_estimators.append(mask_estimator)
970
-
971
- self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
972
- self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
973
- self.multi_stft_n_fft = stft_n_fft
974
- self.multi_stft_window_fn = multi_stft_window_fn
975
-
976
- self.multi_stft_kwargs = dict(
977
- hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
978
- )
979
-
980
- def forward(self, raw_audio, target=None, return_loss_breakdown=False):
981
-
982
- device = raw_audio.device
983
-
984
- x_is_mps = True if device.type == "mps" else False
985
-
986
- if raw_audio.ndim == 2:
987
- raw_audio = rearrange(raw_audio, "b t -> b 1 t")
988
-
989
- channels = raw_audio.shape[1]
990
- assert (not self.stereo and channels == 1) or (
991
- self.stereo and channels == 2
992
- ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
993
-
994
- raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
995
-
996
- stft_window = self.stft_window_fn(device=device)
997
-
998
- try:
999
- stft_repr = torch.stft(
1000
- raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
1001
- )
1002
- except:
1003
- stft_repr = torch.stft(
1004
- raw_audio.cpu() if x_is_mps else raw_audio,
1005
- **self.stft_kwargs,
1006
- window=stft_window.cpu() if x_is_mps else stft_window,
1007
- return_complex=True,
1008
- ).to(device)
1009
- stft_repr = torch.view_as_real(stft_repr)
1010
-
1011
- stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
1012
-
1013
- stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
1014
-
1015
- x = rearrange(stft_repr, "b f t c -> b t (f c)")
1016
-
1017
- x = self.band_split(x)
1018
-
1019
- for i, transformer_block in enumerate(self.layers):
1020
-
1021
- time_transformer, freq_transformer = transformer_block
1022
-
1023
- x = rearrange(x, "b t f d -> b f t d")
1024
- x, ps = pack([x], "* t d")
1025
-
1026
- x = time_transformer(x)
1027
-
1028
- (x,) = unpack(x, ps, "* t d")
1029
- x = rearrange(x, "b f t d -> b t f d")
1030
- x, ps = pack([x], "* f d")
1031
-
1032
- x = freq_transformer(x)
1033
-
1034
- (x,) = unpack(x, ps, "* f d")
1035
-
1036
- x = self.final_norm(x)
1037
-
1038
- num_stems = len(self.mask_estimators)
1039
-
1040
- mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
1041
- mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
1042
-
1043
- stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
1044
-
1045
- stft_repr = torch.view_as_complex(stft_repr)
1046
- mask = torch.view_as_complex(mask)
1047
-
1048
- stft_repr = stft_repr * mask
1049
-
1050
- stft_repr = rearrange(
1051
- stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
1052
- )
1053
-
1054
- try:
1055
- recon_audio = torch.istft(
1056
- stft_repr,
1057
- **self.stft_kwargs,
1058
- window=stft_window,
1059
- return_complex=False,
1060
- length=raw_audio.shape[-1],
1061
- )
1062
- except:
1063
- recon_audio = torch.istft(
1064
- stft_repr.cpu() if x_is_mps else stft_repr,
1065
- **self.stft_kwargs,
1066
- window=stft_window.cpu() if x_is_mps else stft_window,
1067
- return_complex=False,
1068
- length=raw_audio.shape[-1],
1069
- ).to(device)
1070
-
1071
- recon_audio = rearrange(
1072
- recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
1073
- )
1074
-
1075
- if num_stems == 1:
1076
- recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
1077
-
1078
- if not exists(target):
1079
- return recon_audio
1080
-
1081
- if self.num_stems > 1:
1082
- assert target.ndim == 4 and target.shape[1] == self.num_stems
1083
-
1084
- if target.ndim == 2:
1085
- target = rearrange(target, "... t -> ... 1 t")
1086
-
1087
- target = target[..., : recon_audio.shape[-1]]
1088
-
1089
- loss = F.l1_loss(recon_audio, target)
1090
-
1091
- multi_stft_resolution_loss = 0.0
1092
-
1093
- for window_size in self.multi_stft_resolutions_window_sizes:
1094
- res_stft_kwargs = dict(
1095
- n_fft=max(window_size, self.multi_stft_n_fft),
1096
- win_length=window_size,
1097
- return_complex=True,
1098
- window=self.multi_stft_window_fn(window_size, device=device),
1099
- **self.multi_stft_kwargs,
1100
- )
1101
-
1102
- recon_Y = torch.stft(
1103
- rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
1104
- )
1105
- target_Y = torch.stft(
1106
- rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
1107
- )
1108
-
1109
- multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
1110
- recon_Y, target_Y
1111
- )
1112
-
1113
- weighted_multi_resolution_loss = (
1114
- multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
1115
- )
1116
-
1117
- total_loss = loss + weighted_multi_resolution_loss
1118
-
1119
- if not return_loss_breakdown:
1120
- return total_loss
1121
-
1122
- return total_loss, (loss, multi_stft_resolution_loss)
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn, einsum, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ import torch.nn.functional as F
7
+
8
+ from .attend import Attend
9
+
10
+ try:
11
+ from .attend_sage import Attend as AttendSage
12
+ except:
13
+ pass
14
+ from torch.utils.checkpoint import checkpoint
15
+
16
+ from beartype.typing import Tuple, Optional, List, Callable
17
+ from beartype import beartype
18
+
19
+ from rotary_embedding_torch import RotaryEmbedding
20
+
21
+ from einops import rearrange, pack, unpack
22
+ from einops.layers.torch import Rearrange
23
+ import torchaudio
24
+
25
+
26
def exists(val):
    """True when *val* is an actual value (i.e. not None)."""
    return val is not None
28
+
29
+
30
def default(v, d):
    """Return *v* unless it is None, in which case fall back to *d*."""
    if v is None:
        return d
    return v
32
+
33
+
34
def pack_one(t, pattern):
    """einops.pack a single tensor; returns (packed_tensor, packed_shapes)."""
    packed, ps = pack([t], pattern)
    return packed, ps
36
+
37
+
38
def unpack_one(t, ps, pattern):
    """Inverse of pack_one: einops.unpack and return the first tensor."""
    parts = unpack(t, ps, pattern)
    return parts[0]
40
+
41
+
42
def l2norm(t):
    """L2-normalize *t* along its last dimension."""
    return F.normalize(t, p=2, dim=-1)
44
+
45
+
46
class RMSNorm(Module):
    """RMS normalization with a learned per-channel gain.

    Normalizes the last dimension to unit L2 norm, rescales by sqrt(dim) so
    the expected magnitude is preserved, then applies the learned gamma.
    """

    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        unit = F.normalize(x, dim=-1)
        return unit * self.scale * self.gamma
54
+
55
+
56
class FeedForward(Module):
    """Pre-norm transformer MLP: RMSNorm -> Linear -> GELU -> Linear.

    Hidden width is dim * mult; dropout is applied after each Linear.
    """

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        layers = [
            RMSNorm(dim),
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
71
+
72
+
73
class Attention(Module):
    """Multi-head self-attention with pre-RMSNorm, optional rotary position
    embeddings, and a per-head sigmoid output gate computed from the input."""

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        dropout=0.0,
        rotary_embed=None,
        flash=True,
        sage_attention=False,
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        dim_inner = heads * dim_head

        self.rotary_embed = rotary_embed

        # AttendSage comes from an optional import guarded by try/except at the
        # top of the file; selecting it without the module installed raises
        # NameError here.
        if sage_attention:
            self.attend = AttendSage(flash=flash, dropout=dropout)
        else:
            self.attend = Attend(flash=flash, dropout=dropout)

        self.norm = RMSNorm(dim)
        # Single fused projection for q, k and v.
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)

        # One gate logit per head, predicted from the normed input tokens.
        self.to_gates = nn.Linear(dim, heads)

        self.to_out = nn.Sequential(
            nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.norm(x)

        # Split the fused projection into (q, k, v) and heads.
        q, k, v = rearrange(
            self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
        )

        if exists(self.rotary_embed):
            q = self.rotary_embed.rotate_queries_or_keys(q)
            k = self.rotary_embed.rotate_queries_or_keys(k)

        out = self.attend(q, k, v)

        # Gate each head's output with a sigmoid of the input-derived logits.
        gates = self.to_gates(x)
        out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
123
+
124
+
125
class LinearAttention(Module):
    """Cosine-style attention variant: q and k are L2-normalized and scaled by
    a learned per-head temperature. Note the transposed layout — features on
    dim -2, sequence on dim -1 (b h d n). No rotary embedding support."""

    @beartype
    def __init__(
        self,
        *,
        dim,
        dim_head=32,
        heads=8,
        scale=8,
        flash=True,
        dropout=0.0,
        sage_attention=False,
    ):
        super().__init__()
        dim_inner = dim_head * heads
        self.norm = RMSNorm(dim)

        # Project and reshape to (qkv, b, h, d, n) in one go.
        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias=False),
            Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
        )

        # Learned log-temperature per head, applied after q/k normalization.
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        if sage_attention:
            self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash)
        else:
            self.attend = Attend(scale=scale, dropout=dropout, flash=flash)

        self.to_out = nn.Sequential(
            Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
        )

    def forward(self, x):
        x = self.norm(x)

        q, k, v = self.to_qkv(x)

        # Unit-normalize q and k, then apply exp(temperature) as the scale.
        q, k = map(l2norm, (q, k))
        q = q * self.temperature.exp()

        out = self.attend(q, k, v)

        return self.to_out(out)
170
+
171
+
172
class Transformer(Module):
    """Stack of `depth` residual (attention, feed-forward) layers.

    When linear_attn is True, LinearAttention is used (rotary_embed is
    ignored by that path); otherwise softmax Attention with the provided
    rotary embedding. norm_output toggles a final RMSNorm.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        ff_mult=4,
        norm_output=True,
        rotary_embed=None,
        flash_attn=True,
        linear_attn=False,
        sage_attention=False,
    ):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            if linear_attn:
                attn = LinearAttention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    dropout=attn_dropout,
                    flash=flash_attn,
                    sage_attention=sage_attention,
                )
            else:
                attn = Attention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    dropout=attn_dropout,
                    rotary_embed=rotary_embed,
                    flash=flash_attn,
                    sage_attention=sage_attention,
                )

            self.layers.append(
                ModuleList(
                    [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
                )
            )

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        # Pre-norm residual blocks: x <- x + attn(x); x <- x + ff(x).
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)
228
+
229
+
230
class BandSplit(Module):
    """Embed each frequency band into a shared feature dimension.

    Splits the last axis according to `dim_inputs`, applies a per-band
    RMSNorm + Linear projection, and stacks the results along a new band
    axis at dim -2: (..., sum(dim_inputs)) -> (..., n_bands, dim).
    """

    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...]):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList(
            nn.Sequential(RMSNorm(band_dim), nn.Linear(band_dim, dim))
            for band_dim in dim_inputs
        )

    def forward(self, x):
        bands = x.split(self.dim_inputs, dim=-1)
        embedded = [proj(band) for band, proj in zip(bands, self.to_features)]
        return torch.stack(embedded, dim=-2)
254
+
255
+
256
class Conv(nn.Module):
    """Conv2d -> InstanceNorm2d -> SiLU building block.

    Padding defaults to 'same'-style via autopad; the norm is always applied,
    the activation only when act=True.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
        self.act = nn.SiLU() if act else nn.Identity()

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)
265
+
266
+
267
def autopad(k, p=None):
    """Return explicit padding *p* if given, else 'same'-style k // 2."""
    if p is not None:
        return p
    if isinstance(k, int):
        return k // 2
    return [v // 2 for v in k]
271
+
272
+
273
class DSConv(nn.Module):
    """Depthwise-separable convolution: depthwise kxk, then pointwise 1x1,
    followed by InstanceNorm and (optionally) SiLU."""

    def __init__(self, c1, c2, k=3, s=1, p=None, act=True):
        super().__init__()
        self.dwconv = nn.Conv2d(c1, c1, k, s, autopad(k, p), groups=c1, bias=False)
        self.pwconv = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
        self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
        self.act = nn.SiLU() if act else nn.Identity()

    def forward(self, x):
        out = self.dwconv(x)
        out = self.pwconv(out)
        return self.act(self.bn(out))
283
+
284
+
285
class DS_Bottleneck(nn.Module):
    """Two stacked DSConvs with an optional residual connection.

    The residual is only taken when shortcut=True AND input/output channel
    counts match.
    """

    def __init__(self, c1, c2, k=3, shortcut=True):
        super().__init__()
        mid = c1
        self.dsconv1 = DSConv(c1, mid, k=3, s=1)
        self.dsconv2 = DSConv(mid, c2, k=k, s=1)
        self.shortcut = shortcut and c1 == c2

    def forward(self, x):
        out = self.dsconv2(self.dsconv1(x))
        if self.shortcut:
            return x + out
        return out
299
+
300
+
301
class DS_C3k(nn.Module):
    """CSP-style block: one branch through n DS_Bottlenecks, one 1x1 lateral
    branch, concatenated and fused with a final 1x1 Conv."""

    def __init__(self, c1, c2, n=1, k=3, e=0.5):
        super().__init__()
        c_ = int(c2 * e)  # hidden width of both branches
        self.cv1 = Conv(c1, c_, 1, 1)  # entry to the bottleneck branch
        self.cv2 = Conv(c1, c_, 1, 1)  # lateral branch
        self.cv3 = Conv(2 * c_, c2, 1, 1)  # fusion of the two branches
        self.m = nn.Sequential(
            *[DS_Bottleneck(c_, c_, k=k, shortcut=True) for _ in range(n)]
        )

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
314
+
315
+
316
class DS_C3k2(nn.Module):
    """Bottlenecked wrapper around DS_C3k: 1x1 reduce -> DS_C3k -> 1x1 expand."""

    def __init__(self, c1, c2, n=1, k=3, e=0.5):
        super().__init__()
        c_ = int(c2 * e)  # reduced inner width
        self.cv1 = Conv(c1, c_, 1, 1)
        self.m = DS_C3k(c_, c_, n=n, k=k, e=1.0)
        self.cv2 = Conv(c_, c2, 1, 1)

    def forward(self, x):
        x_ = self.cv1(x)
        x_ = self.m(x_)
        return self.cv2(x_)
328
+
329
+
330
class AdaptiveHyperedgeGeneration(nn.Module):
    """Produce a soft incidence matrix A of shape (B, E, N) assigning N tokens
    to E hyperedges.

    Hyperedge prototypes are a learned global set plus a per-sample offset
    predicted from pooled [avg; max] context; assignments come from multi-head
    dot-product similarity averaged over heads.
    """

    def __init__(self, in_channels, num_hyperedges, num_heads=8):
        super().__init__()
        self.num_hyperedges = num_hyperedges
        self.num_heads = num_heads
        # NOTE(review): assumes in_channels is divisible by num_heads.
        self.head_dim = in_channels // num_heads

        # Input-independent hyperedge prototypes.
        self.global_proto = nn.Parameter(torch.randn(num_hyperedges, in_channels))

        # Maps pooled context (B, 2C) to a per-sample prototype offset.
        self.context_mapper = nn.Linear(
            2 * in_channels, num_hyperedges * in_channels, bias=False
        )

        self.query_proj = nn.Linear(in_channels, in_channels, bias=False)

        self.scale = self.head_dim**-0.5

    def forward(self, x):
        # x: (B, N, C) token features.
        B, N, C = x.shape

        # Global context: per-channel avg and max over tokens -> (B, 2C).
        f_avg = F.adaptive_avg_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
        f_max = F.adaptive_max_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
        f_ctx = torch.cat((f_avg, f_max), dim=1)

        # Per-sample prototypes: global prototypes + predicted offset (B, E, C).
        delta_P = self.context_mapper(f_ctx).view(B, self.num_hyperedges, C)
        P = self.global_proto.unsqueeze(0) + delta_P

        z = self.query_proj(x)

        # Split into heads: z -> (B, H, N, d).
        z = z.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # Prototypes -> (B, H, d, E) for the matmul below.
        P = P.view(B, self.num_hyperedges, self.num_heads, self.head_dim).permute(
            0, 2, 3, 1
        )

        # Scaled similarity (B, H, N, E).
        sim = (z @ P) * self.scale

        # Average over heads -> (B, N, E).
        s_bar = sim.mean(dim=1)

        # Softmax over tokens: each hyperedge's memberships sum to 1 over N.
        A = F.softmax(s_bar.permute(0, 2, 1), dim=-1)

        return A
372
+
373
+
374
+ class HypergraphConvolution(nn.Module):
375
+ def __init__(self, in_channels, out_channels):
376
+ super().__init__()
377
+ self.W_e = nn.Linear(in_channels, in_channels, bias=False)
378
+ self.W_v = nn.Linear(in_channels, out_channels, bias=False)
379
+ self.act = nn.SiLU()
380
+
381
+ def forward(self, x, A):
382
+ f_m = torch.bmm(A, x)
383
+ f_m = self.act(self.W_e(f_m))
384
+
385
+ x_out = torch.bmm(A.transpose(1, 2), f_m)
386
+ x_out = self.act(self.W_v(x_out))
387
+
388
+ return x + x_out
389
+
390
+
391
class AdaptiveHypergraphComputation(nn.Module):
    """Run adaptive hypergraph message passing over a 2-D feature map.

    Flattens (B, C, H, W) to per-position tokens, generates the soft incidence
    matrix, applies one hypergraph convolution, and restores the spatial map.
    """

    def __init__(self, in_channels, out_channels, num_hyperedges=8, num_heads=8):
        super().__init__()
        self.adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(
            in_channels, num_hyperedges, num_heads
        )
        self.hypergraph_conv = HypergraphConvolution(in_channels, out_channels)

    def forward(self, x):
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        tokens = x.flatten(2).permute(0, 2, 1)

        incidence = self.adaptive_hyperedge_gen(tokens)
        tokens = self.hypergraph_conv(tokens, incidence)

        # Restore the spatial layout.
        return tokens.permute(0, 2, 1).view(B, -1, H, W)
409
+
410
+
411
class C3AH(nn.Module):
    """Two-branch block: an adaptive hypergraph branch and a 1x1 lateral
    branch, concatenated and fused by a final 1x1 Conv."""

    def __init__(self, c1, c2, num_hyperedges=8, num_heads=8, e=0.5):
        super().__init__()
        c_ = int(c1 * e)  # width of each branch
        self.cv1 = Conv(c1, c_, 1, 1)  # lateral branch
        self.cv2 = Conv(c1, c_, 1, 1)  # projection feeding the hypergraph branch
        self.ahc = AdaptiveHypergraphComputation(c_, c_, num_hyperedges, num_heads)
        self.cv3 = Conv(2 * c_, c2, 1, 1)  # fusion

    def forward(self, x):
        x_lateral = self.cv1(x)
        x_ahc = self.ahc(self.cv2(x))
        return self.cv3(torch.cat((x_ahc, x_lateral), dim=1))
424
+
425
+
426
class HyperACE(nn.Module):
    """Multi-scale fusion at the B4 resolution with hypergraph enhancement.

    The four pyramid features are resized to B4's spatial size and fused by a
    1x1 Conv, then split channel-wise into three groups: a high-order group
    processed by k parallel C3AH hypergraph blocks, a low-order group
    processed by l stacked DS_C3k blocks, and an untouched shortcut group.
    The groups are concatenated and fused to `out_channels`.

    NOTE(review): parameter `l` (lowercase L) is easy to misread; callers in
    this file pass it by keyword (l=2), so the name must not change.
    """

    def __init__(
        self,
        in_channels: List[int],
        out_channels: int,
        num_hyperedges=8,
        num_heads=8,
        k=2,  # number of parallel C3AH blocks in the high-order branch
        l=1,  # number of sequential DS_C3k blocks in the low-order branch
        c_h=0.5,  # fraction of fused channels for the high-order group
        c_l=0.25,  # fraction of fused channels for the low-order group
    ):
        super().__init__()

        c2, c3, c4, c5 = in_channels
        c_mid = c4  # fused width follows the B4 level

        self.fuse_conv = Conv(c2 + c3 + c4 + c5, c_mid, 1, 1)

        # Channel split: high-order / low-order / shortcut (remainder).
        self.c_h = int(c_mid * c_h)
        self.c_l = int(c_mid * c_l)
        self.c_s = c_mid - self.c_h - self.c_l
        assert self.c_s > 0, "Channel split error"

        self.high_order_branch = nn.ModuleList(
            [
                C3AH(self.c_h, self.c_h, num_hyperedges, num_heads, e=1.0)
                for _ in range(k)
            ]
        )
        self.high_order_fuse = Conv(self.c_h * k, self.c_h, 1, 1)

        self.low_order_branch = nn.Sequential(
            *[DS_C3k(self.c_l, self.c_l, n=1, k=3, e=1.0) for _ in range(l)]
        )

        self.final_fuse = Conv(self.c_h + self.c_l + self.c_s, out_channels, 1, 1)

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        B2, B3, B4, B5 = x

        B, _, H4, W4 = B4.shape

        # Bring every pyramid level to B4's spatial resolution.
        B2_resized = F.interpolate(
            B2, size=(H4, W4), mode="bilinear", align_corners=False
        )
        B3_resized = F.interpolate(
            B3, size=(H4, W4), mode="bilinear", align_corners=False
        )
        B5_resized = F.interpolate(
            B5, size=(H4, W4), mode="bilinear", align_corners=False
        )

        x_b = self.fuse_conv(torch.cat((B2_resized, B3_resized, B4, B5_resized), dim=1))

        x_h, x_l, x_s = torch.split(x_b, [self.c_h, self.c_l, self.c_s], dim=1)

        # High-order branch: k parallel hypergraph blocks, concatenated, fused.
        x_h_outs = [m(x_h) for m in self.high_order_branch]
        x_h_fused = self.high_order_fuse(torch.cat(x_h_outs, dim=1))

        # Low-order branch: plain depthwise-separable conv stack.
        x_l_out = self.low_order_branch(x_l)

        y = self.final_fuse(torch.cat((x_h_fused, x_l_out, x_s), dim=1))

        return y
491
+
492
+
493
class GatedFusion(nn.Module):
    """Residual fusion f_in + gamma * h with a learned per-channel gate.

    gamma is zero-initialized, so the module is an exact identity on f_in at
    initialization; raises ValueError on channel mismatch.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))

    def forward(self, f_in, h):
        if f_in.shape[1] != h.shape[1]:
            raise ValueError(f"Channel mismatch: f_in={f_in.shape}, h={h.shape}")
        gated = self.gamma * h
        return f_in + gated
502
+
503
+
504
class Backbone(nn.Module):
    """Five-stage encoder of DSConv downsampling + DS_C3k2 refinement stages.

    Every stride is (2, 1): only the first spatial axis (time, in the
    "b c t f" layout fed by SegmModel) is halved per stage; the band axis
    keeps its size. Returns the four deepest feature maps [p2, p3, p4, p5].
    """

    def __init__(self, in_channels=256, base_channels=64, base_depth=3):
        super().__init__()
        c = base_channels  # NOTE(review): unused local; c2 below is the stem width
        c2 = base_channels
        c3 = 256
        c4 = 384
        c5 = 512
        c6 = 768

        self.stem = DSConv(in_channels, c2, k=3, s=(2, 1), p=1)

        self.p2 = nn.Sequential(
            DSConv(c2, c3, k=3, s=(2, 1), p=1), DS_C3k2(c3, c3, n=base_depth)
        )

        self.p3 = nn.Sequential(
            DSConv(c3, c4, k=3, s=(2, 1), p=1), DS_C3k2(c4, c4, n=base_depth * 2)
        )

        self.p4 = nn.Sequential(
            DSConv(c4, c5, k=3, s=(2, 1), p=1), DS_C3k2(c5, c5, n=base_depth * 2)
        )

        self.p5 = nn.Sequential(
            DSConv(c5, c6, k=3, s=(2, 1), p=1), DS_C3k2(c6, c6, n=base_depth)
        )

        # Channel widths of [p2, p3, p4, p5], consumed by HyperACE / Decoder.
        self.out_channels = [c3, c4, c5, c6]

    def forward(self, x):
        x = self.stem(x)
        x2 = self.p2(x)
        x3 = self.p3(x2)
        x4 = self.p4(x3)
        x5 = self.p5(x4)
        return [x2, x3, x4, x5]
541
+
542
+
543
class Decoder(nn.Module):
    """Top-down decoder with HyperACE injection at every level.

    At each level the encoder skip is projected to the decoder width (1x1
    Conv), the resized HyperACE feature is gated in (GatedFusion, identity at
    init), and the result is upsampled, refined with DS_C3k2 and added to the
    next skip. Returns the finest (d2) feature map.
    """

    def __init__(
        self,
        encoder_channels: List[int],
        hyperace_out_c: int,
        decoder_channels: List[int],
    ):
        super().__init__()
        c_p2, c_p3, c_p4, c_p5 = encoder_channels
        c_d2, c_d3, c_d4, c_d5 = decoder_channels

        # 1x1 projections of the HyperACE feature to each decoder width.
        self.h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
        self.h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
        self.h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
        self.h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)

        # Zero-initialized gates: HyperACE contribution starts disabled.
        self.fusion_d5 = GatedFusion(c_d5)
        self.fusion_d4 = GatedFusion(c_d4)
        self.fusion_d3 = GatedFusion(c_d3)
        self.fusion_d2 = GatedFusion(c_d2)

        # Encoder-skip projections to the decoder widths.
        self.skip_p5 = Conv(c_p5, c_d5, 1, 1)
        self.skip_p4 = Conv(c_p4, c_d4, 1, 1)
        self.skip_p3 = Conv(c_p3, c_d3, 1, 1)
        self.skip_p2 = Conv(c_p2, c_d2, 1, 1)

        # Refinement blocks that also step the channel width down one level.
        self.up_d5 = DS_C3k2(c_d5, c_d4, n=1)
        self.up_d4 = DS_C3k2(c_d4, c_d3, n=1)
        self.up_d3 = DS_C3k2(c_d3, c_d2, n=1)

        self.final_d2 = DS_C3k2(c_d2, c_d2, n=1)

    def forward(self, enc_feats: List[torch.Tensor], h_ace: torch.Tensor):
        p2, p3, p4, p5 = enc_feats

        # Deepest level: projected skip + gated, resized HyperACE feature.
        d5 = self.skip_p5(p5)
        h_d5 = self.h_to_d5(F.interpolate(h_ace, size=d5.shape[2:], mode="bilinear"))
        d5 = self.fusion_d5(d5, h_d5)

        # d5 -> d4: upsample to p4's size, refine, add skip, gate HyperACE in.
        d5_up = F.interpolate(d5, size=p4.shape[2:], mode="bilinear")
        d4_skip = self.skip_p4(p4)
        d4 = self.up_d5(d5_up) + d4_skip

        h_d4 = self.h_to_d4(F.interpolate(h_ace, size=d4.shape[2:], mode="bilinear"))
        d4 = self.fusion_d4(d4, h_d4)

        # d4 -> d3.
        d4_up = F.interpolate(d4, size=p3.shape[2:], mode="bilinear")
        d3_skip = self.skip_p3(p3)
        d3 = self.up_d4(d4_up) + d3_skip

        h_d3 = self.h_to_d3(F.interpolate(h_ace, size=d3.shape[2:], mode="bilinear"))
        d3 = self.fusion_d3(d3, h_d3)

        # d3 -> d2 (finest level).
        d3_up = F.interpolate(d3, size=p2.shape[2:], mode="bilinear")
        d2_skip = self.skip_p2(p2)
        d2 = self.up_d3(d3_up) + d2_skip

        h_d2 = self.h_to_d2(F.interpolate(h_ace, size=d2.shape[2:], mode="bilinear"))
        d2 = self.fusion_d2(d2, h_d2)

        d2_final = self.final_d2(d2)

        return d2_final
606
+
607
+
608
class FreqPixelShuffle(nn.Module):
    """Upsample the frequency (last) axis by `scale` via pixel-shuffle.

    The DSConv produces out_channels * scale channels; the extra factor is
    then folded out of the channel axis and into the width axis, turning
    (B, C*s, H, W) into (B, C, H, W*s).
    """

    def __init__(self, in_channels, out_channels, scale=2):
        super().__init__()
        self.scale = scale
        self.conv = DSConv(in_channels, out_channels * scale, k=3, s=1, p=1)
        # Fix: the original also created `self.act = nn.SiLU()` but never used
        # it in forward. DSConv already ends in SiLU, so the unused (and
        # parameter-free) module is removed; state_dict keys are unaffected.

    def forward(self, x):
        x = self.conv(x)
        B, C_r, H, W = x.shape
        out_c = C_r // self.scale

        # (B, out_c*s, H, W) -> (B, out_c, s, H, W) -> (B, out_c, H, W, s).
        x = x.view(B, out_c, self.scale, H, W)
        x = x.permute(0, 1, 3, 4, 2).contiguous()

        # Interleave the shuffle factor into the width axis.
        return x.view(B, out_c, H, W * self.scale)
626
+
627
+
628
class ProgressiveUpsampleHead(nn.Module):
    """Expand the frequency (last) axis 16x through four FreqPixelShuffle
    blocks, snap to exactly `target_bins` via bilinear interpolation when the
    shuffled width does not match, then map to `out_channels` with a 1x1 conv.
    """

    def __init__(self, in_channels, out_channels, target_bins=1025):
        super().__init__()
        self.target_bins = target_bins

        c = in_channels

        # Channel width steps down while frequency width doubles each block.
        self.block1 = FreqPixelShuffle(c, c, scale=2)
        self.block2 = FreqPixelShuffle(c, c // 2, scale=2)
        self.block3 = FreqPixelShuffle(c // 2, c // 2, scale=2)
        self.block4 = FreqPixelShuffle(c // 2, c // 4, scale=2)

        self.final_conv = nn.Conv2d(c // 4, out_channels, kernel_size=1, bias=False)

    def forward(self, x):

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)

        # 16x upsampling rarely lands exactly on target_bins (e.g. 1025 is
        # odd); correct the width here while keeping the time axis unchanged.
        if x.shape[-1] != self.target_bins:
            x = F.interpolate(
                x,
                size=(x.shape[2], self.target_bins),
                mode="bilinear",
                align_corners=False,
            )

        x = self.final_conv(x)
        return x
659
+
660
+
661
class SegmModel(nn.Module):
    """Segmentation-style network over a (band, time) feature map.

    Pipeline: Backbone encoder -> HyperACE multi-scale fusion -> Decoder ->
    ProgressiveUpsampleHead, producing (B, out_channels, T, out_bins).

    NOTE(review): `in_bands` is currently unused — the backbone is driven by
    `in_dim` (channel count) and the input's spatial extent.
    """

    def __init__(
        self,
        in_bands=62,
        in_dim=256,
        out_bins=1025,
        out_channels=4,
        base_channels=64,
        base_depth=2,
        num_hyperedges=16,
        num_heads=8,
    ):
        super().__init__()

        self.backbone = Backbone(
            in_channels=in_dim, base_channels=base_channels, base_depth=base_depth
        )
        enc_channels = self.backbone.out_channels
        c2, c3, c4, c5 = enc_channels

        # HyperACE fuses all four levels and emits c4-wide features.
        hyperace_in_channels = enc_channels
        hyperace_out_channels = c4
        self.hyperace = HyperACE(
            hyperace_in_channels,
            hyperace_out_channels,
            num_hyperedges,
            num_heads,
            k=3,
            l=2,
        )

        # Decoder widths mirror the encoder widths level-for-level.
        decoder_channels = [c2, c3, c4, c5]
        self.decoder = Decoder(enc_channels, hyperace_out_channels, decoder_channels)

        self.upsample_head = ProgressiveUpsampleHead(
            in_channels=decoder_channels[0],
            out_channels=out_channels,
            target_bins=out_bins,
        )

    def forward(self, x):
        # x: (B, in_dim, T, bands); H is the time axis, W the band axis.
        H, W = x.shape[2:]

        enc_feats = self.backbone(x)

        h_ace_feats = self.hyperace(enc_feats)

        dec_feat = self.decoder(enc_feats, h_ace_feats)

        # Backbone strides halved the time axis; restore full time resolution
        # before the frequency upsampling head.
        feat_time_restored = F.interpolate(
            dec_feat, size=(H, dec_feat.shape[-1]), mode="bilinear", align_corners=False
        )

        out = self.upsample_head(feat_time_restored)

        return out
717
+
718
+
719
def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
    """Build an MLP as nn.Sequential: `depth` Linear layers with `activation`
    between them (no activation after the last). Hidden width defaults to
    dim_in when dim_hidden is None."""
    hidden = dim_hidden if dim_hidden is not None else dim_in
    dims = [dim_in] + [hidden] * (depth - 1) + [dim_out]

    layers = []
    last = len(dims) - 2
    for idx in range(last + 1):
        layers.append(nn.Linear(dims[idx], dims[idx + 1]))
        if idx != last:
            layers.append(activation())

    return nn.Sequential(*layers)
736
+
737
+
738
class MaskEstimator(Module):
    """Per-band mask estimator with a global segmentation correction branch.

    Each band's features pass through an MLP whose output width is doubled so
    the following GLU halves it back; a SegmModel run over the stacked band
    features adds a full-spectrum correction term to the concatenated result.
    """

    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        dim_hidden = dim * mlp_expansion_factor

        for dim_in in dim_inputs:
            # Fix: removed dead local `net = []` that the original created and
            # never used inside this loop.
            mlp = nn.Sequential(
                MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1)
            )
            self.to_freqs.append(mlp)

        # out_bins = sum(dim_inputs) // 4: dim_inputs carry a 2x complex factor
        # and a 2x channel factor, so this recovers raw frequency bins.
        # NOTE(review): assumes stereo input (audio_channels == 2) — confirm.
        self.segm = SegmModel(
            in_bands=len(dim_inputs), in_dim=dim, out_bins=sum(dim_inputs) // 4
        )

    def forward(self, x):
        # Global branch: treat (time, band) as a 2-D map with `dim` channels.
        y = rearrange(x, "b t f c -> b c t f")
        y = self.segm(y)
        y = rearrange(y, "b c t f -> b t (f c)")

        # Per-band branch: an independent MLP+GLU head for every band.
        x = x.unbind(dim=-2)
        outs = [mlp(band_features) for band_features, mlp in zip(x, self.to_freqs)]

        # Concatenate band outputs along frequency and add the global term.
        return torch.cat(outs, dim=-1) + y
773
+
774
+
775
# Default per-band frequency-bin counts for a 2048-point STFT (1025 bins total):
# 24 bands of 2 bins, 12 of 4, 8 of 12, 8 of 24, 8 of 48, then two wide top
# bands of 128 and 129 bins. Narrow bands at low frequencies give finer
# resolution where it matters most.
DEFAULT_FREQS_PER_BANDS = (
    (2,) * 24
    + (4,) * 12
    + (12,) * 8
    + (24,) * 8
    + (48,) * 8
    + (128, 129)
)
839
+
840
+
841
+ class BSRoformerHyperACE(Module):
842
+
843
    @beartype
    def __init__(
        self,
        dim,
        *,
        depth,
        stereo=False,
        num_stems=1,
        time_transformer_depth=2,
        freq_transformer_depth=2,
        # NOTE(review): linear_transformer_depth is accepted but not used in
        # the visible body — confirm whether it is intentionally ignored.
        linear_transformer_depth=0,
        freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        flash_attn=True,
        dim_freqs_in=1025,
        stft_n_fft=2048,
        stft_hop_length=512,
        stft_win_length=2048,
        stft_normalized=False,
        stft_window_fn: Optional[Callable] = None,
        mask_estimator_depth=2,
        multi_stft_resolution_loss_weight=1.0,
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
            4096,
            2048,
            1024,
            512,
            256,
        ),
        multi_stft_hop_size=147,
        multi_stft_normalized=False,
        multi_stft_window_fn: Callable = torch.hann_window,
        mlp_expansion_factor=4,
        use_torch_checkpoint=False,
        skip_connection=False,
        sage_attention=False,
    ):
        """Band-split RoFormer separator whose mask estimators embed a
        HyperACE segmentation branch.

        Builds `depth` pairs of (time, frequency) axial Transformers over
        band-split STFT features, plus one MaskEstimator per stem, and stores
        the STFT / multi-resolution-STFT-loss configuration used by forward().
        """
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems
        self.use_torch_checkpoint = use_torch_checkpoint
        self.skip_connection = skip_connection

        self.layers = ModuleList([])

        if sage_attention:
            print("Use Sage Attention")

        # Shared settings for every axial Transformer below.
        transformer_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            flash_attn=flash_attn,
            norm_output=False,
            sage_attention=sage_attention,
        )

        # Rotary embeddings are shared across all depth levels (one for the
        # time axis, one for the frequency/band axis).
        time_rotary_embed = RotaryEmbedding(dim=dim_head)
        freq_rotary_embed = RotaryEmbedding(dim=dim_head)

        for _ in range(depth):
            tran_modules = []
            # Attention along the time axis within each band.
            tran_modules.append(
                Transformer(
                    depth=time_transformer_depth,
                    rotary_embed=time_rotary_embed,
                    **transformer_kwargs,
                )
            )
            # Attention across bands at each time step.
            tran_modules.append(
                Transformer(
                    depth=freq_transformer_depth,
                    rotary_embed=freq_rotary_embed,
                    **transformer_kwargs,
                )
            )
            self.layers.append(nn.ModuleList(tran_modules))

        self.final_norm = RMSNorm(dim)

        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized,
        )

        # Window factory; forward() calls it with device=... to materialize
        # the window on the right device.
        self.stft_window_fn = partial(
            default(stft_window_fn, torch.hann_window), stft_win_length
        )

        # Probe the configured STFT with dummy audio to learn the actual
        # number of frequency bins, so freqs_per_bands can be validated.
        freqs = torch.stft(
            torch.randn(1, 4096),
            **self.stft_kwargs,
            window=torch.ones(stft_win_length),
            return_complex=True,
        ).shape[1]

        assert len(freqs_per_bands) > 1
        assert (
            sum(freqs_per_bands) == freqs
        ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"

        # Each band carries real+imag (x2) for every audio channel.
        freqs_per_bands_with_complex = tuple(
            2 * f * self.audio_channels for f in freqs_per_bands
        )

        self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)

        # One mask estimator per output stem.
        self.mask_estimators = nn.ModuleList([])

        for _ in range(num_stems):
            mask_estimator = MaskEstimator(
                dim=dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=mask_estimator_depth,
                mlp_expansion_factor=mlp_expansion_factor,
            )

            self.mask_estimators.append(mask_estimator)

        # Multi-resolution STFT loss configuration (used only when a target
        # is passed to forward()).
        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(
            hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
        )
979
+
980
+ def forward(self, raw_audio, target=None, return_loss_breakdown=False):
981
+
982
+ device = raw_audio.device
983
+
984
+ x_is_mps = True if device.type == "mps" else False
985
+
986
+ if raw_audio.ndim == 2:
987
+ raw_audio = rearrange(raw_audio, "b t -> b 1 t")
988
+
989
+ channels = raw_audio.shape[1]
990
+ assert (not self.stereo and channels == 1) or (
991
+ self.stereo and channels == 2
992
+ ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
993
+
994
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
995
+
996
+ stft_window = self.stft_window_fn(device=device)
997
+
998
+ try:
999
+ stft_repr = torch.stft(
1000
+ raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
1001
+ )
1002
+ except:
1003
+ stft_repr = torch.stft(
1004
+ raw_audio.cpu() if x_is_mps else raw_audio,
1005
+ **self.stft_kwargs,
1006
+ window=stft_window.cpu() if x_is_mps else stft_window,
1007
+ return_complex=True,
1008
+ ).to(device)
1009
+ stft_repr = torch.view_as_real(stft_repr)
1010
+
1011
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
1012
+
1013
+ stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
1014
+
1015
+ x = rearrange(stft_repr, "b f t c -> b t (f c)")
1016
+
1017
+ x = self.band_split(x)
1018
+
1019
+ for i, transformer_block in enumerate(self.layers):
1020
+
1021
+ time_transformer, freq_transformer = transformer_block
1022
+
1023
+ x = rearrange(x, "b t f d -> b f t d")
1024
+ x, ps = pack([x], "* t d")
1025
+
1026
+ x = time_transformer(x)
1027
+
1028
+ (x,) = unpack(x, ps, "* t d")
1029
+ x = rearrange(x, "b f t d -> b t f d")
1030
+ x, ps = pack([x], "* f d")
1031
+
1032
+ x = freq_transformer(x)
1033
+
1034
+ (x,) = unpack(x, ps, "* f d")
1035
+
1036
+ x = self.final_norm(x)
1037
+
1038
+ num_stems = len(self.mask_estimators)
1039
+
1040
+ mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
1041
+ mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
1042
+
1043
+ stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
1044
+
1045
+ stft_repr = torch.view_as_complex(stft_repr)
1046
+ mask = torch.view_as_complex(mask)
1047
+
1048
+ stft_repr = stft_repr * mask
1049
+
1050
+ stft_repr = rearrange(
1051
+ stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
1052
+ )
1053
+
1054
+ try:
1055
+ recon_audio = torch.istft(
1056
+ stft_repr,
1057
+ **self.stft_kwargs,
1058
+ window=stft_window,
1059
+ return_complex=False,
1060
+ length=raw_audio.shape[-1],
1061
+ )
1062
+ except:
1063
+ recon_audio = torch.istft(
1064
+ stft_repr.cpu() if x_is_mps else stft_repr,
1065
+ **self.stft_kwargs,
1066
+ window=stft_window.cpu() if x_is_mps else stft_window,
1067
+ return_complex=False,
1068
+ length=raw_audio.shape[-1],
1069
+ ).to(device)
1070
+
1071
+ recon_audio = rearrange(
1072
+ recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
1073
+ )
1074
+
1075
+ if num_stems == 1:
1076
+ recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
1077
+
1078
+ if not exists(target):
1079
+ return recon_audio
1080
+
1081
+ if self.num_stems > 1:
1082
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
1083
+
1084
+ if target.ndim == 2:
1085
+ target = rearrange(target, "... t -> ... 1 t")
1086
+
1087
+ target = target[..., : recon_audio.shape[-1]]
1088
+
1089
+ loss = F.l1_loss(recon_audio, target)
1090
+
1091
+ multi_stft_resolution_loss = 0.0
1092
+
1093
+ for window_size in self.multi_stft_resolutions_window_sizes:
1094
+ res_stft_kwargs = dict(
1095
+ n_fft=max(window_size, self.multi_stft_n_fft),
1096
+ win_length=window_size,
1097
+ return_complex=True,
1098
+ window=self.multi_stft_window_fn(window_size, device=device),
1099
+ **self.multi_stft_kwargs,
1100
+ )
1101
+
1102
+ recon_Y = torch.stft(
1103
+ rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
1104
+ )
1105
+ target_Y = torch.stft(
1106
+ rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
1107
+ )
1108
+
1109
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
1110
+ recon_Y, target_Y
1111
+ )
1112
+
1113
+ weighted_multi_resolution_loss = (
1114
+ multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
1115
+ )
1116
+
1117
+ total_loss = loss + weighted_multi_resolution_loss
1118
+
1119
+ if not return_loss_breakdown:
1120
+ return total_loss
1121
+
1122
+ return total_loss, (loss, multi_stft_resolution_loss)