Commit ·
8a06b33
1
Parent(s): e4bb613
Code actualized
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- mvsepless/additional_app.py +15 -4
- mvsepless/app.py +11 -13
- mvsepless/i18n.py +4 -4
- mvsepless/infer_utils.py +18 -4
- mvsepless/install.py +6 -15
- mvsepless/models/bandit/core/__init__.py +669 -669
- mvsepless/models/bandit/core/data/__init__.py +2 -2
- mvsepless/models/bandit/core/data/_types.py +17 -17
- mvsepless/models/bandit/core/data/augmentation.py +102 -102
- mvsepless/models/bandit/core/data/augmented.py +34 -34
- mvsepless/models/bandit/core/data/base.py +60 -60
- mvsepless/models/bandit/core/data/dnr/datamodule.py +64 -64
- mvsepless/models/bandit/core/data/dnr/dataset.py +360 -360
- mvsepless/models/bandit/core/data/dnr/preprocess.py +51 -51
- mvsepless/models/bandit/core/data/musdb/datamodule.py +75 -75
- mvsepless/models/bandit/core/data/musdb/dataset.py +241 -241
- mvsepless/models/bandit/core/data/musdb/preprocess.py +223 -223
- mvsepless/models/bandit/core/data/musdb/validation.yaml +14 -14
- mvsepless/models/bandit/core/loss/__init__.py +8 -8
- mvsepless/models/bandit/core/loss/_complex.py +27 -27
- mvsepless/models/bandit/core/loss/_multistem.py +43 -43
- mvsepless/models/bandit/core/loss/_timefreq.py +94 -94
- mvsepless/models/bandit/core/loss/snr.py +131 -131
- mvsepless/models/bandit/core/metrics/__init__.py +7 -7
- mvsepless/models/bandit/core/metrics/_squim.py +350 -350
- mvsepless/models/bandit/core/metrics/snr.py +124 -124
- mvsepless/models/bandit/core/model/__init__.py +3 -3
- mvsepless/models/bandit/core/model/_spectral.py +54 -54
- mvsepless/models/bandit/core/model/bsrnn/__init__.py +23 -23
- mvsepless/models/bandit/core/model/bsrnn/bandsplit.py +119 -119
- mvsepless/models/bandit/core/model/bsrnn/core.py +619 -619
- mvsepless/models/bandit/core/model/bsrnn/maskestim.py +327 -327
- mvsepless/models/bandit/core/model/bsrnn/tfmodel.py +287 -287
- mvsepless/models/bandit/core/model/bsrnn/utils.py +518 -518
- mvsepless/models/bandit/core/model/bsrnn/wrapper.py +828 -828
- mvsepless/models/bandit/core/utils/audio.py +324 -324
- mvsepless/models/bandit/model_from_config.py +26 -26
- mvsepless/models/bandit_v2/bandit.py +360 -360
- mvsepless/models/bandit_v2/bandsplit.py +127 -127
- mvsepless/models/bandit_v2/film.py +23 -23
- mvsepless/models/bandit_v2/maskestim.py +269 -269
- mvsepless/models/bandit_v2/tfmodel.py +141 -141
- mvsepless/models/bandit_v2/utils.py +384 -384
- mvsepless/models/bs_roformer/__init__.py +14 -14
- mvsepless/models/bs_roformer/attend.py +127 -127
- mvsepless/models/bs_roformer/attend_sage.py +146 -146
- mvsepless/models/bs_roformer/attend_sw.py +88 -88
- mvsepless/models/bs_roformer/bs_roformer.py +696 -696
- mvsepless/models/bs_roformer/bs_roformer_fno.py +704 -704
- mvsepless/models/bs_roformer/bs_roformer_hyperace.py +1122 -1122
mvsepless/additional_app.py
CHANGED
|
@@ -279,6 +279,7 @@ class CustomSeparator(Separator, GradioHelper):
|
|
| 279 |
use_spec_invert: bool = False,
|
| 280 |
econom_mode: Optional[bool] = None,
|
| 281 |
chunk_duration: float = 300,
|
|
|
|
| 282 |
progress: gr.Progress = gr.Progress(track_tqdm=True),
|
| 283 |
) -> List:
|
| 284 |
"""
|
|
@@ -295,7 +296,7 @@ class CustomSeparator(Separator, GradioHelper):
|
|
| 295 |
vr_aggr: Агрессивность для VR
|
| 296 |
vr_post_process: Постобработка для VR
|
| 297 |
vr_high_end_process: Обработка высоких частот для VR
|
| 298 |
-
|
| 299 |
use_spec_invert: Использовать инверсию спектрограммы
|
| 300 |
econom_mode: Эконом-режим
|
| 301 |
chunk_duration: Длительность чанка
|
|
@@ -309,7 +310,8 @@ class CustomSeparator(Separator, GradioHelper):
|
|
| 309 |
|
| 310 |
add_settings: Dict[str, Any] = {
|
| 311 |
"add_single_sep_text_progress": None,
|
| 312 |
-
"single_mode": False
|
|
|
|
| 313 |
}
|
| 314 |
|
| 315 |
if econom_mode is not None:
|
|
@@ -412,6 +414,11 @@ class CustomSeparator(Separator, GradioHelper):
|
|
| 412 |
value=False,
|
| 413 |
interactive=True
|
| 414 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
|
| 416 |
gr.Markdown(f"<h4>{_i18n('economy_settings')}</h4>", container=True)
|
| 417 |
econom_mode = gr.Checkbox(
|
|
@@ -497,7 +504,8 @@ class CustomSeparator(Separator, GradioHelper):
|
|
| 497 |
stems,
|
| 498 |
use_spec_for_extract_instrumental,
|
| 499 |
econom_mode,
|
| 500 |
-
chunk_dur_slider
|
|
|
|
| 501 |
],
|
| 502 |
outputs=[sep_state, status],
|
| 503 |
show_progress="full",
|
|
@@ -514,6 +522,7 @@ class CustomSeparator(Separator, GradioHelper):
|
|
| 514 |
u_spec: bool,
|
| 515 |
ec_mode: bool,
|
| 516 |
ch_dur: float,
|
|
|
|
| 517 |
progress: gr.Progress = gr.Progress(track_tqdm=True),
|
| 518 |
) -> Tuple[gr.update, gr.update]:
|
| 519 |
results = self._separate_batch(
|
|
@@ -527,6 +536,7 @@ class CustomSeparator(Separator, GradioHelper):
|
|
| 527 |
u_spec,
|
| 528 |
ec_mode,
|
| 529 |
ch_dur * 60,
|
|
|
|
| 530 |
progress=progress,
|
| 531 |
)
|
| 532 |
return gr.update(value=str(results)), gr.update(visible=False)
|
|
@@ -801,10 +811,11 @@ class CustomSeparator(Separator, GradioHelper):
|
|
| 801 |
def refresh_all_models() -> Tuple[gr.update, gr.update, gr.update]:
|
| 802 |
models = self.get_mn()
|
| 803 |
first_model = models[0] if models else None
|
|
|
|
| 804 |
return (
|
| 805 |
gr.update(choices=models, value=first_model),
|
| 806 |
gr.update(choices=models, value=first_model),
|
| 807 |
-
gr.update(choices=models, value=
|
| 808 |
)
|
| 809 |
|
| 810 |
class AutoEnsembless(Separator, GradioHelper):
|
|
|
|
| 279 |
use_spec_invert: bool = False,
|
| 280 |
econom_mode: Optional[bool] = None,
|
| 281 |
chunk_duration: float = 300,
|
| 282 |
+
denoise: bool = False,
|
| 283 |
progress: gr.Progress = gr.Progress(track_tqdm=True),
|
| 284 |
) -> List:
|
| 285 |
"""
|
|
|
|
| 296 |
vr_aggr: Агрессивность для VR
|
| 297 |
vr_post_process: Постобработка для VR
|
| 298 |
vr_high_end_process: Обработка высоких частот для VR
|
| 299 |
+
denoise: Шумоподавление для MDX
|
| 300 |
use_spec_invert: Использовать инверсию спектрограммы
|
| 301 |
econom_mode: Эконом-режим
|
| 302 |
chunk_duration: Длительность чанка
|
|
|
|
| 310 |
|
| 311 |
add_settings: Dict[str, Any] = {
|
| 312 |
"add_single_sep_text_progress": None,
|
| 313 |
+
"single_mode": False,
|
| 314 |
+
"denoise": denoise
|
| 315 |
}
|
| 316 |
|
| 317 |
if econom_mode is not None:
|
|
|
|
| 414 |
value=False,
|
| 415 |
interactive=True
|
| 416 |
)
|
| 417 |
+
denoise = gr.Checkbox(
|
| 418 |
+
label=_i18n("denoise"),
|
| 419 |
+
value=False,
|
| 420 |
+
interactive=True,
|
| 421 |
+
)
|
| 422 |
|
| 423 |
gr.Markdown(f"<h4>{_i18n('economy_settings')}</h4>", container=True)
|
| 424 |
econom_mode = gr.Checkbox(
|
|
|
|
| 504 |
stems,
|
| 505 |
use_spec_for_extract_instrumental,
|
| 506 |
econom_mode,
|
| 507 |
+
chunk_dur_slider,
|
| 508 |
+
denoise
|
| 509 |
],
|
| 510 |
outputs=[sep_state, status],
|
| 511 |
show_progress="full",
|
|
|
|
| 522 |
u_spec: bool,
|
| 523 |
ec_mode: bool,
|
| 524 |
ch_dur: float,
|
| 525 |
+
den: bool,
|
| 526 |
progress: gr.Progress = gr.Progress(track_tqdm=True),
|
| 527 |
) -> Tuple[gr.update, gr.update]:
|
| 528 |
results = self._separate_batch(
|
|
|
|
| 536 |
u_spec,
|
| 537 |
ec_mode,
|
| 538 |
ch_dur * 60,
|
| 539 |
+
den,
|
| 540 |
progress=progress,
|
| 541 |
)
|
| 542 |
return gr.update(value=str(results)), gr.update(visible=False)
|
|
|
|
| 811 |
def refresh_all_models() -> Tuple[gr.update, gr.update, gr.update]:
|
| 812 |
models = self.get_mn()
|
| 813 |
first_model = models[0] if models else None
|
| 814 |
+
first_model2 = [models[0]] if models else []
|
| 815 |
return (
|
| 816 |
gr.update(choices=models, value=first_model),
|
| 817 |
gr.update(choices=models, value=first_model),
|
| 818 |
+
gr.update(choices=models, value=first_model2),
|
| 819 |
)
|
| 820 |
|
| 821 |
class AutoEnsembless(Separator, GradioHelper):
|
mvsepless/app.py
CHANGED
|
@@ -278,7 +278,7 @@ class SeparatorGradio(GradioHelper, DownloadModelManager):
|
|
| 278 |
vr_aggr: int = 5,
|
| 279 |
vr_post_process: bool = False,
|
| 280 |
vr_high_end_process: bool = False,
|
| 281 |
-
|
| 282 |
use_spec_invert: bool = False,
|
| 283 |
econom_mode: Optional[bool] = None,
|
| 284 |
chunk_duration: float = 300,
|
|
@@ -298,7 +298,7 @@ class SeparatorGradio(GradioHelper, DownloadModelManager):
|
|
| 298 |
vr_aggr: Агрессивность для VR
|
| 299 |
vr_post_process: Постобработка для VR
|
| 300 |
vr_high_end_process: Обработка высоких частот для VR
|
| 301 |
-
|
| 302 |
use_spec_invert: Использовать инверсию спектрограммы
|
| 303 |
econom_mode: Эконом-режим
|
| 304 |
chunk_duration: Длительность чанка
|
|
@@ -311,7 +311,7 @@ class SeparatorGradio(GradioHelper, DownloadModelManager):
|
|
| 311 |
self.chunk_duration = chunk_duration
|
| 312 |
|
| 313 |
add_settings: Dict[str, Any] = {
|
| 314 |
-
"
|
| 315 |
"vr_aggr": vr_aggr,
|
| 316 |
"vr_post_process": vr_post_process,
|
| 317 |
"vr_high_end_process": vr_high_end_process,
|
|
@@ -487,19 +487,17 @@ class SeparatorGradio(GradioHelper, DownloadModelManager):
|
|
| 487 |
interactive=True
|
| 488 |
)
|
| 489 |
|
| 490 |
-
gr.Markdown(f"<h4>{_i18n('mdx_settings')}</h4>", container=True)
|
| 491 |
-
mdx_denoise = gr.Checkbox(
|
| 492 |
-
label=_i18n("mdx_denoise"),
|
| 493 |
-
value=False,
|
| 494 |
-
interactive=True,
|
| 495 |
-
)
|
| 496 |
-
|
| 497 |
gr.Markdown(f"<h4>{_i18n('invert_settings')}</h4>", container=True)
|
| 498 |
use_spec_for_extract_instrumental = gr.Checkbox(
|
| 499 |
label=_i18n("use_spectrogram_invert"),
|
| 500 |
value=False,
|
| 501 |
interactive=True
|
| 502 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
|
| 504 |
gr.Markdown(f"<h4>{_i18n('economy_settings')}</h4>", container=True)
|
| 505 |
econom_mode = gr.Checkbox(
|
|
@@ -583,7 +581,7 @@ class SeparatorGradio(GradioHelper, DownloadModelManager):
|
|
| 583 |
output_bitrate,
|
| 584 |
template,
|
| 585 |
stems,
|
| 586 |
-
|
| 587 |
vr_aggr,
|
| 588 |
vr_enable_post_process,
|
| 589 |
vr_enable_high_end_process,
|
|
@@ -603,7 +601,7 @@ class SeparatorGradio(GradioHelper, DownloadModelManager):
|
|
| 603 |
output_bitrate: int,
|
| 604 |
template: str,
|
| 605 |
stems: List[str],
|
| 606 |
-
|
| 607 |
vr_aggr: int,
|
| 608 |
vr_pp: bool,
|
| 609 |
vr_hip: bool,
|
|
@@ -623,7 +621,7 @@ class SeparatorGradio(GradioHelper, DownloadModelManager):
|
|
| 623 |
vr_aggr,
|
| 624 |
vr_pp,
|
| 625 |
vr_hip,
|
| 626 |
-
|
| 627 |
u_spec,
|
| 628 |
ec_mode,
|
| 629 |
ch_dur * 60,
|
|
|
|
| 278 |
vr_aggr: int = 5,
|
| 279 |
vr_post_process: bool = False,
|
| 280 |
vr_high_end_process: bool = False,
|
| 281 |
+
denoise: bool = False,
|
| 282 |
use_spec_invert: bool = False,
|
| 283 |
econom_mode: Optional[bool] = None,
|
| 284 |
chunk_duration: float = 300,
|
|
|
|
| 298 |
vr_aggr: Агрессивность для VR
|
| 299 |
vr_post_process: Постобработка для VR
|
| 300 |
vr_high_end_process: Обработка высоких частот для VR
|
| 301 |
+
denoise: Шумоподавление для MDX
|
| 302 |
use_spec_invert: Использовать инверсию спектрограммы
|
| 303 |
econom_mode: Эконом-режим
|
| 304 |
chunk_duration: Длительность чанка
|
|
|
|
| 311 |
self.chunk_duration = chunk_duration
|
| 312 |
|
| 313 |
add_settings: Dict[str, Any] = {
|
| 314 |
+
"denoise": denoise,
|
| 315 |
"vr_aggr": vr_aggr,
|
| 316 |
"vr_post_process": vr_post_process,
|
| 317 |
"vr_high_end_process": vr_high_end_process,
|
|
|
|
| 487 |
interactive=True
|
| 488 |
)
|
| 489 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 490 |
gr.Markdown(f"<h4>{_i18n('invert_settings')}</h4>", container=True)
|
| 491 |
use_spec_for_extract_instrumental = gr.Checkbox(
|
| 492 |
label=_i18n("use_spectrogram_invert"),
|
| 493 |
value=False,
|
| 494 |
interactive=True
|
| 495 |
)
|
| 496 |
+
denoise = gr.Checkbox(
|
| 497 |
+
label=_i18n("denoise"),
|
| 498 |
+
value=False,
|
| 499 |
+
interactive=True,
|
| 500 |
+
)
|
| 501 |
|
| 502 |
gr.Markdown(f"<h4>{_i18n('economy_settings')}</h4>", container=True)
|
| 503 |
econom_mode = gr.Checkbox(
|
|
|
|
| 581 |
output_bitrate,
|
| 582 |
template,
|
| 583 |
stems,
|
| 584 |
+
denoise,
|
| 585 |
vr_aggr,
|
| 586 |
vr_enable_post_process,
|
| 587 |
vr_enable_high_end_process,
|
|
|
|
| 601 |
output_bitrate: int,
|
| 602 |
template: str,
|
| 603 |
stems: List[str],
|
| 604 |
+
denoise: bool,
|
| 605 |
vr_aggr: int,
|
| 606 |
vr_pp: bool,
|
| 607 |
vr_hip: bool,
|
|
|
|
| 621 |
vr_aggr,
|
| 622 |
vr_pp,
|
| 623 |
vr_hip,
|
| 624 |
+
denoise,
|
| 625 |
u_spec,
|
| 626 |
ec_mode,
|
| 627 |
ch_dur * 60,
|
mvsepless/i18n.py
CHANGED
|
@@ -128,7 +128,7 @@ TRANSLATIONS: Dict[Language, Dict[str, str]] = {
|
|
| 128 |
"vr_post_process": "Дополнительная обработка для улучшения качества разделения",
|
| 129 |
"vr_high_end_process": "Восстановление недостающих высоких частот",
|
| 130 |
"mdx_settings": "MDX-NET",
|
| 131 |
-
"
|
| 132 |
"invert_settings": "Инвертирование результата",
|
| 133 |
"use_spectrogram_invert": "При извлечении инструментала/второго стема/остатка использовать спектрограмму",
|
| 134 |
"economy_settings": "Экономия",
|
|
@@ -441,7 +441,7 @@ TRANSLATIONS: Dict[Language, Dict[str, str]] = {
|
|
| 441 |
"ext_inst_help": "Извлечь инструментал вычитанием",
|
| 442 |
"use_spec_invert_help": "Инверсия спектрограммы для вторичного стема",
|
| 443 |
"install_only_help": "Только установка модели",
|
| 444 |
-
"
|
| 445 |
"vr_aggression_help": "Агрессивность для VR моделей (по умолчанию: 5)",
|
| 446 |
"vr_high_end_help": "Восстановление недостающих высоких частот на VR моделях",
|
| 447 |
"vr_post_process_help": "Дополнительная обработка для улучшения качества разделения VR модели",
|
|
@@ -856,7 +856,7 @@ TRANSLATIONS: Dict[Language, Dict[str, str]] = {
|
|
| 856 |
"vr_post_process": "Additional post-processing to improve separation quality",
|
| 857 |
"vr_high_end_process": "Restore missing high frequencies",
|
| 858 |
"mdx_settings": "MDX-NET",
|
| 859 |
-
"
|
| 860 |
"invert_settings": "Result Inversion",
|
| 861 |
"use_spectrogram_invert": "Use spectrogram for extracting instrumental/second stem/remainder",
|
| 862 |
"economy_settings": "Economy",
|
|
@@ -1169,7 +1169,7 @@ TRANSLATIONS: Dict[Language, Dict[str, str]] = {
|
|
| 1169 |
"ext_inst_help": "Extract instrumental by subtraction",
|
| 1170 |
"use_spec_invert_help": "Use spectrogram inversion for secondary stem",
|
| 1171 |
"install_only_help": "Download model only",
|
| 1172 |
-
"
|
| 1173 |
"vr_aggression_help": "Aggression for VR models (default: 5)",
|
| 1174 |
"vr_high_end_help": "Restore missing high frequencies on VR models",
|
| 1175 |
"vr_post_process_help": "Additional post-processing for VR models",
|
|
|
|
| 128 |
"vr_post_process": "Дополнительная обработка для улучшения качества разделения",
|
| 129 |
"vr_high_end_process": "Восстановление недостающих высоких частот",
|
| 130 |
"mdx_settings": "MDX-NET",
|
| 131 |
+
"denoise": "Шумоподавление",
|
| 132 |
"invert_settings": "Инвертирование результата",
|
| 133 |
"use_spectrogram_invert": "При извлечении инструментала/второго стема/остатка использовать спектрограмму",
|
| 134 |
"economy_settings": "Экономия",
|
|
|
|
| 441 |
"ext_inst_help": "Извлечь инструментал вычитанием",
|
| 442 |
"use_spec_invert_help": "Инверсия спектрограммы для вторичного стема",
|
| 443 |
"install_only_help": "Только установка модели",
|
| 444 |
+
"denoise_help": "Удаление шума из вывода",
|
| 445 |
"vr_aggression_help": "Агрессивность для VR моделей (по умолчанию: 5)",
|
| 446 |
"vr_high_end_help": "Восстановление недостающих высоких частот на VR моделях",
|
| 447 |
"vr_post_process_help": "Дополнительная обработка для улучшения качества разделения VR модели",
|
|
|
|
| 856 |
"vr_post_process": "Additional post-processing to improve separation quality",
|
| 857 |
"vr_high_end_process": "Restore missing high frequencies",
|
| 858 |
"mdx_settings": "MDX-NET",
|
| 859 |
+
"denoise": "Denoise",
|
| 860 |
"invert_settings": "Result Inversion",
|
| 861 |
"use_spectrogram_invert": "Use spectrogram for extracting instrumental/second stem/remainder",
|
| 862 |
"economy_settings": "Economy",
|
|
|
|
| 1169 |
"ext_inst_help": "Extract instrumental by subtraction",
|
| 1170 |
"use_spec_invert_help": "Use spectrogram inversion for secondary stem",
|
| 1171 |
"install_only_help": "Download model only",
|
| 1172 |
+
"denoise_help": "Remove noise from output",
|
| 1173 |
"vr_aggression_help": "Aggression for VR models (default: 5)",
|
| 1174 |
"vr_high_end_help": "Restore missing high frequencies on VR models",
|
| 1175 |
"vr_post_process_help": "Additional post-processing for VR models",
|
mvsepless/infer_utils.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import sys
|
| 2 |
sys.stdout.reconfigure(encoding='utf-8')
|
| 3 |
sys.stderr.reconfigure(encoding='utf-8')
|
|
@@ -70,7 +71,9 @@ def get_model_from_config(model_type: str, config_path: str) -> Tuple[Any, Any]:
|
|
| 70 |
from models.vr_arch import VRNet
|
| 71 |
model = VRNet(**dict(config.model))
|
| 72 |
elif model_type == "htdemucs":
|
| 73 |
-
|
|
|
|
|
|
|
| 74 |
model = get_model(config)
|
| 75 |
elif model_type == "mel_band_roformer":
|
| 76 |
if hasattr(config, "windowed"):
|
|
@@ -419,6 +422,7 @@ def demix_demucs(
|
|
| 419 |
mix = torch.tensor(mix, dtype=torch.float32)
|
| 420 |
chunk_size = config.training.samplerate * config.training.segment
|
| 421 |
num_instruments = len(config.training.instruments)
|
|
|
|
| 422 |
num_overlap = config.inference.num_overlap
|
| 423 |
step = chunk_size // num_overlap
|
| 424 |
fade_size = chunk_size // 10
|
|
@@ -451,8 +455,12 @@ def demix_demucs(
|
|
| 451 |
|
| 452 |
if len(batch_data) >= batch_size or i >= mix.shape[1]:
|
| 453 |
arr = torch.stack(batch_data, dim=0)
|
| 454 |
-
|
| 455 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
window = windowing_array.clone()
|
| 457 |
if i - step == 0:
|
| 458 |
window[:fade_size] = 1
|
|
@@ -511,6 +519,7 @@ def demix_generic(
|
|
| 511 |
chunk_size = config.audio.chunk_size
|
| 512 |
instruments = prefer_target_instrument(config)
|
| 513 |
num_instruments = len(instruments)
|
|
|
|
| 514 |
num_overlap = config.inference.num_overlap
|
| 515 |
|
| 516 |
fade_size = chunk_size // 10
|
|
@@ -550,7 +559,12 @@ def demix_generic(
|
|
| 550 |
|
| 551 |
if len(batch_data) >= batch_size or i >= mix.shape[1]:
|
| 552 |
arr = torch.stack(batch_data, dim=0)
|
| 553 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 554 |
|
| 555 |
window = windowing_array.clone()
|
| 556 |
if i - step == 0:
|
|
|
|
| 1 |
+
import os
|
| 2 |
import sys
|
| 3 |
sys.stdout.reconfigure(encoding='utf-8')
|
| 4 |
sys.stderr.reconfigure(encoding='utf-8')
|
|
|
|
| 71 |
from models.vr_arch import VRNet
|
| 72 |
model = VRNet(**dict(config.model))
|
| 73 |
elif model_type == "htdemucs":
|
| 74 |
+
models_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models')
|
| 75 |
+
sys.path.append(models_path)
|
| 76 |
+
from demucs import get_model
|
| 77 |
model = get_model(config)
|
| 78 |
elif model_type == "mel_band_roformer":
|
| 79 |
if hasattr(config, "windowed"):
|
|
|
|
| 422 |
mix = torch.tensor(mix, dtype=torch.float32)
|
| 423 |
chunk_size = config.training.samplerate * config.training.segment
|
| 424 |
num_instruments = len(config.training.instruments)
|
| 425 |
+
denoise = config.inference.denoise
|
| 426 |
num_overlap = config.inference.num_overlap
|
| 427 |
step = chunk_size // num_overlap
|
| 428 |
fade_size = chunk_size // 10
|
|
|
|
| 455 |
|
| 456 |
if len(batch_data) >= batch_size or i >= mix.shape[1]:
|
| 457 |
arr = torch.stack(batch_data, dim=0)
|
| 458 |
+
if denoise:
|
| 459 |
+
x1 = model(arr)
|
| 460 |
+
x2 = model(-arr)
|
| 461 |
+
x = (x1 + -x2) * 0.5
|
| 462 |
+
else:
|
| 463 |
+
x = model(arr)
|
| 464 |
window = windowing_array.clone()
|
| 465 |
if i - step == 0:
|
| 466 |
window[:fade_size] = 1
|
|
|
|
| 519 |
chunk_size = config.audio.chunk_size
|
| 520 |
instruments = prefer_target_instrument(config)
|
| 521 |
num_instruments = len(instruments)
|
| 522 |
+
denoise = config.inference.denoise
|
| 523 |
num_overlap = config.inference.num_overlap
|
| 524 |
|
| 525 |
fade_size = chunk_size // 10
|
|
|
|
| 559 |
|
| 560 |
if len(batch_data) >= batch_size or i >= mix.shape[1]:
|
| 561 |
arr = torch.stack(batch_data, dim=0)
|
| 562 |
+
if denoise:
|
| 563 |
+
x1 = model(arr)
|
| 564 |
+
x2 = model(-arr)
|
| 565 |
+
x = (x1 + -x2) * 0.5
|
| 566 |
+
else:
|
| 567 |
+
x = model(arr)
|
| 568 |
|
| 569 |
window = windowing_array.clone()
|
| 570 |
if i - step == 0:
|
mvsepless/install.py
CHANGED
|
@@ -221,33 +221,28 @@ universal_requirements: List[str] = [
|
|
| 221 |
"matplotlib",
|
| 222 |
"tqdm",
|
| 223 |
"einops",
|
| 224 |
-
"protobuf",
|
| 225 |
"soundfile",
|
| 226 |
"pydub",
|
| 227 |
"webrtcvad",
|
| 228 |
"audiomentations",
|
| 229 |
"pedalboard",
|
| 230 |
"ml_collections",
|
| 231 |
-
"timm",
|
| 232 |
"wandb",
|
| 233 |
-
"accelerate",
|
| 234 |
"bitsandbytes",
|
| 235 |
"tokenizers",
|
| 236 |
"huggingface-hub",
|
| 237 |
"transformers",
|
| 238 |
-
"
|
| 239 |
-
"
|
|
|
|
| 240 |
"asteroid>=0.6.0",
|
| 241 |
"pyloudnorm",
|
| 242 |
-
"prodigyopt",
|
| 243 |
-
"torch_log_wmse",
|
| 244 |
"rotary_embedding_torch",
|
| 245 |
"gradio<6.0",
|
| 246 |
"omegaconf",
|
| 247 |
"beartype",
|
| 248 |
"spafe",
|
| 249 |
"torch_audiomentations",
|
| 250 |
-
"auraloss",
|
| 251 |
"onnx>=1.17",
|
| 252 |
"onnx2torch>=0.3.0",
|
| 253 |
"onnxruntime-gpu>=1.17" if cuda_available else "onnxruntime>=1.17",
|
|
@@ -279,32 +274,28 @@ old_requirements: List[str] = [
|
|
| 279 |
"matplotlib==3.10.8",
|
| 280 |
"tqdm==4.67.1",
|
| 281 |
"einops==0.8.1",
|
| 282 |
-
"protobuf==6.33.4",
|
| 283 |
"soundfile==0.13.1",
|
| 284 |
"pydub==0.25.1",
|
| 285 |
"webrtcvad==2.0.10",
|
| 286 |
"audiomentations==0.43.1",
|
| 287 |
"pedalboard==0.8.2",
|
| 288 |
"ml_collections==1.1.0",
|
| 289 |
-
"timm==1.0.24",
|
| 290 |
"wandb==0.24.0",
|
| 291 |
-
"accelerate==1.2.1",
|
| 292 |
"bitsandbytes==0.45.0",
|
| 293 |
"tokenizers==0.15.2",
|
| 294 |
"huggingface-hub==0.34.2",
|
| 295 |
"transformers==4.39.3",
|
| 296 |
-
"
|
| 297 |
-
"
|
|
|
|
| 298 |
"asteroid==0.6.0",
|
| 299 |
"pyloudnorm",
|
| 300 |
-
"prodigyopt==1.1.2",
|
| 301 |
"rotary_embedding_torch==0.3.6",
|
| 302 |
"gradio<6.0.0",
|
| 303 |
"omegaconf==2.3.0",
|
| 304 |
"beartype==0.22.9",
|
| 305 |
"spafe==0.3.3",
|
| 306 |
"torch_audiomentations==0.12.0",
|
| 307 |
-
"auraloss==0.4.0",
|
| 308 |
"onnx>=1.17",
|
| 309 |
"onnx2torch>=0.3.0",
|
| 310 |
"onnxruntime-gpu>=1.17" if cuda_available else "onnxruntime>=1.17",
|
|
|
|
| 221 |
"matplotlib",
|
| 222 |
"tqdm",
|
| 223 |
"einops",
|
|
|
|
| 224 |
"soundfile",
|
| 225 |
"pydub",
|
| 226 |
"webrtcvad",
|
| 227 |
"audiomentations",
|
| 228 |
"pedalboard",
|
| 229 |
"ml_collections",
|
|
|
|
| 230 |
"wandb",
|
|
|
|
| 231 |
"bitsandbytes",
|
| 232 |
"tokenizers",
|
| 233 |
"huggingface-hub",
|
| 234 |
"transformers",
|
| 235 |
+
"diffq>=0.2.1",
|
| 236 |
+
"julius>=0.2.3",
|
| 237 |
+
"openunmix",
|
| 238 |
"asteroid>=0.6.0",
|
| 239 |
"pyloudnorm",
|
|
|
|
|
|
|
| 240 |
"rotary_embedding_torch",
|
| 241 |
"gradio<6.0",
|
| 242 |
"omegaconf",
|
| 243 |
"beartype",
|
| 244 |
"spafe",
|
| 245 |
"torch_audiomentations",
|
|
|
|
| 246 |
"onnx>=1.17",
|
| 247 |
"onnx2torch>=0.3.0",
|
| 248 |
"onnxruntime-gpu>=1.17" if cuda_available else "onnxruntime>=1.17",
|
|
|
|
| 274 |
"matplotlib==3.10.8",
|
| 275 |
"tqdm==4.67.1",
|
| 276 |
"einops==0.8.1",
|
|
|
|
| 277 |
"soundfile==0.13.1",
|
| 278 |
"pydub==0.25.1",
|
| 279 |
"webrtcvad==2.0.10",
|
| 280 |
"audiomentations==0.43.1",
|
| 281 |
"pedalboard==0.8.2",
|
| 282 |
"ml_collections==1.1.0",
|
|
|
|
| 283 |
"wandb==0.24.0",
|
|
|
|
| 284 |
"bitsandbytes==0.45.0",
|
| 285 |
"tokenizers==0.15.2",
|
| 286 |
"huggingface-hub==0.34.2",
|
| 287 |
"transformers==4.39.3",
|
| 288 |
+
"diffq>=0.2.1",
|
| 289 |
+
"julius>=0.2.3",
|
| 290 |
+
"openunmix",
|
| 291 |
"asteroid==0.6.0",
|
| 292 |
"pyloudnorm",
|
|
|
|
| 293 |
"rotary_embedding_torch==0.3.6",
|
| 294 |
"gradio<6.0.0",
|
| 295 |
"omegaconf==2.3.0",
|
| 296 |
"beartype==0.22.9",
|
| 297 |
"spafe==0.3.3",
|
| 298 |
"torch_audiomentations==0.12.0",
|
|
|
|
| 299 |
"onnx>=1.17",
|
| 300 |
"onnx2torch>=0.3.0",
|
| 301 |
"onnxruntime-gpu>=1.17" if cuda_available else "onnxruntime>=1.17",
|
mvsepless/models/bandit/core/__init__.py
CHANGED
|
@@ -1,669 +1,669 @@
|
|
| 1 |
-
import os.path
|
| 2 |
-
from collections import defaultdict
|
| 3 |
-
from itertools import chain, combinations
|
| 4 |
-
from typing import Any, Dict, Iterator, Mapping, Optional, Tuple, Type, TypedDict
|
| 5 |
-
|
| 6 |
-
import pytorch_lightning as pl
|
| 7 |
-
import torch
|
| 8 |
-
import torchaudio as ta
|
| 9 |
-
import torchmetrics as tm
|
| 10 |
-
from asteroid import losses as asteroid_losses
|
| 11 |
-
|
| 12 |
-
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
| 13 |
-
from torch import nn, optim
|
| 14 |
-
from torch.optim import lr_scheduler
|
| 15 |
-
from torch.optim.lr_scheduler import LRScheduler
|
| 16 |
-
|
| 17 |
-
from . import loss, metrics as metrics_, model
|
| 18 |
-
from .data._types import BatchedDataDict
|
| 19 |
-
from .data.augmentation import BaseAugmentor, StemAugmentor
|
| 20 |
-
from .utils import audio as audio_
|
| 21 |
-
from .utils.audio import BaseFader
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
ConfigDict = TypedDict("ConfigDict", {"name": str, "kwargs": Dict[str, Any]})
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
class SchedulerConfigDict(ConfigDict):
|
| 28 |
-
monitor: str
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
OptimizerSchedulerConfigDict = TypedDict(
|
| 32 |
-
"OptimizerSchedulerConfigDict",
|
| 33 |
-
{"optimizer": ConfigDict, "scheduler": SchedulerConfigDict},
|
| 34 |
-
total=False,
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
class LRSchedulerReturnDict(TypedDict, total=False):
|
| 39 |
-
scheduler: LRScheduler
|
| 40 |
-
monitor: str
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
class ConfigureOptimizerReturnDict(TypedDict, total=False):
|
| 44 |
-
optimizer: torch.optim.Optimizer
|
| 45 |
-
lr_scheduler: LRSchedulerReturnDict
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
OutputType = Dict[str, Any]
|
| 49 |
-
MetricsType = Dict[str, torch.Tensor]
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
def get_optimizer_class(name: str) -> Type[optim.Optimizer]:
|
| 53 |
-
|
| 54 |
-
if name == "DeepSpeedCPUAdam":
|
| 55 |
-
return DeepSpeedCPUAdam
|
| 56 |
-
|
| 57 |
-
for module in [optim, gooptim]:
|
| 58 |
-
if name in module.__dict__:
|
| 59 |
-
return module.__dict__[name]
|
| 60 |
-
|
| 61 |
-
raise NameError
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
def parse_optimizer_config(
|
| 65 |
-
config: OptimizerSchedulerConfigDict, parameters: Iterator[nn.Parameter]
|
| 66 |
-
) -> ConfigureOptimizerReturnDict:
|
| 67 |
-
optim_class = get_optimizer_class(config["optimizer"]["name"])
|
| 68 |
-
optimizer = optim_class(parameters, **config["optimizer"]["kwargs"])
|
| 69 |
-
|
| 70 |
-
optim_dict: ConfigureOptimizerReturnDict = {
|
| 71 |
-
"optimizer": optimizer,
|
| 72 |
-
}
|
| 73 |
-
|
| 74 |
-
if "scheduler" in config:
|
| 75 |
-
|
| 76 |
-
lr_scheduler_class_ = config["scheduler"]["name"]
|
| 77 |
-
lr_scheduler_class = lr_scheduler.__dict__[lr_scheduler_class_]
|
| 78 |
-
lr_scheduler_dict: LRSchedulerReturnDict = {
|
| 79 |
-
"scheduler": lr_scheduler_class(optimizer, **config["scheduler"]["kwargs"])
|
| 80 |
-
}
|
| 81 |
-
|
| 82 |
-
if lr_scheduler_class_ == "ReduceLROnPlateau":
|
| 83 |
-
lr_scheduler_dict["monitor"] = config["scheduler"]["monitor"]
|
| 84 |
-
|
| 85 |
-
optim_dict["lr_scheduler"] = lr_scheduler_dict
|
| 86 |
-
|
| 87 |
-
return optim_dict
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
def parse_model_config(config: ConfigDict) -> Any:
|
| 91 |
-
name = config["name"]
|
| 92 |
-
|
| 93 |
-
for module in [model]:
|
| 94 |
-
if name in module.__dict__:
|
| 95 |
-
return module.__dict__[name](**config["kwargs"])
|
| 96 |
-
|
| 97 |
-
raise NameError
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
_LEGACY_LOSS_NAMES = ["HybridL1Loss"]
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
def _parse_legacy_loss_config(config: ConfigDict) -> nn.Module:
|
| 104 |
-
name = config["name"]
|
| 105 |
-
|
| 106 |
-
if name == "HybridL1Loss":
|
| 107 |
-
return loss.TimeFreqL1Loss(**config["kwargs"])
|
| 108 |
-
|
| 109 |
-
raise NameError
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
def parse_loss_config(config: ConfigDict) -> nn.Module:
|
| 113 |
-
name = config["name"]
|
| 114 |
-
|
| 115 |
-
if name in _LEGACY_LOSS_NAMES:
|
| 116 |
-
return _parse_legacy_loss_config(config)
|
| 117 |
-
|
| 118 |
-
for module in [loss, nn.modules.loss, asteroid_losses]:
|
| 119 |
-
if name in module.__dict__:
|
| 120 |
-
return module.__dict__[name](**config["kwargs"])
|
| 121 |
-
|
| 122 |
-
raise NameError
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
def get_metric(config: ConfigDict) -> tm.Metric:
|
| 126 |
-
name = config["name"]
|
| 127 |
-
|
| 128 |
-
for module in [tm, metrics_]:
|
| 129 |
-
if name in module.__dict__:
|
| 130 |
-
return module.__dict__[name](**config["kwargs"])
|
| 131 |
-
raise NameError
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
def parse_metric_config(config: Dict[str, ConfigDict]) -> tm.MetricCollection:
|
| 135 |
-
metrics = {}
|
| 136 |
-
|
| 137 |
-
for metric in config:
|
| 138 |
-
metrics[metric] = get_metric(config[metric])
|
| 139 |
-
|
| 140 |
-
return tm.MetricCollection(metrics)
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
def parse_fader_config(config: ConfigDict) -> BaseFader:
|
| 144 |
-
name = config["name"]
|
| 145 |
-
|
| 146 |
-
for module in [audio_]:
|
| 147 |
-
if name in module.__dict__:
|
| 148 |
-
return module.__dict__[name](**config["kwargs"])
|
| 149 |
-
|
| 150 |
-
raise NameError
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
class LightningSystem(pl.LightningModule):
|
| 154 |
-
_VOX_STEMS = ["speech", "vocals"]
|
| 155 |
-
_BG_STEMS = ["background", "effects", "mne"]
|
| 156 |
-
|
| 157 |
-
def __init__(
|
| 158 |
-
self, config: Dict, loss_adjustment: float = 1.0, attach_fader: bool = False
|
| 159 |
-
) -> None:
|
| 160 |
-
super().__init__()
|
| 161 |
-
self.optimizer_config = config["optimizer"]
|
| 162 |
-
self.model = parse_model_config(config["model"])
|
| 163 |
-
self.loss = parse_loss_config(config["loss"])
|
| 164 |
-
self.metrics = nn.ModuleDict(
|
| 165 |
-
{
|
| 166 |
-
stem: parse_metric_config(config["metrics"]["dev"])
|
| 167 |
-
for stem in self.model.stems
|
| 168 |
-
}
|
| 169 |
-
)
|
| 170 |
-
|
| 171 |
-
self.metrics.disallow_fsdp = True
|
| 172 |
-
|
| 173 |
-
self.test_metrics = nn.ModuleDict(
|
| 174 |
-
{
|
| 175 |
-
stem: parse_metric_config(config["metrics"]["test"])
|
| 176 |
-
for stem in self.model.stems
|
| 177 |
-
}
|
| 178 |
-
)
|
| 179 |
-
|
| 180 |
-
self.test_metrics.disallow_fsdp = True
|
| 181 |
-
|
| 182 |
-
self.fs = config["model"]["kwargs"]["fs"]
|
| 183 |
-
|
| 184 |
-
self.fader_config = config["inference"]["fader"]
|
| 185 |
-
if attach_fader:
|
| 186 |
-
self.fader = parse_fader_config(config["inference"]["fader"])
|
| 187 |
-
else:
|
| 188 |
-
self.fader = None
|
| 189 |
-
|
| 190 |
-
self.augmentation: Optional[BaseAugmentor]
|
| 191 |
-
if config.get("augmentation", None) is not None:
|
| 192 |
-
self.augmentation = StemAugmentor(**config["augmentation"])
|
| 193 |
-
else:
|
| 194 |
-
self.augmentation = None
|
| 195 |
-
|
| 196 |
-
self.predict_output_path: Optional[str] = None
|
| 197 |
-
self.loss_adjustment = loss_adjustment
|
| 198 |
-
|
| 199 |
-
self.val_prefix = None
|
| 200 |
-
self.test_prefix = None
|
| 201 |
-
|
| 202 |
-
def configure_optimizers(self) -> Any:
|
| 203 |
-
return parse_optimizer_config(
|
| 204 |
-
self.optimizer_config, self.trainer.model.parameters()
|
| 205 |
-
)
|
| 206 |
-
|
| 207 |
-
def compute_loss(
|
| 208 |
-
self, batch: BatchedDataDict, output: OutputType
|
| 209 |
-
) -> Dict[str, torch.Tensor]:
|
| 210 |
-
return {"loss": self.loss(output, batch)}
|
| 211 |
-
|
| 212 |
-
def update_metrics(
|
| 213 |
-
self, batch: BatchedDataDict, output: OutputType, mode: str
|
| 214 |
-
) -> None:
|
| 215 |
-
|
| 216 |
-
if mode == "test":
|
| 217 |
-
metrics = self.test_metrics
|
| 218 |
-
else:
|
| 219 |
-
metrics = self.metrics
|
| 220 |
-
|
| 221 |
-
for stem, metric in metrics.items():
|
| 222 |
-
|
| 223 |
-
if stem == "mne:+":
|
| 224 |
-
stem = "mne"
|
| 225 |
-
|
| 226 |
-
if mode == "train":
|
| 227 |
-
metric.update(
|
| 228 |
-
output["audio"][stem],
|
| 229 |
-
batch["audio"][stem],
|
| 230 |
-
)
|
| 231 |
-
else:
|
| 232 |
-
if stem not in batch["audio"]:
|
| 233 |
-
matched = False
|
| 234 |
-
if stem in self._VOX_STEMS:
|
| 235 |
-
for bstem in self._VOX_STEMS:
|
| 236 |
-
if bstem in batch["audio"]:
|
| 237 |
-
batch["audio"][stem] = batch["audio"][bstem]
|
| 238 |
-
matched = True
|
| 239 |
-
break
|
| 240 |
-
elif stem in self._BG_STEMS:
|
| 241 |
-
for bstem in self._BG_STEMS:
|
| 242 |
-
if bstem in batch["audio"]:
|
| 243 |
-
batch["audio"][stem] = batch["audio"][bstem]
|
| 244 |
-
matched = True
|
| 245 |
-
break
|
| 246 |
-
else:
|
| 247 |
-
matched = True
|
| 248 |
-
|
| 249 |
-
if matched:
|
| 250 |
-
if stem == "mne" and "mne" not in output["audio"]:
|
| 251 |
-
output["audio"]["mne"] = (
|
| 252 |
-
output["audio"]["music"] + output["audio"]["effects"]
|
| 253 |
-
)
|
| 254 |
-
|
| 255 |
-
metric.update(
|
| 256 |
-
output["audio"][stem],
|
| 257 |
-
batch["audio"][stem],
|
| 258 |
-
)
|
| 259 |
-
|
| 260 |
-
def compute_metrics(self, mode: str = "dev") -> Dict[str, torch.Tensor]:
|
| 261 |
-
|
| 262 |
-
if mode == "test":
|
| 263 |
-
metrics = self.test_metrics
|
| 264 |
-
else:
|
| 265 |
-
metrics = self.metrics
|
| 266 |
-
|
| 267 |
-
metric_dict = {}
|
| 268 |
-
|
| 269 |
-
for stem, metric in metrics.items():
|
| 270 |
-
md = metric.compute()
|
| 271 |
-
metric_dict.update({f"{stem}/{k}": v for k, v in md.items()})
|
| 272 |
-
|
| 273 |
-
self.log_dict(metric_dict, prog_bar=True, logger=False)
|
| 274 |
-
|
| 275 |
-
return metric_dict
|
| 276 |
-
|
| 277 |
-
def reset_metrics(self, test_mode: bool = False) -> None:
|
| 278 |
-
|
| 279 |
-
if test_mode:
|
| 280 |
-
metrics = self.test_metrics
|
| 281 |
-
else:
|
| 282 |
-
metrics = self.metrics
|
| 283 |
-
|
| 284 |
-
for _, metric in metrics.items():
|
| 285 |
-
metric.reset()
|
| 286 |
-
|
| 287 |
-
def forward(self, batch: BatchedDataDict) -> Any:
|
| 288 |
-
batch, output = self.model(batch)
|
| 289 |
-
|
| 290 |
-
return batch, output
|
| 291 |
-
|
| 292 |
-
def common_step(self, batch: BatchedDataDict, mode: str) -> Any:
|
| 293 |
-
batch, output = self.forward(batch)
|
| 294 |
-
loss_dict = self.compute_loss(batch, output)
|
| 295 |
-
|
| 296 |
-
with torch.no_grad():
|
| 297 |
-
self.update_metrics(batch, output, mode=mode)
|
| 298 |
-
|
| 299 |
-
if mode == "train":
|
| 300 |
-
self.log("loss", loss_dict["loss"], prog_bar=True)
|
| 301 |
-
|
| 302 |
-
return output, loss_dict
|
| 303 |
-
|
| 304 |
-
def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]:
|
| 305 |
-
|
| 306 |
-
if self.augmentation is not None:
|
| 307 |
-
with torch.no_grad():
|
| 308 |
-
batch = self.augmentation(batch)
|
| 309 |
-
|
| 310 |
-
_, loss_dict = self.common_step(batch, mode="train")
|
| 311 |
-
|
| 312 |
-
with torch.inference_mode():
|
| 313 |
-
self.log_dict_with_prefix(
|
| 314 |
-
loss_dict, "train", batch_size=batch["audio"]["mixture"].shape[0]
|
| 315 |
-
)
|
| 316 |
-
|
| 317 |
-
loss_dict["loss"] *= self.loss_adjustment
|
| 318 |
-
|
| 319 |
-
return loss_dict
|
| 320 |
-
|
| 321 |
-
def on_train_batch_end(
|
| 322 |
-
self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int
|
| 323 |
-
) -> None:
|
| 324 |
-
|
| 325 |
-
metric_dict = self.compute_metrics()
|
| 326 |
-
self.log_dict_with_prefix(metric_dict, "train")
|
| 327 |
-
self.reset_metrics()
|
| 328 |
-
|
| 329 |
-
def validation_step(
|
| 330 |
-
self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
|
| 331 |
-
) -> Dict[str, Any]:
|
| 332 |
-
|
| 333 |
-
with torch.inference_mode():
|
| 334 |
-
curr_val_prefix = f"val{dataloader_idx}" if dataloader_idx > 0 else "val"
|
| 335 |
-
|
| 336 |
-
if curr_val_prefix != self.val_prefix:
|
| 337 |
-
if self.val_prefix is not None:
|
| 338 |
-
self._on_validation_epoch_end()
|
| 339 |
-
self.val_prefix = curr_val_prefix
|
| 340 |
-
_, loss_dict = self.common_step(batch, mode="val")
|
| 341 |
-
|
| 342 |
-
self.log_dict_with_prefix(
|
| 343 |
-
loss_dict,
|
| 344 |
-
self.val_prefix,
|
| 345 |
-
batch_size=batch["audio"]["mixture"].shape[0],
|
| 346 |
-
prog_bar=True,
|
| 347 |
-
add_dataloader_idx=False,
|
| 348 |
-
)
|
| 349 |
-
|
| 350 |
-
return loss_dict
|
| 351 |
-
|
| 352 |
-
def on_validation_epoch_end(self) -> None:
|
| 353 |
-
self._on_validation_epoch_end()
|
| 354 |
-
|
| 355 |
-
def _on_validation_epoch_end(self) -> None:
|
| 356 |
-
metric_dict = self.compute_metrics()
|
| 357 |
-
self.log_dict_with_prefix(
|
| 358 |
-
metric_dict, self.val_prefix, prog_bar=True, add_dataloader_idx=False
|
| 359 |
-
)
|
| 360 |
-
self.reset_metrics()
|
| 361 |
-
|
| 362 |
-
def old_predtest_step(
|
| 363 |
-
self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
|
| 364 |
-
) -> Tuple[BatchedDataDict, OutputType]:
|
| 365 |
-
|
| 366 |
-
audio_batch = batch["audio"]["mixture"]
|
| 367 |
-
track_batch = batch.get("track", ["" for _ in range(len(audio_batch))])
|
| 368 |
-
|
| 369 |
-
output_list_of_dicts = [
|
| 370 |
-
self.fader(audio[None, ...], lambda a: self.test_forward(a, track))
|
| 371 |
-
for audio, track in zip(audio_batch, track_batch)
|
| 372 |
-
]
|
| 373 |
-
|
| 374 |
-
output_dict_of_lists = defaultdict(list)
|
| 375 |
-
|
| 376 |
-
for output_dict in output_list_of_dicts:
|
| 377 |
-
for stem, audio in output_dict.items():
|
| 378 |
-
output_dict_of_lists[stem].append(audio)
|
| 379 |
-
|
| 380 |
-
output = {
|
| 381 |
-
"audio": {
|
| 382 |
-
stem: torch.concat(output_list, dim=0)
|
| 383 |
-
for stem, output_list in output_dict_of_lists.items()
|
| 384 |
-
}
|
| 385 |
-
}
|
| 386 |
-
|
| 387 |
-
return batch, output
|
| 388 |
-
|
| 389 |
-
def predtest_step(
|
| 390 |
-
self, batch: BatchedDataDict, batch_idx: int = -1, dataloader_idx: int = 0
|
| 391 |
-
) -> Tuple[BatchedDataDict, OutputType]:
|
| 392 |
-
|
| 393 |
-
if getattr(self.model, "bypass_fader", False):
|
| 394 |
-
batch, output = self.model(batch)
|
| 395 |
-
else:
|
| 396 |
-
audio_batch = batch["audio"]["mixture"]
|
| 397 |
-
output = self.fader(
|
| 398 |
-
audio_batch, lambda a: self.test_forward(a, "", batch=batch)
|
| 399 |
-
)
|
| 400 |
-
|
| 401 |
-
return batch, output
|
| 402 |
-
|
| 403 |
-
def test_forward(
|
| 404 |
-
self, audio: torch.Tensor, track: str = "", batch: BatchedDataDict = None
|
| 405 |
-
) -> torch.Tensor:
|
| 406 |
-
|
| 407 |
-
if self.fader is None:
|
| 408 |
-
self.attach_fader()
|
| 409 |
-
|
| 410 |
-
cond = batch.get("condition", None)
|
| 411 |
-
|
| 412 |
-
if cond is not None and cond.shape[0] == 1:
|
| 413 |
-
cond = cond.repeat(audio.shape[0], 1)
|
| 414 |
-
|
| 415 |
-
_, output = self.forward(
|
| 416 |
-
{
|
| 417 |
-
"audio": {"mixture": audio},
|
| 418 |
-
"track": track,
|
| 419 |
-
"condition": cond,
|
| 420 |
-
}
|
| 421 |
-
)
|
| 422 |
-
|
| 423 |
-
return output["audio"]
|
| 424 |
-
|
| 425 |
-
def on_test_epoch_start(self) -> None:
|
| 426 |
-
self.attach_fader(force_reattach=True)
|
| 427 |
-
|
| 428 |
-
def test_step(
|
| 429 |
-
self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
|
| 430 |
-
) -> Any:
|
| 431 |
-
curr_test_prefix = f"test{dataloader_idx}"
|
| 432 |
-
|
| 433 |
-
if curr_test_prefix != self.test_prefix:
|
| 434 |
-
if self.test_prefix is not None:
|
| 435 |
-
self._on_test_epoch_end()
|
| 436 |
-
self.test_prefix = curr_test_prefix
|
| 437 |
-
|
| 438 |
-
with torch.inference_mode():
|
| 439 |
-
_, output = self.predtest_step(batch, batch_idx, dataloader_idx)
|
| 440 |
-
self.update_metrics(batch, output, mode="test")
|
| 441 |
-
|
| 442 |
-
return output
|
| 443 |
-
|
| 444 |
-
def on_test_epoch_end(self) -> None:
|
| 445 |
-
self._on_test_epoch_end()
|
| 446 |
-
|
| 447 |
-
def _on_test_epoch_end(self) -> None:
|
| 448 |
-
metric_dict = self.compute_metrics(mode="test")
|
| 449 |
-
self.log_dict_with_prefix(
|
| 450 |
-
metric_dict, self.test_prefix, prog_bar=True, add_dataloader_idx=False
|
| 451 |
-
)
|
| 452 |
-
self.reset_metrics()
|
| 453 |
-
|
| 454 |
-
def predict_step(
|
| 455 |
-
self,
|
| 456 |
-
batch: BatchedDataDict,
|
| 457 |
-
batch_idx: int = 0,
|
| 458 |
-
dataloader_idx: int = 0,
|
| 459 |
-
include_track_name: Optional[bool] = None,
|
| 460 |
-
get_no_vox_combinations: bool = True,
|
| 461 |
-
get_residual: bool = False,
|
| 462 |
-
treat_batch_as_channels: bool = False,
|
| 463 |
-
fs: Optional[int] = None,
|
| 464 |
-
) -> Any:
|
| 465 |
-
assert self.predict_output_path is not None
|
| 466 |
-
|
| 467 |
-
batch_size = batch["audio"]["mixture"].shape[0]
|
| 468 |
-
|
| 469 |
-
if include_track_name is None:
|
| 470 |
-
include_track_name = batch_size > 1
|
| 471 |
-
|
| 472 |
-
with torch.inference_mode():
|
| 473 |
-
batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
|
| 474 |
-
print("Pred test finished...")
|
| 475 |
-
torch.cuda.empty_cache()
|
| 476 |
-
metric_dict = {}
|
| 477 |
-
|
| 478 |
-
if get_residual:
|
| 479 |
-
mixture = batch["audio"]["mixture"]
|
| 480 |
-
extracted = sum([output["audio"][stem] for stem in output["audio"]])
|
| 481 |
-
residual = mixture - extracted
|
| 482 |
-
print(extracted.shape, mixture.shape, residual.shape)
|
| 483 |
-
|
| 484 |
-
output["audio"]["residual"] = residual
|
| 485 |
-
|
| 486 |
-
if get_no_vox_combinations:
|
| 487 |
-
no_vox_stems = [
|
| 488 |
-
stem for stem in output["audio"] if stem not in self._VOX_STEMS
|
| 489 |
-
]
|
| 490 |
-
no_vox_combinations = chain.from_iterable(
|
| 491 |
-
combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1)
|
| 492 |
-
)
|
| 493 |
-
|
| 494 |
-
for combination in no_vox_combinations:
|
| 495 |
-
combination_ = list(combination)
|
| 496 |
-
output["audio"]["+".join(combination_)] = sum(
|
| 497 |
-
[output["audio"][stem] for stem in combination_]
|
| 498 |
-
)
|
| 499 |
-
|
| 500 |
-
if treat_batch_as_channels:
|
| 501 |
-
for stem in output["audio"]:
|
| 502 |
-
output["audio"][stem] = output["audio"][stem].reshape(
|
| 503 |
-
1, -1, output["audio"][stem].shape[-1]
|
| 504 |
-
)
|
| 505 |
-
batch_size = 1
|
| 506 |
-
|
| 507 |
-
for b in range(batch_size):
|
| 508 |
-
print("!!", b)
|
| 509 |
-
for stem in output["audio"]:
|
| 510 |
-
print(f"Saving audio for {stem} to {self.predict_output_path}")
|
| 511 |
-
track_name = batch["track"][b].split("/")[-1]
|
| 512 |
-
|
| 513 |
-
if batch.get("audio", {}).get(stem, None) is not None:
|
| 514 |
-
self.test_metrics[stem].reset()
|
| 515 |
-
metrics = self.test_metrics[stem](
|
| 516 |
-
batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...]
|
| 517 |
-
)
|
| 518 |
-
snr = metrics["snr"]
|
| 519 |
-
sisnr = metrics["sisnr"]
|
| 520 |
-
sdr = metrics["sdr"]
|
| 521 |
-
metric_dict[stem] = metrics
|
| 522 |
-
print(
|
| 523 |
-
track_name,
|
| 524 |
-
f"snr={snr:2.2f} dB",
|
| 525 |
-
f"sisnr={sisnr:2.2f}",
|
| 526 |
-
f"sdr={sdr:2.2f} dB",
|
| 527 |
-
)
|
| 528 |
-
filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
|
| 529 |
-
else:
|
| 530 |
-
filename = f"{stem}.wav"
|
| 531 |
-
|
| 532 |
-
if include_track_name:
|
| 533 |
-
output_dir = os.path.join(self.predict_output_path, track_name)
|
| 534 |
-
else:
|
| 535 |
-
output_dir = self.predict_output_path
|
| 536 |
-
|
| 537 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 538 |
-
|
| 539 |
-
if fs is None:
|
| 540 |
-
fs = self.fs
|
| 541 |
-
|
| 542 |
-
ta.save(
|
| 543 |
-
os.path.join(output_dir, filename),
|
| 544 |
-
output["audio"][stem][b, ...].cpu(),
|
| 545 |
-
fs,
|
| 546 |
-
)
|
| 547 |
-
|
| 548 |
-
return metric_dict
|
| 549 |
-
|
| 550 |
-
def get_stems(
|
| 551 |
-
self,
|
| 552 |
-
batch: BatchedDataDict,
|
| 553 |
-
batch_idx: int = 0,
|
| 554 |
-
dataloader_idx: int = 0,
|
| 555 |
-
include_track_name: Optional[bool] = None,
|
| 556 |
-
get_no_vox_combinations: bool = True,
|
| 557 |
-
get_residual: bool = False,
|
| 558 |
-
treat_batch_as_channels: bool = False,
|
| 559 |
-
fs: Optional[int] = None,
|
| 560 |
-
) -> Any:
|
| 561 |
-
assert self.predict_output_path is not None
|
| 562 |
-
|
| 563 |
-
batch_size = batch["audio"]["mixture"].shape[0]
|
| 564 |
-
|
| 565 |
-
if include_track_name is None:
|
| 566 |
-
include_track_name = batch_size > 1
|
| 567 |
-
|
| 568 |
-
with torch.inference_mode():
|
| 569 |
-
batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
|
| 570 |
-
torch.cuda.empty_cache()
|
| 571 |
-
metric_dict = {}
|
| 572 |
-
|
| 573 |
-
if get_residual:
|
| 574 |
-
mixture = batch["audio"]["mixture"]
|
| 575 |
-
extracted = sum([output["audio"][stem] for stem in output["audio"]])
|
| 576 |
-
residual = mixture - extracted
|
| 577 |
-
|
| 578 |
-
output["audio"]["residual"] = residual
|
| 579 |
-
|
| 580 |
-
if get_no_vox_combinations:
|
| 581 |
-
no_vox_stems = [
|
| 582 |
-
stem for stem in output["audio"] if stem not in self._VOX_STEMS
|
| 583 |
-
]
|
| 584 |
-
no_vox_combinations = chain.from_iterable(
|
| 585 |
-
combinations(no_vox_stems, r) for r in range(2, len(no_vox_stems) + 1)
|
| 586 |
-
)
|
| 587 |
-
|
| 588 |
-
for combination in no_vox_combinations:
|
| 589 |
-
combination_ = list(combination)
|
| 590 |
-
output["audio"]["+".join(combination_)] = sum(
|
| 591 |
-
[output["audio"][stem] for stem in combination_]
|
| 592 |
-
)
|
| 593 |
-
|
| 594 |
-
if treat_batch_as_channels:
|
| 595 |
-
for stem in output["audio"]:
|
| 596 |
-
output["audio"][stem] = output["audio"][stem].reshape(
|
| 597 |
-
1, -1, output["audio"][stem].shape[-1]
|
| 598 |
-
)
|
| 599 |
-
batch_size = 1
|
| 600 |
-
|
| 601 |
-
result = {}
|
| 602 |
-
for b in range(batch_size):
|
| 603 |
-
for stem in output["audio"]:
|
| 604 |
-
track_name = batch["track"][b].split("/")[-1]
|
| 605 |
-
|
| 606 |
-
if batch.get("audio", {}).get(stem, None) is not None:
|
| 607 |
-
self.test_metrics[stem].reset()
|
| 608 |
-
metrics = self.test_metrics[stem](
|
| 609 |
-
batch["audio"][stem][[b], ...], output["audio"][stem][[b], ...]
|
| 610 |
-
)
|
| 611 |
-
snr = metrics["snr"]
|
| 612 |
-
sisnr = metrics["sisnr"]
|
| 613 |
-
sdr = metrics["sdr"]
|
| 614 |
-
metric_dict[stem] = metrics
|
| 615 |
-
print(
|
| 616 |
-
track_name,
|
| 617 |
-
f"snr={snr:2.2f} dB",
|
| 618 |
-
f"sisnr={sisnr:2.2f}",
|
| 619 |
-
f"sdr={sdr:2.2f} dB",
|
| 620 |
-
)
|
| 621 |
-
filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
|
| 622 |
-
else:
|
| 623 |
-
filename = f"{stem}.wav"
|
| 624 |
-
|
| 625 |
-
if include_track_name:
|
| 626 |
-
output_dir = os.path.join(self.predict_output_path, track_name)
|
| 627 |
-
else:
|
| 628 |
-
output_dir = self.predict_output_path
|
| 629 |
-
|
| 630 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 631 |
-
|
| 632 |
-
if fs is None:
|
| 633 |
-
fs = self.fs
|
| 634 |
-
|
| 635 |
-
result[stem] = output["audio"][stem][b, ...].cpu().numpy()
|
| 636 |
-
|
| 637 |
-
return result
|
| 638 |
-
|
| 639 |
-
def load_state_dict(
|
| 640 |
-
self, state_dict: Mapping[str, Any], strict: bool = False
|
| 641 |
-
) -> Any:
|
| 642 |
-
|
| 643 |
-
return super().load_state_dict(state_dict, strict=False)
|
| 644 |
-
|
| 645 |
-
def set_predict_output_path(self, path: str) -> None:
|
| 646 |
-
self.predict_output_path = path
|
| 647 |
-
os.makedirs(self.predict_output_path, exist_ok=True)
|
| 648 |
-
|
| 649 |
-
self.attach_fader()
|
| 650 |
-
|
| 651 |
-
def attach_fader(self, force_reattach=False) -> None:
|
| 652 |
-
if self.fader is None or force_reattach:
|
| 653 |
-
self.fader = parse_fader_config(self.fader_config)
|
| 654 |
-
self.fader.to(self.device)
|
| 655 |
-
|
| 656 |
-
def log_dict_with_prefix(
|
| 657 |
-
self,
|
| 658 |
-
dict_: Dict[str, torch.Tensor],
|
| 659 |
-
prefix: str,
|
| 660 |
-
batch_size: Optional[int] = None,
|
| 661 |
-
**kwargs: Any,
|
| 662 |
-
) -> None:
|
| 663 |
-
self.log_dict(
|
| 664 |
-
{f"{prefix}/{k}": v for k, v in dict_.items()},
|
| 665 |
-
batch_size=batch_size,
|
| 666 |
-
logger=True,
|
| 667 |
-
sync_dist=True,
|
| 668 |
-
**kwargs,
|
| 669 |
-
)
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from itertools import chain, combinations
|
| 4 |
+
from typing import Any, Dict, Iterator, Mapping, Optional, Tuple, Type, TypedDict
|
| 5 |
+
|
| 6 |
+
import pytorch_lightning as pl
|
| 7 |
+
import torch
|
| 8 |
+
import torchaudio as ta
|
| 9 |
+
import torchmetrics as tm
|
| 10 |
+
from asteroid import losses as asteroid_losses
|
| 11 |
+
|
| 12 |
+
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
| 13 |
+
from torch import nn, optim
|
| 14 |
+
from torch.optim import lr_scheduler
|
| 15 |
+
from torch.optim.lr_scheduler import LRScheduler
|
| 16 |
+
|
| 17 |
+
from . import loss, metrics as metrics_, model
|
| 18 |
+
from .data._types import BatchedDataDict
|
| 19 |
+
from .data.augmentation import BaseAugmentor, StemAugmentor
|
| 20 |
+
from .utils import audio as audio_
|
| 21 |
+
from .utils.audio import BaseFader
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
ConfigDict = TypedDict("ConfigDict", {"name": str, "kwargs": Dict[str, Any]})
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SchedulerConfigDict(ConfigDict):
|
| 28 |
+
monitor: str
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
OptimizerSchedulerConfigDict = TypedDict(
|
| 32 |
+
"OptimizerSchedulerConfigDict",
|
| 33 |
+
{"optimizer": ConfigDict, "scheduler": SchedulerConfigDict},
|
| 34 |
+
total=False,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class LRSchedulerReturnDict(TypedDict, total=False):
|
| 39 |
+
scheduler: LRScheduler
|
| 40 |
+
monitor: str
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ConfigureOptimizerReturnDict(TypedDict, total=False):
|
| 44 |
+
optimizer: torch.optim.Optimizer
|
| 45 |
+
lr_scheduler: LRSchedulerReturnDict
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
OutputType = Dict[str, Any]
|
| 49 |
+
MetricsType = Dict[str, torch.Tensor]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_optimizer_class(name: str) -> Type[optim.Optimizer]:
|
| 53 |
+
|
| 54 |
+
if name == "DeepSpeedCPUAdam":
|
| 55 |
+
return DeepSpeedCPUAdam
|
| 56 |
+
|
| 57 |
+
for module in [optim, gooptim]:
|
| 58 |
+
if name in module.__dict__:
|
| 59 |
+
return module.__dict__[name]
|
| 60 |
+
|
| 61 |
+
raise NameError
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def parse_optimizer_config(
|
| 65 |
+
config: OptimizerSchedulerConfigDict, parameters: Iterator[nn.Parameter]
|
| 66 |
+
) -> ConfigureOptimizerReturnDict:
|
| 67 |
+
optim_class = get_optimizer_class(config["optimizer"]["name"])
|
| 68 |
+
optimizer = optim_class(parameters, **config["optimizer"]["kwargs"])
|
| 69 |
+
|
| 70 |
+
optim_dict: ConfigureOptimizerReturnDict = {
|
| 71 |
+
"optimizer": optimizer,
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
if "scheduler" in config:
|
| 75 |
+
|
| 76 |
+
lr_scheduler_class_ = config["scheduler"]["name"]
|
| 77 |
+
lr_scheduler_class = lr_scheduler.__dict__[lr_scheduler_class_]
|
| 78 |
+
lr_scheduler_dict: LRSchedulerReturnDict = {
|
| 79 |
+
"scheduler": lr_scheduler_class(optimizer, **config["scheduler"]["kwargs"])
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
if lr_scheduler_class_ == "ReduceLROnPlateau":
|
| 83 |
+
lr_scheduler_dict["monitor"] = config["scheduler"]["monitor"]
|
| 84 |
+
|
| 85 |
+
optim_dict["lr_scheduler"] = lr_scheduler_dict
|
| 86 |
+
|
| 87 |
+
return optim_dict
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def parse_model_config(config: ConfigDict) -> Any:
|
| 91 |
+
name = config["name"]
|
| 92 |
+
|
| 93 |
+
for module in [model]:
|
| 94 |
+
if name in module.__dict__:
|
| 95 |
+
return module.__dict__[name](**config["kwargs"])
|
| 96 |
+
|
| 97 |
+
raise NameError
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
_LEGACY_LOSS_NAMES = ["HybridL1Loss"]
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _parse_legacy_loss_config(config: ConfigDict) -> nn.Module:
|
| 104 |
+
name = config["name"]
|
| 105 |
+
|
| 106 |
+
if name == "HybridL1Loss":
|
| 107 |
+
return loss.TimeFreqL1Loss(**config["kwargs"])
|
| 108 |
+
|
| 109 |
+
raise NameError
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def parse_loss_config(config: ConfigDict) -> nn.Module:
|
| 113 |
+
name = config["name"]
|
| 114 |
+
|
| 115 |
+
if name in _LEGACY_LOSS_NAMES:
|
| 116 |
+
return _parse_legacy_loss_config(config)
|
| 117 |
+
|
| 118 |
+
for module in [loss, nn.modules.loss, asteroid_losses]:
|
| 119 |
+
if name in module.__dict__:
|
| 120 |
+
return module.__dict__[name](**config["kwargs"])
|
| 121 |
+
|
| 122 |
+
raise NameError
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def get_metric(config: ConfigDict) -> tm.Metric:
|
| 126 |
+
name = config["name"]
|
| 127 |
+
|
| 128 |
+
for module in [tm, metrics_]:
|
| 129 |
+
if name in module.__dict__:
|
| 130 |
+
return module.__dict__[name](**config["kwargs"])
|
| 131 |
+
raise NameError
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def parse_metric_config(config: Dict[str, ConfigDict]) -> tm.MetricCollection:
|
| 135 |
+
metrics = {}
|
| 136 |
+
|
| 137 |
+
for metric in config:
|
| 138 |
+
metrics[metric] = get_metric(config[metric])
|
| 139 |
+
|
| 140 |
+
return tm.MetricCollection(metrics)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def parse_fader_config(config: ConfigDict) -> BaseFader:
|
| 144 |
+
name = config["name"]
|
| 145 |
+
|
| 146 |
+
for module in [audio_]:
|
| 147 |
+
if name in module.__dict__:
|
| 148 |
+
return module.__dict__[name](**config["kwargs"])
|
| 149 |
+
|
| 150 |
+
raise NameError
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class LightningSystem(pl.LightningModule):
|
| 154 |
+
_VOX_STEMS = ["speech", "vocals"]
|
| 155 |
+
_BG_STEMS = ["background", "effects", "mne"]
|
| 156 |
+
|
| 157 |
+
def __init__(
|
| 158 |
+
self, config: Dict, loss_adjustment: float = 1.0, attach_fader: bool = False
|
| 159 |
+
) -> None:
|
| 160 |
+
super().__init__()
|
| 161 |
+
self.optimizer_config = config["optimizer"]
|
| 162 |
+
self.model = parse_model_config(config["model"])
|
| 163 |
+
self.loss = parse_loss_config(config["loss"])
|
| 164 |
+
self.metrics = nn.ModuleDict(
|
| 165 |
+
{
|
| 166 |
+
stem: parse_metric_config(config["metrics"]["dev"])
|
| 167 |
+
for stem in self.model.stems
|
| 168 |
+
}
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
self.metrics.disallow_fsdp = True
|
| 172 |
+
|
| 173 |
+
self.test_metrics = nn.ModuleDict(
|
| 174 |
+
{
|
| 175 |
+
stem: parse_metric_config(config["metrics"]["test"])
|
| 176 |
+
for stem in self.model.stems
|
| 177 |
+
}
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
self.test_metrics.disallow_fsdp = True
|
| 181 |
+
|
| 182 |
+
self.fs = config["model"]["kwargs"]["fs"]
|
| 183 |
+
|
| 184 |
+
self.fader_config = config["inference"]["fader"]
|
| 185 |
+
if attach_fader:
|
| 186 |
+
self.fader = parse_fader_config(config["inference"]["fader"])
|
| 187 |
+
else:
|
| 188 |
+
self.fader = None
|
| 189 |
+
|
| 190 |
+
self.augmentation: Optional[BaseAugmentor]
|
| 191 |
+
if config.get("augmentation", None) is not None:
|
| 192 |
+
self.augmentation = StemAugmentor(**config["augmentation"])
|
| 193 |
+
else:
|
| 194 |
+
self.augmentation = None
|
| 195 |
+
|
| 196 |
+
self.predict_output_path: Optional[str] = None
|
| 197 |
+
self.loss_adjustment = loss_adjustment
|
| 198 |
+
|
| 199 |
+
self.val_prefix = None
|
| 200 |
+
self.test_prefix = None
|
| 201 |
+
|
| 202 |
+
def configure_optimizers(self) -> Any:
|
| 203 |
+
return parse_optimizer_config(
|
| 204 |
+
self.optimizer_config, self.trainer.model.parameters()
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
def compute_loss(
|
| 208 |
+
self, batch: BatchedDataDict, output: OutputType
|
| 209 |
+
) -> Dict[str, torch.Tensor]:
|
| 210 |
+
return {"loss": self.loss(output, batch)}
|
| 211 |
+
|
| 212 |
+
def update_metrics(
|
| 213 |
+
self, batch: BatchedDataDict, output: OutputType, mode: str
|
| 214 |
+
) -> None:
|
| 215 |
+
|
| 216 |
+
if mode == "test":
|
| 217 |
+
metrics = self.test_metrics
|
| 218 |
+
else:
|
| 219 |
+
metrics = self.metrics
|
| 220 |
+
|
| 221 |
+
for stem, metric in metrics.items():
|
| 222 |
+
|
| 223 |
+
if stem == "mne:+":
|
| 224 |
+
stem = "mne"
|
| 225 |
+
|
| 226 |
+
if mode == "train":
|
| 227 |
+
metric.update(
|
| 228 |
+
output["audio"][stem],
|
| 229 |
+
batch["audio"][stem],
|
| 230 |
+
)
|
| 231 |
+
else:
|
| 232 |
+
if stem not in batch["audio"]:
|
| 233 |
+
matched = False
|
| 234 |
+
if stem in self._VOX_STEMS:
|
| 235 |
+
for bstem in self._VOX_STEMS:
|
| 236 |
+
if bstem in batch["audio"]:
|
| 237 |
+
batch["audio"][stem] = batch["audio"][bstem]
|
| 238 |
+
matched = True
|
| 239 |
+
break
|
| 240 |
+
elif stem in self._BG_STEMS:
|
| 241 |
+
for bstem in self._BG_STEMS:
|
| 242 |
+
if bstem in batch["audio"]:
|
| 243 |
+
batch["audio"][stem] = batch["audio"][bstem]
|
| 244 |
+
matched = True
|
| 245 |
+
break
|
| 246 |
+
else:
|
| 247 |
+
matched = True
|
| 248 |
+
|
| 249 |
+
if matched:
|
| 250 |
+
if stem == "mne" and "mne" not in output["audio"]:
|
| 251 |
+
output["audio"]["mne"] = (
|
| 252 |
+
output["audio"]["music"] + output["audio"]["effects"]
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
metric.update(
|
| 256 |
+
output["audio"][stem],
|
| 257 |
+
batch["audio"][stem],
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
def compute_metrics(self, mode: str = "dev") -> Dict[str, torch.Tensor]:
|
| 261 |
+
|
| 262 |
+
if mode == "test":
|
| 263 |
+
metrics = self.test_metrics
|
| 264 |
+
else:
|
| 265 |
+
metrics = self.metrics
|
| 266 |
+
|
| 267 |
+
metric_dict = {}
|
| 268 |
+
|
| 269 |
+
for stem, metric in metrics.items():
|
| 270 |
+
md = metric.compute()
|
| 271 |
+
metric_dict.update({f"{stem}/{k}": v for k, v in md.items()})
|
| 272 |
+
|
| 273 |
+
self.log_dict(metric_dict, prog_bar=True, logger=False)
|
| 274 |
+
|
| 275 |
+
return metric_dict
|
| 276 |
+
|
| 277 |
+
def reset_metrics(self, test_mode: bool = False) -> None:
|
| 278 |
+
|
| 279 |
+
if test_mode:
|
| 280 |
+
metrics = self.test_metrics
|
| 281 |
+
else:
|
| 282 |
+
metrics = self.metrics
|
| 283 |
+
|
| 284 |
+
for _, metric in metrics.items():
|
| 285 |
+
metric.reset()
|
| 286 |
+
|
| 287 |
+
def forward(self, batch: BatchedDataDict) -> Any:
|
| 288 |
+
batch, output = self.model(batch)
|
| 289 |
+
|
| 290 |
+
return batch, output
|
| 291 |
+
|
| 292 |
+
def common_step(self, batch: BatchedDataDict, mode: str) -> Any:
|
| 293 |
+
batch, output = self.forward(batch)
|
| 294 |
+
loss_dict = self.compute_loss(batch, output)
|
| 295 |
+
|
| 296 |
+
with torch.no_grad():
|
| 297 |
+
self.update_metrics(batch, output, mode=mode)
|
| 298 |
+
|
| 299 |
+
if mode == "train":
|
| 300 |
+
self.log("loss", loss_dict["loss"], prog_bar=True)
|
| 301 |
+
|
| 302 |
+
return output, loss_dict
|
| 303 |
+
|
| 304 |
+
def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]:
|
| 305 |
+
|
| 306 |
+
if self.augmentation is not None:
|
| 307 |
+
with torch.no_grad():
|
| 308 |
+
batch = self.augmentation(batch)
|
| 309 |
+
|
| 310 |
+
_, loss_dict = self.common_step(batch, mode="train")
|
| 311 |
+
|
| 312 |
+
with torch.inference_mode():
|
| 313 |
+
self.log_dict_with_prefix(
|
| 314 |
+
loss_dict, "train", batch_size=batch["audio"]["mixture"].shape[0]
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
loss_dict["loss"] *= self.loss_adjustment
|
| 318 |
+
|
| 319 |
+
return loss_dict
|
| 320 |
+
|
| 321 |
+
def on_train_batch_end(
|
| 322 |
+
self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int
|
| 323 |
+
) -> None:
|
| 324 |
+
|
| 325 |
+
metric_dict = self.compute_metrics()
|
| 326 |
+
self.log_dict_with_prefix(metric_dict, "train")
|
| 327 |
+
self.reset_metrics()
|
| 328 |
+
|
| 329 |
+
def validation_step(
|
| 330 |
+
self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
|
| 331 |
+
) -> Dict[str, Any]:
|
| 332 |
+
|
| 333 |
+
with torch.inference_mode():
|
| 334 |
+
curr_val_prefix = f"val{dataloader_idx}" if dataloader_idx > 0 else "val"
|
| 335 |
+
|
| 336 |
+
if curr_val_prefix != self.val_prefix:
|
| 337 |
+
if self.val_prefix is not None:
|
| 338 |
+
self._on_validation_epoch_end()
|
| 339 |
+
self.val_prefix = curr_val_prefix
|
| 340 |
+
_, loss_dict = self.common_step(batch, mode="val")
|
| 341 |
+
|
| 342 |
+
self.log_dict_with_prefix(
|
| 343 |
+
loss_dict,
|
| 344 |
+
self.val_prefix,
|
| 345 |
+
batch_size=batch["audio"]["mixture"].shape[0],
|
| 346 |
+
prog_bar=True,
|
| 347 |
+
add_dataloader_idx=False,
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
return loss_dict
|
| 351 |
+
|
| 352 |
+
def on_validation_epoch_end(self) -> None:
|
| 353 |
+
self._on_validation_epoch_end()
|
| 354 |
+
|
| 355 |
+
def _on_validation_epoch_end(self) -> None:
|
| 356 |
+
metric_dict = self.compute_metrics()
|
| 357 |
+
self.log_dict_with_prefix(
|
| 358 |
+
metric_dict, self.val_prefix, prog_bar=True, add_dataloader_idx=False
|
| 359 |
+
)
|
| 360 |
+
self.reset_metrics()
|
| 361 |
+
|
| 362 |
+
def old_predtest_step(
|
| 363 |
+
self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
|
| 364 |
+
) -> Tuple[BatchedDataDict, OutputType]:
|
| 365 |
+
|
| 366 |
+
audio_batch = batch["audio"]["mixture"]
|
| 367 |
+
track_batch = batch.get("track", ["" for _ in range(len(audio_batch))])
|
| 368 |
+
|
| 369 |
+
output_list_of_dicts = [
|
| 370 |
+
self.fader(audio[None, ...], lambda a: self.test_forward(a, track))
|
| 371 |
+
for audio, track in zip(audio_batch, track_batch)
|
| 372 |
+
]
|
| 373 |
+
|
| 374 |
+
output_dict_of_lists = defaultdict(list)
|
| 375 |
+
|
| 376 |
+
for output_dict in output_list_of_dicts:
|
| 377 |
+
for stem, audio in output_dict.items():
|
| 378 |
+
output_dict_of_lists[stem].append(audio)
|
| 379 |
+
|
| 380 |
+
output = {
|
| 381 |
+
"audio": {
|
| 382 |
+
stem: torch.concat(output_list, dim=0)
|
| 383 |
+
for stem, output_list in output_dict_of_lists.items()
|
| 384 |
+
}
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
return batch, output
|
| 388 |
+
|
| 389 |
+
def predtest_step(
|
| 390 |
+
self, batch: BatchedDataDict, batch_idx: int = -1, dataloader_idx: int = 0
|
| 391 |
+
) -> Tuple[BatchedDataDict, OutputType]:
|
| 392 |
+
|
| 393 |
+
if getattr(self.model, "bypass_fader", False):
|
| 394 |
+
batch, output = self.model(batch)
|
| 395 |
+
else:
|
| 396 |
+
audio_batch = batch["audio"]["mixture"]
|
| 397 |
+
output = self.fader(
|
| 398 |
+
audio_batch, lambda a: self.test_forward(a, "", batch=batch)
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
return batch, output
|
| 402 |
+
|
| 403 |
+
def test_forward(
|
| 404 |
+
self, audio: torch.Tensor, track: str = "", batch: BatchedDataDict = None
|
| 405 |
+
) -> torch.Tensor:
|
| 406 |
+
|
| 407 |
+
if self.fader is None:
|
| 408 |
+
self.attach_fader()
|
| 409 |
+
|
| 410 |
+
cond = batch.get("condition", None)
|
| 411 |
+
|
| 412 |
+
if cond is not None and cond.shape[0] == 1:
|
| 413 |
+
cond = cond.repeat(audio.shape[0], 1)
|
| 414 |
+
|
| 415 |
+
_, output = self.forward(
|
| 416 |
+
{
|
| 417 |
+
"audio": {"mixture": audio},
|
| 418 |
+
"track": track,
|
| 419 |
+
"condition": cond,
|
| 420 |
+
}
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
return output["audio"]
|
| 424 |
+
|
| 425 |
+
def on_test_epoch_start(self) -> None:
|
| 426 |
+
self.attach_fader(force_reattach=True)
|
| 427 |
+
|
| 428 |
+
def test_step(
|
| 429 |
+
self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
|
| 430 |
+
) -> Any:
|
| 431 |
+
curr_test_prefix = f"test{dataloader_idx}"
|
| 432 |
+
|
| 433 |
+
if curr_test_prefix != self.test_prefix:
|
| 434 |
+
if self.test_prefix is not None:
|
| 435 |
+
self._on_test_epoch_end()
|
| 436 |
+
self.test_prefix = curr_test_prefix
|
| 437 |
+
|
| 438 |
+
with torch.inference_mode():
|
| 439 |
+
_, output = self.predtest_step(batch, batch_idx, dataloader_idx)
|
| 440 |
+
self.update_metrics(batch, output, mode="test")
|
| 441 |
+
|
| 442 |
+
return output
|
| 443 |
+
|
| 444 |
+
def on_test_epoch_end(self) -> None:
|
| 445 |
+
self._on_test_epoch_end()
|
| 446 |
+
|
| 447 |
+
def _on_test_epoch_end(self) -> None:
    """Aggregate, log, and reset the test metrics for the current dataloader."""
    computed = self.compute_metrics(mode="test")
    self.log_dict_with_prefix(
        computed,
        self.test_prefix,
        prog_bar=True,
        add_dataloader_idx=False,
    )
    self.reset_metrics()
| 454 |
+
def predict_step(
    self,
    batch: BatchedDataDict,
    batch_idx: int = 0,
    dataloader_idx: int = 0,
    include_track_name: Optional[bool] = None,
    get_no_vox_combinations: bool = True,
    get_residual: bool = False,
    treat_batch_as_channels: bool = False,
    fs: Optional[int] = None,
) -> Any:
    """Separate a batch and write every estimated stem to disk as a WAV file.

    Args:
        batch: collated batch; must contain ``batch["audio"]["mixture"]``.
        include_track_name: nest outputs in a per-track subfolder. Defaults
            to True only when the batch holds more than one track.
        get_no_vox_combinations: additionally emit summed combinations of all
            non-vocal stems (e.g. "music+effects").
        get_residual: also emit ``mixture - sum(stems)`` as a "residual" stem.
        treat_batch_as_channels: fold the batch dimension into channels and
            save one multi-channel file per stem.
        fs: output sample rate; falls back to ``self.fs``.

    Returns:
        Dict of per-stem metric dicts for stems that have a reference in
        ``batch`` (empty otherwise).
    """
    assert self.predict_output_path is not None

    batch_size = batch["audio"]["mixture"].shape[0]

    if include_track_name is None:
        include_track_name = batch_size > 1

    # Resolve the output sample rate once, not per saved file.
    if fs is None:
        fs = self.fs

    metric_dict = {}

    with torch.inference_mode():
        batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
        torch.cuda.empty_cache()

        if get_residual:
            mixture = batch["audio"]["mixture"]
            # sum over a generator: no intermediate list of tensors
            extracted = sum(output["audio"][stem] for stem in output["audio"])
            output["audio"]["residual"] = mixture - extracted

        if get_no_vox_combinations:
            # Every 2..n subset of the non-vocal stems, saved as a summed stem.
            no_vox_stems = [
                stem for stem in output["audio"] if stem not in self._VOX_STEMS
            ]
            no_vox_combinations = chain.from_iterable(
                combinations(no_vox_stems, r)
                for r in range(2, len(no_vox_stems) + 1)
            )
            for combination in no_vox_combinations:
                combination_ = list(combination)
                output["audio"]["+".join(combination_)] = sum(
                    output["audio"][stem] for stem in combination_
                )

        if treat_batch_as_channels:
            for stem in output["audio"]:
                output["audio"][stem] = output["audio"][stem].reshape(
                    1, -1, output["audio"][stem].shape[-1]
                )
            batch_size = 1

        for b in range(batch_size):
            for stem in output["audio"]:
                track_name = batch["track"][b].split("/")[-1]

                if batch.get("audio", {}).get(stem, None) is not None:
                    # A reference is available: compute metrics and embed
                    # them in the output file name.
                    self.test_metrics[stem].reset()
                    metrics = self.test_metrics[stem](
                        batch["audio"][stem][[b], ...],
                        output["audio"][stem][[b], ...],
                    )
                    snr = metrics["snr"]
                    sisnr = metrics["sisnr"]
                    sdr = metrics["sdr"]
                    metric_dict[stem] = metrics
                    print(
                        track_name,
                        f"snr={snr:2.2f} dB",
                        f"sisnr={sisnr:2.2f}",
                        f"sdr={sdr:2.2f} dB",
                    )
                    filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
                else:
                    filename = f"{stem}.wav"

                if include_track_name:
                    output_dir = os.path.join(self.predict_output_path, track_name)
                else:
                    output_dir = self.predict_output_path

                os.makedirs(output_dir, exist_ok=True)

                ta.save(
                    os.path.join(output_dir, filename),
                    output["audio"][stem][b, ...].cpu(),
                    fs,
                )

    return metric_dict
| 550 |
+
def get_stems(
    self,
    batch: BatchedDataDict,
    batch_idx: int = 0,
    dataloader_idx: int = 0,
    include_track_name: Optional[bool] = None,
    get_no_vox_combinations: bool = True,
    get_residual: bool = False,
    treat_batch_as_channels: bool = False,
    fs: Optional[int] = None,
) -> Any:
    """Separate a batch and return the stems as numpy arrays (nothing is saved).

    In-memory counterpart of :meth:`predict_step`. ``include_track_name`` and
    ``fs`` are accepted only for signature parity with ``predict_step`` and
    are unused here (the original body built filenames and created output
    directories that were never used; that dead copy-paste has been removed).

    Returns:
        Dict mapping stem name to a numpy array. Note: with a batch size > 1
        later tracks overwrite earlier ones per stem (original behavior kept).
    """
    # Kept from the original: the output path setter also attaches the fader
    # required for chunked inference.
    assert self.predict_output_path is not None

    batch_size = batch["audio"]["mixture"].shape[0]

    with torch.inference_mode():
        batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
        torch.cuda.empty_cache()

        if get_residual:
            mixture = batch["audio"]["mixture"]
            extracted = sum(output["audio"][stem] for stem in output["audio"])
            output["audio"]["residual"] = mixture - extracted

        if get_no_vox_combinations:
            no_vox_stems = [
                stem for stem in output["audio"] if stem not in self._VOX_STEMS
            ]
            no_vox_combinations = chain.from_iterable(
                combinations(no_vox_stems, r)
                for r in range(2, len(no_vox_stems) + 1)
            )
            for combination in no_vox_combinations:
                combination_ = list(combination)
                output["audio"]["+".join(combination_)] = sum(
                    output["audio"][stem] for stem in combination_
                )

        if treat_batch_as_channels:
            for stem in output["audio"]:
                output["audio"][stem] = output["audio"][stem].reshape(
                    1, -1, output["audio"][stem].shape[-1]
                )
            batch_size = 1

        result = {}
        for b in range(batch_size):
            for stem in output["audio"]:
                track_name = batch["track"][b].split("/")[-1]

                if batch.get("audio", {}).get(stem, None) is not None:
                    # Reference available: report metrics on stdout as before.
                    self.test_metrics[stem].reset()
                    metrics = self.test_metrics[stem](
                        batch["audio"][stem][[b], ...],
                        output["audio"][stem][[b], ...],
                    )
                    print(
                        track_name,
                        f"snr={metrics['snr']:2.2f} dB",
                        f"sisnr={metrics['sisnr']:2.2f}",
                        f"sdr={metrics['sdr']:2.2f} dB",
                    )

                result[stem] = output["audio"][stem][b, ...].cpu().numpy()

    return result
| 639 |
+
def load_state_dict(
    self, state_dict: Mapping[str, Any], strict: bool = False
) -> Any:
    # NOTE(review): `strict` is accepted for API compatibility but is
    # deliberately ignored — loading is always non-strict so that checkpoints
    # with extra or missing keys (e.g. metric buffers) still load.
    return super().load_state_dict(state_dict, strict=False)
| 645 |
+
def set_predict_output_path(self, path: str) -> None:
    """Remember (and create) the prediction output directory, then make sure
    a fader is available for chunked inference."""
    self.predict_output_path = path
    os.makedirs(path, exist_ok=True)

    self.attach_fader()
| 651 |
+
def attach_fader(self, force_reattach=False) -> None:
    """(Re)build the overlap-add fader from its config and move it to the
    module's current device. No-op if a fader already exists, unless
    *force_reattach* is set."""
    needs_fader = force_reattach or self.fader is None
    if needs_fader:
        self.fader = parse_fader_config(self.fader_config)
        self.fader.to(self.device)
| 656 |
+
def log_dict_with_prefix(
    self,
    dict_: Dict[str, torch.Tensor],
    prefix: str,
    batch_size: Optional[int] = None,
    **kwargs: Any,
) -> None:
    """Log every entry of *dict_* under the key ``<prefix>/<name>``.

    Always logs to the logger with distributed synchronization enabled;
    extra keyword arguments are forwarded to ``log_dict``.
    """
    prefixed = {f"{prefix}/{key}": value for key, value in dict_.items()}
    self.log_dict(
        prefixed,
        batch_size=batch_size,
        logger=True,
        sync_dist=True,
        **kwargs,
    )
mvsepless/models/bandit/core/data/__init__.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
| 1 |
-
from .dnr.datamodule import DivideAndRemasterDataModule
|
| 2 |
-
from .musdb.datamodule import MUSDB18DataModule
|
|
|
|
| 1 |
+
from .dnr.datamodule import DivideAndRemasterDataModule
|
| 2 |
+
from .musdb.datamodule import MUSDB18DataModule
|
mvsepless/models/bandit/core/data/_types.py
CHANGED
|
@@ -1,17 +1,17 @@
|
|
| 1 |
-
from typing import Dict, Sequence, TypedDict
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
|
| 5 |
-
AudioDict = Dict[str, torch.Tensor]
|
| 6 |
-
|
| 7 |
-
DataDict = TypedDict("DataDict", {"audio": AudioDict, "track": str})
|
| 8 |
-
|
| 9 |
-
BatchedDataDict = TypedDict(
|
| 10 |
-
"BatchedDataDict", {"audio": AudioDict, "track": Sequence[str]}
|
| 11 |
-
)
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class DataDictWithLanguage(TypedDict):
|
| 15 |
-
audio: AudioDict
|
| 16 |
-
track: str
|
| 17 |
-
language: str
|
|
|
|
| 1 |
+
from typing import Dict, Sequence, TypedDict

import torch

# Maps a stem name (e.g. "mixture", "speech") to its waveform tensor.
AudioDict = Dict[str, torch.Tensor]

# A single example: per-stem audio plus the identifier of the source track.
DataDict = TypedDict("DataDict", {"audio": AudioDict, "track": str})

# A collated batch: stem tensors carry a batch dimension and there is one
# track identifier per batch element.
BatchedDataDict = TypedDict(
    "BatchedDataDict", {"audio": AudioDict, "track": Sequence[str]}
)


class DataDictWithLanguage(TypedDict):
    """A single example with a language tag.

    NOTE(review): presumably a language code for the speech content — verify
    against the dataset that produces it.
    """

    audio: AudioDict
    track: str
    language: str
mvsepless/models/bandit/core/data/augmentation.py
CHANGED
|
@@ -1,102 +1,102 @@
|
|
| 1 |
-
from abc import ABC
|
| 2 |
-
from typing import Any, Dict, Union
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
import torch_audiomentations as tam
|
| 6 |
-
from torch import nn
|
| 7 |
-
|
| 8 |
-
from ._types import BatchedDataDict, DataDict
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class BaseAugmentor(nn.Module, ABC):
|
| 12 |
-
def forward(
|
| 13 |
-
self, item: Union[DataDict, BatchedDataDict]
|
| 14 |
-
) -> Union[DataDict, BatchedDataDict]:
|
| 15 |
-
raise NotImplementedError
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class StemAugmentor(BaseAugmentor):
|
| 19 |
-
def __init__(
|
| 20 |
-
self,
|
| 21 |
-
audiomentations: Dict[str, Dict[str, Any]],
|
| 22 |
-
fix_clipping: bool = True,
|
| 23 |
-
scaler_margin: float = 0.5,
|
| 24 |
-
apply_both_default_and_common: bool = False,
|
| 25 |
-
) -> None:
|
| 26 |
-
super().__init__()
|
| 27 |
-
|
| 28 |
-
augmentations = {}
|
| 29 |
-
|
| 30 |
-
self.has_default = "[default]" in audiomentations
|
| 31 |
-
self.has_common = "[common]" in audiomentations
|
| 32 |
-
self.apply_both_default_and_common = apply_both_default_and_common
|
| 33 |
-
|
| 34 |
-
for stem in audiomentations:
|
| 35 |
-
if audiomentations[stem]["name"] == "Compose":
|
| 36 |
-
augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
|
| 37 |
-
[
|
| 38 |
-
getattr(tam, aug["name"])(**aug["kwargs"])
|
| 39 |
-
for aug in audiomentations[stem]["kwargs"]["transforms"]
|
| 40 |
-
],
|
| 41 |
-
**audiomentations[stem]["kwargs"]["kwargs"],
|
| 42 |
-
)
|
| 43 |
-
else:
|
| 44 |
-
augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
|
| 45 |
-
**audiomentations[stem]["kwargs"]
|
| 46 |
-
)
|
| 47 |
-
|
| 48 |
-
self.augmentations = nn.ModuleDict(augmentations)
|
| 49 |
-
self.fix_clipping = fix_clipping
|
| 50 |
-
self.scaler_margin = scaler_margin
|
| 51 |
-
|
| 52 |
-
def check_and_fix_clipping(
|
| 53 |
-
self, item: Union[DataDict, BatchedDataDict]
|
| 54 |
-
) -> Union[DataDict, BatchedDataDict]:
|
| 55 |
-
max_abs = []
|
| 56 |
-
|
| 57 |
-
for stem in item["audio"]:
|
| 58 |
-
max_abs.append(item["audio"][stem].abs().max().item())
|
| 59 |
-
|
| 60 |
-
if max(max_abs) > 1.0:
|
| 61 |
-
scaler = 1.0 / (
|
| 62 |
-
max(max_abs)
|
| 63 |
-
+ torch.rand((1,), device=item["audio"]["mixture"].device)
|
| 64 |
-
* self.scaler_margin
|
| 65 |
-
)
|
| 66 |
-
|
| 67 |
-
for stem in item["audio"]:
|
| 68 |
-
item["audio"][stem] *= scaler
|
| 69 |
-
|
| 70 |
-
return item
|
| 71 |
-
|
| 72 |
-
def forward(
|
| 73 |
-
self, item: Union[DataDict, BatchedDataDict]
|
| 74 |
-
) -> Union[DataDict, BatchedDataDict]:
|
| 75 |
-
|
| 76 |
-
for stem in item["audio"]:
|
| 77 |
-
if stem == "mixture":
|
| 78 |
-
continue
|
| 79 |
-
|
| 80 |
-
if self.has_common:
|
| 81 |
-
item["audio"][stem] = self.augmentations["[common]"](
|
| 82 |
-
item["audio"][stem]
|
| 83 |
-
).samples
|
| 84 |
-
|
| 85 |
-
if stem in self.augmentations:
|
| 86 |
-
item["audio"][stem] = self.augmentations[stem](
|
| 87 |
-
item["audio"][stem]
|
| 88 |
-
).samples
|
| 89 |
-
elif self.has_default:
|
| 90 |
-
if not self.has_common or self.apply_both_default_and_common:
|
| 91 |
-
item["audio"][stem] = self.augmentations["[default]"](
|
| 92 |
-
item["audio"][stem]
|
| 93 |
-
).samples
|
| 94 |
-
|
| 95 |
-
item["audio"]["mixture"] = sum(
|
| 96 |
-
[item["audio"][stem] for stem in item["audio"] if stem != "mixture"]
|
| 97 |
-
) # type: ignore[call-overload, assignment]
|
| 98 |
-
|
| 99 |
-
if self.fix_clipping:
|
| 100 |
-
item = self.check_and_fix_clipping(item)
|
| 101 |
-
|
| 102 |
-
return item
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
from typing import Any, Dict, Union

import torch
import torch_audiomentations as tam
from torch import nn

from ._types import BatchedDataDict, DataDict
|
| 10 |
+
|
| 11 |
+
class BaseAugmentor(nn.Module, ABC):
    """Interface for augmentors that transform a (batched) data dict."""

    # Marked abstract so the ABC machinery actually prevents direct
    # instantiation; previously the class declared ABC but had no abstract
    # member, so `BaseAugmentor()` would silently succeed.
    @abstractmethod
    def forward(
        self, item: Union[DataDict, BatchedDataDict]
    ) -> Union[DataDict, BatchedDataDict]:
        raise NotImplementedError
| 18 |
+
class StemAugmentor(BaseAugmentor):
    """Applies torch-audiomentations per stem, then rebuilds the mixture.

    *audiomentations* maps a stem name to a transform spec of the form
    ``{"name": <tam class>, "kwargs": ...}``. Two special keys:

    - ``"[common]"``: applied to every non-mixture stem first.
    - ``"[default]"``: applied to stems without their own entry, but only
      when there is no ``"[common]"`` entry unless
      ``apply_both_default_and_common`` is set.

    After augmentation the mixture is recomputed as the sum of all stems
    and, when *fix_clipping* is set, all stems are rescaled together if any
    of them exceeds full scale.
    """

    def __init__(
        self,
        audiomentations: Dict[str, Dict[str, Any]],
        fix_clipping: bool = True,
        scaler_margin: float = 0.5,
        apply_both_default_and_common: bool = False,
    ) -> None:
        super().__init__()

        augmentations = {}

        self.has_default = "[default]" in audiomentations
        self.has_common = "[common]" in audiomentations
        self.apply_both_default_and_common = apply_both_default_and_common

        for stem, spec in audiomentations.items():
            if spec["name"] == "Compose":
                # A Compose spec nests its member transforms under
                # kwargs["transforms"] and its own kwargs under kwargs["kwargs"].
                augmentations[stem] = getattr(tam, spec["name"])(
                    [
                        getattr(tam, aug["name"])(**aug["kwargs"])
                        for aug in spec["kwargs"]["transforms"]
                    ],
                    **spec["kwargs"]["kwargs"],
                )
            else:
                augmentations[stem] = getattr(tam, spec["name"])(**spec["kwargs"])

        self.augmentations = nn.ModuleDict(augmentations)
        self.fix_clipping = fix_clipping
        self.scaler_margin = scaler_margin

    def check_and_fix_clipping(
        self, item: Union[DataDict, BatchedDataDict]
    ) -> Union[DataDict, BatchedDataDict]:
        """If any stem exceeds full scale, rescale *all* stems by one factor."""
        # Hoisted: the original computed max(max_abs) twice.
        peak = max(item["audio"][stem].abs().max().item() for stem in item["audio"])

        if peak > 1.0:
            # Random extra headroom keeps the rescaled peak strictly below 1.
            scaler = 1.0 / (
                peak
                + torch.rand((1,), device=item["audio"]["mixture"].device)
                * self.scaler_margin
            )

            for stem in item["audio"]:
                item["audio"][stem] *= scaler

        return item

    def forward(
        self, item: Union[DataDict, BatchedDataDict]
    ) -> Union[DataDict, BatchedDataDict]:
        for stem in item["audio"]:
            if stem == "mixture":
                # The mixture is recomputed below, never augmented directly.
                continue

            if self.has_common:
                item["audio"][stem] = self.augmentations["[common]"](
                    item["audio"][stem]
                ).samples

            if stem in self.augmentations:
                item["audio"][stem] = self.augmentations[stem](
                    item["audio"][stem]
                ).samples
            elif self.has_default:
                if not self.has_common or self.apply_both_default_and_common:
                    item["audio"][stem] = self.augmentations["[default]"](
                        item["audio"][stem]
                    ).samples

        item["audio"]["mixture"] = sum(
            item["audio"][stem] for stem in item["audio"] if stem != "mixture"
        )  # type: ignore[call-overload, assignment]

        if self.fix_clipping:
            item = self.check_and_fix_clipping(item)

        return item
mvsepless/models/bandit/core/data/augmented.py
CHANGED
|
@@ -1,34 +1,34 @@
|
|
| 1 |
-
import warnings
|
| 2 |
-
from typing import Dict, Optional, Union
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
from torch import nn
|
| 6 |
-
from torch.utils import data
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class AugmentedDataset(data.Dataset):
|
| 10 |
-
def __init__(
|
| 11 |
-
self,
|
| 12 |
-
dataset: data.Dataset,
|
| 13 |
-
augmentation: nn.Module = nn.Identity(),
|
| 14 |
-
target_length: Optional[int] = None,
|
| 15 |
-
) -> None:
|
| 16 |
-
warnings.warn(
|
| 17 |
-
"This class is no longer used. Attach augmentation to "
|
| 18 |
-
"the LightningSystem instead.",
|
| 19 |
-
DeprecationWarning,
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
-
self.dataset = dataset
|
| 23 |
-
self.augmentation = augmentation
|
| 24 |
-
|
| 25 |
-
self.ds_length: int = len(dataset) # type: ignore[arg-type]
|
| 26 |
-
self.length = target_length if target_length is not None else self.ds_length
|
| 27 |
-
|
| 28 |
-
def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str, torch.Tensor]]]:
|
| 29 |
-
item = self.dataset[index % self.ds_length]
|
| 30 |
-
item = self.augmentation(item)
|
| 31 |
-
return item
|
| 32 |
-
|
| 33 |
-
def __len__(self) -> int:
|
| 34 |
-
return self.length
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from typing import Dict, Optional, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.utils import data
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AugmentedDataset(data.Dataset):
    """Deprecated wrapper that applies an augmentation to items of *dataset*.

    *target_length* may exceed the underlying dataset's length; indices wrap
    around so items are reused.
    """

    def __init__(
        self,
        dataset: data.Dataset,
        augmentation: Optional[nn.Module] = None,
        target_length: Optional[int] = None,
    ) -> None:
        warnings.warn(
            "This class is no longer used. Attach augmentation to "
            "the LightningSystem instead.",
            DeprecationWarning,
        )

        self.dataset = dataset
        # BUGFIX: the original used `augmentation: nn.Module = nn.Identity()`.
        # Defaults are evaluated once at definition time, so a single module
        # instance was shared by every AugmentedDataset. Identity is stateless
        # so this was latent, but the None-sentinel form is safe in general.
        self.augmentation = nn.Identity() if augmentation is None else augmentation

        self.ds_length: int = len(dataset)  # type: ignore[arg-type]
        self.length = target_length if target_length is not None else self.ds_length

    def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str, torch.Tensor]]]:
        # Wrap around so a target_length longer than the dataset reuses items.
        item = self.dataset[index % self.ds_length]
        item = self.augmentation(item)
        return item

    def __len__(self) -> int:
        return self.length
mvsepless/models/bandit/core/data/base.py
CHANGED
|
@@ -1,60 +1,60 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from abc import ABC, abstractmethod
|
| 3 |
-
from typing import Any, Dict, List, Optional
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import pedalboard as pb
|
| 7 |
-
import torch
|
| 8 |
-
import torchaudio as ta
|
| 9 |
-
from torch.utils import data
|
| 10 |
-
|
| 11 |
-
from ._types import AudioDict, DataDict
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class BaseSourceSeparationDataset(data.Dataset, ABC):
|
| 15 |
-
def __init__(
|
| 16 |
-
self,
|
| 17 |
-
split: str,
|
| 18 |
-
stems: List[str],
|
| 19 |
-
files: List[str],
|
| 20 |
-
data_path: str,
|
| 21 |
-
fs: int,
|
| 22 |
-
npy_memmap: bool,
|
| 23 |
-
recompute_mixture: bool,
|
| 24 |
-
):
|
| 25 |
-
self.split = split
|
| 26 |
-
self.stems = stems
|
| 27 |
-
self.stems_no_mixture = [s for s in stems if s != "mixture"]
|
| 28 |
-
self.files = files
|
| 29 |
-
self.data_path = data_path
|
| 30 |
-
self.fs = fs
|
| 31 |
-
self.npy_memmap = npy_memmap
|
| 32 |
-
self.recompute_mixture = recompute_mixture
|
| 33 |
-
|
| 34 |
-
@abstractmethod
|
| 35 |
-
def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
|
| 36 |
-
raise NotImplementedError
|
| 37 |
-
|
| 38 |
-
def _get_audio(self, stems, identifier: Dict[str, Any]):
|
| 39 |
-
audio = {}
|
| 40 |
-
for stem in stems:
|
| 41 |
-
audio[stem] = self.get_stem(stem=stem, identifier=identifier)
|
| 42 |
-
|
| 43 |
-
return audio
|
| 44 |
-
|
| 45 |
-
def get_audio(self, identifier: Dict[str, Any]) -> AudioDict:
|
| 46 |
-
|
| 47 |
-
if self.recompute_mixture:
|
| 48 |
-
audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
|
| 49 |
-
audio["mixture"] = self.compute_mixture(audio)
|
| 50 |
-
return audio
|
| 51 |
-
else:
|
| 52 |
-
return self._get_audio(self.stems, identifier=identifier)
|
| 53 |
-
|
| 54 |
-
@abstractmethod
|
| 55 |
-
def get_identifier(self, index: int) -> Dict[str, Any]:
|
| 56 |
-
pass
|
| 57 |
-
|
| 58 |
-
def compute_mixture(self, audio: AudioDict) -> torch.Tensor:
|
| 59 |
-
|
| 60 |
-
return sum(audio[stem] for stem in audio if stem != "mixture")
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
from typing import Any, Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pedalboard as pb
|
| 7 |
+
import torch
|
| 8 |
+
import torchaudio as ta
|
| 9 |
+
from torch.utils import data
|
| 10 |
+
|
| 11 |
+
from ._types import AudioDict, DataDict
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class BaseSourceSeparationDataset(data.Dataset, ABC):
    """Shared plumbing for source-separation datasets.

    Holds the split/stem/file bookkeeping and assembles per-stem audio into
    an AudioDict; subclasses provide the actual stem loading and indexing.
    """

    def __init__(
        self,
        split: str,
        stems: List[str],
        files: List[str],
        data_path: str,
        fs: int,
        npy_memmap: bool,
        recompute_mixture: bool,
    ):
        self.split = split
        self.stems = stems
        self.stems_no_mixture = [name for name in stems if name != "mixture"]
        self.files = files
        self.data_path = data_path
        self.fs = fs
        self.npy_memmap = npy_memmap
        self.recompute_mixture = recompute_mixture

    @abstractmethod
    def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
        """Load and return the waveform for one stem of one track."""
        raise NotImplementedError

    def _get_audio(self, stems, identifier: Dict[str, Any]):
        # Load each requested stem through the subclass hook.
        return {
            stem: self.get_stem(stem=stem, identifier=identifier) for stem in stems
        }

    def get_audio(self, identifier: Dict[str, Any]) -> AudioDict:
        if not self.recompute_mixture:
            return self._get_audio(self.stems, identifier=identifier)

        # Load only the individual stems and synthesize the mixture.
        audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
        audio["mixture"] = self.compute_mixture(audio)
        return audio

    @abstractmethod
    def get_identifier(self, index: int) -> Dict[str, Any]:
        pass

    def compute_mixture(self, audio: AudioDict) -> torch.Tensor:
        # The mixture is simply the sum of all non-mixture stems.
        return sum(audio[stem] for stem in audio if stem != "mixture")
mvsepless/models/bandit/core/data/dnr/datamodule.py
CHANGED
|
@@ -1,64 +1,64 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from typing import Mapping, Optional
|
| 3 |
-
|
| 4 |
-
import pytorch_lightning as pl
|
| 5 |
-
|
| 6 |
-
from .dataset import (
|
| 7 |
-
DivideAndRemasterDataset,
|
| 8 |
-
DivideAndRemasterDeterministicChunkDataset,
|
| 9 |
-
DivideAndRemasterRandomChunkDataset,
|
| 10 |
-
DivideAndRemasterRandomChunkDatasetWithSpeechReverb,
|
| 11 |
-
)
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def DivideAndRemasterDataModule(
|
| 15 |
-
data_root: str = "$DATA_ROOT/DnR/v2",
|
| 16 |
-
batch_size: int = 2,
|
| 17 |
-
num_workers: int = 8,
|
| 18 |
-
train_kwargs: Optional[Mapping] = None,
|
| 19 |
-
val_kwargs: Optional[Mapping] = None,
|
| 20 |
-
test_kwargs: Optional[Mapping] = None,
|
| 21 |
-
datamodule_kwargs: Optional[Mapping] = None,
|
| 22 |
-
use_speech_reverb: bool = False,
|
| 23 |
-
) -> pl.LightningDataModule:
|
| 24 |
-
if train_kwargs is None:
|
| 25 |
-
train_kwargs = {}
|
| 26 |
-
|
| 27 |
-
if val_kwargs is None:
|
| 28 |
-
val_kwargs = {}
|
| 29 |
-
|
| 30 |
-
if test_kwargs is None:
|
| 31 |
-
test_kwargs = {}
|
| 32 |
-
|
| 33 |
-
if datamodule_kwargs is None:
|
| 34 |
-
datamodule_kwargs = {}
|
| 35 |
-
|
| 36 |
-
if num_workers is None:
|
| 37 |
-
num_workers = os.cpu_count()
|
| 38 |
-
|
| 39 |
-
if num_workers is None:
|
| 40 |
-
num_workers = 32
|
| 41 |
-
|
| 42 |
-
num_workers = min(num_workers, 64)
|
| 43 |
-
|
| 44 |
-
if use_speech_reverb:
|
| 45 |
-
train_cls = DivideAndRemasterRandomChunkDatasetWithSpeechReverb
|
| 46 |
-
else:
|
| 47 |
-
train_cls = DivideAndRemasterRandomChunkDataset
|
| 48 |
-
|
| 49 |
-
train_dataset = train_cls(data_root, "train", **train_kwargs)
|
| 50 |
-
|
| 51 |
-
datamodule = pl.LightningDataModule.from_datasets(
|
| 52 |
-
train_dataset=train_dataset,
|
| 53 |
-
val_dataset=DivideAndRemasterDeterministicChunkDataset(
|
| 54 |
-
data_root, "val", **val_kwargs
|
| 55 |
-
),
|
| 56 |
-
test_dataset=DivideAndRemasterDataset(data_root, "test", **test_kwargs),
|
| 57 |
-
batch_size=batch_size,
|
| 58 |
-
num_workers=num_workers,
|
| 59 |
-
**datamodule_kwargs,
|
| 60 |
-
)
|
| 61 |
-
|
| 62 |
-
datamodule.predict_dataloader = datamodule.test_dataloader # type: ignore[method-assign]
|
| 63 |
-
|
| 64 |
-
return datamodule
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Mapping, Optional
|
| 3 |
+
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
|
| 6 |
+
from .dataset import (
|
| 7 |
+
DivideAndRemasterDataset,
|
| 8 |
+
DivideAndRemasterDeterministicChunkDataset,
|
| 9 |
+
DivideAndRemasterRandomChunkDataset,
|
| 10 |
+
DivideAndRemasterRandomChunkDatasetWithSpeechReverb,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def DivideAndRemasterDataModule(
    data_root: str = "$DATA_ROOT/DnR/v2",
    batch_size: int = 2,
    num_workers: int = 8,
    train_kwargs: Optional[Mapping] = None,
    val_kwargs: Optional[Mapping] = None,
    test_kwargs: Optional[Mapping] = None,
    datamodule_kwargs: Optional[Mapping] = None,
    use_speech_reverb: bool = False,
) -> pl.LightningDataModule:
    """Build a LightningDataModule over the Divide-and-Remaster dataset.

    Train uses random chunks (optionally with speech reverb), validation uses
    deterministic chunks, and test/predict iterate full tracks.

    Args:
        data_root: dataset root (``$DATA_ROOT`` expansion happens upstream).
        num_workers: dataloader workers; ``None`` means the CPU count
            (or 32 when undetectable), capped at 64.
        *_kwargs: forwarded to the corresponding dataset constructors /
            ``from_datasets``.
        use_speech_reverb: train on the reverb-augmented speech variant.
    """
    train_kwargs = {} if train_kwargs is None else train_kwargs
    val_kwargs = {} if val_kwargs is None else val_kwargs
    test_kwargs = {} if test_kwargs is None else test_kwargs
    datamodule_kwargs = {} if datamodule_kwargs is None else datamodule_kwargs

    # FIX: the original repeated `if num_workers is None` twice; collapsed.
    if num_workers is None:
        num_workers = os.cpu_count() or 32
    num_workers = min(num_workers, 64)

    train_cls = (
        DivideAndRemasterRandomChunkDatasetWithSpeechReverb
        if use_speech_reverb
        else DivideAndRemasterRandomChunkDataset
    )
    train_dataset = train_cls(data_root, "train", **train_kwargs)

    datamodule = pl.LightningDataModule.from_datasets(
        train_dataset=train_dataset,
        val_dataset=DivideAndRemasterDeterministicChunkDataset(
            data_root, "val", **val_kwargs
        ),
        test_dataset=DivideAndRemasterDataset(data_root, "test", **test_kwargs),
        batch_size=batch_size,
        num_workers=num_workers,
        **datamodule_kwargs,
    )

    # Predictions iterate the same full tracks as testing.
    datamodule.predict_dataloader = datamodule.test_dataloader  # type: ignore[method-assign]

    return datamodule
mvsepless/models/bandit/core/data/dnr/dataset.py
CHANGED
|
@@ -1,360 +1,360 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from abc import ABC
|
| 3 |
-
from typing import Any, Dict, List, Optional
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import pedalboard as pb
|
| 7 |
-
import torch
|
| 8 |
-
import torchaudio as ta
|
| 9 |
-
from torch.utils import data
|
| 10 |
-
|
| 11 |
-
from .._types import AudioDict, DataDict
|
| 12 |
-
from ..base import BaseSourceSeparationDataset
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC):
    """Base dataset for Divide and Remaster (DnR): maps public stem/split
    names to the on-disk layout and loads per-stem audio."""

    # "mne" is a virtual stem: music + effects, synthesized on the fly.
    ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"]
    # Public stem name -> base filename used on disk.
    STEM_NAME_MAP = {
        "mixture": "mix",
        "speech": "speech",
        "music": "music",
        "effects": "sfx",
    }
    # Split name -> DnR directory name.
    SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"}

    # DnR tracks are fixed-length one-minute clips at 44.1 kHz.
    FULL_TRACK_LENGTH_SECOND = 60
    FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100

    def __init__(
        self,
        split: str,
        stems: List[str],
        files: List[str],
        data_path: str,
        fs: int = 44100,
        npy_memmap: bool = True,
        recompute_mixture: bool = False,
    ) -> None:
        super().__init__(
            split=split,
            stems=stems,
            files=files,
            data_path=data_path,
            fs=fs,
            npy_memmap=npy_memmap,
            recompute_mixture=recompute_mixture,
        )

    def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
        """Load one stem of one track; "mne" is computed as music + effects."""

        if stem == "mne":
            return self.get_stem(stem="music", identifier=identifier) + self.get_stem(
                stem="effects", identifier=identifier
            )

        track = identifier["track"]
        path = os.path.join(self.data_path, track)

        if self.npy_memmap:
            # Memory-mapped .npy avoids reading the full minute of audio when
            # only a chunk is consumed downstream.
            audio = np.load(
                os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"), mmap_mode="r"
            )
        else:
            audio, _ = ta.load(os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav"))

        return audio

    def get_identifier(self, index):
        # A track is identified purely by its folder name.
        return dict(track=self.files[index])

    def __getitem__(self, index: int) -> DataDict:
        identifier = self.get_identifier(index)
        audio = self.get_audio(identifier)

        return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
-
|
| 76 |
-
|
| 77 |
-
class DivideAndRemasterDataset(DivideAndRemasterBaseDataset):
    """Full-track DnR dataset: yields one complete 60-second track per index."""

    def __init__(
        self,
        data_root: str,
        split: str,
        stems: Optional[List[str]] = None,
        fs: int = 44100,
        npy_memmap: bool = True,
    ) -> None:

        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])

        # Track folders only: ignore hidden entries and stray files.
        files = [
            entry
            for entry in sorted(os.listdir(data_path))
            if not entry.startswith(".")
            and os.path.isdir(os.path.join(data_path, entry))
        ]

        # Sanity-check the expected DnR v2 track counts for the known splits.
        expected_counts = {"train": 3406, "val": 487, "test": 973}
        if split in expected_counts:
            assert len(files) == expected_counts[split], len(files)

        self.n_tracks = len(files)

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=files,
            fs=fs,
            npy_memmap=npy_memmap,
        )

    def __len__(self) -> int:
        return self.n_tracks
-
|
| 121 |
-
class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset):
|
| 122 |
-
def __init__(
|
| 123 |
-
self,
|
| 124 |
-
data_root: str,
|
| 125 |
-
split: str,
|
| 126 |
-
target_length: int,
|
| 127 |
-
chunk_size_second: float,
|
| 128 |
-
stems: Optional[List[str]] = None,
|
| 129 |
-
fs: int = 44100,
|
| 130 |
-
npy_memmap: bool = True,
|
| 131 |
-
) -> None:
|
| 132 |
-
|
| 133 |
-
if stems is None:
|
| 134 |
-
stems = self.ALLOWED_STEMS
|
| 135 |
-
self.stems = stems
|
| 136 |
-
|
| 137 |
-
data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
|
| 138 |
-
|
| 139 |
-
files = sorted(os.listdir(data_path))
|
| 140 |
-
files = [
|
| 141 |
-
f
|
| 142 |
-
for f in files
|
| 143 |
-
if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
|
| 144 |
-
]
|
| 145 |
-
|
| 146 |
-
if split == "train":
|
| 147 |
-
assert len(files) == 3406, len(files)
|
| 148 |
-
elif split == "val":
|
| 149 |
-
assert len(files) == 487, len(files)
|
| 150 |
-
elif split == "test":
|
| 151 |
-
assert len(files) == 973, len(files)
|
| 152 |
-
|
| 153 |
-
self.n_tracks = len(files)
|
| 154 |
-
|
| 155 |
-
self.target_length = target_length
|
| 156 |
-
self.chunk_size = int(chunk_size_second * fs)
|
| 157 |
-
|
| 158 |
-
super().__init__(
|
| 159 |
-
data_path=data_path,
|
| 160 |
-
split=split,
|
| 161 |
-
stems=stems,
|
| 162 |
-
files=files,
|
| 163 |
-
fs=fs,
|
| 164 |
-
npy_memmap=npy_memmap,
|
| 165 |
-
)
|
| 166 |
-
|
| 167 |
-
def __len__(self) -> int:
|
| 168 |
-
return self.target_length
|
| 169 |
-
|
| 170 |
-
def get_identifier(self, index):
|
| 171 |
-
return super().get_identifier(index % self.n_tracks)
|
| 172 |
-
|
| 173 |
-
def get_stem(
|
| 174 |
-
self,
|
| 175 |
-
*,
|
| 176 |
-
stem: str,
|
| 177 |
-
identifier: Dict[str, Any],
|
| 178 |
-
chunk_here: bool = False,
|
| 179 |
-
) -> torch.Tensor:
|
| 180 |
-
|
| 181 |
-
stem = super().get_stem(stem=stem, identifier=identifier)
|
| 182 |
-
|
| 183 |
-
if chunk_here:
|
| 184 |
-
start = np.random.randint(
|
| 185 |
-
0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
|
| 186 |
-
)
|
| 187 |
-
end = start + self.chunk_size
|
| 188 |
-
|
| 189 |
-
stem = stem[:, start:end]
|
| 190 |
-
|
| 191 |
-
return stem
|
| 192 |
-
|
| 193 |
-
def __getitem__(self, index: int) -> DataDict:
|
| 194 |
-
identifier = self.get_identifier(index)
|
| 195 |
-
audio = self.get_audio(identifier)
|
| 196 |
-
|
| 197 |
-
start = np.random.randint(0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size)
|
| 198 |
-
end = start + self.chunk_size
|
| 199 |
-
|
| 200 |
-
audio = {k: v[:, start:end] for k, v in audio.items()}
|
| 201 |
-
|
| 202 |
-
return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset):
|
| 206 |
-
def __init__(
|
| 207 |
-
self,
|
| 208 |
-
data_root: str,
|
| 209 |
-
split: str,
|
| 210 |
-
chunk_size_second: float,
|
| 211 |
-
hop_size_second: float,
|
| 212 |
-
stems: Optional[List[str]] = None,
|
| 213 |
-
fs: int = 44100,
|
| 214 |
-
npy_memmap: bool = True,
|
| 215 |
-
) -> None:
|
| 216 |
-
|
| 217 |
-
if stems is None:
|
| 218 |
-
stems = self.ALLOWED_STEMS
|
| 219 |
-
self.stems = stems
|
| 220 |
-
|
| 221 |
-
data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
|
| 222 |
-
|
| 223 |
-
files = sorted(os.listdir(data_path))
|
| 224 |
-
files = [
|
| 225 |
-
f
|
| 226 |
-
for f in files
|
| 227 |
-
if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
|
| 228 |
-
]
|
| 229 |
-
if split == "train":
|
| 230 |
-
assert len(files) == 3406, len(files)
|
| 231 |
-
elif split == "val":
|
| 232 |
-
assert len(files) == 487, len(files)
|
| 233 |
-
elif split == "test":
|
| 234 |
-
assert len(files) == 973, len(files)
|
| 235 |
-
|
| 236 |
-
self.n_tracks = len(files)
|
| 237 |
-
|
| 238 |
-
self.chunk_size = int(chunk_size_second * fs)
|
| 239 |
-
self.hop_size = int(hop_size_second * fs)
|
| 240 |
-
self.n_chunks_per_track = int(
|
| 241 |
-
(self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second
|
| 242 |
-
)
|
| 243 |
-
|
| 244 |
-
self.length = self.n_tracks * self.n_chunks_per_track
|
| 245 |
-
|
| 246 |
-
super().__init__(
|
| 247 |
-
data_path=data_path,
|
| 248 |
-
split=split,
|
| 249 |
-
stems=stems,
|
| 250 |
-
files=files,
|
| 251 |
-
fs=fs,
|
| 252 |
-
npy_memmap=npy_memmap,
|
| 253 |
-
)
|
| 254 |
-
|
| 255 |
-
def get_identifier(self, index):
|
| 256 |
-
return super().get_identifier(index % self.n_tracks)
|
| 257 |
-
|
| 258 |
-
def __len__(self) -> int:
|
| 259 |
-
return self.length
|
| 260 |
-
|
| 261 |
-
def __getitem__(self, item: int) -> DataDict:
|
| 262 |
-
|
| 263 |
-
index = item % self.n_tracks
|
| 264 |
-
chunk = item // self.n_tracks
|
| 265 |
-
|
| 266 |
-
data_ = super().__getitem__(index)
|
| 267 |
-
|
| 268 |
-
audio = data_["audio"]
|
| 269 |
-
|
| 270 |
-
start = chunk * self.hop_size
|
| 271 |
-
end = start + self.chunk_size
|
| 272 |
-
|
| 273 |
-
for stem in self.stems:
|
| 274 |
-
data_["audio"][stem] = audio[stem][:, start:end]
|
| 275 |
-
|
| 276 |
-
return data_
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
class DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
|
| 280 |
-
DivideAndRemasterRandomChunkDataset
|
| 281 |
-
):
|
| 282 |
-
def __init__(
|
| 283 |
-
self,
|
| 284 |
-
data_root: str,
|
| 285 |
-
split: str,
|
| 286 |
-
target_length: int,
|
| 287 |
-
chunk_size_second: float,
|
| 288 |
-
stems: Optional[List[str]] = None,
|
| 289 |
-
fs: int = 44100,
|
| 290 |
-
npy_memmap: bool = True,
|
| 291 |
-
) -> None:
|
| 292 |
-
|
| 293 |
-
if stems is None:
|
| 294 |
-
stems = self.ALLOWED_STEMS
|
| 295 |
-
|
| 296 |
-
stems_no_mixture = [s for s in stems if s != "mixture"]
|
| 297 |
-
|
| 298 |
-
super().__init__(
|
| 299 |
-
data_root=data_root,
|
| 300 |
-
split=split,
|
| 301 |
-
target_length=target_length,
|
| 302 |
-
chunk_size_second=chunk_size_second,
|
| 303 |
-
stems=stems_no_mixture,
|
| 304 |
-
fs=fs,
|
| 305 |
-
npy_memmap=npy_memmap,
|
| 306 |
-
)
|
| 307 |
-
|
| 308 |
-
self.stems = stems
|
| 309 |
-
self.stems_no_mixture = stems_no_mixture
|
| 310 |
-
|
| 311 |
-
def __getitem__(self, index: int) -> DataDict:
|
| 312 |
-
|
| 313 |
-
data_ = super().__getitem__(index)
|
| 314 |
-
|
| 315 |
-
dry = data_["audio"]["speech"][:]
|
| 316 |
-
n_samples = dry.shape[-1]
|
| 317 |
-
|
| 318 |
-
wet_level = np.random.rand()
|
| 319 |
-
|
| 320 |
-
speech = pb.Reverb(
|
| 321 |
-
room_size=np.random.rand(),
|
| 322 |
-
damping=np.random.rand(),
|
| 323 |
-
wet_level=wet_level,
|
| 324 |
-
dry_level=(1 - wet_level),
|
| 325 |
-
width=np.random.rand(),
|
| 326 |
-
).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples]
|
| 327 |
-
|
| 328 |
-
data_["audio"]["speech"] = speech
|
| 329 |
-
|
| 330 |
-
data_["audio"]["mixture"] = sum(
|
| 331 |
-
[data_["audio"][s] for s in self.stems_no_mixture]
|
| 332 |
-
)
|
| 333 |
-
|
| 334 |
-
return data_
|
| 335 |
-
|
| 336 |
-
def __len__(self) -> int:
|
| 337 |
-
return super().__len__()
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
if __name__ == "__main__":
|
| 341 |
-
|
| 342 |
-
from pprint import pprint
|
| 343 |
-
from tqdm.auto import tqdm
|
| 344 |
-
|
| 345 |
-
for split_ in ["train", "val", "test"]:
|
| 346 |
-
ds = DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
|
| 347 |
-
data_root="$DATA_ROOT/DnR/v2np",
|
| 348 |
-
split=split_,
|
| 349 |
-
target_length=100,
|
| 350 |
-
chunk_size_second=6.0,
|
| 351 |
-
)
|
| 352 |
-
|
| 353 |
-
print(split_, len(ds))
|
| 354 |
-
|
| 355 |
-
for track_ in tqdm(ds): # type: ignore
|
| 356 |
-
pprint(track_)
|
| 357 |
-
track_["audio"] = {k: v.shape for k, v in track_["audio"].items()}
|
| 358 |
-
pprint(track_)
|
| 359 |
-
|
| 360 |
-
break
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from abc import ABC
|
| 3 |
+
from typing import Any, Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pedalboard as pb
|
| 7 |
+
import torch
|
| 8 |
+
import torchaudio as ta
|
| 9 |
+
from torch.utils import data
|
| 10 |
+
|
| 11 |
+
from .._types import AudioDict, DataDict
|
| 12 |
+
from ..base import BaseSourceSeparationDataset
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC):
|
| 16 |
+
ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"]
|
| 17 |
+
STEM_NAME_MAP = {
|
| 18 |
+
"mixture": "mix",
|
| 19 |
+
"speech": "speech",
|
| 20 |
+
"music": "music",
|
| 21 |
+
"effects": "sfx",
|
| 22 |
+
}
|
| 23 |
+
SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"}
|
| 24 |
+
|
| 25 |
+
FULL_TRACK_LENGTH_SECOND = 60
|
| 26 |
+
FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100
|
| 27 |
+
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
split: str,
|
| 31 |
+
stems: List[str],
|
| 32 |
+
files: List[str],
|
| 33 |
+
data_path: str,
|
| 34 |
+
fs: int = 44100,
|
| 35 |
+
npy_memmap: bool = True,
|
| 36 |
+
recompute_mixture: bool = False,
|
| 37 |
+
) -> None:
|
| 38 |
+
super().__init__(
|
| 39 |
+
split=split,
|
| 40 |
+
stems=stems,
|
| 41 |
+
files=files,
|
| 42 |
+
data_path=data_path,
|
| 43 |
+
fs=fs,
|
| 44 |
+
npy_memmap=npy_memmap,
|
| 45 |
+
recompute_mixture=recompute_mixture,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
|
| 49 |
+
|
| 50 |
+
if stem == "mne":
|
| 51 |
+
return self.get_stem(stem="music", identifier=identifier) + self.get_stem(
|
| 52 |
+
stem="effects", identifier=identifier
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
track = identifier["track"]
|
| 56 |
+
path = os.path.join(self.data_path, track)
|
| 57 |
+
|
| 58 |
+
if self.npy_memmap:
|
| 59 |
+
audio = np.load(
|
| 60 |
+
os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"), mmap_mode="r"
|
| 61 |
+
)
|
| 62 |
+
else:
|
| 63 |
+
audio, _ = ta.load(os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav"))
|
| 64 |
+
|
| 65 |
+
return audio
|
| 66 |
+
|
| 67 |
+
def get_identifier(self, index):
|
| 68 |
+
return dict(track=self.files[index])
|
| 69 |
+
|
| 70 |
+
def __getitem__(self, index: int) -> DataDict:
|
| 71 |
+
identifier = self.get_identifier(index)
|
| 72 |
+
audio = self.get_audio(identifier)
|
| 73 |
+
|
| 74 |
+
return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class DivideAndRemasterDataset(DivideAndRemasterBaseDataset):
|
| 78 |
+
def __init__(
|
| 79 |
+
self,
|
| 80 |
+
data_root: str,
|
| 81 |
+
split: str,
|
| 82 |
+
stems: Optional[List[str]] = None,
|
| 83 |
+
fs: int = 44100,
|
| 84 |
+
npy_memmap: bool = True,
|
| 85 |
+
) -> None:
|
| 86 |
+
|
| 87 |
+
if stems is None:
|
| 88 |
+
stems = self.ALLOWED_STEMS
|
| 89 |
+
self.stems = stems
|
| 90 |
+
|
| 91 |
+
data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
|
| 92 |
+
|
| 93 |
+
files = sorted(os.listdir(data_path))
|
| 94 |
+
files = [
|
| 95 |
+
f
|
| 96 |
+
for f in files
|
| 97 |
+
if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
|
| 98 |
+
]
|
| 99 |
+
if split == "train":
|
| 100 |
+
assert len(files) == 3406, len(files)
|
| 101 |
+
elif split == "val":
|
| 102 |
+
assert len(files) == 487, len(files)
|
| 103 |
+
elif split == "test":
|
| 104 |
+
assert len(files) == 973, len(files)
|
| 105 |
+
|
| 106 |
+
self.n_tracks = len(files)
|
| 107 |
+
|
| 108 |
+
super().__init__(
|
| 109 |
+
data_path=data_path,
|
| 110 |
+
split=split,
|
| 111 |
+
stems=stems,
|
| 112 |
+
files=files,
|
| 113 |
+
fs=fs,
|
| 114 |
+
npy_memmap=npy_memmap,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
def __len__(self) -> int:
|
| 118 |
+
return self.n_tracks
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset):
|
| 122 |
+
def __init__(
|
| 123 |
+
self,
|
| 124 |
+
data_root: str,
|
| 125 |
+
split: str,
|
| 126 |
+
target_length: int,
|
| 127 |
+
chunk_size_second: float,
|
| 128 |
+
stems: Optional[List[str]] = None,
|
| 129 |
+
fs: int = 44100,
|
| 130 |
+
npy_memmap: bool = True,
|
| 131 |
+
) -> None:
|
| 132 |
+
|
| 133 |
+
if stems is None:
|
| 134 |
+
stems = self.ALLOWED_STEMS
|
| 135 |
+
self.stems = stems
|
| 136 |
+
|
| 137 |
+
data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
|
| 138 |
+
|
| 139 |
+
files = sorted(os.listdir(data_path))
|
| 140 |
+
files = [
|
| 141 |
+
f
|
| 142 |
+
for f in files
|
| 143 |
+
if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
|
| 144 |
+
]
|
| 145 |
+
|
| 146 |
+
if split == "train":
|
| 147 |
+
assert len(files) == 3406, len(files)
|
| 148 |
+
elif split == "val":
|
| 149 |
+
assert len(files) == 487, len(files)
|
| 150 |
+
elif split == "test":
|
| 151 |
+
assert len(files) == 973, len(files)
|
| 152 |
+
|
| 153 |
+
self.n_tracks = len(files)
|
| 154 |
+
|
| 155 |
+
self.target_length = target_length
|
| 156 |
+
self.chunk_size = int(chunk_size_second * fs)
|
| 157 |
+
|
| 158 |
+
super().__init__(
|
| 159 |
+
data_path=data_path,
|
| 160 |
+
split=split,
|
| 161 |
+
stems=stems,
|
| 162 |
+
files=files,
|
| 163 |
+
fs=fs,
|
| 164 |
+
npy_memmap=npy_memmap,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
def __len__(self) -> int:
|
| 168 |
+
return self.target_length
|
| 169 |
+
|
| 170 |
+
def get_identifier(self, index):
|
| 171 |
+
return super().get_identifier(index % self.n_tracks)
|
| 172 |
+
|
| 173 |
+
def get_stem(
|
| 174 |
+
self,
|
| 175 |
+
*,
|
| 176 |
+
stem: str,
|
| 177 |
+
identifier: Dict[str, Any],
|
| 178 |
+
chunk_here: bool = False,
|
| 179 |
+
) -> torch.Tensor:
|
| 180 |
+
|
| 181 |
+
stem = super().get_stem(stem=stem, identifier=identifier)
|
| 182 |
+
|
| 183 |
+
if chunk_here:
|
| 184 |
+
start = np.random.randint(
|
| 185 |
+
0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
|
| 186 |
+
)
|
| 187 |
+
end = start + self.chunk_size
|
| 188 |
+
|
| 189 |
+
stem = stem[:, start:end]
|
| 190 |
+
|
| 191 |
+
return stem
|
| 192 |
+
|
| 193 |
+
def __getitem__(self, index: int) -> DataDict:
|
| 194 |
+
identifier = self.get_identifier(index)
|
| 195 |
+
audio = self.get_audio(identifier)
|
| 196 |
+
|
| 197 |
+
start = np.random.randint(0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size)
|
| 198 |
+
end = start + self.chunk_size
|
| 199 |
+
|
| 200 |
+
audio = {k: v[:, start:end] for k, v in audio.items()}
|
| 201 |
+
|
| 202 |
+
return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset):
|
| 206 |
+
def __init__(
|
| 207 |
+
self,
|
| 208 |
+
data_root: str,
|
| 209 |
+
split: str,
|
| 210 |
+
chunk_size_second: float,
|
| 211 |
+
hop_size_second: float,
|
| 212 |
+
stems: Optional[List[str]] = None,
|
| 213 |
+
fs: int = 44100,
|
| 214 |
+
npy_memmap: bool = True,
|
| 215 |
+
) -> None:
|
| 216 |
+
|
| 217 |
+
if stems is None:
|
| 218 |
+
stems = self.ALLOWED_STEMS
|
| 219 |
+
self.stems = stems
|
| 220 |
+
|
| 221 |
+
data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
|
| 222 |
+
|
| 223 |
+
files = sorted(os.listdir(data_path))
|
| 224 |
+
files = [
|
| 225 |
+
f
|
| 226 |
+
for f in files
|
| 227 |
+
if (not f.startswith(".")) and os.path.isdir(os.path.join(data_path, f))
|
| 228 |
+
]
|
| 229 |
+
if split == "train":
|
| 230 |
+
assert len(files) == 3406, len(files)
|
| 231 |
+
elif split == "val":
|
| 232 |
+
assert len(files) == 487, len(files)
|
| 233 |
+
elif split == "test":
|
| 234 |
+
assert len(files) == 973, len(files)
|
| 235 |
+
|
| 236 |
+
self.n_tracks = len(files)
|
| 237 |
+
|
| 238 |
+
self.chunk_size = int(chunk_size_second * fs)
|
| 239 |
+
self.hop_size = int(hop_size_second * fs)
|
| 240 |
+
self.n_chunks_per_track = int(
|
| 241 |
+
(self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
self.length = self.n_tracks * self.n_chunks_per_track
|
| 245 |
+
|
| 246 |
+
super().__init__(
|
| 247 |
+
data_path=data_path,
|
| 248 |
+
split=split,
|
| 249 |
+
stems=stems,
|
| 250 |
+
files=files,
|
| 251 |
+
fs=fs,
|
| 252 |
+
npy_memmap=npy_memmap,
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
def get_identifier(self, index):
|
| 256 |
+
return super().get_identifier(index % self.n_tracks)
|
| 257 |
+
|
| 258 |
+
def __len__(self) -> int:
|
| 259 |
+
return self.length
|
| 260 |
+
|
| 261 |
+
def __getitem__(self, item: int) -> DataDict:
|
| 262 |
+
|
| 263 |
+
index = item % self.n_tracks
|
| 264 |
+
chunk = item // self.n_tracks
|
| 265 |
+
|
| 266 |
+
data_ = super().__getitem__(index)
|
| 267 |
+
|
| 268 |
+
audio = data_["audio"]
|
| 269 |
+
|
| 270 |
+
start = chunk * self.hop_size
|
| 271 |
+
end = start + self.chunk_size
|
| 272 |
+
|
| 273 |
+
for stem in self.stems:
|
| 274 |
+
data_["audio"][stem] = audio[stem][:, start:end]
|
| 275 |
+
|
| 276 |
+
return data_
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
|
| 280 |
+
DivideAndRemasterRandomChunkDataset
|
| 281 |
+
):
|
| 282 |
+
def __init__(
|
| 283 |
+
self,
|
| 284 |
+
data_root: str,
|
| 285 |
+
split: str,
|
| 286 |
+
target_length: int,
|
| 287 |
+
chunk_size_second: float,
|
| 288 |
+
stems: Optional[List[str]] = None,
|
| 289 |
+
fs: int = 44100,
|
| 290 |
+
npy_memmap: bool = True,
|
| 291 |
+
) -> None:
|
| 292 |
+
|
| 293 |
+
if stems is None:
|
| 294 |
+
stems = self.ALLOWED_STEMS
|
| 295 |
+
|
| 296 |
+
stems_no_mixture = [s for s in stems if s != "mixture"]
|
| 297 |
+
|
| 298 |
+
super().__init__(
|
| 299 |
+
data_root=data_root,
|
| 300 |
+
split=split,
|
| 301 |
+
target_length=target_length,
|
| 302 |
+
chunk_size_second=chunk_size_second,
|
| 303 |
+
stems=stems_no_mixture,
|
| 304 |
+
fs=fs,
|
| 305 |
+
npy_memmap=npy_memmap,
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
self.stems = stems
|
| 309 |
+
self.stems_no_mixture = stems_no_mixture
|
| 310 |
+
|
| 311 |
+
def __getitem__(self, index: int) -> DataDict:
|
| 312 |
+
|
| 313 |
+
data_ = super().__getitem__(index)
|
| 314 |
+
|
| 315 |
+
dry = data_["audio"]["speech"][:]
|
| 316 |
+
n_samples = dry.shape[-1]
|
| 317 |
+
|
| 318 |
+
wet_level = np.random.rand()
|
| 319 |
+
|
| 320 |
+
speech = pb.Reverb(
|
| 321 |
+
room_size=np.random.rand(),
|
| 322 |
+
damping=np.random.rand(),
|
| 323 |
+
wet_level=wet_level,
|
| 324 |
+
dry_level=(1 - wet_level),
|
| 325 |
+
width=np.random.rand(),
|
| 326 |
+
).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples]
|
| 327 |
+
|
| 328 |
+
data_["audio"]["speech"] = speech
|
| 329 |
+
|
| 330 |
+
data_["audio"]["mixture"] = sum(
|
| 331 |
+
[data_["audio"][s] for s in self.stems_no_mixture]
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
return data_
|
| 335 |
+
|
| 336 |
+
def __len__(self) -> int:
|
| 337 |
+
return super().__len__()
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
if __name__ == "__main__":
|
| 341 |
+
|
| 342 |
+
from pprint import pprint
|
| 343 |
+
from tqdm.auto import tqdm
|
| 344 |
+
|
| 345 |
+
for split_ in ["train", "val", "test"]:
|
| 346 |
+
ds = DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
|
| 347 |
+
data_root="$DATA_ROOT/DnR/v2np",
|
| 348 |
+
split=split_,
|
| 349 |
+
target_length=100,
|
| 350 |
+
chunk_size_second=6.0,
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
print(split_, len(ds))
|
| 354 |
+
|
| 355 |
+
for track_ in tqdm(ds): # type: ignore
|
| 356 |
+
pprint(track_)
|
| 357 |
+
track_["audio"] = {k: v.shape for k, v in track_["audio"].items()}
|
| 358 |
+
pprint(track_)
|
| 359 |
+
|
| 360 |
+
break
|
mvsepless/models/bandit/core/data/dnr/preprocess.py
CHANGED
|
@@ -1,51 +1,51 @@
|
|
| 1 |
-
import glob
|
| 2 |
-
import os
|
| 3 |
-
from typing import Tuple
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import torchaudio as ta
|
| 7 |
-
from tqdm.contrib.concurrent import process_map
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def process_one(inputs: Tuple[str, str, int]) -> None:
|
| 11 |
-
infile, outfile, target_fs = inputs
|
| 12 |
-
|
| 13 |
-
dir = os.path.dirname(outfile)
|
| 14 |
-
os.makedirs(dir, exist_ok=True)
|
| 15 |
-
|
| 16 |
-
data, fs = ta.load(infile)
|
| 17 |
-
|
| 18 |
-
if fs != target_fs:
|
| 19 |
-
data = ta.functional.resample(
|
| 20 |
-
data, fs, target_fs, resampling_method="sinc_interp_kaiser"
|
| 21 |
-
)
|
| 22 |
-
fs = target_fs
|
| 23 |
-
|
| 24 |
-
data = data.numpy()
|
| 25 |
-
data = data.astype(np.float32)
|
| 26 |
-
|
| 27 |
-
if os.path.exists(outfile):
|
| 28 |
-
data_ = np.load(outfile)
|
| 29 |
-
if np.allclose(data, data_):
|
| 30 |
-
return
|
| 31 |
-
|
| 32 |
-
np.save(outfile, data)
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def preprocess(data_path: str, output_path: str, fs: int) -> None:
|
| 36 |
-
files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
|
| 37 |
-
print(files)
|
| 38 |
-
outfiles = [
|
| 39 |
-
f.replace(data_path, output_path).replace(".wav", ".npy") for f in files
|
| 40 |
-
]
|
| 41 |
-
|
| 42 |
-
os.makedirs(output_path, exist_ok=True)
|
| 43 |
-
inputs = list(zip(files, outfiles, [fs] * len(files)))
|
| 44 |
-
|
| 45 |
-
process_map(process_one, inputs, chunksize=32)
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
if __name__ == "__main__":
|
| 49 |
-
import fire
|
| 50 |
-
|
| 51 |
-
fire.Fire()
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torchaudio as ta
|
| 7 |
+
from tqdm.contrib.concurrent import process_map
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def process_one(inputs: Tuple[str, str, int]) -> None:
|
| 11 |
+
infile, outfile, target_fs = inputs
|
| 12 |
+
|
| 13 |
+
dir = os.path.dirname(outfile)
|
| 14 |
+
os.makedirs(dir, exist_ok=True)
|
| 15 |
+
|
| 16 |
+
data, fs = ta.load(infile)
|
| 17 |
+
|
| 18 |
+
if fs != target_fs:
|
| 19 |
+
data = ta.functional.resample(
|
| 20 |
+
data, fs, target_fs, resampling_method="sinc_interp_kaiser"
|
| 21 |
+
)
|
| 22 |
+
fs = target_fs
|
| 23 |
+
|
| 24 |
+
data = data.numpy()
|
| 25 |
+
data = data.astype(np.float32)
|
| 26 |
+
|
| 27 |
+
if os.path.exists(outfile):
|
| 28 |
+
data_ = np.load(outfile)
|
| 29 |
+
if np.allclose(data, data_):
|
| 30 |
+
return
|
| 31 |
+
|
| 32 |
+
np.save(outfile, data)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def preprocess(data_path: str, output_path: str, fs: int) -> None:
|
| 36 |
+
files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
|
| 37 |
+
print(files)
|
| 38 |
+
outfiles = [
|
| 39 |
+
f.replace(data_path, output_path).replace(".wav", ".npy") for f in files
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
os.makedirs(output_path, exist_ok=True)
|
| 43 |
+
inputs = list(zip(files, outfiles, [fs] * len(files)))
|
| 44 |
+
|
| 45 |
+
process_map(process_one, inputs, chunksize=32)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
if __name__ == "__main__":
|
| 49 |
+
import fire
|
| 50 |
+
|
| 51 |
+
fire.Fire()
|
mvsepless/models/bandit/core/data/musdb/datamodule.py
CHANGED
|
@@ -1,75 +1,75 @@
|
|
| 1 |
-
import os.path
|
| 2 |
-
from typing import Mapping, Optional
|
| 3 |
-
|
| 4 |
-
import pytorch_lightning as pl
|
| 5 |
-
|
| 6 |
-
from .dataset import (
|
| 7 |
-
MUSDB18BaseDataset,
|
| 8 |
-
MUSDB18FullTrackDataset,
|
| 9 |
-
MUSDB18SadDataset,
|
| 10 |
-
MUSDB18SadOnTheFlyAugmentedDataset,
|
| 11 |
-
)
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def MUSDB18DataModule(
|
| 15 |
-
data_root: str = "$DATA_ROOT/MUSDB18/HQ",
|
| 16 |
-
target_stem: str = "vocals",
|
| 17 |
-
batch_size: int = 2,
|
| 18 |
-
num_workers: int = 8,
|
| 19 |
-
train_kwargs: Optional[Mapping] = None,
|
| 20 |
-
val_kwargs: Optional[Mapping] = None,
|
| 21 |
-
test_kwargs: Optional[Mapping] = None,
|
| 22 |
-
datamodule_kwargs: Optional[Mapping] = None,
|
| 23 |
-
use_on_the_fly: bool = True,
|
| 24 |
-
npy_memmap: bool = True,
|
| 25 |
-
) -> pl.LightningDataModule:
|
| 26 |
-
if train_kwargs is None:
|
| 27 |
-
train_kwargs = {}
|
| 28 |
-
|
| 29 |
-
if val_kwargs is None:
|
| 30 |
-
val_kwargs = {}
|
| 31 |
-
|
| 32 |
-
if test_kwargs is None:
|
| 33 |
-
test_kwargs = {}
|
| 34 |
-
|
| 35 |
-
if datamodule_kwargs is None:
|
| 36 |
-
datamodule_kwargs = {}
|
| 37 |
-
|
| 38 |
-
train_dataset: MUSDB18BaseDataset
|
| 39 |
-
|
| 40 |
-
if use_on_the_fly:
|
| 41 |
-
train_dataset = MUSDB18SadOnTheFlyAugmentedDataset(
|
| 42 |
-
data_root=os.path.join(data_root, "saded-np"),
|
| 43 |
-
split="train",
|
| 44 |
-
target_stem=target_stem,
|
| 45 |
-
**train_kwargs,
|
| 46 |
-
)
|
| 47 |
-
else:
|
| 48 |
-
train_dataset = MUSDB18SadDataset(
|
| 49 |
-
data_root=os.path.join(data_root, "saded-np"),
|
| 50 |
-
split="train",
|
| 51 |
-
target_stem=target_stem,
|
| 52 |
-
**train_kwargs,
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
datamodule = pl.LightningDataModule.from_datasets(
|
| 56 |
-
train_dataset=train_dataset,
|
| 57 |
-
val_dataset=MUSDB18SadDataset(
|
| 58 |
-
data_root=os.path.join(data_root, "saded-np"),
|
| 59 |
-
split="val",
|
| 60 |
-
target_stem=target_stem,
|
| 61 |
-
**val_kwargs,
|
| 62 |
-
),
|
| 63 |
-
test_dataset=MUSDB18FullTrackDataset(
|
| 64 |
-
data_root=os.path.join(data_root, "canonical"), split="test", **test_kwargs
|
| 65 |
-
),
|
| 66 |
-
batch_size=batch_size,
|
| 67 |
-
num_workers=num_workers,
|
| 68 |
-
**datamodule_kwargs,
|
| 69 |
-
)
|
| 70 |
-
|
| 71 |
-
datamodule.predict_dataloader = ( # type: ignore[method-assign]
|
| 72 |
-
datamodule.test_dataloader
|
| 73 |
-
)
|
| 74 |
-
|
| 75 |
-
return datamodule
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
from typing import Mapping, Optional
|
| 3 |
+
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
|
| 6 |
+
from .dataset import (
|
| 7 |
+
MUSDB18BaseDataset,
|
| 8 |
+
MUSDB18FullTrackDataset,
|
| 9 |
+
MUSDB18SadDataset,
|
| 10 |
+
MUSDB18SadOnTheFlyAugmentedDataset,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def MUSDB18DataModule(
|
| 15 |
+
data_root: str = "$DATA_ROOT/MUSDB18/HQ",
|
| 16 |
+
target_stem: str = "vocals",
|
| 17 |
+
batch_size: int = 2,
|
| 18 |
+
num_workers: int = 8,
|
| 19 |
+
train_kwargs: Optional[Mapping] = None,
|
| 20 |
+
val_kwargs: Optional[Mapping] = None,
|
| 21 |
+
test_kwargs: Optional[Mapping] = None,
|
| 22 |
+
datamodule_kwargs: Optional[Mapping] = None,
|
| 23 |
+
use_on_the_fly: bool = True,
|
| 24 |
+
npy_memmap: bool = True,
|
| 25 |
+
) -> pl.LightningDataModule:
|
| 26 |
+
if train_kwargs is None:
|
| 27 |
+
train_kwargs = {}
|
| 28 |
+
|
| 29 |
+
if val_kwargs is None:
|
| 30 |
+
val_kwargs = {}
|
| 31 |
+
|
| 32 |
+
if test_kwargs is None:
|
| 33 |
+
test_kwargs = {}
|
| 34 |
+
|
| 35 |
+
if datamodule_kwargs is None:
|
| 36 |
+
datamodule_kwargs = {}
|
| 37 |
+
|
| 38 |
+
train_dataset: MUSDB18BaseDataset
|
| 39 |
+
|
| 40 |
+
if use_on_the_fly:
|
| 41 |
+
train_dataset = MUSDB18SadOnTheFlyAugmentedDataset(
|
| 42 |
+
data_root=os.path.join(data_root, "saded-np"),
|
| 43 |
+
split="train",
|
| 44 |
+
target_stem=target_stem,
|
| 45 |
+
**train_kwargs,
|
| 46 |
+
)
|
| 47 |
+
else:
|
| 48 |
+
train_dataset = MUSDB18SadDataset(
|
| 49 |
+
data_root=os.path.join(data_root, "saded-np"),
|
| 50 |
+
split="train",
|
| 51 |
+
target_stem=target_stem,
|
| 52 |
+
**train_kwargs,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
datamodule = pl.LightningDataModule.from_datasets(
|
| 56 |
+
train_dataset=train_dataset,
|
| 57 |
+
val_dataset=MUSDB18SadDataset(
|
| 58 |
+
data_root=os.path.join(data_root, "saded-np"),
|
| 59 |
+
split="val",
|
| 60 |
+
target_stem=target_stem,
|
| 61 |
+
**val_kwargs,
|
| 62 |
+
),
|
| 63 |
+
test_dataset=MUSDB18FullTrackDataset(
|
| 64 |
+
data_root=os.path.join(data_root, "canonical"), split="test", **test_kwargs
|
| 65 |
+
),
|
| 66 |
+
batch_size=batch_size,
|
| 67 |
+
num_workers=num_workers,
|
| 68 |
+
**datamodule_kwargs,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
datamodule.predict_dataloader = ( # type: ignore[method-assign]
|
| 72 |
+
datamodule.test_dataloader
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
return datamodule
|
mvsepless/models/bandit/core/data/musdb/dataset.py
CHANGED
|
@@ -1,241 +1,241 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from abc import ABC
|
| 3 |
-
from typing import List, Optional, Tuple
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import torch
|
| 7 |
-
import torchaudio as ta
|
| 8 |
-
from torch.utils import data
|
| 9 |
-
|
| 10 |
-
from .._types import AudioDict, DataDict
|
| 11 |
-
from ..base import BaseSourceSeparationDataset
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC):
|
| 15 |
-
|
| 16 |
-
ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"]
|
| 17 |
-
|
| 18 |
-
def __init__(
|
| 19 |
-
self,
|
| 20 |
-
split: str,
|
| 21 |
-
stems: List[str],
|
| 22 |
-
files: List[str],
|
| 23 |
-
data_path: str,
|
| 24 |
-
fs: int = 44100,
|
| 25 |
-
npy_memmap=False,
|
| 26 |
-
) -> None:
|
| 27 |
-
super().__init__(
|
| 28 |
-
split=split,
|
| 29 |
-
stems=stems,
|
| 30 |
-
files=files,
|
| 31 |
-
data_path=data_path,
|
| 32 |
-
fs=fs,
|
| 33 |
-
npy_memmap=npy_memmap,
|
| 34 |
-
recompute_mixture=False,
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
|
| 38 |
-
track = identifier["track"]
|
| 39 |
-
path = os.path.join(self.data_path, track)
|
| 40 |
-
|
| 41 |
-
if self.npy_memmap:
|
| 42 |
-
audio = np.load(os.path.join(path, f"{stem}.wav.npy"), mmap_mode="r")
|
| 43 |
-
else:
|
| 44 |
-
audio, _ = ta.load(os.path.join(path, f"{stem}.wav"))
|
| 45 |
-
|
| 46 |
-
return audio
|
| 47 |
-
|
| 48 |
-
def get_identifier(self, index):
|
| 49 |
-
return dict(track=self.files[index])
|
| 50 |
-
|
| 51 |
-
def __getitem__(self, index: int) -> DataDict:
|
| 52 |
-
identifier = self.get_identifier(index)
|
| 53 |
-
audio = self.get_audio(identifier)
|
| 54 |
-
|
| 55 |
-
return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
class MUSDB18FullTrackDataset(MUSDB18BaseDataset):
|
| 59 |
-
|
| 60 |
-
N_TRAIN_TRACKS = 100
|
| 61 |
-
N_TEST_TRACKS = 50
|
| 62 |
-
VALIDATION_FILES = [
|
| 63 |
-
"Actions - One Minute Smile",
|
| 64 |
-
"Clara Berry And Wooldog - Waltz For My Victims",
|
| 65 |
-
"Johnny Lokke - Promises & Lies",
|
| 66 |
-
"Patrick Talbot - A Reason To Leave",
|
| 67 |
-
"Triviul - Angelsaint",
|
| 68 |
-
"Alexander Ross - Goodbye Bolero",
|
| 69 |
-
"Fergessen - Nos Palpitants",
|
| 70 |
-
"Leaf - Summerghost",
|
| 71 |
-
"Skelpolu - Human Mistakes",
|
| 72 |
-
"Young Griffo - Pennies",
|
| 73 |
-
"ANiMAL - Rockshow",
|
| 74 |
-
"James May - On The Line",
|
| 75 |
-
"Meaxic - Take A Step",
|
| 76 |
-
"Traffic Experiment - Sirens",
|
| 77 |
-
]
|
| 78 |
-
|
| 79 |
-
def __init__(
|
| 80 |
-
self, data_root: str, split: str, stems: Optional[List[str]] = None
|
| 81 |
-
) -> None:
|
| 82 |
-
|
| 83 |
-
if stems is None:
|
| 84 |
-
stems = self.ALLOWED_STEMS
|
| 85 |
-
self.stems = stems
|
| 86 |
-
|
| 87 |
-
if split == "test":
|
| 88 |
-
subset = "test"
|
| 89 |
-
elif split in ["train", "val"]:
|
| 90 |
-
subset = "train"
|
| 91 |
-
else:
|
| 92 |
-
raise NameError
|
| 93 |
-
|
| 94 |
-
data_path = os.path.join(data_root, subset)
|
| 95 |
-
|
| 96 |
-
files = sorted(os.listdir(data_path))
|
| 97 |
-
files = [f for f in files if not f.startswith(".")]
|
| 98 |
-
if subset == "train":
|
| 99 |
-
assert len(files) == 100, len(files)
|
| 100 |
-
if split == "train":
|
| 101 |
-
files = [f for f in files if f not in self.VALIDATION_FILES]
|
| 102 |
-
assert len(files) == 100 - len(self.VALIDATION_FILES)
|
| 103 |
-
else:
|
| 104 |
-
files = [f for f in files if f in self.VALIDATION_FILES]
|
| 105 |
-
assert len(files) == len(self.VALIDATION_FILES)
|
| 106 |
-
else:
|
| 107 |
-
split = "test"
|
| 108 |
-
assert len(files) == 50
|
| 109 |
-
|
| 110 |
-
self.n_tracks = len(files)
|
| 111 |
-
|
| 112 |
-
super().__init__(data_path=data_path, split=split, stems=stems, files=files)
|
| 113 |
-
|
| 114 |
-
def __len__(self) -> int:
|
| 115 |
-
return self.n_tracks
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
class MUSDB18SadDataset(MUSDB18BaseDataset):
|
| 119 |
-
def __init__(
|
| 120 |
-
self,
|
| 121 |
-
data_root: str,
|
| 122 |
-
split: str,
|
| 123 |
-
target_stem: str,
|
| 124 |
-
stems: Optional[List[str]] = None,
|
| 125 |
-
target_length: Optional[int] = None,
|
| 126 |
-
npy_memmap=False,
|
| 127 |
-
) -> None:
|
| 128 |
-
|
| 129 |
-
if stems is None:
|
| 130 |
-
stems = self.ALLOWED_STEMS
|
| 131 |
-
|
| 132 |
-
data_path = os.path.join(data_root, target_stem, split)
|
| 133 |
-
|
| 134 |
-
files = sorted(os.listdir(data_path))
|
| 135 |
-
files = [f for f in files if not f.startswith(".")]
|
| 136 |
-
|
| 137 |
-
super().__init__(
|
| 138 |
-
data_path=data_path,
|
| 139 |
-
split=split,
|
| 140 |
-
stems=stems,
|
| 141 |
-
files=files,
|
| 142 |
-
npy_memmap=npy_memmap,
|
| 143 |
-
)
|
| 144 |
-
self.n_segments = len(files)
|
| 145 |
-
self.target_stem = target_stem
|
| 146 |
-
self.target_length = (
|
| 147 |
-
target_length if target_length is not None else self.n_segments
|
| 148 |
-
)
|
| 149 |
-
|
| 150 |
-
def __len__(self) -> int:
|
| 151 |
-
return self.target_length
|
| 152 |
-
|
| 153 |
-
def __getitem__(self, index: int) -> DataDict:
|
| 154 |
-
|
| 155 |
-
index = index % self.n_segments
|
| 156 |
-
|
| 157 |
-
return super().__getitem__(index)
|
| 158 |
-
|
| 159 |
-
def get_identifier(self, index):
|
| 160 |
-
return super().get_identifier(index % self.n_segments)
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset):
|
| 164 |
-
def __init__(
|
| 165 |
-
self,
|
| 166 |
-
data_root: str,
|
| 167 |
-
split: str,
|
| 168 |
-
target_stem: str,
|
| 169 |
-
stems: Optional[List[str]] = None,
|
| 170 |
-
target_length: int = 20000,
|
| 171 |
-
apply_probability: Optional[float] = None,
|
| 172 |
-
chunk_size_second: float = 3.0,
|
| 173 |
-
random_scale_range_db: Tuple[float, float] = (-10, 10),
|
| 174 |
-
drop_probability: float = 0.1,
|
| 175 |
-
rescale: bool = True,
|
| 176 |
-
) -> None:
|
| 177 |
-
super().__init__(data_root, split, target_stem, stems)
|
| 178 |
-
|
| 179 |
-
if apply_probability is None:
|
| 180 |
-
apply_probability = (target_length - self.n_segments) / target_length
|
| 181 |
-
|
| 182 |
-
self.apply_probability = apply_probability
|
| 183 |
-
self.drop_probability = drop_probability
|
| 184 |
-
self.chunk_size_second = chunk_size_second
|
| 185 |
-
self.random_scale_range_db = random_scale_range_db
|
| 186 |
-
self.rescale = rescale
|
| 187 |
-
|
| 188 |
-
self.chunk_size_sample = int(self.chunk_size_second * self.fs)
|
| 189 |
-
self.target_length = target_length
|
| 190 |
-
|
| 191 |
-
def __len__(self) -> int:
|
| 192 |
-
return self.target_length
|
| 193 |
-
|
| 194 |
-
def __getitem__(self, index: int) -> DataDict:
|
| 195 |
-
|
| 196 |
-
index = index % self.n_segments
|
| 197 |
-
|
| 198 |
-
audio = {}
|
| 199 |
-
identifier = self.get_identifier(index)
|
| 200 |
-
|
| 201 |
-
for stem in self.stems_no_mixture:
|
| 202 |
-
if stem == self.target_stem:
|
| 203 |
-
identifier_ = identifier
|
| 204 |
-
else:
|
| 205 |
-
if np.random.rand() < self.apply_probability:
|
| 206 |
-
index_ = np.random.randint(self.n_segments)
|
| 207 |
-
identifier_ = self.get_identifier(index_)
|
| 208 |
-
else:
|
| 209 |
-
identifier_ = identifier
|
| 210 |
-
|
| 211 |
-
audio[stem] = self.get_stem(stem=stem, identifier=identifier_)
|
| 212 |
-
|
| 213 |
-
if self.chunk_size_sample < audio[stem].shape[-1]:
|
| 214 |
-
chunk_start = np.random.randint(
|
| 215 |
-
audio[stem].shape[-1] - self.chunk_size_sample
|
| 216 |
-
)
|
| 217 |
-
else:
|
| 218 |
-
chunk_start = 0
|
| 219 |
-
|
| 220 |
-
if np.random.rand() < self.drop_probability:
|
| 221 |
-
linear_scale = 0.0
|
| 222 |
-
else:
|
| 223 |
-
db_scale = np.random.uniform(*self.random_scale_range_db)
|
| 224 |
-
linear_scale = np.power(10, db_scale / 20)
|
| 225 |
-
audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample] = (
|
| 226 |
-
linear_scale
|
| 227 |
-
* audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample]
|
| 228 |
-
)
|
| 229 |
-
|
| 230 |
-
audio["mixture"] = self.compute_mixture(audio)
|
| 231 |
-
|
| 232 |
-
if self.rescale:
|
| 233 |
-
max_abs_val = max(
|
| 234 |
-
[torch.max(torch.abs(audio[stem])) for stem in self.stems]
|
| 235 |
-
) # type: ignore[type-var]
|
| 236 |
-
if max_abs_val > 1:
|
| 237 |
-
audio = {k: v / max_abs_val for k, v in audio.items()}
|
| 238 |
-
|
| 239 |
-
track = identifier["track"]
|
| 240 |
-
|
| 241 |
-
return {"audio": audio, "track": f"{self.split}/{track}"}
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from abc import ABC
|
| 3 |
+
from typing import List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torchaudio as ta
|
| 8 |
+
from torch.utils import data
|
| 9 |
+
|
| 10 |
+
from .._types import AudioDict, DataDict
|
| 11 |
+
from ..base import BaseSourceSeparationDataset
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC):
|
| 15 |
+
|
| 16 |
+
ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"]
|
| 17 |
+
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
split: str,
|
| 21 |
+
stems: List[str],
|
| 22 |
+
files: List[str],
|
| 23 |
+
data_path: str,
|
| 24 |
+
fs: int = 44100,
|
| 25 |
+
npy_memmap=False,
|
| 26 |
+
) -> None:
|
| 27 |
+
super().__init__(
|
| 28 |
+
split=split,
|
| 29 |
+
stems=stems,
|
| 30 |
+
files=files,
|
| 31 |
+
data_path=data_path,
|
| 32 |
+
fs=fs,
|
| 33 |
+
npy_memmap=npy_memmap,
|
| 34 |
+
recompute_mixture=False,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
|
| 38 |
+
track = identifier["track"]
|
| 39 |
+
path = os.path.join(self.data_path, track)
|
| 40 |
+
|
| 41 |
+
if self.npy_memmap:
|
| 42 |
+
audio = np.load(os.path.join(path, f"{stem}.wav.npy"), mmap_mode="r")
|
| 43 |
+
else:
|
| 44 |
+
audio, _ = ta.load(os.path.join(path, f"{stem}.wav"))
|
| 45 |
+
|
| 46 |
+
return audio
|
| 47 |
+
|
| 48 |
+
def get_identifier(self, index):
|
| 49 |
+
return dict(track=self.files[index])
|
| 50 |
+
|
| 51 |
+
def __getitem__(self, index: int) -> DataDict:
|
| 52 |
+
identifier = self.get_identifier(index)
|
| 53 |
+
audio = self.get_audio(identifier)
|
| 54 |
+
|
| 55 |
+
return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class MUSDB18FullTrackDataset(MUSDB18BaseDataset):
|
| 59 |
+
|
| 60 |
+
N_TRAIN_TRACKS = 100
|
| 61 |
+
N_TEST_TRACKS = 50
|
| 62 |
+
VALIDATION_FILES = [
|
| 63 |
+
"Actions - One Minute Smile",
|
| 64 |
+
"Clara Berry And Wooldog - Waltz For My Victims",
|
| 65 |
+
"Johnny Lokke - Promises & Lies",
|
| 66 |
+
"Patrick Talbot - A Reason To Leave",
|
| 67 |
+
"Triviul - Angelsaint",
|
| 68 |
+
"Alexander Ross - Goodbye Bolero",
|
| 69 |
+
"Fergessen - Nos Palpitants",
|
| 70 |
+
"Leaf - Summerghost",
|
| 71 |
+
"Skelpolu - Human Mistakes",
|
| 72 |
+
"Young Griffo - Pennies",
|
| 73 |
+
"ANiMAL - Rockshow",
|
| 74 |
+
"James May - On The Line",
|
| 75 |
+
"Meaxic - Take A Step",
|
| 76 |
+
"Traffic Experiment - Sirens",
|
| 77 |
+
]
|
| 78 |
+
|
| 79 |
+
def __init__(
|
| 80 |
+
self, data_root: str, split: str, stems: Optional[List[str]] = None
|
| 81 |
+
) -> None:
|
| 82 |
+
|
| 83 |
+
if stems is None:
|
| 84 |
+
stems = self.ALLOWED_STEMS
|
| 85 |
+
self.stems = stems
|
| 86 |
+
|
| 87 |
+
if split == "test":
|
| 88 |
+
subset = "test"
|
| 89 |
+
elif split in ["train", "val"]:
|
| 90 |
+
subset = "train"
|
| 91 |
+
else:
|
| 92 |
+
raise NameError
|
| 93 |
+
|
| 94 |
+
data_path = os.path.join(data_root, subset)
|
| 95 |
+
|
| 96 |
+
files = sorted(os.listdir(data_path))
|
| 97 |
+
files = [f for f in files if not f.startswith(".")]
|
| 98 |
+
if subset == "train":
|
| 99 |
+
assert len(files) == 100, len(files)
|
| 100 |
+
if split == "train":
|
| 101 |
+
files = [f for f in files if f not in self.VALIDATION_FILES]
|
| 102 |
+
assert len(files) == 100 - len(self.VALIDATION_FILES)
|
| 103 |
+
else:
|
| 104 |
+
files = [f for f in files if f in self.VALIDATION_FILES]
|
| 105 |
+
assert len(files) == len(self.VALIDATION_FILES)
|
| 106 |
+
else:
|
| 107 |
+
split = "test"
|
| 108 |
+
assert len(files) == 50
|
| 109 |
+
|
| 110 |
+
self.n_tracks = len(files)
|
| 111 |
+
|
| 112 |
+
super().__init__(data_path=data_path, split=split, stems=stems, files=files)
|
| 113 |
+
|
| 114 |
+
def __len__(self) -> int:
|
| 115 |
+
return self.n_tracks
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class MUSDB18SadDataset(MUSDB18BaseDataset):
|
| 119 |
+
def __init__(
|
| 120 |
+
self,
|
| 121 |
+
data_root: str,
|
| 122 |
+
split: str,
|
| 123 |
+
target_stem: str,
|
| 124 |
+
stems: Optional[List[str]] = None,
|
| 125 |
+
target_length: Optional[int] = None,
|
| 126 |
+
npy_memmap=False,
|
| 127 |
+
) -> None:
|
| 128 |
+
|
| 129 |
+
if stems is None:
|
| 130 |
+
stems = self.ALLOWED_STEMS
|
| 131 |
+
|
| 132 |
+
data_path = os.path.join(data_root, target_stem, split)
|
| 133 |
+
|
| 134 |
+
files = sorted(os.listdir(data_path))
|
| 135 |
+
files = [f for f in files if not f.startswith(".")]
|
| 136 |
+
|
| 137 |
+
super().__init__(
|
| 138 |
+
data_path=data_path,
|
| 139 |
+
split=split,
|
| 140 |
+
stems=stems,
|
| 141 |
+
files=files,
|
| 142 |
+
npy_memmap=npy_memmap,
|
| 143 |
+
)
|
| 144 |
+
self.n_segments = len(files)
|
| 145 |
+
self.target_stem = target_stem
|
| 146 |
+
self.target_length = (
|
| 147 |
+
target_length if target_length is not None else self.n_segments
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
def __len__(self) -> int:
|
| 151 |
+
return self.target_length
|
| 152 |
+
|
| 153 |
+
def __getitem__(self, index: int) -> DataDict:
|
| 154 |
+
|
| 155 |
+
index = index % self.n_segments
|
| 156 |
+
|
| 157 |
+
return super().__getitem__(index)
|
| 158 |
+
|
| 159 |
+
def get_identifier(self, index):
|
| 160 |
+
return super().get_identifier(index % self.n_segments)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset):
|
| 164 |
+
def __init__(
|
| 165 |
+
self,
|
| 166 |
+
data_root: str,
|
| 167 |
+
split: str,
|
| 168 |
+
target_stem: str,
|
| 169 |
+
stems: Optional[List[str]] = None,
|
| 170 |
+
target_length: int = 20000,
|
| 171 |
+
apply_probability: Optional[float] = None,
|
| 172 |
+
chunk_size_second: float = 3.0,
|
| 173 |
+
random_scale_range_db: Tuple[float, float] = (-10, 10),
|
| 174 |
+
drop_probability: float = 0.1,
|
| 175 |
+
rescale: bool = True,
|
| 176 |
+
) -> None:
|
| 177 |
+
super().__init__(data_root, split, target_stem, stems)
|
| 178 |
+
|
| 179 |
+
if apply_probability is None:
|
| 180 |
+
apply_probability = (target_length - self.n_segments) / target_length
|
| 181 |
+
|
| 182 |
+
self.apply_probability = apply_probability
|
| 183 |
+
self.drop_probability = drop_probability
|
| 184 |
+
self.chunk_size_second = chunk_size_second
|
| 185 |
+
self.random_scale_range_db = random_scale_range_db
|
| 186 |
+
self.rescale = rescale
|
| 187 |
+
|
| 188 |
+
self.chunk_size_sample = int(self.chunk_size_second * self.fs)
|
| 189 |
+
self.target_length = target_length
|
| 190 |
+
|
| 191 |
+
def __len__(self) -> int:
|
| 192 |
+
return self.target_length
|
| 193 |
+
|
| 194 |
+
def __getitem__(self, index: int) -> DataDict:
|
| 195 |
+
|
| 196 |
+
index = index % self.n_segments
|
| 197 |
+
|
| 198 |
+
audio = {}
|
| 199 |
+
identifier = self.get_identifier(index)
|
| 200 |
+
|
| 201 |
+
for stem in self.stems_no_mixture:
|
| 202 |
+
if stem == self.target_stem:
|
| 203 |
+
identifier_ = identifier
|
| 204 |
+
else:
|
| 205 |
+
if np.random.rand() < self.apply_probability:
|
| 206 |
+
index_ = np.random.randint(self.n_segments)
|
| 207 |
+
identifier_ = self.get_identifier(index_)
|
| 208 |
+
else:
|
| 209 |
+
identifier_ = identifier
|
| 210 |
+
|
| 211 |
+
audio[stem] = self.get_stem(stem=stem, identifier=identifier_)
|
| 212 |
+
|
| 213 |
+
if self.chunk_size_sample < audio[stem].shape[-1]:
|
| 214 |
+
chunk_start = np.random.randint(
|
| 215 |
+
audio[stem].shape[-1] - self.chunk_size_sample
|
| 216 |
+
)
|
| 217 |
+
else:
|
| 218 |
+
chunk_start = 0
|
| 219 |
+
|
| 220 |
+
if np.random.rand() < self.drop_probability:
|
| 221 |
+
linear_scale = 0.0
|
| 222 |
+
else:
|
| 223 |
+
db_scale = np.random.uniform(*self.random_scale_range_db)
|
| 224 |
+
linear_scale = np.power(10, db_scale / 20)
|
| 225 |
+
audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample] = (
|
| 226 |
+
linear_scale
|
| 227 |
+
* audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample]
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
audio["mixture"] = self.compute_mixture(audio)
|
| 231 |
+
|
| 232 |
+
if self.rescale:
|
| 233 |
+
max_abs_val = max(
|
| 234 |
+
[torch.max(torch.abs(audio[stem])) for stem in self.stems]
|
| 235 |
+
) # type: ignore[type-var]
|
| 236 |
+
if max_abs_val > 1:
|
| 237 |
+
audio = {k: v / max_abs_val for k, v in audio.items()}
|
| 238 |
+
|
| 239 |
+
track = identifier["track"]
|
| 240 |
+
|
| 241 |
+
return {"audio": audio, "track": f"{self.split}/{track}"}
|
mvsepless/models/bandit/core/data/musdb/preprocess.py
CHANGED
|
@@ -1,223 +1,223 @@
|
|
| 1 |
-
import glob
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
import numpy as np
|
| 5 |
-
import torch
|
| 6 |
-
import torchaudio as ta
|
| 7 |
-
from torch import nn
|
| 8 |
-
from torch.nn import functional as F
|
| 9 |
-
from tqdm.contrib.concurrent import process_map
|
| 10 |
-
|
| 11 |
-
from .._types import DataDict
|
| 12 |
-
from .dataset import MUSDB18FullTrackDataset
|
| 13 |
-
import pyloudnorm as pyln
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class SourceActivityDetector(nn.Module):
|
| 17 |
-
def __init__(
|
| 18 |
-
self,
|
| 19 |
-
analysis_stem: str,
|
| 20 |
-
output_path: str,
|
| 21 |
-
fs: int = 44100,
|
| 22 |
-
segment_length_second: float = 6.0,
|
| 23 |
-
hop_length_second: float = 3.0,
|
| 24 |
-
n_chunks: int = 10,
|
| 25 |
-
chunk_epsilon: float = 1e-5,
|
| 26 |
-
energy_threshold_quantile: float = 0.15,
|
| 27 |
-
segment_epsilon: float = 1e-3,
|
| 28 |
-
salient_proportion_threshold: float = 0.5,
|
| 29 |
-
target_lufs: float = -24,
|
| 30 |
-
) -> None:
|
| 31 |
-
super().__init__()
|
| 32 |
-
|
| 33 |
-
self.fs = fs
|
| 34 |
-
self.segment_length = int(segment_length_second * self.fs)
|
| 35 |
-
self.hop_length = int(hop_length_second * self.fs)
|
| 36 |
-
self.n_chunks = n_chunks
|
| 37 |
-
assert self.segment_length % self.n_chunks == 0
|
| 38 |
-
self.chunk_size = self.segment_length // self.n_chunks
|
| 39 |
-
self.chunk_epsilon = chunk_epsilon
|
| 40 |
-
self.energy_threshold_quantile = energy_threshold_quantile
|
| 41 |
-
self.segment_epsilon = segment_epsilon
|
| 42 |
-
self.salient_proportion_threshold = salient_proportion_threshold
|
| 43 |
-
self.analysis_stem = analysis_stem
|
| 44 |
-
|
| 45 |
-
self.meter = pyln.Meter(self.fs)
|
| 46 |
-
self.target_lufs = target_lufs
|
| 47 |
-
|
| 48 |
-
self.output_path = output_path
|
| 49 |
-
|
| 50 |
-
def forward(self, data: DataDict) -> None:
|
| 51 |
-
|
| 52 |
-
stem_ = self.analysis_stem if (self.analysis_stem != "none") else "mixture"
|
| 53 |
-
|
| 54 |
-
x = data["audio"][stem_]
|
| 55 |
-
|
| 56 |
-
xnp = x.numpy()
|
| 57 |
-
loudness = self.meter.integrated_loudness(xnp.T)
|
| 58 |
-
|
| 59 |
-
for stem in data["audio"]:
|
| 60 |
-
s = data["audio"][stem]
|
| 61 |
-
s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T
|
| 62 |
-
s = torch.as_tensor(s)
|
| 63 |
-
data["audio"][stem] = s
|
| 64 |
-
|
| 65 |
-
if x.ndim == 3:
|
| 66 |
-
assert x.shape[0] == 1
|
| 67 |
-
x = x[0]
|
| 68 |
-
|
| 69 |
-
n_chan, n_samples = x.shape
|
| 70 |
-
|
| 71 |
-
n_segments = (
|
| 72 |
-
int(np.ceil((n_samples - self.segment_length) / self.hop_length)) + 1
|
| 73 |
-
)
|
| 74 |
-
|
| 75 |
-
segments = torch.zeros((n_segments, n_chan, self.segment_length))
|
| 76 |
-
for i in range(n_segments):
|
| 77 |
-
start = i * self.hop_length
|
| 78 |
-
end = start + self.segment_length
|
| 79 |
-
end = min(end, n_samples)
|
| 80 |
-
|
| 81 |
-
xseg = x[:, start:end]
|
| 82 |
-
|
| 83 |
-
if end - start < self.segment_length:
|
| 84 |
-
xseg = F.pad(
|
| 85 |
-
xseg, pad=(0, self.segment_length - (end - start)), value=torch.nan
|
| 86 |
-
)
|
| 87 |
-
|
| 88 |
-
segments[i, :, :] = xseg
|
| 89 |
-
|
| 90 |
-
chunks = segments.reshape((n_segments, n_chan, self.n_chunks, self.chunk_size))
|
| 91 |
-
|
| 92 |
-
if self.analysis_stem != "none":
|
| 93 |
-
chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3))
|
| 94 |
-
chunk_energies = torch.nan_to_num(chunk_energies, nan=0)
|
| 95 |
-
chunk_energies[chunk_energies == 0] = self.chunk_epsilon
|
| 96 |
-
|
| 97 |
-
energy_threshold = torch.nanquantile(
|
| 98 |
-
chunk_energies, q=self.energy_threshold_quantile
|
| 99 |
-
)
|
| 100 |
-
|
| 101 |
-
if energy_threshold < self.segment_epsilon:
|
| 102 |
-
energy_threshold = self.segment_epsilon # type: ignore[assignment]
|
| 103 |
-
|
| 104 |
-
chunks_above_threshold = chunk_energies > energy_threshold
|
| 105 |
-
n_chunks_above_threshold = torch.mean(
|
| 106 |
-
chunks_above_threshold.to(torch.float), dim=-1
|
| 107 |
-
)
|
| 108 |
-
|
| 109 |
-
segment_above_threshold = (
|
| 110 |
-
n_chunks_above_threshold > self.salient_proportion_threshold
|
| 111 |
-
)
|
| 112 |
-
|
| 113 |
-
if torch.sum(segment_above_threshold) == 0:
|
| 114 |
-
return
|
| 115 |
-
|
| 116 |
-
else:
|
| 117 |
-
segment_above_threshold = torch.ones((n_segments,))
|
| 118 |
-
|
| 119 |
-
for i in range(n_segments):
|
| 120 |
-
if not segment_above_threshold[i]:
|
| 121 |
-
continue
|
| 122 |
-
|
| 123 |
-
outpath = os.path.join(
|
| 124 |
-
self.output_path,
|
| 125 |
-
self.analysis_stem,
|
| 126 |
-
f"{data['track']} - {self.analysis_stem}{i:03d}",
|
| 127 |
-
)
|
| 128 |
-
os.makedirs(outpath, exist_ok=True)
|
| 129 |
-
|
| 130 |
-
for stem in data["audio"]:
|
| 131 |
-
if stem == self.analysis_stem:
|
| 132 |
-
segment = torch.nan_to_num(segments[i, :, :], nan=0)
|
| 133 |
-
else:
|
| 134 |
-
start = i * self.hop_length
|
| 135 |
-
end = start + self.segment_length
|
| 136 |
-
end = min(n_samples, end)
|
| 137 |
-
|
| 138 |
-
segment = data["audio"][stem][:, start:end]
|
| 139 |
-
|
| 140 |
-
if end - start < self.segment_length:
|
| 141 |
-
segment = F.pad(
|
| 142 |
-
segment, (0, self.segment_length - (end - start))
|
| 143 |
-
)
|
| 144 |
-
|
| 145 |
-
assert segment.shape[-1] == self.segment_length, segment.shape
|
| 146 |
-
|
| 147 |
-
np.save(os.path.join(outpath, f"{stem}.wav"), segment)
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
def preprocess(
|
| 151 |
-
analysis_stem: str,
|
| 152 |
-
output_path: str = "/data/MUSDB18/HQ/saded-np",
|
| 153 |
-
fs: int = 44100,
|
| 154 |
-
segment_length_second: float = 6.0,
|
| 155 |
-
hop_length_second: float = 3.0,
|
| 156 |
-
n_chunks: int = 10,
|
| 157 |
-
chunk_epsilon: float = 1e-5,
|
| 158 |
-
energy_threshold_quantile: float = 0.15,
|
| 159 |
-
segment_epsilon: float = 1e-3,
|
| 160 |
-
salient_proportion_threshold: float = 0.5,
|
| 161 |
-
) -> None:
|
| 162 |
-
|
| 163 |
-
sad = SourceActivityDetector(
|
| 164 |
-
analysis_stem=analysis_stem,
|
| 165 |
-
output_path=output_path,
|
| 166 |
-
fs=fs,
|
| 167 |
-
segment_length_second=segment_length_second,
|
| 168 |
-
hop_length_second=hop_length_second,
|
| 169 |
-
n_chunks=n_chunks,
|
| 170 |
-
chunk_epsilon=chunk_epsilon,
|
| 171 |
-
energy_threshold_quantile=energy_threshold_quantile,
|
| 172 |
-
segment_epsilon=segment_epsilon,
|
| 173 |
-
salient_proportion_threshold=salient_proportion_threshold,
|
| 174 |
-
)
|
| 175 |
-
|
| 176 |
-
for split in ["train", "val", "test"]:
|
| 177 |
-
ds = MUSDB18FullTrackDataset(
|
| 178 |
-
data_root="/data/MUSDB18/HQ/canonical",
|
| 179 |
-
split=split,
|
| 180 |
-
)
|
| 181 |
-
|
| 182 |
-
tracks = []
|
| 183 |
-
for i, track in enumerate(tqdm(ds, total=len(ds))):
|
| 184 |
-
if i % 32 == 0 and tracks:
|
| 185 |
-
process_map(sad, tracks, max_workers=8)
|
| 186 |
-
tracks = []
|
| 187 |
-
tracks.append(track)
|
| 188 |
-
process_map(sad, tracks, max_workers=8)
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
def loudness_norm_one(inputs):
|
| 192 |
-
infile, outfile, target_lufs = inputs
|
| 193 |
-
|
| 194 |
-
audio, fs = ta.load(infile)
|
| 195 |
-
audio = audio.mean(dim=0, keepdim=True).numpy().T
|
| 196 |
-
|
| 197 |
-
meter = pyln.Meter(fs)
|
| 198 |
-
loudness = meter.integrated_loudness(audio)
|
| 199 |
-
audio = pyln.normalize.loudness(audio, loudness, target_lufs)
|
| 200 |
-
|
| 201 |
-
os.makedirs(os.path.dirname(outfile), exist_ok=True)
|
| 202 |
-
np.save(outfile, audio.T)
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
def loudness_norm(
|
| 206 |
-
data_path: str,
|
| 207 |
-
target_lufs=-17.0,
|
| 208 |
-
):
|
| 209 |
-
files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
|
| 210 |
-
|
| 211 |
-
outfiles = [f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files]
|
| 212 |
-
|
| 213 |
-
files = [(f, o, target_lufs) for f, o in zip(files, outfiles)]
|
| 214 |
-
|
| 215 |
-
process_map(loudness_norm_one, files, chunksize=2)
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
if __name__ == "__main__":
|
| 219 |
-
|
| 220 |
-
from tqdm.auto import tqdm
|
| 221 |
-
import fire
|
| 222 |
-
|
| 223 |
-
fire.Fire()
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torchaudio as ta
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.nn import functional as F
|
| 9 |
+
from tqdm.contrib.concurrent import process_map
|
| 10 |
+
|
| 11 |
+
from .._types import DataDict
|
| 12 |
+
from .dataset import MUSDB18FullTrackDataset
|
| 13 |
+
import pyloudnorm as pyln
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SourceActivityDetector(nn.Module):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
analysis_stem: str,
|
| 20 |
+
output_path: str,
|
| 21 |
+
fs: int = 44100,
|
| 22 |
+
segment_length_second: float = 6.0,
|
| 23 |
+
hop_length_second: float = 3.0,
|
| 24 |
+
n_chunks: int = 10,
|
| 25 |
+
chunk_epsilon: float = 1e-5,
|
| 26 |
+
energy_threshold_quantile: float = 0.15,
|
| 27 |
+
segment_epsilon: float = 1e-3,
|
| 28 |
+
salient_proportion_threshold: float = 0.5,
|
| 29 |
+
target_lufs: float = -24,
|
| 30 |
+
) -> None:
|
| 31 |
+
super().__init__()
|
| 32 |
+
|
| 33 |
+
self.fs = fs
|
| 34 |
+
self.segment_length = int(segment_length_second * self.fs)
|
| 35 |
+
self.hop_length = int(hop_length_second * self.fs)
|
| 36 |
+
self.n_chunks = n_chunks
|
| 37 |
+
assert self.segment_length % self.n_chunks == 0
|
| 38 |
+
self.chunk_size = self.segment_length // self.n_chunks
|
| 39 |
+
self.chunk_epsilon = chunk_epsilon
|
| 40 |
+
self.energy_threshold_quantile = energy_threshold_quantile
|
| 41 |
+
self.segment_epsilon = segment_epsilon
|
| 42 |
+
self.salient_proportion_threshold = salient_proportion_threshold
|
| 43 |
+
self.analysis_stem = analysis_stem
|
| 44 |
+
|
| 45 |
+
self.meter = pyln.Meter(self.fs)
|
| 46 |
+
self.target_lufs = target_lufs
|
| 47 |
+
|
| 48 |
+
self.output_path = output_path
|
| 49 |
+
|
| 50 |
+
def forward(self, data: DataDict) -> None:
|
| 51 |
+
|
| 52 |
+
stem_ = self.analysis_stem if (self.analysis_stem != "none") else "mixture"
|
| 53 |
+
|
| 54 |
+
x = data["audio"][stem_]
|
| 55 |
+
|
| 56 |
+
xnp = x.numpy()
|
| 57 |
+
loudness = self.meter.integrated_loudness(xnp.T)
|
| 58 |
+
|
| 59 |
+
for stem in data["audio"]:
|
| 60 |
+
s = data["audio"][stem]
|
| 61 |
+
s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T
|
| 62 |
+
s = torch.as_tensor(s)
|
| 63 |
+
data["audio"][stem] = s
|
| 64 |
+
|
| 65 |
+
if x.ndim == 3:
|
| 66 |
+
assert x.shape[0] == 1
|
| 67 |
+
x = x[0]
|
| 68 |
+
|
| 69 |
+
n_chan, n_samples = x.shape
|
| 70 |
+
|
| 71 |
+
n_segments = (
|
| 72 |
+
int(np.ceil((n_samples - self.segment_length) / self.hop_length)) + 1
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
segments = torch.zeros((n_segments, n_chan, self.segment_length))
|
| 76 |
+
for i in range(n_segments):
|
| 77 |
+
start = i * self.hop_length
|
| 78 |
+
end = start + self.segment_length
|
| 79 |
+
end = min(end, n_samples)
|
| 80 |
+
|
| 81 |
+
xseg = x[:, start:end]
|
| 82 |
+
|
| 83 |
+
if end - start < self.segment_length:
|
| 84 |
+
xseg = F.pad(
|
| 85 |
+
xseg, pad=(0, self.segment_length - (end - start)), value=torch.nan
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
segments[i, :, :] = xseg
|
| 89 |
+
|
| 90 |
+
chunks = segments.reshape((n_segments, n_chan, self.n_chunks, self.chunk_size))
|
| 91 |
+
|
| 92 |
+
if self.analysis_stem != "none":
|
| 93 |
+
chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3))
|
| 94 |
+
chunk_energies = torch.nan_to_num(chunk_energies, nan=0)
|
| 95 |
+
chunk_energies[chunk_energies == 0] = self.chunk_epsilon
|
| 96 |
+
|
| 97 |
+
energy_threshold = torch.nanquantile(
|
| 98 |
+
chunk_energies, q=self.energy_threshold_quantile
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
if energy_threshold < self.segment_epsilon:
|
| 102 |
+
energy_threshold = self.segment_epsilon # type: ignore[assignment]
|
| 103 |
+
|
| 104 |
+
chunks_above_threshold = chunk_energies > energy_threshold
|
| 105 |
+
n_chunks_above_threshold = torch.mean(
|
| 106 |
+
chunks_above_threshold.to(torch.float), dim=-1
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
segment_above_threshold = (
|
| 110 |
+
n_chunks_above_threshold > self.salient_proportion_threshold
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
if torch.sum(segment_above_threshold) == 0:
|
| 114 |
+
return
|
| 115 |
+
|
| 116 |
+
else:
|
| 117 |
+
segment_above_threshold = torch.ones((n_segments,))
|
| 118 |
+
|
| 119 |
+
for i in range(n_segments):
|
| 120 |
+
if not segment_above_threshold[i]:
|
| 121 |
+
continue
|
| 122 |
+
|
| 123 |
+
outpath = os.path.join(
|
| 124 |
+
self.output_path,
|
| 125 |
+
self.analysis_stem,
|
| 126 |
+
f"{data['track']} - {self.analysis_stem}{i:03d}",
|
| 127 |
+
)
|
| 128 |
+
os.makedirs(outpath, exist_ok=True)
|
| 129 |
+
|
| 130 |
+
for stem in data["audio"]:
|
| 131 |
+
if stem == self.analysis_stem:
|
| 132 |
+
segment = torch.nan_to_num(segments[i, :, :], nan=0)
|
| 133 |
+
else:
|
| 134 |
+
start = i * self.hop_length
|
| 135 |
+
end = start + self.segment_length
|
| 136 |
+
end = min(n_samples, end)
|
| 137 |
+
|
| 138 |
+
segment = data["audio"][stem][:, start:end]
|
| 139 |
+
|
| 140 |
+
if end - start < self.segment_length:
|
| 141 |
+
segment = F.pad(
|
| 142 |
+
segment, (0, self.segment_length - (end - start))
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
assert segment.shape[-1] == self.segment_length, segment.shape
|
| 146 |
+
|
| 147 |
+
np.save(os.path.join(outpath, f"{stem}.wav"), segment)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def preprocess(
|
| 151 |
+
analysis_stem: str,
|
| 152 |
+
output_path: str = "/data/MUSDB18/HQ/saded-np",
|
| 153 |
+
fs: int = 44100,
|
| 154 |
+
segment_length_second: float = 6.0,
|
| 155 |
+
hop_length_second: float = 3.0,
|
| 156 |
+
n_chunks: int = 10,
|
| 157 |
+
chunk_epsilon: float = 1e-5,
|
| 158 |
+
energy_threshold_quantile: float = 0.15,
|
| 159 |
+
segment_epsilon: float = 1e-3,
|
| 160 |
+
salient_proportion_threshold: float = 0.5,
|
| 161 |
+
) -> None:
|
| 162 |
+
|
| 163 |
+
sad = SourceActivityDetector(
|
| 164 |
+
analysis_stem=analysis_stem,
|
| 165 |
+
output_path=output_path,
|
| 166 |
+
fs=fs,
|
| 167 |
+
segment_length_second=segment_length_second,
|
| 168 |
+
hop_length_second=hop_length_second,
|
| 169 |
+
n_chunks=n_chunks,
|
| 170 |
+
chunk_epsilon=chunk_epsilon,
|
| 171 |
+
energy_threshold_quantile=energy_threshold_quantile,
|
| 172 |
+
segment_epsilon=segment_epsilon,
|
| 173 |
+
salient_proportion_threshold=salient_proportion_threshold,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
for split in ["train", "val", "test"]:
|
| 177 |
+
ds = MUSDB18FullTrackDataset(
|
| 178 |
+
data_root="/data/MUSDB18/HQ/canonical",
|
| 179 |
+
split=split,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
tracks = []
|
| 183 |
+
for i, track in enumerate(tqdm(ds, total=len(ds))):
|
| 184 |
+
if i % 32 == 0 and tracks:
|
| 185 |
+
process_map(sad, tracks, max_workers=8)
|
| 186 |
+
tracks = []
|
| 187 |
+
tracks.append(track)
|
| 188 |
+
process_map(sad, tracks, max_workers=8)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def loudness_norm_one(inputs):
|
| 192 |
+
infile, outfile, target_lufs = inputs
|
| 193 |
+
|
| 194 |
+
audio, fs = ta.load(infile)
|
| 195 |
+
audio = audio.mean(dim=0, keepdim=True).numpy().T
|
| 196 |
+
|
| 197 |
+
meter = pyln.Meter(fs)
|
| 198 |
+
loudness = meter.integrated_loudness(audio)
|
| 199 |
+
audio = pyln.normalize.loudness(audio, loudness, target_lufs)
|
| 200 |
+
|
| 201 |
+
os.makedirs(os.path.dirname(outfile), exist_ok=True)
|
| 202 |
+
np.save(outfile, audio.T)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def loudness_norm(
|
| 206 |
+
data_path: str,
|
| 207 |
+
target_lufs=-17.0,
|
| 208 |
+
):
|
| 209 |
+
files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
|
| 210 |
+
|
| 211 |
+
outfiles = [f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files]
|
| 212 |
+
|
| 213 |
+
files = [(f, o, target_lufs) for f, o in zip(files, outfiles)]
|
| 214 |
+
|
| 215 |
+
process_map(loudness_norm_one, files, chunksize=2)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
if __name__ == "__main__":
|
| 219 |
+
|
| 220 |
+
from tqdm.auto import tqdm
|
| 221 |
+
import fire
|
| 222 |
+
|
| 223 |
+
fire.Fire()
|
mvsepless/models/bandit/core/data/musdb/validation.yaml
CHANGED
|
@@ -1,15 +1,15 @@
|
|
| 1 |
-
validation:
|
| 2 |
-
- 'Actions - One Minute Smile'
|
| 3 |
-
- 'Clara Berry And Wooldog - Waltz For My Victims'
|
| 4 |
-
- 'Johnny Lokke - Promises & Lies'
|
| 5 |
-
- 'Patrick Talbot - A Reason To Leave'
|
| 6 |
-
- 'Triviul - Angelsaint'
|
| 7 |
-
- 'Alexander Ross - Goodbye Bolero'
|
| 8 |
-
- 'Fergessen - Nos Palpitants'
|
| 9 |
-
- 'Leaf - Summerghost'
|
| 10 |
-
- 'Skelpolu - Human Mistakes'
|
| 11 |
-
- 'Young Griffo - Pennies'
|
| 12 |
-
- 'ANiMAL - Rockshow'
|
| 13 |
-
- 'James May - On The Line'
|
| 14 |
-
- 'Meaxic - Take A Step'
|
| 15 |
- 'Traffic Experiment - Sirens'
|
|
|
|
| 1 |
+
validation:
|
| 2 |
+
- 'Actions - One Minute Smile'
|
| 3 |
+
- 'Clara Berry And Wooldog - Waltz For My Victims'
|
| 4 |
+
- 'Johnny Lokke - Promises & Lies'
|
| 5 |
+
- 'Patrick Talbot - A Reason To Leave'
|
| 6 |
+
- 'Triviul - Angelsaint'
|
| 7 |
+
- 'Alexander Ross - Goodbye Bolero'
|
| 8 |
+
- 'Fergessen - Nos Palpitants'
|
| 9 |
+
- 'Leaf - Summerghost'
|
| 10 |
+
- 'Skelpolu - Human Mistakes'
|
| 11 |
+
- 'Young Griffo - Pennies'
|
| 12 |
+
- 'ANiMAL - Rockshow'
|
| 13 |
+
- 'James May - On The Line'
|
| 14 |
+
- 'Meaxic - Take A Step'
|
| 15 |
- 'Traffic Experiment - Sirens'
|
mvsepless/models/bandit/core/loss/__init__.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
-
from ._multistem import MultiStemWrapperFromConfig
|
| 2 |
-
from ._timefreq import (
|
| 3 |
-
ReImL1Loss,
|
| 4 |
-
ReImL2Loss,
|
| 5 |
-
TimeFreqL1Loss,
|
| 6 |
-
TimeFreqL2Loss,
|
| 7 |
-
TimeFreqSignalNoisePNormRatioLoss,
|
| 8 |
-
)
|
|
|
|
| 1 |
+
from ._multistem import MultiStemWrapperFromConfig
|
| 2 |
+
from ._timefreq import (
|
| 3 |
+
ReImL1Loss,
|
| 4 |
+
ReImL2Loss,
|
| 5 |
+
TimeFreqL1Loss,
|
| 6 |
+
TimeFreqL2Loss,
|
| 7 |
+
TimeFreqSignalNoisePNormRatioLoss,
|
| 8 |
+
)
|
mvsepless/models/bandit/core/loss/_complex.py
CHANGED
|
@@ -1,27 +1,27 @@
|
|
| 1 |
-
from typing import Any
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import nn
|
| 5 |
-
from torch.nn.modules import loss as _loss
|
| 6 |
-
from torch.nn.modules.loss import _Loss
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class ReImLossWrapper(_Loss):
|
| 10 |
-
def __init__(self, module: _Loss) -> None:
|
| 11 |
-
super().__init__()
|
| 12 |
-
self.module = module
|
| 13 |
-
|
| 14 |
-
def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
| 15 |
-
return self.module(torch.view_as_real(preds), torch.view_as_real(target))
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class ReImL1Loss(ReImLossWrapper):
|
| 19 |
-
def __init__(self, **kwargs: Any) -> None:
|
| 20 |
-
l1_loss = _loss.L1Loss(**kwargs)
|
| 21 |
-
super().__init__(module=(l1_loss))
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class ReImL2Loss(ReImLossWrapper):
|
| 25 |
-
def __init__(self, **kwargs: Any) -> None:
|
| 26 |
-
l2_loss = _loss.MSELoss(**kwargs)
|
| 27 |
-
super().__init__(module=(l2_loss))
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn.modules import loss as _loss
|
| 6 |
+
from torch.nn.modules.loss import _Loss
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ReImLossWrapper(_Loss):
|
| 10 |
+
def __init__(self, module: _Loss) -> None:
|
| 11 |
+
super().__init__()
|
| 12 |
+
self.module = module
|
| 13 |
+
|
| 14 |
+
def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
| 15 |
+
return self.module(torch.view_as_real(preds), torch.view_as_real(target))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ReImL1Loss(ReImLossWrapper):
|
| 19 |
+
def __init__(self, **kwargs: Any) -> None:
|
| 20 |
+
l1_loss = _loss.L1Loss(**kwargs)
|
| 21 |
+
super().__init__(module=(l1_loss))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ReImL2Loss(ReImLossWrapper):
|
| 25 |
+
def __init__(self, **kwargs: Any) -> None:
|
| 26 |
+
l2_loss = _loss.MSELoss(**kwargs)
|
| 27 |
+
super().__init__(module=(l2_loss))
|
mvsepless/models/bandit/core/loss/_multistem.py
CHANGED
|
@@ -1,43 +1,43 @@
|
|
| 1 |
-
from typing import Any, Dict
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from asteroid import losses as asteroid_losses
|
| 5 |
-
from torch import nn
|
| 6 |
-
from torch.nn.modules.loss import _Loss
|
| 7 |
-
|
| 8 |
-
from . import snr
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss:
|
| 12 |
-
|
| 13 |
-
for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]:
|
| 14 |
-
if name in module.__dict__:
|
| 15 |
-
return module.__dict__[name](**kwargs)
|
| 16 |
-
|
| 17 |
-
raise NameError
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class MultiStemWrapper(_Loss):
|
| 21 |
-
def __init__(self, module: _Loss, modality: str = "audio") -> None:
|
| 22 |
-
super().__init__()
|
| 23 |
-
self.loss = module
|
| 24 |
-
self.modality = modality
|
| 25 |
-
|
| 26 |
-
def forward(
|
| 27 |
-
self,
|
| 28 |
-
preds: Dict[str, Dict[str, torch.Tensor]],
|
| 29 |
-
target: Dict[str, Dict[str, torch.Tensor]],
|
| 30 |
-
) -> torch.Tensor:
|
| 31 |
-
loss = {
|
| 32 |
-
stem: self.loss(preds[self.modality][stem], target[self.modality][stem])
|
| 33 |
-
for stem in preds[self.modality]
|
| 34 |
-
if stem in target[self.modality]
|
| 35 |
-
}
|
| 36 |
-
|
| 37 |
-
return sum(list(loss.values()))
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
class MultiStemWrapperFromConfig(MultiStemWrapper):
|
| 41 |
-
def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None:
|
| 42 |
-
loss = parse_loss(name, kwargs)
|
| 43 |
-
super().__init__(module=loss, modality=modality)
|
|
|
|
| 1 |
+
from typing import Any, Dict
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from asteroid import losses as asteroid_losses
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn.modules.loss import _Loss
|
| 7 |
+
|
| 8 |
+
from . import snr
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss:
|
| 12 |
+
|
| 13 |
+
for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]:
|
| 14 |
+
if name in module.__dict__:
|
| 15 |
+
return module.__dict__[name](**kwargs)
|
| 16 |
+
|
| 17 |
+
raise NameError
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class MultiStemWrapper(_Loss):
|
| 21 |
+
def __init__(self, module: _Loss, modality: str = "audio") -> None:
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.loss = module
|
| 24 |
+
self.modality = modality
|
| 25 |
+
|
| 26 |
+
def forward(
|
| 27 |
+
self,
|
| 28 |
+
preds: Dict[str, Dict[str, torch.Tensor]],
|
| 29 |
+
target: Dict[str, Dict[str, torch.Tensor]],
|
| 30 |
+
) -> torch.Tensor:
|
| 31 |
+
loss = {
|
| 32 |
+
stem: self.loss(preds[self.modality][stem], target[self.modality][stem])
|
| 33 |
+
for stem in preds[self.modality]
|
| 34 |
+
if stem in target[self.modality]
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
return sum(list(loss.values()))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class MultiStemWrapperFromConfig(MultiStemWrapper):
|
| 41 |
+
def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None:
|
| 42 |
+
loss = parse_loss(name, kwargs)
|
| 43 |
+
super().__init__(module=loss, modality=modality)
|
mvsepless/models/bandit/core/loss/_timefreq.py
CHANGED
|
@@ -1,94 +1,94 @@
|
|
| 1 |
-
from typing import Any, Dict, Optional
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import nn
|
| 5 |
-
from torch.nn.modules.loss import _Loss
|
| 6 |
-
|
| 7 |
-
from ._multistem import MultiStemWrapper
|
| 8 |
-
from ._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper
|
| 9 |
-
from .snr import SignalNoisePNormRatio
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class TimeFreqWrapper(_Loss):
|
| 13 |
-
def __init__(
|
| 14 |
-
self,
|
| 15 |
-
time_module: _Loss,
|
| 16 |
-
freq_module: Optional[_Loss] = None,
|
| 17 |
-
time_weight: float = 1.0,
|
| 18 |
-
freq_weight: float = 1.0,
|
| 19 |
-
multistem: bool = True,
|
| 20 |
-
) -> None:
|
| 21 |
-
super().__init__()
|
| 22 |
-
|
| 23 |
-
if freq_module is None:
|
| 24 |
-
freq_module = time_module
|
| 25 |
-
|
| 26 |
-
if multistem:
|
| 27 |
-
time_module = MultiStemWrapper(time_module, modality="audio")
|
| 28 |
-
freq_module = MultiStemWrapper(freq_module, modality="spectrogram")
|
| 29 |
-
|
| 30 |
-
self.time_module = time_module
|
| 31 |
-
self.freq_module = freq_module
|
| 32 |
-
|
| 33 |
-
self.time_weight = time_weight
|
| 34 |
-
self.freq_weight = freq_weight
|
| 35 |
-
|
| 36 |
-
def forward(self, preds: Any, target: Any) -> torch.Tensor:
|
| 37 |
-
|
| 38 |
-
return self.time_weight * self.time_module(
|
| 39 |
-
preds, target
|
| 40 |
-
) + self.freq_weight * self.freq_module(preds, target)
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
class TimeFreqL1Loss(TimeFreqWrapper):
|
| 44 |
-
def __init__(
|
| 45 |
-
self,
|
| 46 |
-
time_weight: float = 1.0,
|
| 47 |
-
freq_weight: float = 1.0,
|
| 48 |
-
tkwargs: Optional[Dict[str, Any]] = None,
|
| 49 |
-
fkwargs: Optional[Dict[str, Any]] = None,
|
| 50 |
-
multistem: bool = True,
|
| 51 |
-
) -> None:
|
| 52 |
-
if tkwargs is None:
|
| 53 |
-
tkwargs = {}
|
| 54 |
-
if fkwargs is None:
|
| 55 |
-
fkwargs = {}
|
| 56 |
-
time_module = nn.L1Loss(**tkwargs)
|
| 57 |
-
freq_module = ReImL1Loss(**fkwargs)
|
| 58 |
-
super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
class TimeFreqL2Loss(TimeFreqWrapper):
|
| 62 |
-
def __init__(
|
| 63 |
-
self,
|
| 64 |
-
time_weight: float = 1.0,
|
| 65 |
-
freq_weight: float = 1.0,
|
| 66 |
-
tkwargs: Optional[Dict[str, Any]] = None,
|
| 67 |
-
fkwargs: Optional[Dict[str, Any]] = None,
|
| 68 |
-
multistem: bool = True,
|
| 69 |
-
) -> None:
|
| 70 |
-
if tkwargs is None:
|
| 71 |
-
tkwargs = {}
|
| 72 |
-
if fkwargs is None:
|
| 73 |
-
fkwargs = {}
|
| 74 |
-
time_module = nn.MSELoss(**tkwargs)
|
| 75 |
-
freq_module = ReImL2Loss(**fkwargs)
|
| 76 |
-
super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper):
|
| 80 |
-
def __init__(
|
| 81 |
-
self,
|
| 82 |
-
time_weight: float = 1.0,
|
| 83 |
-
freq_weight: float = 1.0,
|
| 84 |
-
tkwargs: Optional[Dict[str, Any]] = None,
|
| 85 |
-
fkwargs: Optional[Dict[str, Any]] = None,
|
| 86 |
-
multistem: bool = True,
|
| 87 |
-
) -> None:
|
| 88 |
-
if tkwargs is None:
|
| 89 |
-
tkwargs = {}
|
| 90 |
-
if fkwargs is None:
|
| 91 |
-
fkwargs = {}
|
| 92 |
-
time_module = SignalNoisePNormRatio(**tkwargs)
|
| 93 |
-
freq_module = SignalNoisePNormRatio(**fkwargs)
|
| 94 |
-
super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
|
|
|
|
| 1 |
+
from typing import Any, Dict, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn.modules.loss import _Loss
|
| 6 |
+
|
| 7 |
+
from ._multistem import MultiStemWrapper
|
| 8 |
+
from ._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper
|
| 9 |
+
from .snr import SignalNoisePNormRatio
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TimeFreqWrapper(_Loss):
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
time_module: _Loss,
|
| 16 |
+
freq_module: Optional[_Loss] = None,
|
| 17 |
+
time_weight: float = 1.0,
|
| 18 |
+
freq_weight: float = 1.0,
|
| 19 |
+
multistem: bool = True,
|
| 20 |
+
) -> None:
|
| 21 |
+
super().__init__()
|
| 22 |
+
|
| 23 |
+
if freq_module is None:
|
| 24 |
+
freq_module = time_module
|
| 25 |
+
|
| 26 |
+
if multistem:
|
| 27 |
+
time_module = MultiStemWrapper(time_module, modality="audio")
|
| 28 |
+
freq_module = MultiStemWrapper(freq_module, modality="spectrogram")
|
| 29 |
+
|
| 30 |
+
self.time_module = time_module
|
| 31 |
+
self.freq_module = freq_module
|
| 32 |
+
|
| 33 |
+
self.time_weight = time_weight
|
| 34 |
+
self.freq_weight = freq_weight
|
| 35 |
+
|
| 36 |
+
def forward(self, preds: Any, target: Any) -> torch.Tensor:
|
| 37 |
+
|
| 38 |
+
return self.time_weight * self.time_module(
|
| 39 |
+
preds, target
|
| 40 |
+
) + self.freq_weight * self.freq_module(preds, target)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class TimeFreqL1Loss(TimeFreqWrapper):
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
time_weight: float = 1.0,
|
| 47 |
+
freq_weight: float = 1.0,
|
| 48 |
+
tkwargs: Optional[Dict[str, Any]] = None,
|
| 49 |
+
fkwargs: Optional[Dict[str, Any]] = None,
|
| 50 |
+
multistem: bool = True,
|
| 51 |
+
) -> None:
|
| 52 |
+
if tkwargs is None:
|
| 53 |
+
tkwargs = {}
|
| 54 |
+
if fkwargs is None:
|
| 55 |
+
fkwargs = {}
|
| 56 |
+
time_module = nn.L1Loss(**tkwargs)
|
| 57 |
+
freq_module = ReImL1Loss(**fkwargs)
|
| 58 |
+
super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class TimeFreqL2Loss(TimeFreqWrapper):
|
| 62 |
+
def __init__(
|
| 63 |
+
self,
|
| 64 |
+
time_weight: float = 1.0,
|
| 65 |
+
freq_weight: float = 1.0,
|
| 66 |
+
tkwargs: Optional[Dict[str, Any]] = None,
|
| 67 |
+
fkwargs: Optional[Dict[str, Any]] = None,
|
| 68 |
+
multistem: bool = True,
|
| 69 |
+
) -> None:
|
| 70 |
+
if tkwargs is None:
|
| 71 |
+
tkwargs = {}
|
| 72 |
+
if fkwargs is None:
|
| 73 |
+
fkwargs = {}
|
| 74 |
+
time_module = nn.MSELoss(**tkwargs)
|
| 75 |
+
freq_module = ReImL2Loss(**fkwargs)
|
| 76 |
+
super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper):
|
| 80 |
+
def __init__(
|
| 81 |
+
self,
|
| 82 |
+
time_weight: float = 1.0,
|
| 83 |
+
freq_weight: float = 1.0,
|
| 84 |
+
tkwargs: Optional[Dict[str, Any]] = None,
|
| 85 |
+
fkwargs: Optional[Dict[str, Any]] = None,
|
| 86 |
+
multistem: bool = True,
|
| 87 |
+
) -> None:
|
| 88 |
+
if tkwargs is None:
|
| 89 |
+
tkwargs = {}
|
| 90 |
+
if fkwargs is None:
|
| 91 |
+
fkwargs = {}
|
| 92 |
+
time_module = SignalNoisePNormRatio(**tkwargs)
|
| 93 |
+
freq_module = SignalNoisePNormRatio(**fkwargs)
|
| 94 |
+
super().__init__(time_module, freq_module, time_weight, freq_weight, multistem)
|
mvsepless/models/bandit/core/loss/snr.py
CHANGED
|
@@ -1,131 +1,131 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from torch.nn.modules.loss import _Loss
|
| 3 |
-
from torch.nn import functional as F
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
class SignalNoisePNormRatio(_Loss):
|
| 7 |
-
def __init__(
|
| 8 |
-
self,
|
| 9 |
-
p: float = 1.0,
|
| 10 |
-
scale_invariant: bool = False,
|
| 11 |
-
zero_mean: bool = False,
|
| 12 |
-
take_log: bool = True,
|
| 13 |
-
reduction: str = "mean",
|
| 14 |
-
EPS: float = 1e-3,
|
| 15 |
-
) -> None:
|
| 16 |
-
assert reduction != "sum", NotImplementedError
|
| 17 |
-
super().__init__(reduction=reduction)
|
| 18 |
-
assert not zero_mean
|
| 19 |
-
|
| 20 |
-
self.p = p
|
| 21 |
-
|
| 22 |
-
self.EPS = EPS
|
| 23 |
-
self.take_log = take_log
|
| 24 |
-
|
| 25 |
-
self.scale_invariant = scale_invariant
|
| 26 |
-
|
| 27 |
-
def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
| 28 |
-
|
| 29 |
-
target_ = target
|
| 30 |
-
if self.scale_invariant:
|
| 31 |
-
ndim = target.ndim
|
| 32 |
-
dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True)
|
| 33 |
-
s_target_energy = torch.sum(
|
| 34 |
-
target * torch.conj(target), dim=-1, keepdim=True
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
if ndim > 2:
|
| 38 |
-
dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True)
|
| 39 |
-
s_target_energy = torch.sum(
|
| 40 |
-
s_target_energy, dim=list(range(1, ndim)), keepdim=True
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8)
|
| 44 |
-
target = target_ * target_scaler
|
| 45 |
-
|
| 46 |
-
if torch.is_complex(est_target):
|
| 47 |
-
est_target = torch.view_as_real(est_target)
|
| 48 |
-
target = torch.view_as_real(target)
|
| 49 |
-
|
| 50 |
-
batch_size = est_target.shape[0]
|
| 51 |
-
est_target = est_target.reshape(batch_size, -1)
|
| 52 |
-
target = target.reshape(batch_size, -1)
|
| 53 |
-
|
| 54 |
-
if self.p == 1:
|
| 55 |
-
e_error = torch.abs(est_target - target).mean(dim=-1)
|
| 56 |
-
e_target = torch.abs(target).mean(dim=-1)
|
| 57 |
-
elif self.p == 2:
|
| 58 |
-
e_error = torch.square(est_target - target).mean(dim=-1)
|
| 59 |
-
e_target = torch.square(target).mean(dim=-1)
|
| 60 |
-
else:
|
| 61 |
-
raise NotImplementedError
|
| 62 |
-
|
| 63 |
-
if self.take_log:
|
| 64 |
-
loss = 10 * (
|
| 65 |
-
torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS)
|
| 66 |
-
)
|
| 67 |
-
else:
|
| 68 |
-
loss = (e_error + self.EPS) / (e_target + self.EPS)
|
| 69 |
-
|
| 70 |
-
if self.reduction == "mean":
|
| 71 |
-
loss = loss.mean()
|
| 72 |
-
elif self.reduction == "sum":
|
| 73 |
-
loss = loss.sum()
|
| 74 |
-
|
| 75 |
-
return loss
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
class MultichannelSingleSrcNegSDR(_Loss):
|
| 79 |
-
def __init__(
|
| 80 |
-
self,
|
| 81 |
-
sdr_type: str,
|
| 82 |
-
p: float = 2.0,
|
| 83 |
-
zero_mean: bool = True,
|
| 84 |
-
take_log: bool = True,
|
| 85 |
-
reduction: str = "mean",
|
| 86 |
-
EPS: float = 1e-8,
|
| 87 |
-
) -> None:
|
| 88 |
-
assert reduction != "sum", NotImplementedError
|
| 89 |
-
super().__init__(reduction=reduction)
|
| 90 |
-
|
| 91 |
-
assert sdr_type in ["snr", "sisdr", "sdsdr"]
|
| 92 |
-
self.sdr_type = sdr_type
|
| 93 |
-
self.zero_mean = zero_mean
|
| 94 |
-
self.take_log = take_log
|
| 95 |
-
self.EPS = 1e-8
|
| 96 |
-
|
| 97 |
-
self.p = p
|
| 98 |
-
|
| 99 |
-
def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
| 100 |
-
if target.size() != est_target.size() or target.ndim != 3:
|
| 101 |
-
raise TypeError(
|
| 102 |
-
f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead"
|
| 103 |
-
)
|
| 104 |
-
if self.zero_mean:
|
| 105 |
-
mean_source = torch.mean(target, dim=[1, 2], keepdim=True)
|
| 106 |
-
mean_estimate = torch.mean(est_target, dim=[1, 2], keepdim=True)
|
| 107 |
-
target = target - mean_source
|
| 108 |
-
est_target = est_target - mean_estimate
|
| 109 |
-
if self.sdr_type in ["sisdr", "sdsdr"]:
|
| 110 |
-
dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True)
|
| 111 |
-
s_target_energy = torch.sum(target**2, dim=[1, 2], keepdim=True) + self.EPS
|
| 112 |
-
scaled_target = dot * target / s_target_energy
|
| 113 |
-
else:
|
| 114 |
-
scaled_target = target
|
| 115 |
-
if self.sdr_type in ["sdsdr", "snr"]:
|
| 116 |
-
e_noise = est_target - target
|
| 117 |
-
else:
|
| 118 |
-
e_noise = est_target - scaled_target
|
| 119 |
-
|
| 120 |
-
if self.p == 2.0:
|
| 121 |
-
losses = torch.sum(scaled_target**2, dim=[1, 2]) / (
|
| 122 |
-
torch.sum(e_noise**2, dim=[1, 2]) + self.EPS
|
| 123 |
-
)
|
| 124 |
-
else:
|
| 125 |
-
losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / (
|
| 126 |
-
torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS
|
| 127 |
-
)
|
| 128 |
-
if self.take_log:
|
| 129 |
-
losses = 10 * torch.log10(losses + self.EPS)
|
| 130 |
-
losses = losses.mean() if self.reduction == "mean" else losses
|
| 131 |
-
return -losses
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.nn.modules.loss import _Loss
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SignalNoisePNormRatio(_Loss):
|
| 7 |
+
def __init__(
|
| 8 |
+
self,
|
| 9 |
+
p: float = 1.0,
|
| 10 |
+
scale_invariant: bool = False,
|
| 11 |
+
zero_mean: bool = False,
|
| 12 |
+
take_log: bool = True,
|
| 13 |
+
reduction: str = "mean",
|
| 14 |
+
EPS: float = 1e-3,
|
| 15 |
+
) -> None:
|
| 16 |
+
assert reduction != "sum", NotImplementedError
|
| 17 |
+
super().__init__(reduction=reduction)
|
| 18 |
+
assert not zero_mean
|
| 19 |
+
|
| 20 |
+
self.p = p
|
| 21 |
+
|
| 22 |
+
self.EPS = EPS
|
| 23 |
+
self.take_log = take_log
|
| 24 |
+
|
| 25 |
+
self.scale_invariant = scale_invariant
|
| 26 |
+
|
| 27 |
+
def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
| 28 |
+
|
| 29 |
+
target_ = target
|
| 30 |
+
if self.scale_invariant:
|
| 31 |
+
ndim = target.ndim
|
| 32 |
+
dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True)
|
| 33 |
+
s_target_energy = torch.sum(
|
| 34 |
+
target * torch.conj(target), dim=-1, keepdim=True
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
if ndim > 2:
|
| 38 |
+
dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True)
|
| 39 |
+
s_target_energy = torch.sum(
|
| 40 |
+
s_target_energy, dim=list(range(1, ndim)), keepdim=True
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8)
|
| 44 |
+
target = target_ * target_scaler
|
| 45 |
+
|
| 46 |
+
if torch.is_complex(est_target):
|
| 47 |
+
est_target = torch.view_as_real(est_target)
|
| 48 |
+
target = torch.view_as_real(target)
|
| 49 |
+
|
| 50 |
+
batch_size = est_target.shape[0]
|
| 51 |
+
est_target = est_target.reshape(batch_size, -1)
|
| 52 |
+
target = target.reshape(batch_size, -1)
|
| 53 |
+
|
| 54 |
+
if self.p == 1:
|
| 55 |
+
e_error = torch.abs(est_target - target).mean(dim=-1)
|
| 56 |
+
e_target = torch.abs(target).mean(dim=-1)
|
| 57 |
+
elif self.p == 2:
|
| 58 |
+
e_error = torch.square(est_target - target).mean(dim=-1)
|
| 59 |
+
e_target = torch.square(target).mean(dim=-1)
|
| 60 |
+
else:
|
| 61 |
+
raise NotImplementedError
|
| 62 |
+
|
| 63 |
+
if self.take_log:
|
| 64 |
+
loss = 10 * (
|
| 65 |
+
torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS)
|
| 66 |
+
)
|
| 67 |
+
else:
|
| 68 |
+
loss = (e_error + self.EPS) / (e_target + self.EPS)
|
| 69 |
+
|
| 70 |
+
if self.reduction == "mean":
|
| 71 |
+
loss = loss.mean()
|
| 72 |
+
elif self.reduction == "sum":
|
| 73 |
+
loss = loss.sum()
|
| 74 |
+
|
| 75 |
+
return loss
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class MultichannelSingleSrcNegSDR(_Loss):
|
| 79 |
+
def __init__(
|
| 80 |
+
self,
|
| 81 |
+
sdr_type: str,
|
| 82 |
+
p: float = 2.0,
|
| 83 |
+
zero_mean: bool = True,
|
| 84 |
+
take_log: bool = True,
|
| 85 |
+
reduction: str = "mean",
|
| 86 |
+
EPS: float = 1e-8,
|
| 87 |
+
) -> None:
|
| 88 |
+
assert reduction != "sum", NotImplementedError
|
| 89 |
+
super().__init__(reduction=reduction)
|
| 90 |
+
|
| 91 |
+
assert sdr_type in ["snr", "sisdr", "sdsdr"]
|
| 92 |
+
self.sdr_type = sdr_type
|
| 93 |
+
self.zero_mean = zero_mean
|
| 94 |
+
self.take_log = take_log
|
| 95 |
+
self.EPS = 1e-8
|
| 96 |
+
|
| 97 |
+
self.p = p
|
| 98 |
+
|
| 99 |
+
def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
| 100 |
+
if target.size() != est_target.size() or target.ndim != 3:
|
| 101 |
+
raise TypeError(
|
| 102 |
+
f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead"
|
| 103 |
+
)
|
| 104 |
+
if self.zero_mean:
|
| 105 |
+
mean_source = torch.mean(target, dim=[1, 2], keepdim=True)
|
| 106 |
+
mean_estimate = torch.mean(est_target, dim=[1, 2], keepdim=True)
|
| 107 |
+
target = target - mean_source
|
| 108 |
+
est_target = est_target - mean_estimate
|
| 109 |
+
if self.sdr_type in ["sisdr", "sdsdr"]:
|
| 110 |
+
dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True)
|
| 111 |
+
s_target_energy = torch.sum(target**2, dim=[1, 2], keepdim=True) + self.EPS
|
| 112 |
+
scaled_target = dot * target / s_target_energy
|
| 113 |
+
else:
|
| 114 |
+
scaled_target = target
|
| 115 |
+
if self.sdr_type in ["sdsdr", "snr"]:
|
| 116 |
+
e_noise = est_target - target
|
| 117 |
+
else:
|
| 118 |
+
e_noise = est_target - scaled_target
|
| 119 |
+
|
| 120 |
+
if self.p == 2.0:
|
| 121 |
+
losses = torch.sum(scaled_target**2, dim=[1, 2]) / (
|
| 122 |
+
torch.sum(e_noise**2, dim=[1, 2]) + self.EPS
|
| 123 |
+
)
|
| 124 |
+
else:
|
| 125 |
+
losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / (
|
| 126 |
+
torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS
|
| 127 |
+
)
|
| 128 |
+
if self.take_log:
|
| 129 |
+
losses = 10 * torch.log10(losses + self.EPS)
|
| 130 |
+
losses = losses.mean() if self.reduction == "mean" else losses
|
| 131 |
+
return -losses
|
mvsepless/models/bandit/core/metrics/__init__.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
-
from .snr import (
|
| 2 |
-
ChunkMedianScaleInvariantSignalDistortionRatio,
|
| 3 |
-
ChunkMedianScaleInvariantSignalNoiseRatio,
|
| 4 |
-
ChunkMedianSignalDistortionRatio,
|
| 5 |
-
ChunkMedianSignalNoiseRatio,
|
| 6 |
-
SafeSignalDistortionRatio,
|
| 7 |
-
)
|
|
|
|
| 1 |
+
from .snr import (
|
| 2 |
+
ChunkMedianScaleInvariantSignalDistortionRatio,
|
| 3 |
+
ChunkMedianScaleInvariantSignalNoiseRatio,
|
| 4 |
+
ChunkMedianSignalDistortionRatio,
|
| 5 |
+
ChunkMedianSignalNoiseRatio,
|
| 6 |
+
SafeSignalDistortionRatio,
|
| 7 |
+
)
|
mvsepless/models/bandit/core/metrics/_squim.py
CHANGED
|
@@ -1,350 +1,350 @@
|
|
| 1 |
-
from dataclasses import dataclass
|
| 2 |
-
|
| 3 |
-
from torchaudio._internal import load_state_dict_from_url
|
| 4 |
-
|
| 5 |
-
import math
|
| 6 |
-
from typing import List, Optional, Tuple
|
| 7 |
-
|
| 8 |
-
import torch
|
| 9 |
-
import torch.nn as nn
|
| 10 |
-
import torch.nn.functional as F
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def transform_wb_pesq_range(x: float) -> float:
|
| 14 |
-
return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224))
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
PESQRange: Tuple[float, float] = (
|
| 18 |
-
1.0,
|
| 19 |
-
transform_wb_pesq_range(4.5),
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
class RangeSigmoid(nn.Module):
|
| 24 |
-
def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
|
| 25 |
-
super(RangeSigmoid, self).__init__()
|
| 26 |
-
assert isinstance(val_range, tuple) and len(val_range) == 2
|
| 27 |
-
self.val_range: Tuple[float, float] = val_range
|
| 28 |
-
self.sigmoid: nn.modules.Module = nn.Sigmoid()
|
| 29 |
-
|
| 30 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 31 |
-
out = (
|
| 32 |
-
self.sigmoid(x) * (self.val_range[1] - self.val_range[0])
|
| 33 |
-
+ self.val_range[0]
|
| 34 |
-
)
|
| 35 |
-
return out
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
class Encoder(nn.Module):
|
| 39 |
-
|
| 40 |
-
def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
|
| 41 |
-
super(Encoder, self).__init__()
|
| 42 |
-
|
| 43 |
-
self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)
|
| 44 |
-
|
| 45 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 46 |
-
out = x.unsqueeze(dim=1)
|
| 47 |
-
out = F.relu(self.conv1d(out))
|
| 48 |
-
return out
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
class SingleRNN(nn.Module):
|
| 52 |
-
def __init__(
|
| 53 |
-
self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0
|
| 54 |
-
) -> None:
|
| 55 |
-
super(SingleRNN, self).__init__()
|
| 56 |
-
|
| 57 |
-
self.rnn_type = rnn_type
|
| 58 |
-
self.input_size = input_size
|
| 59 |
-
self.hidden_size = hidden_size
|
| 60 |
-
|
| 61 |
-
self.rnn: nn.modules.Module = getattr(nn, rnn_type)(
|
| 62 |
-
input_size,
|
| 63 |
-
hidden_size,
|
| 64 |
-
1,
|
| 65 |
-
dropout=dropout,
|
| 66 |
-
batch_first=True,
|
| 67 |
-
bidirectional=True,
|
| 68 |
-
)
|
| 69 |
-
|
| 70 |
-
self.proj = nn.Linear(hidden_size * 2, input_size)
|
| 71 |
-
|
| 72 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 73 |
-
out, _ = self.rnn(x)
|
| 74 |
-
out = self.proj(out)
|
| 75 |
-
return out
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
class DPRNN(nn.Module):
|
| 79 |
-
|
| 80 |
-
def __init__(
|
| 81 |
-
self,
|
| 82 |
-
feat_dim: int = 64,
|
| 83 |
-
hidden_dim: int = 128,
|
| 84 |
-
num_blocks: int = 6,
|
| 85 |
-
rnn_type: str = "LSTM",
|
| 86 |
-
d_model: int = 256,
|
| 87 |
-
chunk_size: int = 100,
|
| 88 |
-
chunk_stride: int = 50,
|
| 89 |
-
) -> None:
|
| 90 |
-
super(DPRNN, self).__init__()
|
| 91 |
-
|
| 92 |
-
self.num_blocks = num_blocks
|
| 93 |
-
|
| 94 |
-
self.row_rnn = nn.ModuleList([])
|
| 95 |
-
self.col_rnn = nn.ModuleList([])
|
| 96 |
-
self.row_norm = nn.ModuleList([])
|
| 97 |
-
self.col_norm = nn.ModuleList([])
|
| 98 |
-
for _ in range(num_blocks):
|
| 99 |
-
self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
|
| 100 |
-
self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
|
| 101 |
-
self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
|
| 102 |
-
self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
|
| 103 |
-
self.conv = nn.Sequential(
|
| 104 |
-
nn.Conv2d(feat_dim, d_model, 1),
|
| 105 |
-
nn.PReLU(),
|
| 106 |
-
)
|
| 107 |
-
self.chunk_size = chunk_size
|
| 108 |
-
self.chunk_stride = chunk_stride
|
| 109 |
-
|
| 110 |
-
def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
|
| 111 |
-
seq_len = x.shape[-1]
|
| 112 |
-
|
| 113 |
-
rest = (
|
| 114 |
-
self.chunk_size
|
| 115 |
-
- (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
|
| 116 |
-
)
|
| 117 |
-
out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])
|
| 118 |
-
|
| 119 |
-
return out, rest
|
| 120 |
-
|
| 121 |
-
def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
|
| 122 |
-
out, rest = self.pad_chunk(x)
|
| 123 |
-
batch_size, feat_dim, seq_len = out.shape
|
| 124 |
-
|
| 125 |
-
segments1 = (
|
| 126 |
-
out[:, :, : -self.chunk_stride]
|
| 127 |
-
.contiguous()
|
| 128 |
-
.view(batch_size, feat_dim, -1, self.chunk_size)
|
| 129 |
-
)
|
| 130 |
-
segments2 = (
|
| 131 |
-
out[:, :, self.chunk_stride :]
|
| 132 |
-
.contiguous()
|
| 133 |
-
.view(batch_size, feat_dim, -1, self.chunk_size)
|
| 134 |
-
)
|
| 135 |
-
out = torch.cat([segments1, segments2], dim=3)
|
| 136 |
-
out = (
|
| 137 |
-
out.view(batch_size, feat_dim, -1, self.chunk_size)
|
| 138 |
-
.transpose(2, 3)
|
| 139 |
-
.contiguous()
|
| 140 |
-
)
|
| 141 |
-
|
| 142 |
-
return out, rest
|
| 143 |
-
|
| 144 |
-
def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
|
| 145 |
-
batch_size, dim, _, _ = x.shape
|
| 146 |
-
out = (
|
| 147 |
-
x.transpose(2, 3)
|
| 148 |
-
.contiguous()
|
| 149 |
-
.view(batch_size, dim, -1, self.chunk_size * 2)
|
| 150 |
-
)
|
| 151 |
-
out1 = (
|
| 152 |
-
out[:, :, :, : self.chunk_size]
|
| 153 |
-
.contiguous()
|
| 154 |
-
.view(batch_size, dim, -1)[:, :, self.chunk_stride :]
|
| 155 |
-
)
|
| 156 |
-
out2 = (
|
| 157 |
-
out[:, :, :, self.chunk_size :]
|
| 158 |
-
.contiguous()
|
| 159 |
-
.view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
|
| 160 |
-
)
|
| 161 |
-
out = out1 + out2
|
| 162 |
-
if rest > 0:
|
| 163 |
-
out = out[:, :, :-rest]
|
| 164 |
-
out = out.contiguous()
|
| 165 |
-
return out
|
| 166 |
-
|
| 167 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 168 |
-
x, rest = self.chunking(x)
|
| 169 |
-
batch_size, _, dim1, dim2 = x.shape
|
| 170 |
-
out = x
|
| 171 |
-
for row_rnn, row_norm, col_rnn, col_norm in zip(
|
| 172 |
-
self.row_rnn, self.row_norm, self.col_rnn, self.col_norm
|
| 173 |
-
):
|
| 174 |
-
row_in = (
|
| 175 |
-
out.permute(0, 3, 2, 1)
|
| 176 |
-
.contiguous()
|
| 177 |
-
.view(batch_size * dim2, dim1, -1)
|
| 178 |
-
.contiguous()
|
| 179 |
-
)
|
| 180 |
-
row_out = row_rnn(row_in)
|
| 181 |
-
row_out = (
|
| 182 |
-
row_out.view(batch_size, dim2, dim1, -1)
|
| 183 |
-
.permute(0, 3, 2, 1)
|
| 184 |
-
.contiguous()
|
| 185 |
-
)
|
| 186 |
-
row_out = row_norm(row_out)
|
| 187 |
-
out = out + row_out
|
| 188 |
-
|
| 189 |
-
col_in = (
|
| 190 |
-
out.permute(0, 2, 3, 1)
|
| 191 |
-
.contiguous()
|
| 192 |
-
.view(batch_size * dim1, dim2, -1)
|
| 193 |
-
.contiguous()
|
| 194 |
-
)
|
| 195 |
-
col_out = col_rnn(col_in)
|
| 196 |
-
col_out = (
|
| 197 |
-
col_out.view(batch_size, dim1, dim2, -1)
|
| 198 |
-
.permute(0, 3, 1, 2)
|
| 199 |
-
.contiguous()
|
| 200 |
-
)
|
| 201 |
-
col_out = col_norm(col_out)
|
| 202 |
-
out = out + col_out
|
| 203 |
-
out = self.conv(out)
|
| 204 |
-
out = self.merging(out, rest)
|
| 205 |
-
out = out.transpose(1, 2).contiguous()
|
| 206 |
-
return out
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
class AutoPool(nn.Module):
|
| 210 |
-
def __init__(self, pool_dim: int = 1) -> None:
|
| 211 |
-
super(AutoPool, self).__init__()
|
| 212 |
-
self.pool_dim: int = pool_dim
|
| 213 |
-
self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
|
| 214 |
-
self.register_parameter("alpha", nn.Parameter(torch.ones(1)))
|
| 215 |
-
|
| 216 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 217 |
-
weight = self.softmax(torch.mul(x, self.alpha))
|
| 218 |
-
out = torch.sum(torch.mul(x, weight), dim=self.pool_dim)
|
| 219 |
-
return out
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
class SquimObjective(nn.Module):
|
| 223 |
-
|
| 224 |
-
def __init__(
|
| 225 |
-
self,
|
| 226 |
-
encoder: nn.Module,
|
| 227 |
-
dprnn: nn.Module,
|
| 228 |
-
branches: nn.ModuleList,
|
| 229 |
-
):
|
| 230 |
-
super(SquimObjective, self).__init__()
|
| 231 |
-
self.encoder = encoder
|
| 232 |
-
self.dprnn = dprnn
|
| 233 |
-
self.branches = branches
|
| 234 |
-
|
| 235 |
-
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
| 236 |
-
if x.ndim != 2:
|
| 237 |
-
raise ValueError(
|
| 238 |
-
f"The input must be a 2D Tensor. Found dimension {x.ndim}."
|
| 239 |
-
)
|
| 240 |
-
x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
|
| 241 |
-
out = self.encoder(x)
|
| 242 |
-
out = self.dprnn(out)
|
| 243 |
-
scores = []
|
| 244 |
-
for branch in self.branches:
|
| 245 |
-
scores.append(branch(out).squeeze(dim=1))
|
| 246 |
-
return scores
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
|
| 250 |
-
layer1 = nn.TransformerEncoderLayer(
|
| 251 |
-
d_model, nhead, d_model * 4, dropout=0.0, batch_first=True
|
| 252 |
-
)
|
| 253 |
-
layer2 = AutoPool()
|
| 254 |
-
if metric == "stoi":
|
| 255 |
-
layer3 = nn.Sequential(
|
| 256 |
-
nn.Linear(d_model, d_model),
|
| 257 |
-
nn.PReLU(),
|
| 258 |
-
nn.Linear(d_model, 1),
|
| 259 |
-
RangeSigmoid(),
|
| 260 |
-
)
|
| 261 |
-
elif metric == "pesq":
|
| 262 |
-
layer3 = nn.Sequential(
|
| 263 |
-
nn.Linear(d_model, d_model),
|
| 264 |
-
nn.PReLU(),
|
| 265 |
-
nn.Linear(d_model, 1),
|
| 266 |
-
RangeSigmoid(val_range=PESQRange),
|
| 267 |
-
)
|
| 268 |
-
else:
|
| 269 |
-
layer3: nn.modules.Module = nn.Sequential(
|
| 270 |
-
nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1)
|
| 271 |
-
)
|
| 272 |
-
return nn.Sequential(layer1, layer2, layer3)
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
def squim_objective_model(
|
| 276 |
-
feat_dim: int,
|
| 277 |
-
win_len: int,
|
| 278 |
-
d_model: int,
|
| 279 |
-
nhead: int,
|
| 280 |
-
hidden_dim: int,
|
| 281 |
-
num_blocks: int,
|
| 282 |
-
rnn_type: str,
|
| 283 |
-
chunk_size: int,
|
| 284 |
-
chunk_stride: Optional[int] = None,
|
| 285 |
-
) -> SquimObjective:
|
| 286 |
-
if chunk_stride is None:
|
| 287 |
-
chunk_stride = chunk_size // 2
|
| 288 |
-
encoder = Encoder(feat_dim, win_len)
|
| 289 |
-
dprnn = DPRNN(
|
| 290 |
-
feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride
|
| 291 |
-
)
|
| 292 |
-
branches = nn.ModuleList(
|
| 293 |
-
[
|
| 294 |
-
_create_branch(d_model, nhead, "stoi"),
|
| 295 |
-
_create_branch(d_model, nhead, "pesq"),
|
| 296 |
-
_create_branch(d_model, nhead, "sisdr"),
|
| 297 |
-
]
|
| 298 |
-
)
|
| 299 |
-
return SquimObjective(encoder, dprnn, branches)
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
def squim_objective_base() -> SquimObjective:
|
| 303 |
-
return squim_objective_model(
|
| 304 |
-
feat_dim=256,
|
| 305 |
-
win_len=64,
|
| 306 |
-
d_model=256,
|
| 307 |
-
nhead=4,
|
| 308 |
-
hidden_dim=256,
|
| 309 |
-
num_blocks=2,
|
| 310 |
-
rnn_type="LSTM",
|
| 311 |
-
chunk_size=71,
|
| 312 |
-
)
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
@dataclass
|
| 316 |
-
class SquimObjectiveBundle:
|
| 317 |
-
|
| 318 |
-
_path: str
|
| 319 |
-
_sample_rate: float
|
| 320 |
-
|
| 321 |
-
def _get_state_dict(self, dl_kwargs):
|
| 322 |
-
url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
|
| 323 |
-
dl_kwargs = {} if dl_kwargs is None else dl_kwargs
|
| 324 |
-
state_dict = load_state_dict_from_url(url, **dl_kwargs)
|
| 325 |
-
return state_dict
|
| 326 |
-
|
| 327 |
-
def get_model(self, *, dl_kwargs=None) -> SquimObjective:
|
| 328 |
-
model = squim_objective_base()
|
| 329 |
-
model.load_state_dict(self._get_state_dict(dl_kwargs))
|
| 330 |
-
model.eval()
|
| 331 |
-
return model
|
| 332 |
-
|
| 333 |
-
@property
|
| 334 |
-
def sample_rate(self):
|
| 335 |
-
return self._sample_rate
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
SQUIM_OBJECTIVE = SquimObjectiveBundle(
|
| 339 |
-
"squim_objective_dns2020.pth",
|
| 340 |
-
_sample_rate=16000,
|
| 341 |
-
)
|
| 342 |
-
SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
|
| 343 |
-
:cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.
|
| 344 |
-
|
| 345 |
-
The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
|
| 346 |
-
The weights are under `Creative Commons Attribution 4.0 International License
|
| 347 |
-
<https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.
|
| 348 |
-
|
| 349 |
-
Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
|
| 350 |
-
"""
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
from torchaudio._internal import load_state_dict_from_url
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
from typing import List, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def transform_wb_pesq_range(x: float) -> float:
|
| 14 |
+
return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224))
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
PESQRange: Tuple[float, float] = (
|
| 18 |
+
1.0,
|
| 19 |
+
transform_wb_pesq_range(4.5),
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class RangeSigmoid(nn.Module):
|
| 24 |
+
def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
|
| 25 |
+
super(RangeSigmoid, self).__init__()
|
| 26 |
+
assert isinstance(val_range, tuple) and len(val_range) == 2
|
| 27 |
+
self.val_range: Tuple[float, float] = val_range
|
| 28 |
+
self.sigmoid: nn.modules.Module = nn.Sigmoid()
|
| 29 |
+
|
| 30 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 31 |
+
out = (
|
| 32 |
+
self.sigmoid(x) * (self.val_range[1] - self.val_range[0])
|
| 33 |
+
+ self.val_range[0]
|
| 34 |
+
)
|
| 35 |
+
return out
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Encoder(nn.Module):
|
| 39 |
+
|
| 40 |
+
def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
|
| 41 |
+
super(Encoder, self).__init__()
|
| 42 |
+
|
| 43 |
+
self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)
|
| 44 |
+
|
| 45 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 46 |
+
out = x.unsqueeze(dim=1)
|
| 47 |
+
out = F.relu(self.conv1d(out))
|
| 48 |
+
return out
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class SingleRNN(nn.Module):
|
| 52 |
+
def __init__(
|
| 53 |
+
self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0
|
| 54 |
+
) -> None:
|
| 55 |
+
super(SingleRNN, self).__init__()
|
| 56 |
+
|
| 57 |
+
self.rnn_type = rnn_type
|
| 58 |
+
self.input_size = input_size
|
| 59 |
+
self.hidden_size = hidden_size
|
| 60 |
+
|
| 61 |
+
self.rnn: nn.modules.Module = getattr(nn, rnn_type)(
|
| 62 |
+
input_size,
|
| 63 |
+
hidden_size,
|
| 64 |
+
1,
|
| 65 |
+
dropout=dropout,
|
| 66 |
+
batch_first=True,
|
| 67 |
+
bidirectional=True,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
self.proj = nn.Linear(hidden_size * 2, input_size)
|
| 71 |
+
|
| 72 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 73 |
+
out, _ = self.rnn(x)
|
| 74 |
+
out = self.proj(out)
|
| 75 |
+
return out
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class DPRNN(nn.Module):
|
| 79 |
+
|
| 80 |
+
def __init__(
|
| 81 |
+
self,
|
| 82 |
+
feat_dim: int = 64,
|
| 83 |
+
hidden_dim: int = 128,
|
| 84 |
+
num_blocks: int = 6,
|
| 85 |
+
rnn_type: str = "LSTM",
|
| 86 |
+
d_model: int = 256,
|
| 87 |
+
chunk_size: int = 100,
|
| 88 |
+
chunk_stride: int = 50,
|
| 89 |
+
) -> None:
|
| 90 |
+
super(DPRNN, self).__init__()
|
| 91 |
+
|
| 92 |
+
self.num_blocks = num_blocks
|
| 93 |
+
|
| 94 |
+
self.row_rnn = nn.ModuleList([])
|
| 95 |
+
self.col_rnn = nn.ModuleList([])
|
| 96 |
+
self.row_norm = nn.ModuleList([])
|
| 97 |
+
self.col_norm = nn.ModuleList([])
|
| 98 |
+
for _ in range(num_blocks):
|
| 99 |
+
self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
|
| 100 |
+
self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
|
| 101 |
+
self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
|
| 102 |
+
self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
|
| 103 |
+
self.conv = nn.Sequential(
|
| 104 |
+
nn.Conv2d(feat_dim, d_model, 1),
|
| 105 |
+
nn.PReLU(),
|
| 106 |
+
)
|
| 107 |
+
self.chunk_size = chunk_size
|
| 108 |
+
self.chunk_stride = chunk_stride
|
| 109 |
+
|
| 110 |
+
def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
|
| 111 |
+
seq_len = x.shape[-1]
|
| 112 |
+
|
| 113 |
+
rest = (
|
| 114 |
+
self.chunk_size
|
| 115 |
+
- (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
|
| 116 |
+
)
|
| 117 |
+
out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])
|
| 118 |
+
|
| 119 |
+
return out, rest
|
| 120 |
+
|
| 121 |
+
def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
|
| 122 |
+
out, rest = self.pad_chunk(x)
|
| 123 |
+
batch_size, feat_dim, seq_len = out.shape
|
| 124 |
+
|
| 125 |
+
segments1 = (
|
| 126 |
+
out[:, :, : -self.chunk_stride]
|
| 127 |
+
.contiguous()
|
| 128 |
+
.view(batch_size, feat_dim, -1, self.chunk_size)
|
| 129 |
+
)
|
| 130 |
+
segments2 = (
|
| 131 |
+
out[:, :, self.chunk_stride :]
|
| 132 |
+
.contiguous()
|
| 133 |
+
.view(batch_size, feat_dim, -1, self.chunk_size)
|
| 134 |
+
)
|
| 135 |
+
out = torch.cat([segments1, segments2], dim=3)
|
| 136 |
+
out = (
|
| 137 |
+
out.view(batch_size, feat_dim, -1, self.chunk_size)
|
| 138 |
+
.transpose(2, 3)
|
| 139 |
+
.contiguous()
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
return out, rest
|
| 143 |
+
|
| 144 |
+
def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
|
| 145 |
+
batch_size, dim, _, _ = x.shape
|
| 146 |
+
out = (
|
| 147 |
+
x.transpose(2, 3)
|
| 148 |
+
.contiguous()
|
| 149 |
+
.view(batch_size, dim, -1, self.chunk_size * 2)
|
| 150 |
+
)
|
| 151 |
+
out1 = (
|
| 152 |
+
out[:, :, :, : self.chunk_size]
|
| 153 |
+
.contiguous()
|
| 154 |
+
.view(batch_size, dim, -1)[:, :, self.chunk_stride :]
|
| 155 |
+
)
|
| 156 |
+
out2 = (
|
| 157 |
+
out[:, :, :, self.chunk_size :]
|
| 158 |
+
.contiguous()
|
| 159 |
+
.view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
|
| 160 |
+
)
|
| 161 |
+
out = out1 + out2
|
| 162 |
+
if rest > 0:
|
| 163 |
+
out = out[:, :, :-rest]
|
| 164 |
+
out = out.contiguous()
|
| 165 |
+
return out
|
| 166 |
+
|
| 167 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 168 |
+
x, rest = self.chunking(x)
|
| 169 |
+
batch_size, _, dim1, dim2 = x.shape
|
| 170 |
+
out = x
|
| 171 |
+
for row_rnn, row_norm, col_rnn, col_norm in zip(
|
| 172 |
+
self.row_rnn, self.row_norm, self.col_rnn, self.col_norm
|
| 173 |
+
):
|
| 174 |
+
row_in = (
|
| 175 |
+
out.permute(0, 3, 2, 1)
|
| 176 |
+
.contiguous()
|
| 177 |
+
.view(batch_size * dim2, dim1, -1)
|
| 178 |
+
.contiguous()
|
| 179 |
+
)
|
| 180 |
+
row_out = row_rnn(row_in)
|
| 181 |
+
row_out = (
|
| 182 |
+
row_out.view(batch_size, dim2, dim1, -1)
|
| 183 |
+
.permute(0, 3, 2, 1)
|
| 184 |
+
.contiguous()
|
| 185 |
+
)
|
| 186 |
+
row_out = row_norm(row_out)
|
| 187 |
+
out = out + row_out
|
| 188 |
+
|
| 189 |
+
col_in = (
|
| 190 |
+
out.permute(0, 2, 3, 1)
|
| 191 |
+
.contiguous()
|
| 192 |
+
.view(batch_size * dim1, dim2, -1)
|
| 193 |
+
.contiguous()
|
| 194 |
+
)
|
| 195 |
+
col_out = col_rnn(col_in)
|
| 196 |
+
col_out = (
|
| 197 |
+
col_out.view(batch_size, dim1, dim2, -1)
|
| 198 |
+
.permute(0, 3, 1, 2)
|
| 199 |
+
.contiguous()
|
| 200 |
+
)
|
| 201 |
+
col_out = col_norm(col_out)
|
| 202 |
+
out = out + col_out
|
| 203 |
+
out = self.conv(out)
|
| 204 |
+
out = self.merging(out, rest)
|
| 205 |
+
out = out.transpose(1, 2).contiguous()
|
| 206 |
+
return out
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class AutoPool(nn.Module):
|
| 210 |
+
def __init__(self, pool_dim: int = 1) -> None:
|
| 211 |
+
super(AutoPool, self).__init__()
|
| 212 |
+
self.pool_dim: int = pool_dim
|
| 213 |
+
self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
|
| 214 |
+
self.register_parameter("alpha", nn.Parameter(torch.ones(1)))
|
| 215 |
+
|
| 216 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 217 |
+
weight = self.softmax(torch.mul(x, self.alpha))
|
| 218 |
+
out = torch.sum(torch.mul(x, weight), dim=self.pool_dim)
|
| 219 |
+
return out
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class SquimObjective(nn.Module):
|
| 223 |
+
|
| 224 |
+
def __init__(
|
| 225 |
+
self,
|
| 226 |
+
encoder: nn.Module,
|
| 227 |
+
dprnn: nn.Module,
|
| 228 |
+
branches: nn.ModuleList,
|
| 229 |
+
):
|
| 230 |
+
super(SquimObjective, self).__init__()
|
| 231 |
+
self.encoder = encoder
|
| 232 |
+
self.dprnn = dprnn
|
| 233 |
+
self.branches = branches
|
| 234 |
+
|
| 235 |
+
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
| 236 |
+
if x.ndim != 2:
|
| 237 |
+
raise ValueError(
|
| 238 |
+
f"The input must be a 2D Tensor. Found dimension {x.ndim}."
|
| 239 |
+
)
|
| 240 |
+
x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
|
| 241 |
+
out = self.encoder(x)
|
| 242 |
+
out = self.dprnn(out)
|
| 243 |
+
scores = []
|
| 244 |
+
for branch in self.branches:
|
| 245 |
+
scores.append(branch(out).squeeze(dim=1))
|
| 246 |
+
return scores
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
|
| 250 |
+
layer1 = nn.TransformerEncoderLayer(
|
| 251 |
+
d_model, nhead, d_model * 4, dropout=0.0, batch_first=True
|
| 252 |
+
)
|
| 253 |
+
layer2 = AutoPool()
|
| 254 |
+
if metric == "stoi":
|
| 255 |
+
layer3 = nn.Sequential(
|
| 256 |
+
nn.Linear(d_model, d_model),
|
| 257 |
+
nn.PReLU(),
|
| 258 |
+
nn.Linear(d_model, 1),
|
| 259 |
+
RangeSigmoid(),
|
| 260 |
+
)
|
| 261 |
+
elif metric == "pesq":
|
| 262 |
+
layer3 = nn.Sequential(
|
| 263 |
+
nn.Linear(d_model, d_model),
|
| 264 |
+
nn.PReLU(),
|
| 265 |
+
nn.Linear(d_model, 1),
|
| 266 |
+
RangeSigmoid(val_range=PESQRange),
|
| 267 |
+
)
|
| 268 |
+
else:
|
| 269 |
+
layer3: nn.modules.Module = nn.Sequential(
|
| 270 |
+
nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1)
|
| 271 |
+
)
|
| 272 |
+
return nn.Sequential(layer1, layer2, layer3)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def squim_objective_model(
|
| 276 |
+
feat_dim: int,
|
| 277 |
+
win_len: int,
|
| 278 |
+
d_model: int,
|
| 279 |
+
nhead: int,
|
| 280 |
+
hidden_dim: int,
|
| 281 |
+
num_blocks: int,
|
| 282 |
+
rnn_type: str,
|
| 283 |
+
chunk_size: int,
|
| 284 |
+
chunk_stride: Optional[int] = None,
|
| 285 |
+
) -> SquimObjective:
|
| 286 |
+
if chunk_stride is None:
|
| 287 |
+
chunk_stride = chunk_size // 2
|
| 288 |
+
encoder = Encoder(feat_dim, win_len)
|
| 289 |
+
dprnn = DPRNN(
|
| 290 |
+
feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride
|
| 291 |
+
)
|
| 292 |
+
branches = nn.ModuleList(
|
| 293 |
+
[
|
| 294 |
+
_create_branch(d_model, nhead, "stoi"),
|
| 295 |
+
_create_branch(d_model, nhead, "pesq"),
|
| 296 |
+
_create_branch(d_model, nhead, "sisdr"),
|
| 297 |
+
]
|
| 298 |
+
)
|
| 299 |
+
return SquimObjective(encoder, dprnn, branches)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def squim_objective_base() -> SquimObjective:
|
| 303 |
+
return squim_objective_model(
|
| 304 |
+
feat_dim=256,
|
| 305 |
+
win_len=64,
|
| 306 |
+
d_model=256,
|
| 307 |
+
nhead=4,
|
| 308 |
+
hidden_dim=256,
|
| 309 |
+
num_blocks=2,
|
| 310 |
+
rnn_type="LSTM",
|
| 311 |
+
chunk_size=71,
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
@dataclass
|
| 316 |
+
class SquimObjectiveBundle:
|
| 317 |
+
|
| 318 |
+
_path: str
|
| 319 |
+
_sample_rate: float
|
| 320 |
+
|
| 321 |
+
def _get_state_dict(self, dl_kwargs):
|
| 322 |
+
url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
|
| 323 |
+
dl_kwargs = {} if dl_kwargs is None else dl_kwargs
|
| 324 |
+
state_dict = load_state_dict_from_url(url, **dl_kwargs)
|
| 325 |
+
return state_dict
|
| 326 |
+
|
| 327 |
+
def get_model(self, *, dl_kwargs=None) -> SquimObjective:
|
| 328 |
+
model = squim_objective_base()
|
| 329 |
+
model.load_state_dict(self._get_state_dict(dl_kwargs))
|
| 330 |
+
model.eval()
|
| 331 |
+
return model
|
| 332 |
+
|
| 333 |
+
@property
|
| 334 |
+
def sample_rate(self):
|
| 335 |
+
return self._sample_rate
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
SQUIM_OBJECTIVE = SquimObjectiveBundle(
|
| 339 |
+
"squim_objective_dns2020.pth",
|
| 340 |
+
_sample_rate=16000,
|
| 341 |
+
)
|
| 342 |
+
SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
|
| 343 |
+
:cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.
|
| 344 |
+
|
| 345 |
+
The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
|
| 346 |
+
The weights are under `Creative Commons Attribution 4.0 International License
|
| 347 |
+
<https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.
|
| 348 |
+
|
| 349 |
+
Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
|
| 350 |
+
"""
|
mvsepless/models/bandit/core/metrics/snr.py
CHANGED
|
@@ -1,124 +1,124 @@
|
|
| 1 |
-
from typing import Any, Callable
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
import torch
|
| 5 |
-
import torchmetrics as tm
|
| 6 |
-
from torch._C import _LinAlgError
|
| 7 |
-
from torchmetrics import functional as tmF
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class SafeSignalDistortionRatio(tm.SignalDistortionRatio):
|
| 11 |
-
def __init__(self, **kwargs) -> None:
|
| 12 |
-
super().__init__(**kwargs)
|
| 13 |
-
|
| 14 |
-
def update(self, *args, **kwargs) -> Any:
|
| 15 |
-
try:
|
| 16 |
-
super().update(*args, **kwargs)
|
| 17 |
-
except:
|
| 18 |
-
pass
|
| 19 |
-
|
| 20 |
-
def compute(self) -> Any:
|
| 21 |
-
if self.total == 0:
|
| 22 |
-
return torch.tensor(torch.nan)
|
| 23 |
-
return super().compute()
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
class BaseChunkMedianSignalRatio(tm.Metric):
|
| 27 |
-
def __init__(
|
| 28 |
-
self,
|
| 29 |
-
func: Callable,
|
| 30 |
-
window_size: int,
|
| 31 |
-
hop_size: int = None,
|
| 32 |
-
zero_mean: bool = False,
|
| 33 |
-
) -> None:
|
| 34 |
-
super().__init__()
|
| 35 |
-
|
| 36 |
-
self.func = func
|
| 37 |
-
self.window_size = window_size
|
| 38 |
-
if hop_size is None:
|
| 39 |
-
hop_size = window_size
|
| 40 |
-
self.hop_size = hop_size
|
| 41 |
-
|
| 42 |
-
self.add_state("sum_snr", default=torch.tensor(0.0), dist_reduce_fx="sum")
|
| 43 |
-
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
|
| 44 |
-
|
| 45 |
-
def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
|
| 46 |
-
|
| 47 |
-
n_samples = target.shape[-1]
|
| 48 |
-
|
| 49 |
-
n_chunks = int(np.ceil((n_samples - self.window_size) / self.hop_size) + 1)
|
| 50 |
-
|
| 51 |
-
snr_chunk = []
|
| 52 |
-
|
| 53 |
-
for i in range(n_chunks):
|
| 54 |
-
start = i * self.hop_size
|
| 55 |
-
|
| 56 |
-
if n_samples - start < self.window_size:
|
| 57 |
-
continue
|
| 58 |
-
|
| 59 |
-
end = start + self.window_size
|
| 60 |
-
|
| 61 |
-
try:
|
| 62 |
-
chunk_snr = self.func(preds[..., start:end], target[..., start:end])
|
| 63 |
-
|
| 64 |
-
if torch.all(torch.isfinite(chunk_snr)):
|
| 65 |
-
snr_chunk.append(chunk_snr)
|
| 66 |
-
except _LinAlgError:
|
| 67 |
-
pass
|
| 68 |
-
|
| 69 |
-
snr_chunk = torch.stack(snr_chunk, dim=-1)
|
| 70 |
-
snr_batch, _ = torch.nanmedian(snr_chunk, dim=-1)
|
| 71 |
-
|
| 72 |
-
self.sum_snr += snr_batch.sum()
|
| 73 |
-
self.total += snr_batch.numel()
|
| 74 |
-
|
| 75 |
-
def compute(self) -> Any:
|
| 76 |
-
return self.sum_snr / self.total
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio):
|
| 80 |
-
def __init__(
|
| 81 |
-
self, window_size: int, hop_size: int = None, zero_mean: bool = False
|
| 82 |
-
) -> None:
|
| 83 |
-
super().__init__(
|
| 84 |
-
func=tmF.signal_noise_ratio,
|
| 85 |
-
window_size=window_size,
|
| 86 |
-
hop_size=hop_size,
|
| 87 |
-
zero_mean=zero_mean,
|
| 88 |
-
)
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio):
|
| 92 |
-
def __init__(
|
| 93 |
-
self, window_size: int, hop_size: int = None, zero_mean: bool = False
|
| 94 |
-
) -> None:
|
| 95 |
-
super().__init__(
|
| 96 |
-
func=tmF.scale_invariant_signal_noise_ratio,
|
| 97 |
-
window_size=window_size,
|
| 98 |
-
hop_size=hop_size,
|
| 99 |
-
zero_mean=zero_mean,
|
| 100 |
-
)
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio):
|
| 104 |
-
def __init__(
|
| 105 |
-
self, window_size: int, hop_size: int = None, zero_mean: bool = False
|
| 106 |
-
) -> None:
|
| 107 |
-
super().__init__(
|
| 108 |
-
func=tmF.signal_distortion_ratio,
|
| 109 |
-
window_size=window_size,
|
| 110 |
-
hop_size=hop_size,
|
| 111 |
-
zero_mean=zero_mean,
|
| 112 |
-
)
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
class ChunkMedianScaleInvariantSignalDistortionRatio(BaseChunkMedianSignalRatio):
|
| 116 |
-
def __init__(
|
| 117 |
-
self, window_size: int, hop_size: int = None, zero_mean: bool = False
|
| 118 |
-
) -> None:
|
| 119 |
-
super().__init__(
|
| 120 |
-
func=tmF.scale_invariant_signal_distortion_ratio,
|
| 121 |
-
window_size=window_size,
|
| 122 |
-
hop_size=hop_size,
|
| 123 |
-
zero_mean=zero_mean,
|
| 124 |
-
)
|
|
|
|
| 1 |
+
from typing import Any, Callable
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torchmetrics as tm
|
| 6 |
+
from torch._C import _LinAlgError
|
| 7 |
+
from torchmetrics import functional as tmF
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SafeSignalDistortionRatio(tm.SignalDistortionRatio):
|
| 11 |
+
def __init__(self, **kwargs) -> None:
|
| 12 |
+
super().__init__(**kwargs)
|
| 13 |
+
|
| 14 |
+
def update(self, *args, **kwargs) -> Any:
|
| 15 |
+
try:
|
| 16 |
+
super().update(*args, **kwargs)
|
| 17 |
+
except:
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
def compute(self) -> Any:
|
| 21 |
+
if self.total == 0:
|
| 22 |
+
return torch.tensor(torch.nan)
|
| 23 |
+
return super().compute()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class BaseChunkMedianSignalRatio(tm.Metric):
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
func: Callable,
|
| 30 |
+
window_size: int,
|
| 31 |
+
hop_size: int = None,
|
| 32 |
+
zero_mean: bool = False,
|
| 33 |
+
) -> None:
|
| 34 |
+
super().__init__()
|
| 35 |
+
|
| 36 |
+
self.func = func
|
| 37 |
+
self.window_size = window_size
|
| 38 |
+
if hop_size is None:
|
| 39 |
+
hop_size = window_size
|
| 40 |
+
self.hop_size = hop_size
|
| 41 |
+
|
| 42 |
+
self.add_state("sum_snr", default=torch.tensor(0.0), dist_reduce_fx="sum")
|
| 43 |
+
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
|
| 44 |
+
|
| 45 |
+
def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
|
| 46 |
+
|
| 47 |
+
n_samples = target.shape[-1]
|
| 48 |
+
|
| 49 |
+
n_chunks = int(np.ceil((n_samples - self.window_size) / self.hop_size) + 1)
|
| 50 |
+
|
| 51 |
+
snr_chunk = []
|
| 52 |
+
|
| 53 |
+
for i in range(n_chunks):
|
| 54 |
+
start = i * self.hop_size
|
| 55 |
+
|
| 56 |
+
if n_samples - start < self.window_size:
|
| 57 |
+
continue
|
| 58 |
+
|
| 59 |
+
end = start + self.window_size
|
| 60 |
+
|
| 61 |
+
try:
|
| 62 |
+
chunk_snr = self.func(preds[..., start:end], target[..., start:end])
|
| 63 |
+
|
| 64 |
+
if torch.all(torch.isfinite(chunk_snr)):
|
| 65 |
+
snr_chunk.append(chunk_snr)
|
| 66 |
+
except _LinAlgError:
|
| 67 |
+
pass
|
| 68 |
+
|
| 69 |
+
snr_chunk = torch.stack(snr_chunk, dim=-1)
|
| 70 |
+
snr_batch, _ = torch.nanmedian(snr_chunk, dim=-1)
|
| 71 |
+
|
| 72 |
+
self.sum_snr += snr_batch.sum()
|
| 73 |
+
self.total += snr_batch.numel()
|
| 74 |
+
|
| 75 |
+
def compute(self) -> Any:
|
| 76 |
+
return self.sum_snr / self.total
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio):
|
| 80 |
+
def __init__(
|
| 81 |
+
self, window_size: int, hop_size: int = None, zero_mean: bool = False
|
| 82 |
+
) -> None:
|
| 83 |
+
super().__init__(
|
| 84 |
+
func=tmF.signal_noise_ratio,
|
| 85 |
+
window_size=window_size,
|
| 86 |
+
hop_size=hop_size,
|
| 87 |
+
zero_mean=zero_mean,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio):
|
| 92 |
+
def __init__(
|
| 93 |
+
self, window_size: int, hop_size: int = None, zero_mean: bool = False
|
| 94 |
+
) -> None:
|
| 95 |
+
super().__init__(
|
| 96 |
+
func=tmF.scale_invariant_signal_noise_ratio,
|
| 97 |
+
window_size=window_size,
|
| 98 |
+
hop_size=hop_size,
|
| 99 |
+
zero_mean=zero_mean,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio):
|
| 104 |
+
def __init__(
|
| 105 |
+
self, window_size: int, hop_size: int = None, zero_mean: bool = False
|
| 106 |
+
) -> None:
|
| 107 |
+
super().__init__(
|
| 108 |
+
func=tmF.signal_distortion_ratio,
|
| 109 |
+
window_size=window_size,
|
| 110 |
+
hop_size=hop_size,
|
| 111 |
+
zero_mean=zero_mean,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class ChunkMedianScaleInvariantSignalDistortionRatio(BaseChunkMedianSignalRatio):
|
| 116 |
+
def __init__(
|
| 117 |
+
self, window_size: int, hop_size: int = None, zero_mean: bool = False
|
| 118 |
+
) -> None:
|
| 119 |
+
super().__init__(
|
| 120 |
+
func=tmF.scale_invariant_signal_distortion_ratio,
|
| 121 |
+
window_size=window_size,
|
| 122 |
+
hop_size=hop_size,
|
| 123 |
+
zero_mean=zero_mean,
|
| 124 |
+
)
|
mvsepless/models/bandit/core/model/__init__.py
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
-
from .bsrnn.wrapper import (
|
| 2 |
-
MultiMaskMultiSourceBandSplitRNNSimple,
|
| 3 |
-
)
|
|
|
|
| 1 |
+
from .bsrnn.wrapper import (
|
| 2 |
+
MultiMaskMultiSourceBandSplitRNNSimple,
|
| 3 |
+
)
|
mvsepless/models/bandit/core/model/_spectral.py
CHANGED
|
@@ -1,54 +1,54 @@
|
|
| 1 |
-
from typing import Dict, Optional
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
import torchaudio as ta
|
| 5 |
-
from torch import nn
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class _SpectralComponent(nn.Module):
|
| 9 |
-
def __init__(
|
| 10 |
-
self,
|
| 11 |
-
n_fft: int = 2048,
|
| 12 |
-
win_length: Optional[int] = 2048,
|
| 13 |
-
hop_length: int = 512,
|
| 14 |
-
window_fn: str = "hann_window",
|
| 15 |
-
wkwargs: Optional[Dict] = None,
|
| 16 |
-
power: Optional[int] = None,
|
| 17 |
-
center: bool = True,
|
| 18 |
-
normalized: bool = True,
|
| 19 |
-
pad_mode: str = "constant",
|
| 20 |
-
onesided: bool = True,
|
| 21 |
-
**kwargs,
|
| 22 |
-
) -> None:
|
| 23 |
-
super().__init__()
|
| 24 |
-
|
| 25 |
-
assert power is None
|
| 26 |
-
|
| 27 |
-
window_fn = torch.__dict__[window_fn]
|
| 28 |
-
|
| 29 |
-
self.stft = ta.transforms.Spectrogram(
|
| 30 |
-
n_fft=n_fft,
|
| 31 |
-
win_length=win_length,
|
| 32 |
-
hop_length=hop_length,
|
| 33 |
-
pad_mode=pad_mode,
|
| 34 |
-
pad=0,
|
| 35 |
-
window_fn=window_fn,
|
| 36 |
-
wkwargs=wkwargs,
|
| 37 |
-
power=power,
|
| 38 |
-
normalized=normalized,
|
| 39 |
-
center=center,
|
| 40 |
-
onesided=onesided,
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
self.istft = ta.transforms.InverseSpectrogram(
|
| 44 |
-
n_fft=n_fft,
|
| 45 |
-
win_length=win_length,
|
| 46 |
-
hop_length=hop_length,
|
| 47 |
-
pad_mode=pad_mode,
|
| 48 |
-
pad=0,
|
| 49 |
-
window_fn=window_fn,
|
| 50 |
-
wkwargs=wkwargs,
|
| 51 |
-
normalized=normalized,
|
| 52 |
-
center=center,
|
| 53 |
-
onesided=onesided,
|
| 54 |
-
)
|
|
|
|
| 1 |
+
from typing import Dict, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torchaudio as ta
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class _SpectralComponent(nn.Module):
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
n_fft: int = 2048,
|
| 12 |
+
win_length: Optional[int] = 2048,
|
| 13 |
+
hop_length: int = 512,
|
| 14 |
+
window_fn: str = "hann_window",
|
| 15 |
+
wkwargs: Optional[Dict] = None,
|
| 16 |
+
power: Optional[int] = None,
|
| 17 |
+
center: bool = True,
|
| 18 |
+
normalized: bool = True,
|
| 19 |
+
pad_mode: str = "constant",
|
| 20 |
+
onesided: bool = True,
|
| 21 |
+
**kwargs,
|
| 22 |
+
) -> None:
|
| 23 |
+
super().__init__()
|
| 24 |
+
|
| 25 |
+
assert power is None
|
| 26 |
+
|
| 27 |
+
window_fn = torch.__dict__[window_fn]
|
| 28 |
+
|
| 29 |
+
self.stft = ta.transforms.Spectrogram(
|
| 30 |
+
n_fft=n_fft,
|
| 31 |
+
win_length=win_length,
|
| 32 |
+
hop_length=hop_length,
|
| 33 |
+
pad_mode=pad_mode,
|
| 34 |
+
pad=0,
|
| 35 |
+
window_fn=window_fn,
|
| 36 |
+
wkwargs=wkwargs,
|
| 37 |
+
power=power,
|
| 38 |
+
normalized=normalized,
|
| 39 |
+
center=center,
|
| 40 |
+
onesided=onesided,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
self.istft = ta.transforms.InverseSpectrogram(
|
| 44 |
+
n_fft=n_fft,
|
| 45 |
+
win_length=win_length,
|
| 46 |
+
hop_length=hop_length,
|
| 47 |
+
pad_mode=pad_mode,
|
| 48 |
+
pad=0,
|
| 49 |
+
window_fn=window_fn,
|
| 50 |
+
wkwargs=wkwargs,
|
| 51 |
+
normalized=normalized,
|
| 52 |
+
center=center,
|
| 53 |
+
onesided=onesided,
|
| 54 |
+
)
|
mvsepless/models/bandit/core/model/bsrnn/__init__.py
CHANGED
|
@@ -1,23 +1,23 @@
|
|
| 1 |
-
from abc import ABC
|
| 2 |
-
from typing import Iterable, Mapping, Union
|
| 3 |
-
|
| 4 |
-
from torch import nn
|
| 5 |
-
|
| 6 |
-
from .bandsplit import BandSplitModule
|
| 7 |
-
from .tfmodel import (
|
| 8 |
-
SeqBandModellingModule,
|
| 9 |
-
TransformerTimeFreqModule,
|
| 10 |
-
)
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class BandsplitCoreBase(nn.Module, ABC):
|
| 14 |
-
band_split: nn.Module
|
| 15 |
-
tf_model: nn.Module
|
| 16 |
-
mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]]
|
| 17 |
-
|
| 18 |
-
def __init__(self) -> None:
|
| 19 |
-
super().__init__()
|
| 20 |
-
|
| 21 |
-
@staticmethod
|
| 22 |
-
def mask(x, m):
|
| 23 |
-
return x * m
|
|
|
|
| 1 |
+
from abc import ABC
|
| 2 |
+
from typing import Iterable, Mapping, Union
|
| 3 |
+
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
from .bandsplit import BandSplitModule
|
| 7 |
+
from .tfmodel import (
|
| 8 |
+
SeqBandModellingModule,
|
| 9 |
+
TransformerTimeFreqModule,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BandsplitCoreBase(nn.Module, ABC):
|
| 14 |
+
band_split: nn.Module
|
| 15 |
+
tf_model: nn.Module
|
| 16 |
+
mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]]
|
| 17 |
+
|
| 18 |
+
def __init__(self) -> None:
|
| 19 |
+
super().__init__()
|
| 20 |
+
|
| 21 |
+
@staticmethod
|
| 22 |
+
def mask(x, m):
|
| 23 |
+
return x * m
|
mvsepless/models/bandit/core/model/bsrnn/bandsplit.py
CHANGED
|
@@ -1,119 +1,119 @@
|
|
| 1 |
-
from typing import List, Tuple
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import nn
|
| 5 |
-
|
| 6 |
-
from .utils import (
|
| 7 |
-
band_widths_from_specs,
|
| 8 |
-
check_no_gap,
|
| 9 |
-
check_no_overlap,
|
| 10 |
-
check_nonzero_bandwidth,
|
| 11 |
-
)
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class NormFC(nn.Module):
|
| 15 |
-
def __init__(
|
| 16 |
-
self,
|
| 17 |
-
emb_dim: int,
|
| 18 |
-
bandwidth: int,
|
| 19 |
-
in_channel: int,
|
| 20 |
-
normalize_channel_independently: bool = False,
|
| 21 |
-
treat_channel_as_feature: bool = True,
|
| 22 |
-
) -> None:
|
| 23 |
-
super().__init__()
|
| 24 |
-
|
| 25 |
-
self.treat_channel_as_feature = treat_channel_as_feature
|
| 26 |
-
|
| 27 |
-
if normalize_channel_independently:
|
| 28 |
-
raise NotImplementedError
|
| 29 |
-
|
| 30 |
-
reim = 2
|
| 31 |
-
|
| 32 |
-
self.norm = nn.LayerNorm(in_channel * bandwidth * reim)
|
| 33 |
-
|
| 34 |
-
fc_in = bandwidth * reim
|
| 35 |
-
|
| 36 |
-
if treat_channel_as_feature:
|
| 37 |
-
fc_in *= in_channel
|
| 38 |
-
else:
|
| 39 |
-
assert emb_dim % in_channel == 0
|
| 40 |
-
emb_dim = emb_dim // in_channel
|
| 41 |
-
|
| 42 |
-
self.fc = nn.Linear(fc_in, emb_dim)
|
| 43 |
-
|
| 44 |
-
def forward(self, xb):
|
| 45 |
-
|
| 46 |
-
batch, n_time, in_chan, ribw = xb.shape
|
| 47 |
-
xb = self.norm(xb.reshape(batch, n_time, in_chan * ribw))
|
| 48 |
-
|
| 49 |
-
if not self.treat_channel_as_feature:
|
| 50 |
-
xb = xb.reshape(batch, n_time, in_chan, ribw)
|
| 51 |
-
|
| 52 |
-
zb = self.fc(xb)
|
| 53 |
-
|
| 54 |
-
if not self.treat_channel_as_feature:
|
| 55 |
-
batch, n_time, in_chan, emb_dim_per_chan = zb.shape
|
| 56 |
-
zb = zb.reshape((batch, n_time, in_chan * emb_dim_per_chan))
|
| 57 |
-
|
| 58 |
-
return zb
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
class BandSplitModule(nn.Module):
|
| 62 |
-
def __init__(
|
| 63 |
-
self,
|
| 64 |
-
band_specs: List[Tuple[float, float]],
|
| 65 |
-
emb_dim: int,
|
| 66 |
-
in_channel: int,
|
| 67 |
-
require_no_overlap: bool = False,
|
| 68 |
-
require_no_gap: bool = True,
|
| 69 |
-
normalize_channel_independently: bool = False,
|
| 70 |
-
treat_channel_as_feature: bool = True,
|
| 71 |
-
) -> None:
|
| 72 |
-
super().__init__()
|
| 73 |
-
|
| 74 |
-
check_nonzero_bandwidth(band_specs)
|
| 75 |
-
|
| 76 |
-
if require_no_gap:
|
| 77 |
-
check_no_gap(band_specs)
|
| 78 |
-
|
| 79 |
-
if require_no_overlap:
|
| 80 |
-
check_no_overlap(band_specs)
|
| 81 |
-
|
| 82 |
-
self.band_specs = band_specs
|
| 83 |
-
self.band_widths = band_widths_from_specs(band_specs)
|
| 84 |
-
self.n_bands = len(band_specs)
|
| 85 |
-
self.emb_dim = emb_dim
|
| 86 |
-
|
| 87 |
-
self.norm_fc_modules = nn.ModuleList(
|
| 88 |
-
[ # type: ignore
|
| 89 |
-
(
|
| 90 |
-
NormFC(
|
| 91 |
-
emb_dim=emb_dim,
|
| 92 |
-
bandwidth=bw,
|
| 93 |
-
in_channel=in_channel,
|
| 94 |
-
normalize_channel_independently=normalize_channel_independently,
|
| 95 |
-
treat_channel_as_feature=treat_channel_as_feature,
|
| 96 |
-
)
|
| 97 |
-
)
|
| 98 |
-
for bw in self.band_widths
|
| 99 |
-
]
|
| 100 |
-
)
|
| 101 |
-
|
| 102 |
-
def forward(self, x: torch.Tensor):
|
| 103 |
-
|
| 104 |
-
batch, in_chan, _, n_time = x.shape
|
| 105 |
-
|
| 106 |
-
z = torch.zeros(
|
| 107 |
-
size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
|
| 108 |
-
)
|
| 109 |
-
|
| 110 |
-
xr = torch.view_as_real(x)
|
| 111 |
-
xr = torch.permute(xr, (0, 3, 1, 4, 2))
|
| 112 |
-
batch, n_time, in_chan, reim, band_width = xr.shape
|
| 113 |
-
for i, nfm in enumerate(self.norm_fc_modules):
|
| 114 |
-
fstart, fend = self.band_specs[i]
|
| 115 |
-
xb = xr[..., fstart:fend]
|
| 116 |
-
xb = torch.reshape(xb, (batch, n_time, in_chan, -1))
|
| 117 |
-
z[:, i, :, :] = nfm(xb.contiguous())
|
| 118 |
-
|
| 119 |
-
return z
|
|
|
|
| 1 |
+
from typing import List, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
from .utils import (
|
| 7 |
+
band_widths_from_specs,
|
| 8 |
+
check_no_gap,
|
| 9 |
+
check_no_overlap,
|
| 10 |
+
check_nonzero_bandwidth,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class NormFC(nn.Module):
    """Per-band normalization + linear projection into the embedding space.

    Layer-normalizes the flattened (channel, band-bin, re/im) features of a
    single sub-band and maps them to ``emb_dim`` with one linear layer.
    """

    def __init__(
        self,
        emb_dim: int,
        bandwidth: int,
        in_channel: int,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        self.treat_channel_as_feature = treat_channel_as_feature

        if normalize_channel_independently:
            # Per-channel normalization has never been implemented.
            raise NotImplementedError

        reim = 2  # real + imaginary components of the complex spectrogram

        # Normalization always runs jointly over channels and band bins.
        self.norm = nn.LayerNorm(in_channel * bandwidth * reim)

        if treat_channel_as_feature:
            # All channels feed one projection producing the full embedding.
            fc_in = bandwidth * reim * in_channel
            fc_out = emb_dim
        else:
            # Each channel is projected separately into its own slice of the
            # embedding, so emb_dim must split evenly across channels.
            assert emb_dim % in_channel == 0
            fc_in = bandwidth * reim
            fc_out = emb_dim // in_channel

        self.fc = nn.Linear(fc_in, fc_out)

    def forward(self, xb):
        # xb: (batch, n_time, in_chan, bandwidth * 2)
        batch, n_time, n_chan, feat = xb.shape

        normed = self.norm(xb.reshape(batch, n_time, n_chan * feat))

        if self.treat_channel_as_feature:
            # Single joint projection: (batch, n_time, emb_dim).
            return self.fc(normed)

        # Restore the channel axis, project per channel, then concatenate
        # the per-channel embeddings along the last axis.
        per_chan = self.fc(normed.reshape(batch, n_time, n_chan, feat))
        return per_chan.reshape(batch, n_time, -1)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class BandSplitModule(nn.Module):
    """Split a complex spectrogram into sub-bands and embed each band.

    Every band is sliced from the frequency axis, layer-normalized and
    projected by its own ``NormFC``; the per-band embeddings are combined
    into a (batch, n_bands, n_time, emb_dim) tensor.
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        in_channel: int,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        # Validate the band layout before building any submodules.
        check_nonzero_bandwidth(band_specs)
        if require_no_gap:
            check_no_gap(band_specs)
        if require_no_overlap:
            check_no_overlap(band_specs)

        self.band_specs = band_specs
        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)
        self.emb_dim = emb_dim

        # One NormFC per band, sized by that band's width.
        self.norm_fc_modules = nn.ModuleList(
            [  # type: ignore
                NormFC(
                    emb_dim=emb_dim,
                    bandwidth=width,
                    in_channel=in_channel,
                    normalize_channel_independently=normalize_channel_independently,
                    treat_channel_as_feature=treat_channel_as_feature,
                )
                for width in self.band_widths
            ]
        )

    def forward(self, x: torch.Tensor):
        # x: complex spectrogram, (batch, in_chan, n_freq, n_time).
        xr = torch.view_as_real(x)               # (..., n_freq, n_time, 2)
        xr = torch.permute(xr, (0, 3, 1, 4, 2))  # (batch, n_time, chan, 2, n_freq)
        batch, n_time, in_chan, _, _ = xr.shape

        band_embeddings = []
        for (fstart, fend), nfm in zip(self.band_specs, self.norm_fc_modules):
            # Slice this band's frequency bins and flatten (re/im, bins).
            xb = xr[..., fstart:fend]
            xb = xb.reshape(batch, n_time, in_chan, -1)
            band_embeddings.append(nfm(xb.contiguous()))

        # Stack per-band embeddings: (batch, n_bands, n_time, emb_dim).
        return torch.stack(band_embeddings, dim=1)
|
mvsepless/models/bandit/core/model/bsrnn/core.py
CHANGED
|
@@ -1,619 +1,619 @@
|
|
| 1 |
-
from typing import Dict, List, Optional, Tuple
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import nn
|
| 5 |
-
from torch.nn import functional as F
|
| 6 |
-
|
| 7 |
-
from . import BandsplitCoreBase
|
| 8 |
-
from .bandsplit import BandSplitModule
|
| 9 |
-
from .maskestim import (
|
| 10 |
-
MaskEstimationModule,
|
| 11 |
-
OverlappingMaskEstimationModule,
|
| 12 |
-
)
|
| 13 |
-
from .tfmodel import (
|
| 14 |
-
ConvolutionalTimeFreqModule,
|
| 15 |
-
SeqBandModellingModule,
|
| 16 |
-
TransformerTimeFreqModule,
|
| 17 |
-
)
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class MultiMaskBandSplitCoreBase(BandsplitCoreBase):
    """Base core for band-split models estimating one mask per output stem.

    Subclasses must provide ``self.band_split``, ``self.tf_model`` and
    ``self.mask_estim`` (an ``nn.ModuleDict`` keyed by stem name), typically
    via :meth:`instantiate_bandsplit` and :meth:`instantiate_mask_estim`.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x, cond=None, compute_residual: bool = True):
        """Separate ``x`` into per-stem spectrograms.

        Returns ``{"spectrogram": {stem: tensor}}`` where each stem tensor
        has the same shape as the input ``x``.
        """
        # x: (batch, in_chan, n_freq, n_time). Channels are folded into the
        # batch axis so each channel is masked as a mono spectrogram.
        batch, in_chan, n_freq, n_time = x.shape
        x = torch.reshape(x, (-1, 1, n_freq, n_time))

        z = self.band_split(x)

        # Sequence modelling across time and bands.
        q = self.tf_model(z)

        out = {}

        # NOTE(review): `compute_residual` is accepted but never used in
        # this body — confirm whether callers rely on it.
        for stem, mem in self.mask_estim.items():
            m = mem(q, cond=cond)

            # Apply the stem's mask, then restore the channel layout.
            s = self.mask(x, m)
            s = torch.reshape(s, (batch, in_chan, n_freq, n_time))
            out[stem] = s

        return {"spectrogram": out}

    def instantiate_mask_estim(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        cond_dim: int,
        hidden_activation: str,
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
        mult_add_mask: bool = False,
    ):
        """Build one mask-estimation module per stem into ``self.mask_estim``."""
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        # "mne:+" is a marker entry, not a real output stem; drop it.
        if "mne:+" in stems:
            stems = [s for s in stems if s != "mne:+"]

        if overlapping_band:
            # Overlapping bands require per-frequency weights to recombine
            # bins shared by several bands.
            assert freq_weights is not None
            assert n_freq is not None

            if mult_add_mask:

                # NOTE(review): MultAddMaskEstimationModule is not imported in
                # this module; taking this branch raises NameError — confirm
                # the missing import (presumably from .maskestim).
                self.mask_estim = nn.ModuleDict(
                    {
                        stem: MultAddMaskEstimationModule(
                            band_specs=band_specs,
                            freq_weights=freq_weights,
                            n_freq=n_freq,
                            emb_dim=emb_dim,
                            mlp_dim=mlp_dim,
                            in_channel=in_channel,
                            hidden_activation=hidden_activation,
                            hidden_activation_kwargs=hidden_activation_kwargs,
                            complex_mask=complex_mask,
                            use_freq_weights=use_freq_weights,
                        )
                        for stem in stems
                    }
                )
            else:
                self.mask_estim = nn.ModuleDict(
                    {
                        stem: OverlappingMaskEstimationModule(
                            band_specs=band_specs,
                            freq_weights=freq_weights,
                            n_freq=n_freq,
                            emb_dim=emb_dim,
                            mlp_dim=mlp_dim,
                            in_channel=in_channel,
                            hidden_activation=hidden_activation,
                            hidden_activation_kwargs=hidden_activation_kwargs,
                            complex_mask=complex_mask,
                            use_freq_weights=use_freq_weights,
                        )
                        for stem in stems
                    }
                )
        else:
            # Non-overlapping bands: plain per-band mask decoder. Only this
            # variant receives `cond_dim`.
            self.mask_estim = nn.ModuleDict(
                {
                    stem: MaskEstimationModule(
                        band_specs=band_specs,
                        emb_dim=emb_dim,
                        mlp_dim=mlp_dim,
                        cond_dim=cond_dim,
                        in_channel=in_channel,
                        hidden_activation=hidden_activation,
                        hidden_activation_kwargs=hidden_activation_kwargs,
                        complex_mask=complex_mask,
                    )
                    for stem in stems
                }
            )

    def instantiate_bandsplit(
        self,
        in_channel: int,
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        emb_dim: int = 128,
    ):
        """Build the frequency band-split front end into ``self.band_split``."""
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
class SingleMaskBandsplitCoreBase(BandsplitCoreBase):
    """Base for band-split cores that estimate a single separation mask.

    Subclasses must provide ``band_split``, ``tf_model`` and ``mask_estim``.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()

    def forward(self, x):
        # Pipeline: band split -> sequence model -> mask estimate -> masking.
        embeddings = self.band_split(x)
        features = self.tf_model(embeddings)
        estimated_mask = self.mask_estim(features)
        return self.mask(x, estimated_mask)
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
class SingleMaskBandsplitCoreRNN(
    SingleMaskBandsplitCoreBase,
):
    """Single-mask core: band split -> RNN time/band model -> mask decoder."""

    def __init__(
        self,
        in_channel: int,
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__()
        # Frequency band-split / embedding front end.
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )
        # Interleaved time/band RNN modelling.
        self.tf_model = SeqBandModellingModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )
        # Per-band MLP mask decoder.
        self.mask_estim = MaskEstimationModule(
            in_channel=in_channel,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
class SingleMaskBandsplitCoreTransformer(
    SingleMaskBandsplitCoreBase,
):
    """Single-mask core: band split -> Transformer TF model -> mask decoder."""

    def __init__(
        self,
        in_channel: int,
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        tf_dropout: float = 0.0,
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__()
        # Frequency band-split / embedding front end.
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )
        # Transformer-based time/band modelling.
        self.tf_model = TransformerTimeFreqModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            dropout=tf_dropout,
        )
        # Per-band MLP mask decoder.
        self.mask_estim = MaskEstimationModule(
            in_channel=in_channel,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase):
    """Multi-stem band-split core with RNN time/band modelling.

    Optionally applies a (multiplicative, additive) mask pair per stem
    instead of a single multiplicative mask (``mult_add_mask=True``).
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        cond_dim: int = 0,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
        mult_add_mask: bool = False,
    ) -> None:

        super().__init__()
        # Band-split front end (built by the shared base helper).
        self.instantiate_bandsplit(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

        # Interleaved time/band RNN modelling.
        self.tf_model = SeqBandModellingModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

        # Remembered so mask() can select the mult-add path.
        self.mult_add_mask = mult_add_mask

        self.instantiate_mask_estim(
            in_channel=in_channel,
            stems=stems,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=overlapping_band,
            freq_weights=freq_weights,
            n_freq=n_freq,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )

    @staticmethod
    def _mult_add_mask(x, m):
        """Apply a stacked (multiplicative, additive) mask pair: x*mm + am."""
        # The last axis of m holds the two mask components.
        assert m.ndim == 5

        mm = m[..., 0]
        am = m[..., 1]

        return x * mm + am

    def mask(self, x, m):
        # Dispatch between mult-add masking and the inherited behaviour.
        if self.mult_add_mask:

            return self._mult_add_mask(x, m)
        else:
            return super().mask(x, m)
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
class MultiSourceMultiMaskBandSplitCoreTransformer(
    MultiMaskBandSplitCoreBase,
):
    """Multi-stem band-split core with Transformer time/band modelling."""

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        tf_dropout: float = 0.0,
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
        rnn_type: str = "LSTM",  # accepted for signature parity; unused here
        cond_dim: int = 0,
        mult_add_mask: bool = False,
    ) -> None:
        super().__init__()
        # NOTE(review): unlike the RNN variant, this class does not store
        # `self.mult_add_mask` — confirm whether mask() ever needs it here.
        self.instantiate_bandsplit(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )
        # Transformer-based time/band modelling.
        self.tf_model = TransformerTimeFreqModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            dropout=tf_dropout,
        )

        self.instantiate_mask_estim(
            in_channel=in_channel,
            stems=stems,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=overlapping_band,
            freq_weights=freq_weights,
            n_freq=n_freq,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
class MultiSourceMultiMaskBandSplitCoreConv(
    MultiMaskBandSplitCoreBase,
):
    """Multi-stem band-split core with convolutional time/band modelling."""

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        tf_dropout: float = 0.0,
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
        rnn_type: str = "LSTM",  # accepted for signature parity; unused here
        cond_dim: int = 0,
        mult_add_mask: bool = False,
    ) -> None:
        super().__init__()
        self.instantiate_bandsplit(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )
        # Convolutional time/band modelling backbone.
        self.tf_model = ConvolutionalTimeFreqModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            dropout=tf_dropout,
        )

        self.instantiate_mask_estim(
            in_channel=in_channel,
            stems=stems,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=overlapping_band,
            freq_weights=freq_weights,
            n_freq=n_freq,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
class PatchingMaskBandsplitCoreBase(MultiMaskBandSplitCoreBase):
    """Multi-mask core whose masks act on local time-frequency patches.

    The estimated mask ``m`` carries a (kernel_freq, kernel_time) patch of
    coefficients per output bin, so masking is an unfold/weight/fold
    operation instead of a pointwise product.
    """

    def __init__(self) -> None:
        super().__init__()

    def mask(self, x, m):
        """Apply patch mask ``m`` to spectrogram ``x``.

        m: (batch, n_channel, kernel_freq, kernel_time, n_freq, n_time).
        Returns a (-1, n_channel, n_freq, n_time) tensor.
        """
        _, n_channel, kernel_freq, kernel_time, n_freq, n_time = m.shape
        # "same" padding (exact for odd kernel sizes).
        padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2)

        # Extract every (kernel_freq x kernel_time) patch around each bin.
        xf = F.unfold(
            x,
            kernel_size=(kernel_freq, kernel_time),
            padding=padding,
            stride=(1, 1),
        )

        xf = xf.view(
            -1,
            n_channel,
            kernel_freq,
            kernel_time,
            n_freq,
            n_time,
        )

        # Weight every patch element by its mask coefficient...
        sf = xf * m

        sf = sf.view(
            -1,
            n_channel * kernel_freq * kernel_time,
            n_freq * n_time,
        )

        # ...then fold sums the overlapping contributions back onto the grid.
        s = F.fold(
            sf,
            output_size=(n_freq, n_time),
            kernel_size=(kernel_freq, kernel_time),
            padding=padding,
            stride=(1, 1),
        ).view(
            -1,
            n_channel,
            n_freq,
            n_time,
        )

        return s

    def old_mask(self, x, m):
        """Loop-based reference implementation of patch masking.

        NOTE(review): this variant indexes ``m`` as
        (kernel_freq, kernel_time, ..., n_freq, n_time) — a different layout
        than :meth:`mask` expects; confirm before interchanging them.
        """
        s = torch.zeros_like(x)

        _, n_channel, n_freq, n_time = x.shape
        kernel_freq, kernel_time, _, _, _, _ = m.shape

        kernel_freq_half = (kernel_freq - 1) // 2
        kernel_time_half = (kernel_time - 1) // 2

        for ifreq in range(kernel_freq):
            for itime in range(kernel_time):
                # Shift x so the (ifreq, itime) tap aligns with its target
                # bin, then accumulate the masked contribution on the
                # region unaffected by roll wrap-around.
                df, dt = kernel_freq_half - ifreq, kernel_time_half - itime
                x = x.roll(shifts=(df, dt), dims=(2, 3))

                fslice = slice(max(0, df), min(n_freq, n_freq + df))
                tslice = slice(max(0, dt), min(n_time, n_time + dt))

                s[:, :, fslice, tslice] += (
                    x[:, :, fslice, tslice] * m[ifreq, itime, :, :, fslice, tslice]
                )

        return s
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
class MultiSourceMultiPatchingMaskBandSplitCoreRNN(PatchingMaskBandsplitCoreBase):
    """Multi-stem patching-mask core with RNN time/band modelling.

    Only the overlapping-band configuration is implemented.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        mask_kernel_freq: int,
        mask_kernel_time: int,
        conv_kernel_freq: int,
        conv_kernel_time: int,
        kernel_norm_mlp_version: int,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
    ) -> None:

        super().__init__()
        # Frequency band-split / embedding front end.
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

        # Interleaved time/band RNN modelling.
        self.tf_model = SeqBandModellingModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        if overlapping_band:
            # Overlapping bands need frequency weights for recombination.
            assert freq_weights is not None
            assert n_freq is not None
            # NOTE(review): PatchingMaskEstimationModule is not imported in
            # this module; this branch raises NameError if taken — confirm
            # the intended import.
            self.mask_estim = nn.ModuleDict(
                {
                    stem: PatchingMaskEstimationModule(
                        band_specs=band_specs,
                        freq_weights=freq_weights,
                        n_freq=n_freq,
                        emb_dim=emb_dim,
                        mlp_dim=mlp_dim,
                        in_channel=in_channel,
                        hidden_activation=hidden_activation,
                        hidden_activation_kwargs=hidden_activation_kwargs,
                        complex_mask=complex_mask,
                        mask_kernel_freq=mask_kernel_freq,
                        mask_kernel_time=mask_kernel_time,
                        conv_kernel_freq=conv_kernel_freq,
                        conv_kernel_time=conv_kernel_time,
                        kernel_norm_mlp_version=kernel_norm_mlp_version,
                    )
                    for stem in stems
                }
            )
        else:
            # Non-overlapping patching masks are not supported.
            raise NotImplementedError
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
from . import BandsplitCoreBase
|
| 8 |
+
from .bandsplit import BandSplitModule
|
| 9 |
+
from .maskestim import (
|
| 10 |
+
MaskEstimationModule,
|
| 11 |
+
OverlappingMaskEstimationModule,
|
| 12 |
+
)
|
| 13 |
+
from .tfmodel import (
|
| 14 |
+
ConvolutionalTimeFreqModule,
|
| 15 |
+
SeqBandModellingModule,
|
| 16 |
+
TransformerTimeFreqModule,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class MultiMaskBandSplitCoreBase(BandsplitCoreBase):
|
| 21 |
+
def __init__(self) -> None:
|
| 22 |
+
super().__init__()
|
| 23 |
+
|
| 24 |
+
def forward(self, x, cond=None, compute_residual: bool = True):
|
| 25 |
+
batch, in_chan, n_freq, n_time = x.shape
|
| 26 |
+
x = torch.reshape(x, (-1, 1, n_freq, n_time))
|
| 27 |
+
|
| 28 |
+
z = self.band_split(x)
|
| 29 |
+
|
| 30 |
+
q = self.tf_model(z)
|
| 31 |
+
|
| 32 |
+
out = {}
|
| 33 |
+
|
| 34 |
+
for stem, mem in self.mask_estim.items():
|
| 35 |
+
m = mem(q, cond=cond)
|
| 36 |
+
|
| 37 |
+
s = self.mask(x, m)
|
| 38 |
+
s = torch.reshape(s, (batch, in_chan, n_freq, n_time))
|
| 39 |
+
out[stem] = s
|
| 40 |
+
|
| 41 |
+
return {"spectrogram": out}
|
| 42 |
+
|
| 43 |
+
def instantiate_mask_estim(
|
| 44 |
+
self,
|
| 45 |
+
in_channel: int,
|
| 46 |
+
stems: List[str],
|
| 47 |
+
band_specs: List[Tuple[float, float]],
|
| 48 |
+
emb_dim: int,
|
| 49 |
+
mlp_dim: int,
|
| 50 |
+
cond_dim: int,
|
| 51 |
+
hidden_activation: str,
|
| 52 |
+
hidden_activation_kwargs: Optional[Dict] = None,
|
| 53 |
+
complex_mask: bool = True,
|
| 54 |
+
overlapping_band: bool = False,
|
| 55 |
+
freq_weights: Optional[List[torch.Tensor]] = None,
|
| 56 |
+
n_freq: Optional[int] = None,
|
| 57 |
+
use_freq_weights: bool = True,
|
| 58 |
+
mult_add_mask: bool = False,
|
| 59 |
+
):
|
| 60 |
+
if hidden_activation_kwargs is None:
|
| 61 |
+
hidden_activation_kwargs = {}
|
| 62 |
+
|
| 63 |
+
if "mne:+" in stems:
|
| 64 |
+
stems = [s for s in stems if s != "mne:+"]
|
| 65 |
+
|
| 66 |
+
if overlapping_band:
|
| 67 |
+
assert freq_weights is not None
|
| 68 |
+
assert n_freq is not None
|
| 69 |
+
|
| 70 |
+
if mult_add_mask:
|
| 71 |
+
|
| 72 |
+
self.mask_estim = nn.ModuleDict(
|
| 73 |
+
{
|
| 74 |
+
stem: MultAddMaskEstimationModule(
|
| 75 |
+
band_specs=band_specs,
|
| 76 |
+
freq_weights=freq_weights,
|
| 77 |
+
n_freq=n_freq,
|
| 78 |
+
emb_dim=emb_dim,
|
| 79 |
+
mlp_dim=mlp_dim,
|
| 80 |
+
in_channel=in_channel,
|
| 81 |
+
hidden_activation=hidden_activation,
|
| 82 |
+
hidden_activation_kwargs=hidden_activation_kwargs,
|
| 83 |
+
complex_mask=complex_mask,
|
| 84 |
+
use_freq_weights=use_freq_weights,
|
| 85 |
+
)
|
| 86 |
+
for stem in stems
|
| 87 |
+
}
|
| 88 |
+
)
|
| 89 |
+
else:
|
| 90 |
+
self.mask_estim = nn.ModuleDict(
|
| 91 |
+
{
|
| 92 |
+
stem: OverlappingMaskEstimationModule(
|
| 93 |
+
band_specs=band_specs,
|
| 94 |
+
freq_weights=freq_weights,
|
| 95 |
+
n_freq=n_freq,
|
| 96 |
+
emb_dim=emb_dim,
|
| 97 |
+
mlp_dim=mlp_dim,
|
| 98 |
+
in_channel=in_channel,
|
| 99 |
+
hidden_activation=hidden_activation,
|
| 100 |
+
hidden_activation_kwargs=hidden_activation_kwargs,
|
| 101 |
+
complex_mask=complex_mask,
|
| 102 |
+
use_freq_weights=use_freq_weights,
|
| 103 |
+
)
|
| 104 |
+
for stem in stems
|
| 105 |
+
}
|
| 106 |
+
)
|
| 107 |
+
else:
|
| 108 |
+
self.mask_estim = nn.ModuleDict(
|
| 109 |
+
{
|
| 110 |
+
stem: MaskEstimationModule(
|
| 111 |
+
band_specs=band_specs,
|
| 112 |
+
emb_dim=emb_dim,
|
| 113 |
+
mlp_dim=mlp_dim,
|
| 114 |
+
cond_dim=cond_dim,
|
| 115 |
+
in_channel=in_channel,
|
| 116 |
+
hidden_activation=hidden_activation,
|
| 117 |
+
hidden_activation_kwargs=hidden_activation_kwargs,
|
| 118 |
+
complex_mask=complex_mask,
|
| 119 |
+
)
|
| 120 |
+
for stem in stems
|
| 121 |
+
}
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
def instantiate_bandsplit(
|
| 125 |
+
self,
|
| 126 |
+
in_channel: int,
|
| 127 |
+
band_specs: List[Tuple[float, float]],
|
| 128 |
+
require_no_overlap: bool = False,
|
| 129 |
+
require_no_gap: bool = True,
|
| 130 |
+
normalize_channel_independently: bool = False,
|
| 131 |
+
treat_channel_as_feature: bool = True,
|
| 132 |
+
emb_dim: int = 128,
|
| 133 |
+
):
|
| 134 |
+
self.band_split = BandSplitModule(
|
| 135 |
+
in_channel=in_channel,
|
| 136 |
+
band_specs=band_specs,
|
| 137 |
+
require_no_overlap=require_no_overlap,
|
| 138 |
+
require_no_gap=require_no_gap,
|
| 139 |
+
normalize_channel_independently=normalize_channel_independently,
|
| 140 |
+
treat_channel_as_feature=treat_channel_as_feature,
|
| 141 |
+
emb_dim=emb_dim,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class SingleMaskBandsplitCoreBase(BandsplitCoreBase):
|
| 146 |
+
def __init__(self, **kwargs) -> None:
|
| 147 |
+
super().__init__()
|
| 148 |
+
|
| 149 |
+
def forward(self, x):
|
| 150 |
+
z = self.band_split(x)
|
| 151 |
+
q = self.tf_model(z)
|
| 152 |
+
m = self.mask_estim(q)
|
| 153 |
+
|
| 154 |
+
s = self.mask(x, m)
|
| 155 |
+
|
| 156 |
+
return s
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class SingleMaskBandsplitCoreRNN(
|
| 160 |
+
SingleMaskBandsplitCoreBase,
|
| 161 |
+
):
|
| 162 |
+
def __init__(
|
| 163 |
+
self,
|
| 164 |
+
in_channel: int,
|
| 165 |
+
band_specs: List[Tuple[float, float]],
|
| 166 |
+
require_no_overlap: bool = False,
|
| 167 |
+
require_no_gap: bool = True,
|
| 168 |
+
normalize_channel_independently: bool = False,
|
| 169 |
+
treat_channel_as_feature: bool = True,
|
| 170 |
+
n_sqm_modules: int = 12,
|
| 171 |
+
emb_dim: int = 128,
|
| 172 |
+
rnn_dim: int = 256,
|
| 173 |
+
bidirectional: bool = True,
|
| 174 |
+
rnn_type: str = "LSTM",
|
| 175 |
+
mlp_dim: int = 512,
|
| 176 |
+
hidden_activation: str = "Tanh",
|
| 177 |
+
hidden_activation_kwargs: Optional[Dict] = None,
|
| 178 |
+
complex_mask: bool = True,
|
| 179 |
+
) -> None:
|
| 180 |
+
super().__init__()
|
| 181 |
+
self.band_split = BandSplitModule(
|
| 182 |
+
in_channel=in_channel,
|
| 183 |
+
band_specs=band_specs,
|
| 184 |
+
require_no_overlap=require_no_overlap,
|
| 185 |
+
require_no_gap=require_no_gap,
|
| 186 |
+
normalize_channel_independently=normalize_channel_independently,
|
| 187 |
+
treat_channel_as_feature=treat_channel_as_feature,
|
| 188 |
+
emb_dim=emb_dim,
|
| 189 |
+
)
|
| 190 |
+
self.tf_model = SeqBandModellingModule(
|
| 191 |
+
n_modules=n_sqm_modules,
|
| 192 |
+
emb_dim=emb_dim,
|
| 193 |
+
rnn_dim=rnn_dim,
|
| 194 |
+
bidirectional=bidirectional,
|
| 195 |
+
rnn_type=rnn_type,
|
| 196 |
+
)
|
| 197 |
+
self.mask_estim = MaskEstimationModule(
|
| 198 |
+
in_channel=in_channel,
|
| 199 |
+
band_specs=band_specs,
|
| 200 |
+
emb_dim=emb_dim,
|
| 201 |
+
mlp_dim=mlp_dim,
|
| 202 |
+
hidden_activation=hidden_activation,
|
| 203 |
+
hidden_activation_kwargs=hidden_activation_kwargs,
|
| 204 |
+
complex_mask=complex_mask,
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class SingleMaskBandsplitCoreTransformer(
    SingleMaskBandsplitCoreBase,
):
    """Band-split core using a Transformer time-frequency model.

    Identical wiring to the RNN variant except the middle stage is a
    TransformerTimeFreqModule with dropout instead of an RNN stack.
    """

    def __init__(
        self,
        in_channel: int,
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        tf_dropout: float = 0.0,
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__()

        # Front end: per-band embedding of the input spectrogram.
        bandsplit_config = dict(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )
        self.band_split = BandSplitModule(**bandsplit_config)

        # Middle: Transformer modelling over time and band axes.
        self.tf_model = TransformerTimeFreqModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            dropout=tf_dropout,
        )

        # Back end: single mask estimation from the embeddings.
        mask_config = dict(
            in_channel=in_channel,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )
        self.mask_estim = MaskEstimationModule(**mask_config)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase):
    """RNN band-split core producing one mask per output stem.

    When ``mult_add_mask`` is enabled, each stem's estimator produces a
    stacked (multiplicative, additive) mask pair and :meth:`mask` applies
    ``x * m[..., 0] + m[..., 1]`` instead of the base-class masking.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        cond_dim: int = 0,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
        mult_add_mask: bool = False,
    ) -> None:
        super().__init__()

        # Front end (provided by the base class helper).
        self.instantiate_bandsplit(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

        # Middle: RNN modelling across time and band axes.
        self.tf_model = SeqBandModellingModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

        # Remembered so mask() can pick the right application rule.
        self.mult_add_mask = mult_add_mask

        # Back end: one mask estimator per stem (built by the base helper).
        self.instantiate_mask_estim(
            in_channel=in_channel,
            stems=stems,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=overlapping_band,
            freq_weights=freq_weights,
            n_freq=n_freq,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )

    @staticmethod
    def _mult_add_mask(x, m):
        """Apply a stacked mask pair: ``x * m[..., 0] + m[..., 1]``."""
        # m stacks the multiplicative and additive masks along the last axis.
        assert m.ndim == 5
        mult, add = m[..., 0], m[..., 1]
        return x * mult + add

    def mask(self, x, m):
        # Dispatch between plain masking (base class) and the mult-add
        # variant selected at construction time.
        if not self.mult_add_mask:
            return super().mask(x, m)
        return self._mult_add_mask(x, m)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
class MultiSourceMultiMaskBandSplitCoreTransformer(
    MultiMaskBandSplitCoreBase,
):
    """Transformer band-split core producing one mask per output stem.

    Same wiring as the RNN multi-source core but the middle stage is a
    TransformerTimeFreqModule.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        tf_dropout: float = 0.0,
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
        rnn_type: str = "LSTM",  # accepted for signature parity; unused here
        cond_dim: int = 0,
        mult_add_mask: bool = False,
    ) -> None:
        super().__init__()

        # Front end (base-class helper).
        self.instantiate_bandsplit(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

        # Middle: Transformer over time and band axes.
        self.tf_model = TransformerTimeFreqModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            dropout=tf_dropout,
        )

        # NOTE(review): unlike the RNN core, mult_add_mask is only forwarded
        # to the mask estimator and not stored on self — confirm intended.
        self.instantiate_mask_estim(
            in_channel=in_channel,
            stems=stems,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=overlapping_band,
            freq_weights=freq_weights,
            n_freq=n_freq,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
class MultiSourceMultiMaskBandSplitCoreConv(
    MultiMaskBandSplitCoreBase,
):
    """Convolutional band-split core producing one mask per output stem.

    Same wiring as the Transformer multi-source core but the middle stage
    is a ConvolutionalTimeFreqModule.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        tf_dropout: float = 0.0,
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
        rnn_type: str = "LSTM",  # accepted for signature parity; unused here
        cond_dim: int = 0,
        mult_add_mask: bool = False,
    ) -> None:
        super().__init__()

        # Front end (base-class helper).
        self.instantiate_bandsplit(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

        # Middle: convolutional modelling over time and band axes.
        self.tf_model = ConvolutionalTimeFreqModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            dropout=tf_dropout,
        )

        # Back end: per-stem mask estimators (base-class helper).
        self.instantiate_mask_estim(
            in_channel=in_channel,
            stems=stems,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=overlapping_band,
            freq_weights=freq_weights,
            n_freq=n_freq,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
class PatchingMaskBandsplitCoreBase(MultiMaskBandSplitCoreBase):
    """Core whose estimator predicts a small patch of mask values per TF bin.

    The mask m has shape
    (batch, n_channel, kernel_freq, kernel_time, n_freq, n_time): every
    time-frequency bin carries a kernel_freq x kernel_time patch that is
    applied to its neighbourhood and accumulated over overlaps.
    """

    def __init__(self) -> None:
        super().__init__()

    def mask(self, x, m):
        """Apply the per-bin patch mask to spectrogram x via unfold/fold.

        x: (batch, n_channel, n_freq, n_time); m as described on the class.
        Returns a tensor shaped like x.
        """
        _, n_channel, kernel_freq, kernel_time, n_freq, n_time = m.shape
        # "Same" padding so unfold yields one patch per TF bin
        # (assumes odd kernel sizes — TODO confirm).
        padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2)

        # Gather the kernel-sized neighbourhood of every TF bin.
        xf = F.unfold(
            x,
            kernel_size=(kernel_freq, kernel_time),
            padding=padding,
            stride=(1, 1),
        )

        # (batch, C*kf*kt, n_freq*n_time) -> explicit patch axes.
        xf = xf.view(
            -1,
            n_channel,
            kernel_freq,
            kernel_time,
            n_freq,
            n_time,
        )

        # Element-wise application of each bin's patch mask.
        sf = xf * m

        # Collapse back to the layout F.fold expects.
        sf = sf.view(
            -1,
            n_channel * kernel_freq * kernel_time,
            n_freq * n_time,
        )

        # fold scatter-adds the overlapping masked patches back onto the
        # (n_freq, n_time) grid.
        s = F.fold(
            sf,
            output_size=(n_freq, n_time),
            kernel_size=(kernel_freq, kernel_time),
            padding=padding,
            stride=(1, 1),
        ).view(
            -1,
            n_channel,
            n_freq,
            n_time,
        )

        return s

    def old_mask(self, x, m):
        """Loop-based reference implementation of :meth:`mask`.

        Applies the patch mask with explicit shifts instead of unfold/fold.
        Note m is indexed here as (kernel_freq, kernel_time, ...), a
        different layout than :meth:`mask` expects.
        """
        s = torch.zeros_like(x)

        _, n_channel, n_freq, n_time = x.shape
        kernel_freq, kernel_time, _, _, _, _ = m.shape

        kernel_freq_half = (kernel_freq - 1) // 2
        kernel_time_half = (kernel_time - 1) // 2

        for ifreq in range(kernel_freq):
            for itime in range(kernel_time):
                df, dt = kernel_freq_half - ifreq, kernel_time_half - itime
                # NOTE(review): x is rolled cumulatively (rebinding x each
                # iteration), relying on successive deltas composing to the
                # intended absolute offsets — verify against mask().
                x = x.roll(shifts=(df, dt), dims=(2, 3))

                # Restrict to the region where the shift stays in-bounds.
                fslice = slice(max(0, df), min(n_freq, n_freq + df))
                tslice = slice(max(0, dt), min(n_time, n_time + dt))

                s[:, :, fslice, tslice] += (
                    x[:, :, fslice, tslice] * m[ifreq, itime, :, :, fslice, tslice]
                )

        return s
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
class MultiSourceMultiPatchingMaskBandSplitCoreRNN(PatchingMaskBandsplitCoreBase):
    """RNN band-split core that predicts per-stem patching masks.

    Only the overlapping-band configuration is implemented; any other
    configuration raises NotImplementedError.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        mask_kernel_freq: int,
        mask_kernel_time: int,
        conv_kernel_freq: int,
        conv_kernel_time: int,
        kernel_norm_mlp_version: int,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
    ) -> None:
        super().__init__()

        # Front end: per-band embedding of the input spectrogram.
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

        # Middle: RNN modelling over time and band axes.
        self.tf_model = SeqBandModellingModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        # Guard clause: the non-overlapping configuration has no patching
        # mask estimator implementation.
        if not overlapping_band:
            raise NotImplementedError

        assert freq_weights is not None
        assert n_freq is not None

        def _build_estimator():
            # One independent patching mask estimator per stem.
            return PatchingMaskEstimationModule(
                band_specs=band_specs,
                freq_weights=freq_weights,
                n_freq=n_freq,
                emb_dim=emb_dim,
                mlp_dim=mlp_dim,
                in_channel=in_channel,
                hidden_activation=hidden_activation,
                hidden_activation_kwargs=hidden_activation_kwargs,
                complex_mask=complex_mask,
                mask_kernel_freq=mask_kernel_freq,
                mask_kernel_time=mask_kernel_time,
                conv_kernel_freq=conv_kernel_freq,
                conv_kernel_time=conv_kernel_time,
                kernel_norm_mlp_version=kernel_norm_mlp_version,
            )

        self.mask_estim = nn.ModuleDict(
            {stem: _build_estimator() for stem in stems}
        )
|
mvsepless/models/bandit/core/model/bsrnn/maskestim.py
CHANGED
|
@@ -1,327 +1,327 @@
|
|
| 1 |
-
import warnings
|
| 2 |
-
from typing import Dict, List, Optional, Tuple, Type
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
from torch import nn
|
| 6 |
-
from torch.nn.modules import activation
|
| 7 |
-
|
| 8 |
-
from .utils import (
|
| 9 |
-
band_widths_from_specs,
|
| 10 |
-
check_no_gap,
|
| 11 |
-
check_no_overlap,
|
| 12 |
-
check_nonzero_bandwidth,
|
| 13 |
-
)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class BaseNormMLP(nn.Module):
    """LayerNorm + hidden projection shared by the band-wise mask MLPs.

    Subclasses add the output head(s); this base only stores configuration
    and builds the normalisation and hidden layers.
    """

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ):
        super().__init__()

        self.hidden_activation_kwargs = (
            {} if hidden_activation_kwargs is None else hidden_activation_kwargs
        )
        self.norm = nn.LayerNorm(emb_dim)

        # Activation class is looked up by name in torch.nn.modules.activation.
        act_cls = activation.__dict__[hidden_activation]
        self.hidden = torch.jit.script(
            nn.Sequential(
                nn.Linear(in_features=emb_dim, out_features=mlp_dim),
                act_cls(**self.hidden_activation_kwargs),
            )
        )

        self.bandwidth = bandwidth
        self.in_channel = in_channel

        self.complex_mask = complex_mask
        # Two output values (real/imag) per mask entry when complex.
        self.reim = 2 if complex_mask else 1
        self.glu_mult = 2
| 46 |
-
|
| 47 |
-
|
| 48 |
-
class NormMLP(BaseNormMLP):
    """Per-band MLP mapping embeddings to a (possibly complex) mask slice."""

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__(
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            bandwidth=bandwidth,
            in_channel=in_channel,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

        # GLU halves the feature dimension, hence the extra factor of 2.
        self.output = torch.jit.script(
            nn.Sequential(
                nn.Linear(
                    in_features=mlp_dim,
                    out_features=bandwidth * in_channel * self.reim * 2,
                ),
                nn.GLU(dim=-1),
            )
        )

    def reshape_output(self, mb):
        """Reshape (batch, n_time, feat) to (batch, channel, bandwidth, n_time)."""
        batch, n_time, _ = mb.shape
        base_shape = (batch, n_time, self.in_channel, self.bandwidth)
        if self.complex_mask:
            # Last axis packs (real, imag); fold it into a complex tensor.
            mb = torch.view_as_complex(
                mb.reshape(*base_shape, self.reim).contiguous()
            )
        else:
            mb = mb.reshape(*base_shape)
        # (batch, n_time, channel, band) -> (batch, channel, band, n_time)
        return mb.permute(0, 2, 3, 1)

    def forward(self, qb):
        hidden = self.hidden(self.norm(qb))
        return self.reshape_output(self.output(hidden))
| 102 |
-
|
| 103 |
-
|
| 104 |
-
class MultAddNormMLP(NormMLP):
    """NormMLP variant with two heads: multiplicative and additive masks."""

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channel: "int | None",
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__(
            emb_dim,
            mlp_dim,
            bandwidth,
            in_channel,
            hidden_activation,
            hidden_activation_kwargs,
            complex_mask,
        )

        # Second head, identical in shape to self.output, for the additive mask.
        self.output2 = torch.jit.script(
            nn.Sequential(
                nn.Linear(
                    in_features=mlp_dim,
                    out_features=bandwidth * in_channel * self.reim * 2,
                ),
                nn.GLU(dim=-1),
            )
        )

    def forward(self, qb):
        hidden = self.hidden(self.norm(qb))
        mult_mask = self.reshape_output(self.output(hidden))
        add_mask = self.reshape_output(self.output2(hidden))
        return mult_mask, add_mask
| 145 |
-
|
| 146 |
-
|
| 147 |
-
class MaskEstimationModuleSuperBase(nn.Module):
    """Root of the mask-estimation class hierarchy; defines no behaviour."""

    pass
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
    """Holds one norm+MLP head per frequency band and runs them over q.

    Subclasses decide how the per-band masks are combined into a full
    frequency-axis mask.
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Optional[Dict] = None,
    ) -> None:
        super().__init__()

        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)

        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        if norm_mlp_kwargs is None:
            norm_mlp_kwargs = {}

        # One independent head per band, sized by that band's width.
        self.norm_mlp = nn.ModuleList(
            [
                (
                    norm_mlp_cls(
                        bandwidth=self.band_widths[b],
                        emb_dim=emb_dim,
                        mlp_dim=mlp_dim,
                        in_channel=in_channel,
                        hidden_activation=hidden_activation,
                        hidden_activation_kwargs=hidden_activation_kwargs,
                        complex_mask=complex_mask,
                        **norm_mlp_kwargs,
                    )
                )
                for b in range(self.n_bands)
            ]
        )

    def compute_masks(self, q):
        """Run each band's head on its embedding slice.

        q: (batch, n_bands, n_time, emb_dim). Returns a list with one mask
        (or mask tuple, depending on the head class) per band, in band order.
        """
        batch, n_bands, n_time, emb_dim = q.shape

        masks = []

        for b, nmlp in enumerate(self.norm_mlp):
            qb = q[:, b, :, :]
            mb = nmlp(qb)
            masks.append(mb)

        return masks
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
    """Mask estimator for (possibly) overlapping band specifications.

    Per-band masks are accumulated into a single
    (batch, in_channel, n_freq, n_time) mask; optional per-band frequency
    weights compensate for the overlap between neighbouring bands.
    """

    def __init__(
        self,
        in_channel: int,
        band_specs: List[Tuple[float, float]],
        freq_weights: List[torch.Tensor],
        n_freq: int,
        emb_dim: int,
        mlp_dim: int,
        cond_dim: int = 0,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Optional[Dict] = None,
        use_freq_weights: bool = True,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)

        # cond_dim widens the embedding seen by the per-band MLPs so that a
        # conditioning vector can be concatenated in forward().
        super().__init__(
            band_specs=band_specs,
            emb_dim=emb_dim + cond_dim,
            mlp_dim=mlp_dim,
            in_channel=in_channel,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            norm_mlp_cls=norm_mlp_cls,
            norm_mlp_kwargs=norm_mlp_kwargs,
        )

        self.n_freq = n_freq
        self.band_specs = band_specs
        self.in_channel = in_channel

        if freq_weights is not None:
            # Registered individually so the weights follow the module across
            # devices; "/" is used because buffer names may not contain ".".
            for i, fw in enumerate(freq_weights):
                self.register_buffer(f"freq_weights/{i}", fw)

            self.use_freq_weights = use_freq_weights
        else:
            self.use_freq_weights = False

        self.cond_dim = cond_dim

    def forward(self, q, cond=None):
        """Compute the accumulated mask.

        Args:
            q: band embeddings, (batch, n_bands, n_time, emb_dim).
            cond: optional conditioning, (batch, cond_dim) or
                (batch, n_time, cond_dim).

        Returns:
            Mask tensor of shape (batch, in_channel, n_freq, n_time).
        """
        batch, n_bands, n_time, emb_dim = q.shape

        if cond is not None:
            # Fix: removed a leftover debug `print(cond)` that dumped the
            # conditioning tensor to stdout on every forward pass.
            if cond.ndim == 2:
                # Broadcast a per-item vector over bands and time.
                cond = cond[:, None, None, :].expand(-1, n_bands, n_time, -1)
            elif cond.ndim == 3:
                assert cond.shape[1] == n_time
            else:
                raise ValueError(f"Invalid cond shape: {cond.shape}")

            q = torch.cat([q, cond], dim=-1)
        elif self.cond_dim > 0:
            # No conditioning supplied but the MLPs expect cond features:
            # pad with ones so shapes still match.
            cond = torch.ones(
                (batch, n_bands, n_time, self.cond_dim),
                device=q.device,
                dtype=q.dtype,
            )
            q = torch.cat([q, cond], dim=-1)

        mask_list = self.compute_masks(q)

        # Accumulate per-band masks onto the full frequency axis.
        masks = torch.zeros(
            (batch, self.in_channel, self.n_freq, n_time),
            device=q.device,
            dtype=mask_list[0].dtype,
        )

        for im, mask in enumerate(mask_list):
            fstart, fend = self.band_specs[im]
            if self.use_freq_weights:
                fw = self.get_buffer(f"freq_weights/{im}")[:, None]
                mask = mask * fw
            masks[:, :, fstart:fend, :] += mask

        return masks
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
class MaskEstimationModule(OverlappingMaskEstimationModule):
    """Mask estimator for strictly non-overlapping, gap-free band specs.

    Since bands cannot overlap, the per-band masks are simply concatenated
    along the frequency axis instead of being weighted and summed.
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        # Fix: was `Dict = None` (invalid implicit-Optional annotation).
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        **kwargs,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)
        # Stricter than the parent: overlap is also forbidden.
        check_no_overlap(band_specs)
        super().__init__(
            in_channel=in_channel,
            band_specs=band_specs,
            freq_weights=None,
            n_freq=None,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

    def forward(self, q, cond=None):
        """Concatenate the per-band masks along the frequency axis (dim=2).

        q: (batch, n_bands, n_time, emb_dim). `cond` is accepted for
        interface parity with the parent but is not used here.
        """
        masks = self.compute_masks(q)

        # Bands are disjoint, gap-free and ordered, so concatenation
        # reassembles the full-frequency mask.
        masks = torch.concat(masks, dim=2)

        return masks
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from typing import Dict, List, Optional, Tuple, Type
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn.modules import activation
|
| 7 |
+
|
| 8 |
+
from .utils import (
|
| 9 |
+
band_widths_from_specs,
|
| 10 |
+
check_no_gap,
|
| 11 |
+
check_no_overlap,
|
| 12 |
+
check_nonzero_bandwidth,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BaseNormMLP(nn.Module):
    """LayerNorm + hidden projection shared by the band-wise mask MLPs.

    Subclasses add the output head(s); this base only stores configuration
    and builds the normalisation and hidden layers.
    """

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ):
        super().__init__()

        self.hidden_activation_kwargs = (
            {} if hidden_activation_kwargs is None else hidden_activation_kwargs
        )
        self.norm = nn.LayerNorm(emb_dim)

        # Activation class is looked up by name in torch.nn.modules.activation.
        act_cls = activation.__dict__[hidden_activation]
        self.hidden = torch.jit.script(
            nn.Sequential(
                nn.Linear(in_features=emb_dim, out_features=mlp_dim),
                act_cls(**self.hidden_activation_kwargs),
            )
        )

        self.bandwidth = bandwidth
        self.in_channel = in_channel

        self.complex_mask = complex_mask
        # Two output values (real/imag) per mask entry when complex.
        self.reim = 2 if complex_mask else 1
        self.glu_mult = 2
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class NormMLP(BaseNormMLP):
    """Per-band MLP mapping embeddings to a (possibly complex) mask slice."""

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__(
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            bandwidth=bandwidth,
            in_channel=in_channel,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

        # GLU halves the feature dimension, hence the extra factor of 2.
        self.output = torch.jit.script(
            nn.Sequential(
                nn.Linear(
                    in_features=mlp_dim,
                    out_features=bandwidth * in_channel * self.reim * 2,
                ),
                nn.GLU(dim=-1),
            )
        )

    def reshape_output(self, mb):
        """Reshape (batch, n_time, feat) to (batch, channel, bandwidth, n_time)."""
        batch, n_time, _ = mb.shape
        base_shape = (batch, n_time, self.in_channel, self.bandwidth)
        if self.complex_mask:
            # Last axis packs (real, imag); fold it into a complex tensor.
            mb = torch.view_as_complex(
                mb.reshape(*base_shape, self.reim).contiguous()
            )
        else:
            mb = mb.reshape(*base_shape)
        # (batch, n_time, channel, band) -> (batch, channel, band, n_time)
        return mb.permute(0, 2, 3, 1)

    def forward(self, qb):
        hidden = self.hidden(self.norm(qb))
        return self.reshape_output(self.output(hidden))
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class MultAddNormMLP(NormMLP):
    """NormMLP variant with two heads: multiplicative and additive masks."""

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channel: "int | None",
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__(
            emb_dim,
            mlp_dim,
            bandwidth,
            in_channel,
            hidden_activation,
            hidden_activation_kwargs,
            complex_mask,
        )

        # Second head, identical in shape to self.output, for the additive mask.
        self.output2 = torch.jit.script(
            nn.Sequential(
                nn.Linear(
                    in_features=mlp_dim,
                    out_features=bandwidth * in_channel * self.reim * 2,
                ),
                nn.GLU(dim=-1),
            )
        )

    def forward(self, qb):
        hidden = self.hidden(self.norm(qb))
        mult_mask = self.reshape_output(self.output(hidden))
        add_mask = self.reshape_output(self.output2(hidden))
        return mult_mask, add_mask
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class MaskEstimationModuleSuperBase(nn.Module):
    """Root of the mask-estimation class hierarchy; defines no behaviour."""

    pass
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
    """Holds one norm+MLP head per frequency band and runs them over q.

    Subclasses decide how the per-band masks are combined into a full
    frequency-axis mask.
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Optional[Dict] = None,
    ) -> None:
        super().__init__()

        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)

        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        if norm_mlp_kwargs is None:
            norm_mlp_kwargs = {}

        # One independent head per band, sized by that band's width.
        self.norm_mlp = nn.ModuleList(
            [
                (
                    norm_mlp_cls(
                        bandwidth=self.band_widths[b],
                        emb_dim=emb_dim,
                        mlp_dim=mlp_dim,
                        in_channel=in_channel,
                        hidden_activation=hidden_activation,
                        hidden_activation_kwargs=hidden_activation_kwargs,
                        complex_mask=complex_mask,
                        **norm_mlp_kwargs,
                    )
                )
                for b in range(self.n_bands)
            ]
        )

    def compute_masks(self, q):
        """Run each band's head on its embedding slice.

        q: (batch, n_bands, n_time, emb_dim). Returns a list with one mask
        (or mask tuple, depending on the head class) per band, in band order.
        """
        batch, n_bands, n_time, emb_dim = q.shape

        masks = []

        for b, nmlp in enumerate(self.norm_mlp):
            qb = q[:, b, :, :]
            mb = nmlp(qb)
            masks.append(mb)

        return masks
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
|
| 207 |
+
def __init__(
|
| 208 |
+
self,
|
| 209 |
+
in_channel: int,
|
| 210 |
+
band_specs: List[Tuple[float, float]],
|
| 211 |
+
freq_weights: List[torch.Tensor],
|
| 212 |
+
n_freq: int,
|
| 213 |
+
emb_dim: int,
|
| 214 |
+
mlp_dim: int,
|
| 215 |
+
cond_dim: int = 0,
|
| 216 |
+
hidden_activation: str = "Tanh",
|
| 217 |
+
hidden_activation_kwargs: Dict = None,
|
| 218 |
+
complex_mask: bool = True,
|
| 219 |
+
norm_mlp_cls: Type[nn.Module] = NormMLP,
|
| 220 |
+
norm_mlp_kwargs: Dict = None,
|
| 221 |
+
use_freq_weights: bool = True,
|
| 222 |
+
) -> None:
|
| 223 |
+
check_nonzero_bandwidth(band_specs)
|
| 224 |
+
check_no_gap(band_specs)
|
| 225 |
+
|
| 226 |
+
super().__init__(
|
| 227 |
+
band_specs=band_specs,
|
| 228 |
+
emb_dim=emb_dim + cond_dim,
|
| 229 |
+
mlp_dim=mlp_dim,
|
| 230 |
+
in_channel=in_channel,
|
| 231 |
+
hidden_activation=hidden_activation,
|
| 232 |
+
hidden_activation_kwargs=hidden_activation_kwargs,
|
| 233 |
+
complex_mask=complex_mask,
|
| 234 |
+
norm_mlp_cls=norm_mlp_cls,
|
| 235 |
+
norm_mlp_kwargs=norm_mlp_kwargs,
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
self.n_freq = n_freq
|
| 239 |
+
self.band_specs = band_specs
|
| 240 |
+
self.in_channel = in_channel
|
| 241 |
+
|
| 242 |
+
if freq_weights is not None:
|
| 243 |
+
for i, fw in enumerate(freq_weights):
|
| 244 |
+
self.register_buffer(f"freq_weights/{i}", fw)
|
| 245 |
+
|
| 246 |
+
self.use_freq_weights = use_freq_weights
|
| 247 |
+
else:
|
| 248 |
+
self.use_freq_weights = False
|
| 249 |
+
|
| 250 |
+
self.cond_dim = cond_dim
|
| 251 |
+
|
| 252 |
+
def forward(self, q, cond=None):
|
| 253 |
+
|
| 254 |
+
batch, n_bands, n_time, emb_dim = q.shape
|
| 255 |
+
|
| 256 |
+
if cond is not None:
|
| 257 |
+
print(cond)
|
| 258 |
+
if cond.ndim == 2:
|
| 259 |
+
cond = cond[:, None, None, :].expand(-1, n_bands, n_time, -1)
|
| 260 |
+
elif cond.ndim == 3:
|
| 261 |
+
assert cond.shape[1] == n_time
|
| 262 |
+
else:
|
| 263 |
+
raise ValueError(f"Invalid cond shape: {cond.shape}")
|
| 264 |
+
|
| 265 |
+
q = torch.cat([q, cond], dim=-1)
|
| 266 |
+
elif self.cond_dim > 0:
|
| 267 |
+
cond = torch.ones(
|
| 268 |
+
(batch, n_bands, n_time, self.cond_dim),
|
| 269 |
+
device=q.device,
|
| 270 |
+
dtype=q.dtype,
|
| 271 |
+
)
|
| 272 |
+
q = torch.cat([q, cond], dim=-1)
|
| 273 |
+
else:
|
| 274 |
+
pass
|
| 275 |
+
|
| 276 |
+
mask_list = self.compute_masks(q)
|
| 277 |
+
|
| 278 |
+
masks = torch.zeros(
|
| 279 |
+
(batch, self.in_channel, self.n_freq, n_time),
|
| 280 |
+
device=q.device,
|
| 281 |
+
dtype=mask_list[0].dtype,
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
for im, mask in enumerate(mask_list):
|
| 285 |
+
fstart, fend = self.band_specs[im]
|
| 286 |
+
if self.use_freq_weights:
|
| 287 |
+
fw = self.get_buffer(f"freq_weights/{im}")[:, None]
|
| 288 |
+
mask = mask * fw
|
| 289 |
+
masks[:, :, fstart:fend, :] += mask
|
| 290 |
+
|
| 291 |
+
return masks
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class MaskEstimationModule(OverlappingMaskEstimationModule):
|
| 295 |
+
def __init__(
|
| 296 |
+
self,
|
| 297 |
+
band_specs: List[Tuple[float, float]],
|
| 298 |
+
emb_dim: int,
|
| 299 |
+
mlp_dim: int,
|
| 300 |
+
in_channel: Optional[int],
|
| 301 |
+
hidden_activation: str = "Tanh",
|
| 302 |
+
hidden_activation_kwargs: Dict = None,
|
| 303 |
+
complex_mask: bool = True,
|
| 304 |
+
**kwargs,
|
| 305 |
+
) -> None:
|
| 306 |
+
check_nonzero_bandwidth(band_specs)
|
| 307 |
+
check_no_gap(band_specs)
|
| 308 |
+
check_no_overlap(band_specs)
|
| 309 |
+
super().__init__(
|
| 310 |
+
in_channel=in_channel,
|
| 311 |
+
band_specs=band_specs,
|
| 312 |
+
freq_weights=None,
|
| 313 |
+
n_freq=None,
|
| 314 |
+
emb_dim=emb_dim,
|
| 315 |
+
mlp_dim=mlp_dim,
|
| 316 |
+
hidden_activation=hidden_activation,
|
| 317 |
+
hidden_activation_kwargs=hidden_activation_kwargs,
|
| 318 |
+
complex_mask=complex_mask,
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
def forward(self, q, cond=None):
|
| 322 |
+
|
| 323 |
+
masks = self.compute_masks(q)
|
| 324 |
+
|
| 325 |
+
masks = torch.concat(masks, dim=2)
|
| 326 |
+
|
| 327 |
+
return masks
|
mvsepless/models/bandit/core/model/bsrnn/tfmodel.py
CHANGED
|
@@ -1,287 +1,287 @@
|
|
| 1 |
-
import warnings
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import nn
|
| 5 |
-
from torch.nn import functional as F
|
| 6 |
-
from torch.nn.modules import rnn
|
| 7 |
-
|
| 8 |
-
import torch.backends.cuda
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class TimeFrequencyModellingModule(nn.Module):
|
| 12 |
-
def __init__(self) -> None:
|
| 13 |
-
super().__init__()
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class ResidualRNN(nn.Module):
|
| 17 |
-
def __init__(
|
| 18 |
-
self,
|
| 19 |
-
emb_dim: int,
|
| 20 |
-
rnn_dim: int,
|
| 21 |
-
bidirectional: bool = True,
|
| 22 |
-
rnn_type: str = "LSTM",
|
| 23 |
-
use_batch_trick: bool = True,
|
| 24 |
-
use_layer_norm: bool = True,
|
| 25 |
-
) -> None:
|
| 26 |
-
super().__init__()
|
| 27 |
-
|
| 28 |
-
self.use_layer_norm = use_layer_norm
|
| 29 |
-
if use_layer_norm:
|
| 30 |
-
self.norm = nn.LayerNorm(emb_dim)
|
| 31 |
-
else:
|
| 32 |
-
self.norm = nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim)
|
| 33 |
-
|
| 34 |
-
self.rnn = rnn.__dict__[rnn_type](
|
| 35 |
-
input_size=emb_dim,
|
| 36 |
-
hidden_size=rnn_dim,
|
| 37 |
-
num_layers=1,
|
| 38 |
-
batch_first=True,
|
| 39 |
-
bidirectional=bidirectional,
|
| 40 |
-
)
|
| 41 |
-
|
| 42 |
-
self.fc = nn.Linear(
|
| 43 |
-
in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
|
| 44 |
-
)
|
| 45 |
-
|
| 46 |
-
self.use_batch_trick = use_batch_trick
|
| 47 |
-
if not self.use_batch_trick:
|
| 48 |
-
warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")
|
| 49 |
-
|
| 50 |
-
def forward(self, z):
|
| 51 |
-
|
| 52 |
-
z0 = torch.clone(z)
|
| 53 |
-
|
| 54 |
-
if self.use_layer_norm:
|
| 55 |
-
z = self.norm(z)
|
| 56 |
-
else:
|
| 57 |
-
z = torch.permute(z, (0, 3, 1, 2))
|
| 58 |
-
|
| 59 |
-
z = self.norm(z)
|
| 60 |
-
|
| 61 |
-
z = torch.permute(z, (0, 2, 3, 1))
|
| 62 |
-
|
| 63 |
-
batch, n_uncrossed, n_across, emb_dim = z.shape
|
| 64 |
-
|
| 65 |
-
if self.use_batch_trick:
|
| 66 |
-
z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
|
| 67 |
-
|
| 68 |
-
z = self.rnn(z.contiguous())[0]
|
| 69 |
-
|
| 70 |
-
z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
|
| 71 |
-
else:
|
| 72 |
-
zlist = []
|
| 73 |
-
for i in range(n_uncrossed):
|
| 74 |
-
zi = self.rnn(z[:, i, :, :])[0]
|
| 75 |
-
zlist.append(zi)
|
| 76 |
-
|
| 77 |
-
z = torch.stack(zlist, dim=1)
|
| 78 |
-
|
| 79 |
-
z = self.fc(z)
|
| 80 |
-
|
| 81 |
-
z = z + z0
|
| 82 |
-
|
| 83 |
-
return z
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
class SeqBandModellingModule(TimeFrequencyModellingModule):
|
| 87 |
-
def __init__(
|
| 88 |
-
self,
|
| 89 |
-
n_modules: int = 12,
|
| 90 |
-
emb_dim: int = 128,
|
| 91 |
-
rnn_dim: int = 256,
|
| 92 |
-
bidirectional: bool = True,
|
| 93 |
-
rnn_type: str = "LSTM",
|
| 94 |
-
parallel_mode=False,
|
| 95 |
-
) -> None:
|
| 96 |
-
super().__init__()
|
| 97 |
-
self.seqband = nn.ModuleList([])
|
| 98 |
-
|
| 99 |
-
if parallel_mode:
|
| 100 |
-
for _ in range(n_modules):
|
| 101 |
-
self.seqband.append(
|
| 102 |
-
nn.ModuleList(
|
| 103 |
-
[
|
| 104 |
-
ResidualRNN(
|
| 105 |
-
emb_dim=emb_dim,
|
| 106 |
-
rnn_dim=rnn_dim,
|
| 107 |
-
bidirectional=bidirectional,
|
| 108 |
-
rnn_type=rnn_type,
|
| 109 |
-
),
|
| 110 |
-
ResidualRNN(
|
| 111 |
-
emb_dim=emb_dim,
|
| 112 |
-
rnn_dim=rnn_dim,
|
| 113 |
-
bidirectional=bidirectional,
|
| 114 |
-
rnn_type=rnn_type,
|
| 115 |
-
),
|
| 116 |
-
]
|
| 117 |
-
)
|
| 118 |
-
)
|
| 119 |
-
else:
|
| 120 |
-
|
| 121 |
-
for _ in range(2 * n_modules):
|
| 122 |
-
self.seqband.append(
|
| 123 |
-
ResidualRNN(
|
| 124 |
-
emb_dim=emb_dim,
|
| 125 |
-
rnn_dim=rnn_dim,
|
| 126 |
-
bidirectional=bidirectional,
|
| 127 |
-
rnn_type=rnn_type,
|
| 128 |
-
)
|
| 129 |
-
)
|
| 130 |
-
|
| 131 |
-
self.parallel_mode = parallel_mode
|
| 132 |
-
|
| 133 |
-
def forward(self, z):
|
| 134 |
-
|
| 135 |
-
if self.parallel_mode:
|
| 136 |
-
for sbm_pair in self.seqband:
|
| 137 |
-
sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
|
| 138 |
-
zt = sbm_t(z)
|
| 139 |
-
zf = sbm_f(z.transpose(1, 2))
|
| 140 |
-
z = zt + zf.transpose(1, 2)
|
| 141 |
-
else:
|
| 142 |
-
for sbm in self.seqband:
|
| 143 |
-
z = sbm(z)
|
| 144 |
-
z = z.transpose(1, 2)
|
| 145 |
-
|
| 146 |
-
q = z
|
| 147 |
-
return q
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
class ResidualTransformer(nn.Module):
|
| 151 |
-
def __init__(
|
| 152 |
-
self,
|
| 153 |
-
emb_dim: int = 128,
|
| 154 |
-
rnn_dim: int = 256,
|
| 155 |
-
bidirectional: bool = True,
|
| 156 |
-
dropout: float = 0.0,
|
| 157 |
-
) -> None:
|
| 158 |
-
super().__init__()
|
| 159 |
-
|
| 160 |
-
self.tf = nn.TransformerEncoderLayer(
|
| 161 |
-
d_model=emb_dim, nhead=4, dim_feedforward=rnn_dim, batch_first=True
|
| 162 |
-
)
|
| 163 |
-
|
| 164 |
-
self.is_causal = not bidirectional
|
| 165 |
-
self.dropout = dropout
|
| 166 |
-
|
| 167 |
-
def forward(self, z):
|
| 168 |
-
batch, n_uncrossed, n_across, emb_dim = z.shape
|
| 169 |
-
z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
|
| 170 |
-
z = self.tf(z, is_causal=self.is_causal)
|
| 171 |
-
z = torch.reshape(z, (batch, n_uncrossed, n_across, emb_dim))
|
| 172 |
-
|
| 173 |
-
return z
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
class TransformerTimeFreqModule(TimeFrequencyModellingModule):
|
| 177 |
-
def __init__(
|
| 178 |
-
self,
|
| 179 |
-
n_modules: int = 12,
|
| 180 |
-
emb_dim: int = 128,
|
| 181 |
-
rnn_dim: int = 256,
|
| 182 |
-
bidirectional: bool = True,
|
| 183 |
-
dropout: float = 0.0,
|
| 184 |
-
) -> None:
|
| 185 |
-
super().__init__()
|
| 186 |
-
self.norm = nn.LayerNorm(emb_dim)
|
| 187 |
-
self.seqband = nn.ModuleList([])
|
| 188 |
-
|
| 189 |
-
for _ in range(2 * n_modules):
|
| 190 |
-
self.seqband.append(
|
| 191 |
-
ResidualTransformer(
|
| 192 |
-
emb_dim=emb_dim,
|
| 193 |
-
rnn_dim=rnn_dim,
|
| 194 |
-
bidirectional=bidirectional,
|
| 195 |
-
dropout=dropout,
|
| 196 |
-
)
|
| 197 |
-
)
|
| 198 |
-
|
| 199 |
-
def forward(self, z):
|
| 200 |
-
z = self.norm(z)
|
| 201 |
-
|
| 202 |
-
for sbm in self.seqband:
|
| 203 |
-
z = sbm(z)
|
| 204 |
-
z = z.transpose(1, 2)
|
| 205 |
-
|
| 206 |
-
q = z
|
| 207 |
-
return q
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
class ResidualConvolution(nn.Module):
|
| 211 |
-
def __init__(
|
| 212 |
-
self,
|
| 213 |
-
emb_dim: int = 128,
|
| 214 |
-
rnn_dim: int = 256,
|
| 215 |
-
bidirectional: bool = True,
|
| 216 |
-
dropout: float = 0.0,
|
| 217 |
-
) -> None:
|
| 218 |
-
super().__init__()
|
| 219 |
-
self.norm = nn.InstanceNorm2d(emb_dim, affine=True)
|
| 220 |
-
|
| 221 |
-
self.conv = nn.Sequential(
|
| 222 |
-
nn.Conv2d(
|
| 223 |
-
in_channels=emb_dim,
|
| 224 |
-
out_channels=rnn_dim,
|
| 225 |
-
kernel_size=(3, 3),
|
| 226 |
-
padding="same",
|
| 227 |
-
stride=(1, 1),
|
| 228 |
-
),
|
| 229 |
-
nn.Tanhshrink(),
|
| 230 |
-
)
|
| 231 |
-
|
| 232 |
-
self.is_causal = not bidirectional
|
| 233 |
-
self.dropout = dropout
|
| 234 |
-
|
| 235 |
-
self.fc = nn.Conv2d(
|
| 236 |
-
in_channels=rnn_dim,
|
| 237 |
-
out_channels=emb_dim,
|
| 238 |
-
kernel_size=(1, 1),
|
| 239 |
-
padding="same",
|
| 240 |
-
stride=(1, 1),
|
| 241 |
-
)
|
| 242 |
-
|
| 243 |
-
def forward(self, z):
|
| 244 |
-
|
| 245 |
-
z0 = torch.clone(z)
|
| 246 |
-
|
| 247 |
-
z = self.norm(z)
|
| 248 |
-
z = self.conv(z)
|
| 249 |
-
z = self.fc(z)
|
| 250 |
-
z = z + z0
|
| 251 |
-
|
| 252 |
-
return z
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule):
|
| 256 |
-
def __init__(
|
| 257 |
-
self,
|
| 258 |
-
n_modules: int = 12,
|
| 259 |
-
emb_dim: int = 128,
|
| 260 |
-
rnn_dim: int = 256,
|
| 261 |
-
bidirectional: bool = True,
|
| 262 |
-
dropout: float = 0.0,
|
| 263 |
-
) -> None:
|
| 264 |
-
super().__init__()
|
| 265 |
-
self.seqband = torch.jit.script(
|
| 266 |
-
nn.Sequential(
|
| 267 |
-
*[
|
| 268 |
-
ResidualConvolution(
|
| 269 |
-
emb_dim=emb_dim,
|
| 270 |
-
rnn_dim=rnn_dim,
|
| 271 |
-
bidirectional=bidirectional,
|
| 272 |
-
dropout=dropout,
|
| 273 |
-
)
|
| 274 |
-
for _ in range(2 * n_modules)
|
| 275 |
-
]
|
| 276 |
-
)
|
| 277 |
-
)
|
| 278 |
-
|
| 279 |
-
def forward(self, z):
|
| 280 |
-
|
| 281 |
-
z = torch.permute(z, (0, 3, 1, 2))
|
| 282 |
-
|
| 283 |
-
z = self.seqband(z)
|
| 284 |
-
|
| 285 |
-
z = torch.permute(z, (0, 2, 3, 1))
|
| 286 |
-
|
| 287 |
-
return z
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
from torch.nn.modules import rnn
|
| 7 |
+
|
| 8 |
+
import torch.backends.cuda
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TimeFrequencyModellingModule(nn.Module):
|
| 12 |
+
def __init__(self) -> None:
|
| 13 |
+
super().__init__()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ResidualRNN(nn.Module):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
emb_dim: int,
|
| 20 |
+
rnn_dim: int,
|
| 21 |
+
bidirectional: bool = True,
|
| 22 |
+
rnn_type: str = "LSTM",
|
| 23 |
+
use_batch_trick: bool = True,
|
| 24 |
+
use_layer_norm: bool = True,
|
| 25 |
+
) -> None:
|
| 26 |
+
super().__init__()
|
| 27 |
+
|
| 28 |
+
self.use_layer_norm = use_layer_norm
|
| 29 |
+
if use_layer_norm:
|
| 30 |
+
self.norm = nn.LayerNorm(emb_dim)
|
| 31 |
+
else:
|
| 32 |
+
self.norm = nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim)
|
| 33 |
+
|
| 34 |
+
self.rnn = rnn.__dict__[rnn_type](
|
| 35 |
+
input_size=emb_dim,
|
| 36 |
+
hidden_size=rnn_dim,
|
| 37 |
+
num_layers=1,
|
| 38 |
+
batch_first=True,
|
| 39 |
+
bidirectional=bidirectional,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
self.fc = nn.Linear(
|
| 43 |
+
in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
self.use_batch_trick = use_batch_trick
|
| 47 |
+
if not self.use_batch_trick:
|
| 48 |
+
warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")
|
| 49 |
+
|
| 50 |
+
def forward(self, z):
|
| 51 |
+
|
| 52 |
+
z0 = torch.clone(z)
|
| 53 |
+
|
| 54 |
+
if self.use_layer_norm:
|
| 55 |
+
z = self.norm(z)
|
| 56 |
+
else:
|
| 57 |
+
z = torch.permute(z, (0, 3, 1, 2))
|
| 58 |
+
|
| 59 |
+
z = self.norm(z)
|
| 60 |
+
|
| 61 |
+
z = torch.permute(z, (0, 2, 3, 1))
|
| 62 |
+
|
| 63 |
+
batch, n_uncrossed, n_across, emb_dim = z.shape
|
| 64 |
+
|
| 65 |
+
if self.use_batch_trick:
|
| 66 |
+
z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
|
| 67 |
+
|
| 68 |
+
z = self.rnn(z.contiguous())[0]
|
| 69 |
+
|
| 70 |
+
z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
|
| 71 |
+
else:
|
| 72 |
+
zlist = []
|
| 73 |
+
for i in range(n_uncrossed):
|
| 74 |
+
zi = self.rnn(z[:, i, :, :])[0]
|
| 75 |
+
zlist.append(zi)
|
| 76 |
+
|
| 77 |
+
z = torch.stack(zlist, dim=1)
|
| 78 |
+
|
| 79 |
+
z = self.fc(z)
|
| 80 |
+
|
| 81 |
+
z = z + z0
|
| 82 |
+
|
| 83 |
+
return z
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class SeqBandModellingModule(TimeFrequencyModellingModule):
|
| 87 |
+
def __init__(
|
| 88 |
+
self,
|
| 89 |
+
n_modules: int = 12,
|
| 90 |
+
emb_dim: int = 128,
|
| 91 |
+
rnn_dim: int = 256,
|
| 92 |
+
bidirectional: bool = True,
|
| 93 |
+
rnn_type: str = "LSTM",
|
| 94 |
+
parallel_mode=False,
|
| 95 |
+
) -> None:
|
| 96 |
+
super().__init__()
|
| 97 |
+
self.seqband = nn.ModuleList([])
|
| 98 |
+
|
| 99 |
+
if parallel_mode:
|
| 100 |
+
for _ in range(n_modules):
|
| 101 |
+
self.seqband.append(
|
| 102 |
+
nn.ModuleList(
|
| 103 |
+
[
|
| 104 |
+
ResidualRNN(
|
| 105 |
+
emb_dim=emb_dim,
|
| 106 |
+
rnn_dim=rnn_dim,
|
| 107 |
+
bidirectional=bidirectional,
|
| 108 |
+
rnn_type=rnn_type,
|
| 109 |
+
),
|
| 110 |
+
ResidualRNN(
|
| 111 |
+
emb_dim=emb_dim,
|
| 112 |
+
rnn_dim=rnn_dim,
|
| 113 |
+
bidirectional=bidirectional,
|
| 114 |
+
rnn_type=rnn_type,
|
| 115 |
+
),
|
| 116 |
+
]
|
| 117 |
+
)
|
| 118 |
+
)
|
| 119 |
+
else:
|
| 120 |
+
|
| 121 |
+
for _ in range(2 * n_modules):
|
| 122 |
+
self.seqband.append(
|
| 123 |
+
ResidualRNN(
|
| 124 |
+
emb_dim=emb_dim,
|
| 125 |
+
rnn_dim=rnn_dim,
|
| 126 |
+
bidirectional=bidirectional,
|
| 127 |
+
rnn_type=rnn_type,
|
| 128 |
+
)
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
self.parallel_mode = parallel_mode
|
| 132 |
+
|
| 133 |
+
def forward(self, z):
|
| 134 |
+
|
| 135 |
+
if self.parallel_mode:
|
| 136 |
+
for sbm_pair in self.seqband:
|
| 137 |
+
sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
|
| 138 |
+
zt = sbm_t(z)
|
| 139 |
+
zf = sbm_f(z.transpose(1, 2))
|
| 140 |
+
z = zt + zf.transpose(1, 2)
|
| 141 |
+
else:
|
| 142 |
+
for sbm in self.seqband:
|
| 143 |
+
z = sbm(z)
|
| 144 |
+
z = z.transpose(1, 2)
|
| 145 |
+
|
| 146 |
+
q = z
|
| 147 |
+
return q
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class ResidualTransformer(nn.Module):
|
| 151 |
+
def __init__(
|
| 152 |
+
self,
|
| 153 |
+
emb_dim: int = 128,
|
| 154 |
+
rnn_dim: int = 256,
|
| 155 |
+
bidirectional: bool = True,
|
| 156 |
+
dropout: float = 0.0,
|
| 157 |
+
) -> None:
|
| 158 |
+
super().__init__()
|
| 159 |
+
|
| 160 |
+
self.tf = nn.TransformerEncoderLayer(
|
| 161 |
+
d_model=emb_dim, nhead=4, dim_feedforward=rnn_dim, batch_first=True
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
self.is_causal = not bidirectional
|
| 165 |
+
self.dropout = dropout
|
| 166 |
+
|
| 167 |
+
def forward(self, z):
|
| 168 |
+
batch, n_uncrossed, n_across, emb_dim = z.shape
|
| 169 |
+
z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
|
| 170 |
+
z = self.tf(z, is_causal=self.is_causal)
|
| 171 |
+
z = torch.reshape(z, (batch, n_uncrossed, n_across, emb_dim))
|
| 172 |
+
|
| 173 |
+
return z
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class TransformerTimeFreqModule(TimeFrequencyModellingModule):
|
| 177 |
+
def __init__(
|
| 178 |
+
self,
|
| 179 |
+
n_modules: int = 12,
|
| 180 |
+
emb_dim: int = 128,
|
| 181 |
+
rnn_dim: int = 256,
|
| 182 |
+
bidirectional: bool = True,
|
| 183 |
+
dropout: float = 0.0,
|
| 184 |
+
) -> None:
|
| 185 |
+
super().__init__()
|
| 186 |
+
self.norm = nn.LayerNorm(emb_dim)
|
| 187 |
+
self.seqband = nn.ModuleList([])
|
| 188 |
+
|
| 189 |
+
for _ in range(2 * n_modules):
|
| 190 |
+
self.seqband.append(
|
| 191 |
+
ResidualTransformer(
|
| 192 |
+
emb_dim=emb_dim,
|
| 193 |
+
rnn_dim=rnn_dim,
|
| 194 |
+
bidirectional=bidirectional,
|
| 195 |
+
dropout=dropout,
|
| 196 |
+
)
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
def forward(self, z):
|
| 200 |
+
z = self.norm(z)
|
| 201 |
+
|
| 202 |
+
for sbm in self.seqband:
|
| 203 |
+
z = sbm(z)
|
| 204 |
+
z = z.transpose(1, 2)
|
| 205 |
+
|
| 206 |
+
q = z
|
| 207 |
+
return q
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class ResidualConvolution(nn.Module):
|
| 211 |
+
def __init__(
|
| 212 |
+
self,
|
| 213 |
+
emb_dim: int = 128,
|
| 214 |
+
rnn_dim: int = 256,
|
| 215 |
+
bidirectional: bool = True,
|
| 216 |
+
dropout: float = 0.0,
|
| 217 |
+
) -> None:
|
| 218 |
+
super().__init__()
|
| 219 |
+
self.norm = nn.InstanceNorm2d(emb_dim, affine=True)
|
| 220 |
+
|
| 221 |
+
self.conv = nn.Sequential(
|
| 222 |
+
nn.Conv2d(
|
| 223 |
+
in_channels=emb_dim,
|
| 224 |
+
out_channels=rnn_dim,
|
| 225 |
+
kernel_size=(3, 3),
|
| 226 |
+
padding="same",
|
| 227 |
+
stride=(1, 1),
|
| 228 |
+
),
|
| 229 |
+
nn.Tanhshrink(),
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
self.is_causal = not bidirectional
|
| 233 |
+
self.dropout = dropout
|
| 234 |
+
|
| 235 |
+
self.fc = nn.Conv2d(
|
| 236 |
+
in_channels=rnn_dim,
|
| 237 |
+
out_channels=emb_dim,
|
| 238 |
+
kernel_size=(1, 1),
|
| 239 |
+
padding="same",
|
| 240 |
+
stride=(1, 1),
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
def forward(self, z):
|
| 244 |
+
|
| 245 |
+
z0 = torch.clone(z)
|
| 246 |
+
|
| 247 |
+
z = self.norm(z)
|
| 248 |
+
z = self.conv(z)
|
| 249 |
+
z = self.fc(z)
|
| 250 |
+
z = z + z0
|
| 251 |
+
|
| 252 |
+
return z
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule):
|
| 256 |
+
def __init__(
|
| 257 |
+
self,
|
| 258 |
+
n_modules: int = 12,
|
| 259 |
+
emb_dim: int = 128,
|
| 260 |
+
rnn_dim: int = 256,
|
| 261 |
+
bidirectional: bool = True,
|
| 262 |
+
dropout: float = 0.0,
|
| 263 |
+
) -> None:
|
| 264 |
+
super().__init__()
|
| 265 |
+
self.seqband = torch.jit.script(
|
| 266 |
+
nn.Sequential(
|
| 267 |
+
*[
|
| 268 |
+
ResidualConvolution(
|
| 269 |
+
emb_dim=emb_dim,
|
| 270 |
+
rnn_dim=rnn_dim,
|
| 271 |
+
bidirectional=bidirectional,
|
| 272 |
+
dropout=dropout,
|
| 273 |
+
)
|
| 274 |
+
for _ in range(2 * n_modules)
|
| 275 |
+
]
|
| 276 |
+
)
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
def forward(self, z):
|
| 280 |
+
|
| 281 |
+
z = torch.permute(z, (0, 3, 1, 2))
|
| 282 |
+
|
| 283 |
+
z = self.seqband(z)
|
| 284 |
+
|
| 285 |
+
z = torch.permute(z, (0, 2, 3, 1))
|
| 286 |
+
|
| 287 |
+
return z
|
mvsepless/models/bandit/core/model/bsrnn/utils.py
CHANGED
|
@@ -1,518 +1,518 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from abc import abstractmethod
|
| 3 |
-
from typing import Any, Callable
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import torch
|
| 7 |
-
from librosa import hz_to_midi, midi_to_hz
|
| 8 |
-
from torch import Tensor
|
| 9 |
-
from torchaudio import functional as taF
|
| 10 |
-
from spafe.fbanks import bark_fbanks
|
| 11 |
-
from spafe.utils.converters import erb2hz, hz2bark, hz2erb
|
| 12 |
-
from torchaudio.functional.functional import _create_triangular_filterbank
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def band_widths_from_specs(band_specs):
|
| 16 |
-
return [e - i for i, e in band_specs]
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def check_nonzero_bandwidth(band_specs):
|
| 20 |
-
for fstart, fend in band_specs:
|
| 21 |
-
if fend - fstart <= 0:
|
| 22 |
-
raise ValueError("Bands cannot be zero-width")
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def check_no_overlap(band_specs):
|
| 26 |
-
fend_prev = -1
|
| 27 |
-
for fstart_curr, fend_curr in band_specs:
|
| 28 |
-
if fstart_curr <= fend_prev:
|
| 29 |
-
raise ValueError("Bands cannot overlap")
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def check_no_gap(band_specs):
|
| 33 |
-
fstart, _ = band_specs[0]
|
| 34 |
-
assert fstart == 0
|
| 35 |
-
|
| 36 |
-
fend_prev = -1
|
| 37 |
-
for fstart_curr, fend_curr in band_specs:
|
| 38 |
-
if fstart_curr - fend_prev > 1:
|
| 39 |
-
raise ValueError("Bands cannot leave gap")
|
| 40 |
-
fend_prev = fend_curr
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
class BandsplitSpecification:
|
| 44 |
-
def __init__(self, nfft: int, fs: int) -> None:
|
| 45 |
-
self.fs = fs
|
| 46 |
-
self.nfft = nfft
|
| 47 |
-
self.nyquist = fs / 2
|
| 48 |
-
self.max_index = nfft // 2 + 1
|
| 49 |
-
|
| 50 |
-
self.split500 = self.hertz_to_index(500)
|
| 51 |
-
self.split1k = self.hertz_to_index(1000)
|
| 52 |
-
self.split2k = self.hertz_to_index(2000)
|
| 53 |
-
self.split4k = self.hertz_to_index(4000)
|
| 54 |
-
self.split8k = self.hertz_to_index(8000)
|
| 55 |
-
self.split16k = self.hertz_to_index(16000)
|
| 56 |
-
self.split20k = self.hertz_to_index(20000)
|
| 57 |
-
|
| 58 |
-
self.above20k = [(self.split20k, self.max_index)]
|
| 59 |
-
self.above16k = [(self.split16k, self.split20k)] + self.above20k
|
| 60 |
-
|
| 61 |
-
def index_to_hertz(self, index: int):
|
| 62 |
-
return index * self.fs / self.nfft
|
| 63 |
-
|
| 64 |
-
def hertz_to_index(self, hz: float, round: bool = True):
|
| 65 |
-
index = hz * self.nfft / self.fs
|
| 66 |
-
|
| 67 |
-
if round:
|
| 68 |
-
index = int(np.round(index))
|
| 69 |
-
|
| 70 |
-
return index
|
| 71 |
-
|
| 72 |
-
def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
|
| 73 |
-
band_specs = []
|
| 74 |
-
lower = start_index
|
| 75 |
-
|
| 76 |
-
while lower < end_index:
|
| 77 |
-
upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
|
| 78 |
-
upper = min(upper, end_index)
|
| 79 |
-
|
| 80 |
-
band_specs.append((lower, upper))
|
| 81 |
-
lower = upper
|
| 82 |
-
|
| 83 |
-
return band_specs
|
| 84 |
-
|
| 85 |
-
@abstractmethod
|
| 86 |
-
def get_band_specs(self):
|
| 87 |
-
raise NotImplementedError
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
class VocalBandsplitSpecification(BandsplitSpecification):
|
| 91 |
-
def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
|
| 92 |
-
super().__init__(nfft=nfft, fs=fs)
|
| 93 |
-
|
| 94 |
-
self.version = version
|
| 95 |
-
|
| 96 |
-
def get_band_specs(self):
|
| 97 |
-
return getattr(self, f"version{self.version}")()
|
| 98 |
-
|
| 99 |
-
@property
|
| 100 |
-
def version1(self):
|
| 101 |
-
return self.get_band_specs_with_bandwidth(
|
| 102 |
-
start_index=0, end_index=self.max_index, bandwidth_hz=1000
|
| 103 |
-
)
|
| 104 |
-
|
| 105 |
-
def version2(self):
|
| 106 |
-
below16k = self.get_band_specs_with_bandwidth(
|
| 107 |
-
start_index=0, end_index=self.split16k, bandwidth_hz=1000
|
| 108 |
-
)
|
| 109 |
-
below20k = self.get_band_specs_with_bandwidth(
|
| 110 |
-
start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
|
| 111 |
-
)
|
| 112 |
-
|
| 113 |
-
return below16k + below20k + self.above20k
|
| 114 |
-
|
| 115 |
-
def version3(self):
|
| 116 |
-
below8k = self.get_band_specs_with_bandwidth(
|
| 117 |
-
start_index=0, end_index=self.split8k, bandwidth_hz=1000
|
| 118 |
-
)
|
| 119 |
-
below16k = self.get_band_specs_with_bandwidth(
|
| 120 |
-
start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
|
| 121 |
-
)
|
| 122 |
-
|
| 123 |
-
return below8k + below16k + self.above16k
|
| 124 |
-
|
| 125 |
-
def version4(self):
|
| 126 |
-
below1k = self.get_band_specs_with_bandwidth(
|
| 127 |
-
start_index=0, end_index=self.split1k, bandwidth_hz=100
|
| 128 |
-
)
|
| 129 |
-
below8k = self.get_band_specs_with_bandwidth(
|
| 130 |
-
start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
|
| 131 |
-
)
|
| 132 |
-
below16k = self.get_band_specs_with_bandwidth(
|
| 133 |
-
start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
|
| 134 |
-
)
|
| 135 |
-
|
| 136 |
-
return below1k + below8k + below16k + self.above16k
|
| 137 |
-
|
| 138 |
-
def version5(self):
|
| 139 |
-
below1k = self.get_band_specs_with_bandwidth(
|
| 140 |
-
start_index=0, end_index=self.split1k, bandwidth_hz=100
|
| 141 |
-
)
|
| 142 |
-
below16k = self.get_band_specs_with_bandwidth(
|
| 143 |
-
start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
|
| 144 |
-
)
|
| 145 |
-
below20k = self.get_band_specs_with_bandwidth(
|
| 146 |
-
start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
|
| 147 |
-
)
|
| 148 |
-
return below1k + below16k + below20k + self.above20k
|
| 149 |
-
|
| 150 |
-
def version6(self):
|
| 151 |
-
below1k = self.get_band_specs_with_bandwidth(
|
| 152 |
-
start_index=0, end_index=self.split1k, bandwidth_hz=100
|
| 153 |
-
)
|
| 154 |
-
below4k = self.get_band_specs_with_bandwidth(
|
| 155 |
-
start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
|
| 156 |
-
)
|
| 157 |
-
below8k = self.get_band_specs_with_bandwidth(
|
| 158 |
-
start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
|
| 159 |
-
)
|
| 160 |
-
below16k = self.get_band_specs_with_bandwidth(
|
| 161 |
-
start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
|
| 162 |
-
)
|
| 163 |
-
return below1k + below4k + below8k + below16k + self.above16k
|
| 164 |
-
|
| 165 |
-
def version7(self):
|
| 166 |
-
below1k = self.get_band_specs_with_bandwidth(
|
| 167 |
-
start_index=0, end_index=self.split1k, bandwidth_hz=100
|
| 168 |
-
)
|
| 169 |
-
below4k = self.get_band_specs_with_bandwidth(
|
| 170 |
-
start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
|
| 171 |
-
)
|
| 172 |
-
below8k = self.get_band_specs_with_bandwidth(
|
| 173 |
-
start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
|
| 174 |
-
)
|
| 175 |
-
below16k = self.get_band_specs_with_bandwidth(
|
| 176 |
-
start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
|
| 177 |
-
)
|
| 178 |
-
below20k = self.get_band_specs_with_bandwidth(
|
| 179 |
-
start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
|
| 180 |
-
)
|
| 181 |
-
return below1k + below4k + below8k + below16k + below20k + self.above20k
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
class OtherBandsplitSpecification(VocalBandsplitSpecification):
|
| 185 |
-
def __init__(self, nfft: int, fs: int) -> None:
|
| 186 |
-
super().__init__(nfft=nfft, fs=fs, version="7")
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
class BassBandsplitSpecification(BandsplitSpecification):
|
| 190 |
-
def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
|
| 191 |
-
super().__init__(nfft=nfft, fs=fs)
|
| 192 |
-
|
| 193 |
-
def get_band_specs(self):
|
| 194 |
-
below500 = self.get_band_specs_with_bandwidth(
|
| 195 |
-
start_index=0, end_index=self.split500, bandwidth_hz=50
|
| 196 |
-
)
|
| 197 |
-
below1k = self.get_band_specs_with_bandwidth(
|
| 198 |
-
start_index=self.split500, end_index=self.split1k, bandwidth_hz=100
|
| 199 |
-
)
|
| 200 |
-
below4k = self.get_band_specs_with_bandwidth(
|
| 201 |
-
start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
|
| 202 |
-
)
|
| 203 |
-
below8k = self.get_band_specs_with_bandwidth(
|
| 204 |
-
start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
|
| 205 |
-
)
|
| 206 |
-
below16k = self.get_band_specs_with_bandwidth(
|
| 207 |
-
start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
|
| 208 |
-
)
|
| 209 |
-
above16k = [(self.split16k, self.max_index)]
|
| 210 |
-
|
| 211 |
-
return below500 + below1k + below4k + below8k + below16k + above16k
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
class DrumBandsplitSpecification(BandsplitSpecification):
|
| 215 |
-
def __init__(self, nfft: int, fs: int) -> None:
|
| 216 |
-
super().__init__(nfft=nfft, fs=fs)
|
| 217 |
-
|
| 218 |
-
def get_band_specs(self):
|
| 219 |
-
below1k = self.get_band_specs_with_bandwidth(
|
| 220 |
-
start_index=0, end_index=self.split1k, bandwidth_hz=50
|
| 221 |
-
)
|
| 222 |
-
below2k = self.get_band_specs_with_bandwidth(
|
| 223 |
-
start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100
|
| 224 |
-
)
|
| 225 |
-
below4k = self.get_band_specs_with_bandwidth(
|
| 226 |
-
start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250
|
| 227 |
-
)
|
| 228 |
-
below8k = self.get_band_specs_with_bandwidth(
|
| 229 |
-
start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
|
| 230 |
-
)
|
| 231 |
-
below16k = self.get_band_specs_with_bandwidth(
|
| 232 |
-
start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
|
| 233 |
-
)
|
| 234 |
-
above16k = [(self.split16k, self.max_index)]
|
| 235 |
-
|
| 236 |
-
return below1k + below2k + below4k + below8k + below16k + above16k
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
class PerceptualBandsplitSpecification(BandsplitSpecification):
|
| 240 |
-
def __init__(
|
| 241 |
-
self,
|
| 242 |
-
nfft: int,
|
| 243 |
-
fs: int,
|
| 244 |
-
fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
|
| 245 |
-
n_bands: int,
|
| 246 |
-
f_min: float = 0.0,
|
| 247 |
-
f_max: float = None,
|
| 248 |
-
) -> None:
|
| 249 |
-
super().__init__(nfft=nfft, fs=fs)
|
| 250 |
-
self.n_bands = n_bands
|
| 251 |
-
if f_max is None:
|
| 252 |
-
f_max = fs / 2
|
| 253 |
-
|
| 254 |
-
self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)
|
| 255 |
-
|
| 256 |
-
weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True)
|
| 257 |
-
normalized_mel_fb = self.filterbank / weight_per_bin
|
| 258 |
-
|
| 259 |
-
freq_weights = []
|
| 260 |
-
band_specs = []
|
| 261 |
-
for i in range(self.n_bands):
|
| 262 |
-
active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
|
| 263 |
-
if isinstance(active_bins, int):
|
| 264 |
-
active_bins = (active_bins, active_bins)
|
| 265 |
-
if len(active_bins) == 0:
|
| 266 |
-
continue
|
| 267 |
-
start_index = active_bins[0]
|
| 268 |
-
end_index = active_bins[-1] + 1
|
| 269 |
-
band_specs.append((start_index, end_index))
|
| 270 |
-
freq_weights.append(normalized_mel_fb[i, start_index:end_index])
|
| 271 |
-
|
| 272 |
-
self.freq_weights = freq_weights
|
| 273 |
-
self.band_specs = band_specs
|
| 274 |
-
|
| 275 |
-
def get_band_specs(self):
|
| 276 |
-
return self.band_specs
|
| 277 |
-
|
| 278 |
-
def get_freq_weights(self):
|
| 279 |
-
return self.freq_weights
|
| 280 |
-
|
| 281 |
-
def save_to_file(self, dir_path: str) -> None:
|
| 282 |
-
|
| 283 |
-
os.makedirs(dir_path, exist_ok=True)
|
| 284 |
-
|
| 285 |
-
import pickle
|
| 286 |
-
|
| 287 |
-
with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
|
| 288 |
-
pickle.dump(
|
| 289 |
-
{
|
| 290 |
-
"band_specs": self.band_specs,
|
| 291 |
-
"freq_weights": self.freq_weights,
|
| 292 |
-
"filterbank": self.filterbank,
|
| 293 |
-
},
|
| 294 |
-
f,
|
| 295 |
-
)
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
|
| 299 |
-
fb = taF.melscale_fbanks(
|
| 300 |
-
n_mels=n_bands,
|
| 301 |
-
sample_rate=fs,
|
| 302 |
-
f_min=f_min,
|
| 303 |
-
f_max=f_max,
|
| 304 |
-
n_freqs=n_freqs,
|
| 305 |
-
).T
|
| 306 |
-
|
| 307 |
-
fb[0, 0] = 1.0
|
| 308 |
-
|
| 309 |
-
return fb
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
class MelBandsplitSpecification(PerceptualBandsplitSpecification):
|
| 313 |
-
def __init__(
|
| 314 |
-
self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
|
| 315 |
-
) -> None:
|
| 316 |
-
super().__init__(
|
| 317 |
-
fbank_fn=mel_filterbank,
|
| 318 |
-
nfft=nfft,
|
| 319 |
-
fs=fs,
|
| 320 |
-
n_bands=n_bands,
|
| 321 |
-
f_min=f_min,
|
| 322 |
-
f_max=f_max,
|
| 323 |
-
)
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
|
| 327 |
-
|
| 328 |
-
nfft = 2 * (n_freqs - 1)
|
| 329 |
-
df = fs / nfft
|
| 330 |
-
f_max = f_max or fs / 2
|
| 331 |
-
f_min = f_min or 0
|
| 332 |
-
f_min = fs / nfft
|
| 333 |
-
|
| 334 |
-
n_octaves = np.log2(f_max / f_min)
|
| 335 |
-
n_octaves_per_band = n_octaves / n_bands
|
| 336 |
-
bandwidth_mult = np.power(2.0, n_octaves_per_band)
|
| 337 |
-
|
| 338 |
-
low_midi = max(0, hz_to_midi(f_min))
|
| 339 |
-
high_midi = hz_to_midi(f_max)
|
| 340 |
-
midi_points = np.linspace(low_midi, high_midi, n_bands)
|
| 341 |
-
hz_pts = midi_to_hz(midi_points)
|
| 342 |
-
|
| 343 |
-
low_pts = hz_pts / bandwidth_mult
|
| 344 |
-
high_pts = hz_pts * bandwidth_mult
|
| 345 |
-
|
| 346 |
-
low_bins = np.floor(low_pts / df).astype(int)
|
| 347 |
-
high_bins = np.ceil(high_pts / df).astype(int)
|
| 348 |
-
|
| 349 |
-
fb = np.zeros((n_bands, n_freqs))
|
| 350 |
-
|
| 351 |
-
for i in range(n_bands):
|
| 352 |
-
fb[i, low_bins[i] : high_bins[i] + 1] = 1.0
|
| 353 |
-
|
| 354 |
-
fb[0, : low_bins[0]] = 1.0
|
| 355 |
-
fb[-1, high_bins[-1] + 1 :] = 1.0
|
| 356 |
-
|
| 357 |
-
return torch.as_tensor(fb)
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
|
| 361 |
-
def __init__(
|
| 362 |
-
self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
|
| 363 |
-
) -> None:
|
| 364 |
-
super().__init__(
|
| 365 |
-
fbank_fn=musical_filterbank,
|
| 366 |
-
nfft=nfft,
|
| 367 |
-
fs=fs,
|
| 368 |
-
n_bands=n_bands,
|
| 369 |
-
f_min=f_min,
|
| 370 |
-
f_max=f_max,
|
| 371 |
-
)
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
def bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
|
| 375 |
-
nfft = 2 * (n_freqs - 1)
|
| 376 |
-
fb, _ = bark_fbanks.bark_filter_banks(
|
| 377 |
-
nfilts=n_bands,
|
| 378 |
-
nfft=nfft,
|
| 379 |
-
fs=fs,
|
| 380 |
-
low_freq=f_min,
|
| 381 |
-
high_freq=f_max,
|
| 382 |
-
scale="constant",
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
-
return torch.as_tensor(fb)
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
|
| 389 |
-
def __init__(
|
| 390 |
-
self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
|
| 391 |
-
) -> None:
|
| 392 |
-
super().__init__(
|
| 393 |
-
fbank_fn=bark_filterbank,
|
| 394 |
-
nfft=nfft,
|
| 395 |
-
fs=fs,
|
| 396 |
-
n_bands=n_bands,
|
| 397 |
-
f_min=f_min,
|
| 398 |
-
f_max=f_max,
|
| 399 |
-
)
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
def triangular_bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
|
| 403 |
-
|
| 404 |
-
all_freqs = torch.linspace(0, fs // 2, n_freqs)
|
| 405 |
-
|
| 406 |
-
m_min = hz2bark(f_min)
|
| 407 |
-
m_max = hz2bark(f_max)
|
| 408 |
-
|
| 409 |
-
m_pts = torch.linspace(m_min, m_max, n_bands + 2)
|
| 410 |
-
f_pts = 600 * torch.sinh(m_pts / 6)
|
| 411 |
-
|
| 412 |
-
fb = _create_triangular_filterbank(all_freqs, f_pts)
|
| 413 |
-
|
| 414 |
-
fb = fb.T
|
| 415 |
-
|
| 416 |
-
first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
|
| 417 |
-
first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
|
| 418 |
-
|
| 419 |
-
fb[first_active_band, :first_active_bin] = 1.0
|
| 420 |
-
|
| 421 |
-
return fb
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
|
| 425 |
-
def __init__(
|
| 426 |
-
self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
|
| 427 |
-
) -> None:
|
| 428 |
-
super().__init__(
|
| 429 |
-
fbank_fn=triangular_bark_filterbank,
|
| 430 |
-
nfft=nfft,
|
| 431 |
-
fs=fs,
|
| 432 |
-
n_bands=n_bands,
|
| 433 |
-
f_min=f_min,
|
| 434 |
-
f_max=f_max,
|
| 435 |
-
)
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
def minibark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
|
| 439 |
-
fb = bark_filterbank(n_bands, fs, f_min, f_max, n_freqs)
|
| 440 |
-
|
| 441 |
-
fb[fb < np.sqrt(0.5)] = 0.0
|
| 442 |
-
|
| 443 |
-
return fb
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
|
| 447 |
-
def __init__(
|
| 448 |
-
self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
|
| 449 |
-
) -> None:
|
| 450 |
-
super().__init__(
|
| 451 |
-
fbank_fn=minibark_filterbank,
|
| 452 |
-
nfft=nfft,
|
| 453 |
-
fs=fs,
|
| 454 |
-
n_bands=n_bands,
|
| 455 |
-
f_min=f_min,
|
| 456 |
-
f_max=f_max,
|
| 457 |
-
)
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
def erb_filterbank(
|
| 461 |
-
n_bands: int,
|
| 462 |
-
fs: int,
|
| 463 |
-
f_min: float,
|
| 464 |
-
f_max: float,
|
| 465 |
-
n_freqs: int,
|
| 466 |
-
) -> Tensor:
|
| 467 |
-
A = (1000 * np.log(10)) / (24.7 * 4.37)
|
| 468 |
-
all_freqs = torch.linspace(0, fs // 2, n_freqs)
|
| 469 |
-
|
| 470 |
-
m_min = hz2erb(f_min)
|
| 471 |
-
m_max = hz2erb(f_max)
|
| 472 |
-
|
| 473 |
-
m_pts = torch.linspace(m_min, m_max, n_bands + 2)
|
| 474 |
-
f_pts = (torch.pow(10, (m_pts / A)) - 1) / 0.00437
|
| 475 |
-
|
| 476 |
-
fb = _create_triangular_filterbank(all_freqs, f_pts)
|
| 477 |
-
|
| 478 |
-
fb = fb.T
|
| 479 |
-
|
| 480 |
-
first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
|
| 481 |
-
first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
|
| 482 |
-
|
| 483 |
-
fb[first_active_band, :first_active_bin] = 1.0
|
| 484 |
-
|
| 485 |
-
return fb
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
|
| 489 |
-
def __init__(
|
| 490 |
-
self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
|
| 491 |
-
) -> None:
|
| 492 |
-
super().__init__(
|
| 493 |
-
fbank_fn=erb_filterbank,
|
| 494 |
-
nfft=nfft,
|
| 495 |
-
fs=fs,
|
| 496 |
-
n_bands=n_bands,
|
| 497 |
-
f_min=f_min,
|
| 498 |
-
f_max=f_max,
|
| 499 |
-
)
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
if __name__ == "__main__":
|
| 503 |
-
import pandas as pd
|
| 504 |
-
|
| 505 |
-
band_defs = []
|
| 506 |
-
|
| 507 |
-
for bands in [VocalBandsplitSpecification]:
|
| 508 |
-
band_name = bands.__name__.replace("BandsplitSpecification", "")
|
| 509 |
-
|
| 510 |
-
mbs = bands(nfft=2048, fs=44100).get_band_specs()
|
| 511 |
-
|
| 512 |
-
for i, (f_min, f_max) in enumerate(mbs):
|
| 513 |
-
band_defs.append(
|
| 514 |
-
{"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
|
| 515 |
-
)
|
| 516 |
-
|
| 517 |
-
df = pd.DataFrame(band_defs)
|
| 518 |
-
df.to_csv("vox7bands.csv", index=False)
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from abc import abstractmethod
|
| 3 |
+
from typing import Any, Callable
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from librosa import hz_to_midi, midi_to_hz
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from torchaudio import functional as taF
|
| 10 |
+
from spafe.fbanks import bark_fbanks
|
| 11 |
+
from spafe.utils.converters import erb2hz, hz2bark, hz2erb
|
| 12 |
+
from torchaudio.functional.functional import _create_triangular_filterbank
|
| 13 |
+
|
| 14 |
+
|
def band_widths_from_specs(band_specs):
    """Return the width (in bins) of each (start, end) band spec."""
    return [end - start for start, end in band_specs]
|
| 17 |
+
|
| 18 |
+
|
def check_nonzero_bandwidth(band_specs):
    """Raise ValueError if any (start, end) band has zero or negative width."""
    if any(end - start <= 0 for start, end in band_specs):
        raise ValueError("Bands cannot be zero-width")
|
| 23 |
+
|
| 24 |
+
|
def check_no_overlap(band_specs):
    """Raise ValueError if consecutive band specs overlap.

    Band specs are (start, end) pairs in ascending order. Matching the
    inclusive-end convention used by the sibling ``check_no_gap`` (which
    treats ``start == prev_end + 1`` as adjacent), a band overlaps its
    predecessor when its start does not lie strictly past the
    predecessor's end.
    """
    fend_prev = -1
    for fstart_curr, fend_curr in band_specs:
        if fstart_curr <= fend_prev:
            raise ValueError("Bands cannot overlap")
        # BUG FIX: fend_prev was never advanced before, so every start was
        # compared against -1 and overlaps could never be detected.
        fend_prev = fend_curr
|
| 30 |
+
|
| 31 |
+
|
def check_no_gap(band_specs):
    """Raise ValueError if the band specs leave a gap in bin coverage.

    The first band must start at bin 0 (asserted); every subsequent band
    must start no more than one bin past the previous band's end.
    """
    first_start, _ = band_specs[0]
    assert first_start == 0

    previous_end = -1
    for start, end in band_specs:
        if start - previous_end > 1:
            raise ValueError("Bands cannot leave gap")
        previous_end = end
|
| 41 |
+
|
| 42 |
+
|
class BandsplitSpecification:
    """Base class mapping one-sided STFT bins into frequency bands.

    A "band spec" is a (start_bin, end_bin) pair of FFT-bin indices.
    Subclasses implement ``get_band_specs`` to return their band layout.
    """

    def __init__(self, nfft: int, fs: int) -> None:
        self.fs = fs
        self.nfft = nfft
        self.nyquist = fs / 2
        # Number of bins in a one-sided spectrum of an nfft-point FFT.
        self.max_index = nfft // 2 + 1

        # Pre-computed bin indices of commonly used split frequencies.
        self.split500 = self.hertz_to_index(500)
        self.split1k = self.hertz_to_index(1000)
        self.split2k = self.hertz_to_index(2000)
        self.split4k = self.hertz_to_index(4000)
        self.split8k = self.hertz_to_index(8000)
        self.split16k = self.hertz_to_index(16000)
        self.split20k = self.hertz_to_index(20000)

        # Catch-all bands covering everything above 20 kHz / 16 kHz.
        self.above20k = [(self.split20k, self.max_index)]
        self.above16k = [(self.split16k, self.split20k)] + self.above20k

    def index_to_hertz(self, index: int):
        """Convert an FFT-bin index to its frequency in Hz."""
        return index * self.fs / self.nfft

    def hertz_to_index(self, hz: float, round: bool = True):
        """Convert a frequency in Hz to an FFT-bin index.

        With ``round=True`` (default) the result is rounded to the nearest
        integer bin; otherwise the fractional position is returned.
        """
        position = hz * self.nfft / self.fs
        if round:
            position = int(np.round(position))
        return position

    def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
        """Tile [start_index, end_index) with contiguous ~bandwidth_hz bands.

        The final band is clipped so it never extends past ``end_index``.
        """
        step = self.hertz_to_index(bandwidth_hz)
        bands = []
        cursor = start_index
        while cursor < end_index:
            edge = min(int(np.floor(cursor + step)), end_index)
            bands.append((cursor, edge))
            cursor = edge
        return bands

    @abstractmethod
    def get_band_specs(self):
        raise NotImplementedError
|
| 88 |
+
|
| 89 |
+
|
class VocalBandsplitSpecification(BandsplitSpecification):
    """Hand-tuned band-split layouts for vocal-like stems.

    Seven alternative layouts ("versions") are provided, from coarse
    (version 1: uniform 1 kHz bands) to fine (version 7: 100 Hz bands
    below 1 kHz, widening toward 20 kHz). ``get_band_specs`` resolves the
    layout by name from the ``version`` string given at construction.
    """

    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

        self.version = version

    def get_band_specs(self):
        # Resolve e.g. version "7" to the method self.version7 and call it.
        return getattr(self, f"version{self.version}")()

    def version1(self):
        # BUG FIX: this was previously decorated with @property, which broke
        # the getattr(...)() dispatch above for version="1" (it attempted to
        # call the returned list, raising TypeError). All other versions are
        # plain methods, so version1 is now one too.
        return self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.max_index, bandwidth_hz=1000
        )

    def version2(self):
        below16k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )

        return below16k + below20k + self.above20k

    def version3(self):
        below8k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )

        return below8k + below16k + self.above16k

    def version4(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )

        return below1k + below8k + below16k + self.above16k

    def version5(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )
        return below1k + below16k + below20k + self.above20k

    def version6(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + self.above16k

    def version7(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + below20k + self.above20k
|
| 182 |
+
|
| 183 |
+
|
class OtherBandsplitSpecification(VocalBandsplitSpecification):
    """Band split for "other" stems: reuses the version-7 vocal layout."""

    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft, fs, version="7")
|
| 187 |
+
|
| 188 |
+
|
class BassBandsplitSpecification(BandsplitSpecification):
    """Hand-tuned band split for bass stems: very fine 50 Hz bands at the
    bottom of the spectrum, progressively coarser bands toward the top."""

    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        # NOTE(review): `version` is accepted but never used by this class;
        # it is kept for signature compatibility with existing callers.
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        """Return contiguous (start, end) bin bands covering 0..max_index."""
        # (start_bin, end_bin, bandwidth_hz) schedule, bottom to top.
        schedule = [
            (0, self.split500, 50),
            (self.split500, self.split1k, 100),
            (self.split1k, self.split4k, 500),
            (self.split4k, self.split8k, 1000),
            (self.split8k, self.split16k, 2000),
        ]

        bands = []
        for start, end, width_hz in schedule:
            bands.extend(
                self.get_band_specs_with_bandwidth(
                    start_index=start, end_index=end, bandwidth_hz=width_hz
                )
            )
        # Single catch-all band above 16 kHz.
        bands.append((self.split16k, self.max_index))
        return bands
|
| 212 |
+
|
| 213 |
+
|
class DrumBandsplitSpecification(BandsplitSpecification):
    """Hand-tuned band split for drum stems: fine 50 Hz bands below 1 kHz,
    progressively coarser bands toward the top of the spectrum."""

    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        """Return contiguous (start, end) bin bands covering 0..max_index."""
        # (start_bin, end_bin, bandwidth_hz) schedule, bottom to top.
        schedule = [
            (0, self.split1k, 50),
            (self.split1k, self.split2k, 100),
            (self.split2k, self.split4k, 250),
            (self.split4k, self.split8k, 500),
            (self.split8k, self.split16k, 1000),
        ]

        bands = []
        for start, end, width_hz in schedule:
            bands.extend(
                self.get_band_specs_with_bandwidth(
                    start_index=start, end_index=end, bandwidth_hz=width_hz
                )
            )
        # Single catch-all band above 16 kHz.
        bands.append((self.split16k, self.max_index))
        return bands
|
| 237 |
+
|
| 238 |
+
|
class PerceptualBandsplitSpecification(BandsplitSpecification):
    """Band split derived from a perceptual filterbank (mel, Bark, ERB, ...).

    ``fbank_fn`` builds a (n_bands, max_index) filterbank tensor; each
    band's spec is the span of its nonzero bins, and per-band frequency
    weights are taken from the column-normalized filterbank.
    """

    def __init__(
        self,
        nfft: int,
        fs: int,
        fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
        n_bands: int,
        f_min: float = 0.0,
        f_max: float = None,
    ) -> None:
        super().__init__(nfft=nfft, fs=fs)
        self.n_bands = n_bands
        if f_max is None:
            f_max = fs / 2

        # Filterbank of shape (n_bands, max_index): one row of per-bin
        # weights per band.
        self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)

        # Normalize each bin's weights so they sum to 1 across bands.
        # NOTE(review): bins covered by no band divide by zero here and
        # yield NaN weights — presumably fbank_fn guarantees full coverage
        # (the filterbank builders below patch bin 0); confirm.
        weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True)
        normalized_mel_fb = self.filterbank / weight_per_bin

        freq_weights = []
        band_specs = []
        for i in range(self.n_bands):
            # Indices of bins with nonzero weight in band i.
            active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
            if isinstance(active_bins, int):
                # squeeze() collapsed a single active bin to a scalar;
                # re-wrap so indexing below works uniformly.
                active_bins = (active_bins, active_bins)
            if len(active_bins) == 0:
                # Band with no active bins contributes nothing.
                continue
            # Band spec spans first..last active bin (end exclusive).
            start_index = active_bins[0]
            end_index = active_bins[-1] + 1
            band_specs.append((start_index, end_index))
            freq_weights.append(normalized_mel_fb[i, start_index:end_index])

        self.freq_weights = freq_weights
        self.band_specs = band_specs

    def get_band_specs(self):
        """Return the list of (start_bin, end_bin) band specs."""
        return self.band_specs

    def get_freq_weights(self):
        """Return per-band tensors of normalized per-bin weights."""
        return self.freq_weights

    def save_to_file(self, dir_path: str) -> None:
        """Pickle band specs, weights, and the raw filterbank to dir_path."""

        os.makedirs(dir_path, exist_ok=True)

        import pickle

        with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
            pickle.dump(
                {
                    "band_specs": self.band_specs,
                    "freq_weights": self.freq_weights,
                    "filterbank": self.filterbank,
                },
                f,
            )
|
| 296 |
+
|
| 297 |
+
|
def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    """Mel-scale filterbank of shape (n_bands, n_freqs).

    Transposes torchaudio's (n_freqs, n_mels) layout to bands-first, and
    forces bin 0 into the first band so the DC bin is always covered.
    """
    filterbank = taF.melscale_fbanks(
        n_mels=n_bands,
        sample_rate=fs,
        f_min=f_min,
        f_max=f_max,
        n_freqs=n_freqs,
    ).T

    filterbank[0, 0] = 1.0

    return filterbank
|
| 310 |
+
|
| 311 |
+
|
class MelBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split built from a mel-scale filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(
            nfft=nfft,
            fs=fs,
            fbank_fn=mel_filterbank,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )
|
| 324 |
+
|
| 325 |
+
|
def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
    """Rectangular filterbank with band centers log-spaced in pitch (MIDI).

    Each band spans one band-spacing's worth of octaves on either side of
    its center, so neighboring bands overlap. Returns a (n_bands, n_freqs)
    tensor of 0/1 weights.

    NOTE(review): ``scale`` is accepted but never used.
    """

    nfft = 2 * (n_freqs - 1)
    df = fs / nfft  # bin spacing in Hz
    f_max = f_max or fs / 2
    f_min = f_min or 0
    # NOTE(review): this unconditionally overwrites f_min (making the line
    # above dead code) — presumably to pin the lowest edge to the first
    # nonzero FFT bin and keep log2/hz_to_midi finite; confirm intent.
    f_min = fs / nfft

    n_octaves = np.log2(f_max / f_min)
    n_octaves_per_band = n_octaves / n_bands
    bandwidth_mult = np.power(2.0, n_octaves_per_band)

    low_midi = max(0, hz_to_midi(f_min))
    high_midi = hz_to_midi(f_max)
    midi_points = np.linspace(low_midi, high_midi, n_bands)
    hz_pts = midi_to_hz(midi_points)

    # Band edges one band-spacing below/above each center (overlapping).
    low_pts = hz_pts / bandwidth_mult
    high_pts = hz_pts * bandwidth_mult

    low_bins = np.floor(low_pts / df).astype(int)
    high_bins = np.ceil(high_pts / df).astype(int)

    fb = np.zeros((n_bands, n_freqs))

    for i in range(n_bands):
        fb[i, low_bins[i] : high_bins[i] + 1] = 1.0

    # Extend the outermost bands to cover the spectrum edges.
    fb[0, : low_bins[0]] = 1.0
    fb[-1, high_bins[-1] + 1 :] = 1.0

    return torch.as_tensor(fb)
|
| 358 |
+
|
| 359 |
+
|
class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split built from a musical (log-pitch) filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(
            nfft=nfft,
            fs=fs,
            fbank_fn=musical_filterbank,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )
|
| 372 |
+
|
| 373 |
+
|
def bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    """Bark-scale filterbank of shape (n_bands, n_freqs) as a torch tensor.

    Delegates to spafe's constant-scale Bark filter banks.
    """
    n_fft = 2 * (n_freqs - 1)
    banks, _ = bark_fbanks.bark_filter_banks(
        nfilts=n_bands,
        nfft=n_fft,
        fs=fs,
        low_freq=f_min,
        high_freq=f_max,
        scale="constant",
    )

    return torch.as_tensor(banks)
|
| 386 |
+
|
| 387 |
+
|
class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split built from a Bark-scale filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(
            nfft=nfft,
            fs=fs,
            fbank_fn=bark_filterbank,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )
|
| 400 |
+
|
| 401 |
+
|
def triangular_bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    """Triangular filterbank with edges equally spaced on the Bark scale.

    Returns a (n_bands, n_freqs) tensor; the lowest active band is
    extended down to bin 0 so no low-frequency bins are left uncovered.
    """
    bin_freqs = torch.linspace(0, fs // 2, n_freqs)

    # n_bands + 2 edge points uniform in Bark, mapped back to Hz via the
    # inverse mapping hz = 600 * sinh(bark / 6).
    bark_lo = hz2bark(f_min)
    bark_hi = hz2bark(f_max)
    bark_pts = torch.linspace(bark_lo, bark_hi, n_bands + 2)
    hz_pts = 600 * torch.sinh(bark_pts / 6)

    filterbank = _create_triangular_filterbank(bin_freqs, hz_pts)

    # torchaudio returns (n_freqs, n_bands); flip to bands-first.
    filterbank = filterbank.T

    band0 = torch.nonzero(torch.sum(filterbank, dim=-1))[0, 0]
    bin0 = torch.nonzero(filterbank[band0, :])[0, 0]

    filterbank[band0, :bin0] = 1.0

    return filterbank
|
| 422 |
+
|
| 423 |
+
|
class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split built from a triangular Bark filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(
            nfft=nfft,
            fs=fs,
            fbank_fn=triangular_bark_filterbank,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )
|
| 436 |
+
|
| 437 |
+
|
def minibark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    """Bark filterbank with weak weights (below sqrt(0.5)) zeroed out,
    which trims each band down to its dominant bins."""
    filterbank = bark_filterbank(n_bands, fs, f_min, f_max, n_freqs)

    cutoff = np.sqrt(0.5)
    filterbank[filterbank < cutoff] = 0.0

    return filterbank
|
| 444 |
+
|
| 445 |
+
|
class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split built from a thresholded ("mini") Bark filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(
            nfft=nfft,
            fs=fs,
            fbank_fn=minibark_filterbank,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )
|
| 458 |
+
|
| 459 |
+
|
def erb_filterbank(
    n_bands: int,
    fs: int,
    f_min: float,
    f_max: float,
    n_freqs: int,
) -> Tensor:
    """Triangular filterbank with edges equally spaced on the ERB scale.

    Returns a (n_bands, n_freqs) tensor; the lowest active band is
    extended down to bin 0 so no low-frequency bins are left uncovered.
    """
    # Scale constant for the ERB-to-Hz conversion below.
    scale = (1000 * np.log(10)) / (24.7 * 4.37)
    bin_freqs = torch.linspace(0, fs // 2, n_freqs)

    # n_bands + 2 edge points uniform in ERB, mapped back to Hz.
    erb_lo = hz2erb(f_min)
    erb_hi = hz2erb(f_max)
    erb_pts = torch.linspace(erb_lo, erb_hi, n_bands + 2)
    hz_pts = (torch.pow(10, (erb_pts / scale)) - 1) / 0.00437

    filterbank = _create_triangular_filterbank(bin_freqs, hz_pts)

    # torchaudio returns (n_freqs, n_bands); flip to bands-first.
    filterbank = filterbank.T

    band0 = torch.nonzero(torch.sum(filterbank, dim=-1))[0, 0]
    bin0 = torch.nonzero(filterbank[band0, :])[0, 0]

    filterbank[band0, :bin0] = 1.0

    return filterbank
|
| 486 |
+
|
| 487 |
+
|
class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split built from an ERB-scale triangular filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(
            nfft=nfft,
            fs=fs,
            fbank_fn=erb_filterbank,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )
|
| 500 |
+
|
| 501 |
+
|
if __name__ == "__main__":
    # Export the band layout of each listed specification to CSV for
    # inspection (one row per band).
    import pandas as pd

    rows = []

    for spec_cls in [VocalBandsplitSpecification]:
        label = spec_cls.__name__.replace("BandsplitSpecification", "")

        for idx, (lo, hi) in enumerate(spec_cls(nfft=2048, fs=44100).get_band_specs()):
            rows.append({"band": label, "band_index": idx, "f_min": lo, "f_max": hi})

    pd.DataFrame(rows).to_csv("vox7bands.csv", index=False)
|
mvsepless/models/bandit/core/model/bsrnn/wrapper.py
CHANGED
|
@@ -1,828 +1,828 @@
|
|
| 1 |
-
from pprint import pprint
|
| 2 |
-
from typing import Dict, List, Optional, Tuple, Union
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
from torch import nn
|
| 6 |
-
|
| 7 |
-
from .._spectral import _SpectralComponent
|
| 8 |
-
from .utils import (
|
| 9 |
-
BarkBandsplitSpecification,
|
| 10 |
-
BassBandsplitSpecification,
|
| 11 |
-
DrumBandsplitSpecification,
|
| 12 |
-
EquivalentRectangularBandsplitSpecification,
|
| 13 |
-
MelBandsplitSpecification,
|
| 14 |
-
MusicalBandsplitSpecification,
|
| 15 |
-
OtherBandsplitSpecification,
|
| 16 |
-
TriangularBarkBandsplitSpecification,
|
| 17 |
-
VocalBandsplitSpecification,
|
| 18 |
-
)
|
| 19 |
-
from .core import (
|
| 20 |
-
MultiSourceMultiMaskBandSplitCoreConv,
|
| 21 |
-
MultiSourceMultiMaskBandSplitCoreRNN,
|
| 22 |
-
MultiSourceMultiMaskBandSplitCoreTransformer,
|
| 23 |
-
MultiSourceMultiPatchingMaskBandSplitCoreRNN,
|
| 24 |
-
SingleMaskBandsplitCoreRNN,
|
| 25 |
-
SingleMaskBandsplitCoreTransformer,
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
import pytorch_lightning as pl
|
| 29 |
-
|
| 30 |
-
|
def get_band_specs(band_specs, n_fft, fs, n_bands=None):
    """Resolve a band-spec preset name into its concrete band layout.

    Returns a (band_specs, freq_weights, overlapping_band) triple:
    the list of (start_bin, end_bin) bands, optional per-band frequency
    weights (None for the hand-tuned vocal split), and whether the bands
    may overlap. Raises NameError for an unrecognized preset name.
    """
    if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]:
        bsm = VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs()
        freq_weights = None
        overlapping_band = False
    elif "tribark" in band_specs:
        # NOTE: must be checked before "bark" — "bark" is a substring of
        # "tribark", so reordering these branches would misroute presets.
        assert n_bands is not None
        specs = TriangularBarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "bark" in band_specs:
        assert n_bands is not None
        specs = BarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "erb" in band_specs:
        assert n_bands is not None
        specs = EquivalentRectangularBandsplitSpecification(
            nfft=n_fft, fs=fs, n_bands=n_bands
        )
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "musical" in band_specs:
        assert n_bands is not None
        specs = MusicalBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif band_specs == "dnr:mel" or "mel" in band_specs:
        # NOTE(review): the equality test is redundant — "mel" is already a
        # substring of "dnr:mel"; kept as-is to preserve behavior.
        assert n_bands is not None
        specs = MelBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    else:
        raise NameError

    return bsm, freq_weights, overlapping_band
|
| 72 |
-
|
| 73 |
-
|
def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None):
    """Resolve a per-stem band-spec preset into a {stem: band_specs} map.

    Returns (band_specs_map, freq_weights, overlapping_band); raises
    NameError for an unrecognized preset name.
    """
    if band_specs_map == "musdb:all":
        # Per-stem hand-tuned splits for the four MUSDB stems.
        bsm = {
            "vocals": VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
            "drums": DrumBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
            "bass": BassBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
            "other": OtherBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
        }
        freq_weights = None
        overlapping_band = False
    elif band_specs_map == "dnr:vox7":
        # One shared vox7 split reused for all three DnR stems.
        bsm_, freq_weights, overlapping_band = get_band_specs(
            "dnr:speech", n_fft, fs, n_bands
        )
        bsm = {"speech": bsm_, "music": bsm_, "effects": bsm_}
    elif "dnr:vox7:" in band_specs_map:
        # Single-stem variant, e.g. "dnr:vox7:speech".
        stem = band_specs_map.split(":")[-1]
        bsm_, freq_weights, overlapping_band = get_band_specs(
            "dnr:speech", n_fft, fs, n_bands
        )
        bsm = {stem: bsm_}
    else:
        raise NameError

    return bsm, freq_weights, overlapping_band
|
| 99 |
-
|
| 100 |
-
|
class BandSplitWrapperBase(pl.LightningModule):
    """Common Lightning base for the band-split wrapper models below."""

    # Populated by concrete subclasses with the band-split model(s).
    bsrnn: nn.Module

    def __init__(self, **kwargs):
        # NOTE(review): **kwargs are accepted and silently discarded —
        # presumably to tolerate extra config arguments from callers;
        # confirm this is intentional.
        super().__init__()
|
| 106 |
-
|
| 107 |
-
|
class SingleMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
    """Runs one single-mask band-split model per stem on a shared mixture STFT.

    ``self.bsrnn`` (set by a concrete subclass) maps stem name -> model;
    each model predicts its stem's spectrogram from the mixture.
    """

    def __init__(
        self,
        band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
        fs: int = 44100,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: int = None,
    ) -> None:
        super().__init__(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        if isinstance(band_specs_map, str):
            self.band_specs_map, self.freq_weights, self.overlapping_band = (
                get_band_specs_map(band_specs_map, n_fft, fs, n_bands=n_bands)
            )
        # NOTE(review): when band_specs_map is passed as a dict (as the type
        # hint allows), the attributes above are never assigned and the
        # list(...) below raises AttributeError — confirm whether dict input
        # is actually supported anywhere.

        self.stems = list(self.band_specs_map.keys())

    def forward(self, batch):
        """Separate every stem from batch["audio"]["mixture"].

        Returns (batch, output): batch gains a "spectrogram" entry with the
        STFT of each provided stem; output holds the predicted spectrogram
        and resynthesized audio per stem.
        """
        audio = batch["audio"]

        # STFTs of all provided stems; no gradients are needed for these.
        with torch.no_grad():
            batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}

        X = batch["spectrogram"]["mixture"]
        length = batch["audio"]["mixture"].shape[-1]

        output = {"spectrogram": {}, "audio": {}}

        # Each stem has its own model; all consume the same mixture STFT.
        for stem, bsrnn in self.bsrnn.items():
            S = bsrnn(X)
            s = self.istft(S, length)
            output["spectrogram"][stem] = S
            output["audio"][stem] = s

        return batch, output
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
class MultiMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
    """Base wrapper using a single shared core that emits one mask per stem.

    ``self.bsrnn`` (set by subclasses) maps the mixture spectrogram (and an
    optional condition vector) to a dict containing per-stem spectrograms.
    """

    def __init__(
        self,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,
    ) -> None:
        super().__init__(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        if isinstance(band_specs, str):
            self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
                band_specs, n_fft, fs, n_bands
            )
        else:
            # BUG FIX: the declared interface accepts an explicit band list,
            # but previously a non-string argument left these attributes unset
            # and subclasses reading ``self.band_specs`` crashed.
            self.band_specs = band_specs
            self.freq_weights = None
            self.overlapping_band = False

        self.stems = stems

    def forward(self, batch):
        """Separate the mixture in ``batch``; returns (batch, output) dicts."""
        audio = batch["audio"]
        cond = batch.get("condition", None)

        # The analysis STFT of the inputs is treated as fixed data, not part
        # of the computation graph.
        with torch.no_grad():
            batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}

        X = batch["spectrogram"]["mixture"]
        length = batch["audio"]["mixture"].shape[-1]

        # A single forward pass of the shared core produces all stems.
        output = self.bsrnn(X, cond=cond)
        output["audio"] = {}

        for stem, S in output["spectrogram"].items():
            s = self.istft(S, length)
            output["audio"][stem] = s

        return batch, output
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
class MultiMaskMultiSourceBandSplitBaseSimple(BandSplitWrapperBase, _SpectralComponent):
    """Tensor-in/tensor-out variant of :class:`MultiMaskMultiSourceBandSplitBase`.

    ``forward`` takes the raw mixture waveform directly (no batch dict) and
    returns the separated stems stacked along dim 1.
    """

    def __init__(
        self,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,
    ) -> None:
        super().__init__(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        if isinstance(band_specs, str):
            self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
                band_specs, n_fft, fs, n_bands
            )
        else:
            # BUG FIX: a non-string band spec previously left these attributes
            # unset, breaking subclasses that read ``self.band_specs``.
            self.band_specs = band_specs
            self.freq_weights = None
            self.overlapping_band = False

        self.stems = stems

    def forward(self, batch):
        """Separate the mixture waveform ``batch``.

        Returns a tensor with stems stacked on dim 1, in the iteration order
        of the core's ``output["spectrogram"]`` dict.
        """
        # Only the analysis STFT is excluded from the graph, mirroring the
        # sibling wrappers. NOTE(review): the original diff does not preserve
        # indentation — confirm the core call was not meant to run under
        # no_grad as well.
        with torch.no_grad():
            X = self.stft(batch)
        length = batch.shape[-1]
        output = self.bsrnn(X, cond=None)
        separated = []
        for stem, S in output["spectrogram"].items():
            separated.append(self.istft(S, length))
        return torch.stack(separated, dim=1)
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase):
    """Per-stem band-split RNN separator.

    Builds one independent ``SingleMaskBandsplitCoreRNN`` per stem in the
    resolved ``band_specs_map``; STFT/iSTFT handling comes from the base
    class.
    """

    def __init__(
        self,
        in_channel: int,
        band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ) -> None:
        # STFT configuration and band-spec resolution are handled by the base.
        super().__init__(
            band_specs_map=band_specs_map,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        # One independent core per source; nn.ModuleDict registers each
        # core's parameters under its stem name.
        self.bsrnn = nn.ModuleDict(
            {
                src: SingleMaskBandsplitCoreRNN(
                    band_specs=specs,
                    in_channel=in_channel,
                    require_no_overlap=require_no_overlap,
                    require_no_gap=require_no_gap,
                    normalize_channel_independently=normalize_channel_independently,
                    treat_channel_as_feature=treat_channel_as_feature,
                    n_sqm_modules=n_sqm_modules,
                    emb_dim=emb_dim,
                    rnn_dim=rnn_dim,
                    bidirectional=bidirectional,
                    rnn_type=rnn_type,
                    mlp_dim=mlp_dim,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                )
                for src, specs in self.band_specs_map.items()
            }
        )
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
class SingleMaskMultiSourceBandSplitTransformer(SingleMaskMultiSourceBandSplitBase):
    """Per-stem band-split Transformer separator.

    Same wiring as :class:`SingleMaskMultiSourceBandSplitRNN` but with a
    ``SingleMaskBandsplitCoreTransformer`` per stem (``tf_dropout`` replaces
    the RNN-specific options).
    """

    def __init__(
        self,
        in_channel: int,
        band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        tf_dropout: float = 0.0,
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ) -> None:
        # STFT configuration and band-spec resolution are handled by the base.
        super().__init__(
            band_specs_map=band_specs_map,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        # One independent Transformer core per source stem.
        self.bsrnn = nn.ModuleDict(
            {
                src: SingleMaskBandsplitCoreTransformer(
                    band_specs=specs,
                    in_channel=in_channel,
                    require_no_overlap=require_no_overlap,
                    require_no_gap=require_no_gap,
                    normalize_channel_independently=normalize_channel_independently,
                    treat_channel_as_feature=treat_channel_as_feature,
                    n_sqm_modules=n_sqm_modules,
                    emb_dim=emb_dim,
                    rnn_dim=rnn_dim,
                    bidirectional=bidirectional,
                    tf_dropout=tf_dropout,
                    mlp_dim=mlp_dim,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                )
                for src, specs in self.band_specs_map.items()
            }
        )
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
    """Shared-core band-split RNN emitting one mask per stem.

    A single ``MultiSourceMultiMaskBandSplitCoreRNN`` produces all stems in
    one pass; the base class provides STFT/iSTFT and band-spec resolution.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        cond_dim: int = 0,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,
        use_freq_weights: bool = True,
        normalize_input: bool = False,
        mult_add_mask: bool = False,
        freeze_encoder: bool = False,
    ) -> None:
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )

        # NOTE(review): normalize_input is stored but not used in this class
        # body — presumably consumed elsewhere; confirm.
        self.normalize_input = normalize_input
        self.cond_dim = cond_dim

        if freeze_encoder:
            # Freeze the band-split encoder and time-frequency model so only
            # the remaining (decoder/mask) parameters train.
            for param in self.bsrnn.band_split.parameters():
                param.requires_grad = False

            for param in self.bsrnn.tf_model.parameters():
                param.requires_grad = False
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple):
    """Tensor-in/tensor-out variant of :class:`MultiMaskMultiSourceBandSplitRNN`.

    Identical construction, but inherits the simplified ``forward`` that takes
    a raw waveform tensor and returns stems stacked on dim 1.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        cond_dim: int = 0,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,
        use_freq_weights: bool = True,
        normalize_input: bool = False,
        mult_add_mask: bool = False,
        freeze_encoder: bool = False,
    ) -> None:
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )

        # NOTE(review): normalize_input is stored but not used in this class
        # body — presumably consumed elsewhere; confirm.
        self.normalize_input = normalize_input
        self.cond_dim = cond_dim

        if freeze_encoder:
            # Freeze the band-split encoder and time-frequency model so only
            # the remaining (decoder/mask) parameters train.
            for param in self.bsrnn.band_split.parameters():
                param.requires_grad = False

            for param in self.bsrnn.tf_model.parameters():
                param.requires_grad = False
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
class MultiMaskMultiSourceBandSplitTransformer(MultiMaskMultiSourceBandSplitBase):
    """Shared-core band-split Transformer emitting one mask per stem.

    NOTE(review): unlike the RNN variants, ``normalize_input`` is accepted but
    neither stored nor forwarded here, and ``rnn_type`` is passed through to a
    Transformer core — confirm both are intentional.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        cond_dim: int = 0,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,
        use_freq_weights: bool = True,
        normalize_input: bool = False,
        mult_add_mask: bool = False,
    ) -> None:
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        # Single shared core producing all stem masks in one pass.
        self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
class MultiMaskMultiSourceBandSplitConv(MultiMaskMultiSourceBandSplitBase):
    """Shared-core band-split convolutional model emitting one mask per stem.

    NOTE(review): ``normalize_input`` is accepted but neither stored nor
    forwarded here, and ``rnn_type`` is passed through to a convolutional
    core — confirm both are intentional.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        cond_dim: int = 0,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,
        use_freq_weights: bool = True,
        normalize_input: bool = False,
        mult_add_mask: bool = False,
    ) -> None:
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        # Single shared convolutional core producing all stem masks.
        self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
    """Shared-core band-split RNN using patching (kernel-based) masks.

    Like :class:`MultiMaskMultiSourceBandSplitRNN`, but the core estimates a
    small time-frequency kernel per bin (``mask_kernel_freq`` x
    ``mask_kernel_time``) rather than a pointwise mask.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        kernel_norm_mlp_version: int = 1,
        mask_kernel_freq: int = 3,
        mask_kernel_time: int = 3,
        conv_kernel_freq: int = 1,
        conv_kernel_time: int = 1,
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,
    ) -> None:
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        # Single shared patching-mask core producing all stems.
        self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            mask_kernel_freq=mask_kernel_freq,
            mask_kernel_time=mask_kernel_time,
            conv_kernel_freq=conv_kernel_freq,
            conv_kernel_time=conv_kernel_time,
            kernel_norm_mlp_version=kernel_norm_mlp_version,
        )
|
|
|
|
| 1 |
+
from pprint import pprint
|
| 2 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
from .._spectral import _SpectralComponent
|
| 8 |
+
from .utils import (
|
| 9 |
+
BarkBandsplitSpecification,
|
| 10 |
+
BassBandsplitSpecification,
|
| 11 |
+
DrumBandsplitSpecification,
|
| 12 |
+
EquivalentRectangularBandsplitSpecification,
|
| 13 |
+
MelBandsplitSpecification,
|
| 14 |
+
MusicalBandsplitSpecification,
|
| 15 |
+
OtherBandsplitSpecification,
|
| 16 |
+
TriangularBarkBandsplitSpecification,
|
| 17 |
+
VocalBandsplitSpecification,
|
| 18 |
+
)
|
| 19 |
+
from .core import (
|
| 20 |
+
MultiSourceMultiMaskBandSplitCoreConv,
|
| 21 |
+
MultiSourceMultiMaskBandSplitCoreRNN,
|
| 22 |
+
MultiSourceMultiMaskBandSplitCoreTransformer,
|
| 23 |
+
MultiSourceMultiPatchingMaskBandSplitCoreRNN,
|
| 24 |
+
SingleMaskBandsplitCoreRNN,
|
| 25 |
+
SingleMaskBandsplitCoreTransformer,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
import pytorch_lightning as pl
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_band_specs(band_specs, n_fft, fs, n_bands=None):
    """Resolve a band-spec identifier string to a band specification.

    Matching is substring-based for the parametric specs, so order matters:
    "tribark" must be tested before "bark".

    Args:
        band_specs: Identifier, e.g. "dnr:speech", "dnr:mel<k>", "tribark...",
            "bark...", "erb...", "musical...".
        n_fft: FFT size used to derive the frequency bins.
        fs: Sampling rate in Hz.
        n_bands: Number of bands; required by every spec except the fixed
            vocal spec.

    Returns:
        Tuple ``(band_specs, freq_weights, overlapping_band)``.

    Raises:
        NameError: If the identifier is not recognized.
    """
    if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]:
        bsm = VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs()
        freq_weights = None
        overlapping_band = False
    elif "tribark" in band_specs:
        assert n_bands is not None
        specs = TriangularBarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "bark" in band_specs:
        assert n_bands is not None
        specs = BarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "erb" in band_specs:
        assert n_bands is not None
        specs = EquivalentRectangularBandsplitSpecification(
            nfft=n_fft, fs=fs, n_bands=n_bands
        )
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "musical" in band_specs:
        assert n_bands is not None
        specs = MusicalBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "mel" in band_specs:
        # Simplified from `band_specs == "dnr:mel" or "mel" in band_specs`:
        # the equality test was redundant ("mel" is a substring of "dnr:mel").
        assert n_bands is not None
        specs = MelBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    else:
        # Include the offending identifier so misconfiguration is debuggable
        # (previously raised a bare NameError with no context).
        raise NameError(f"Unknown band_specs: {band_specs!r}")

    return bsm, freq_weights, overlapping_band
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None):
    """Resolve a band-spec-map identifier into per-stem band specifications.

    Args:
        band_specs_map: Identifier string — "musdb:all", "dnr:vox7", or
            "dnr:vox7:<stem>".
        n_fft: FFT size used to derive the frequency bins.
        fs: Sampling rate in Hz.
        n_bands: Number of bands, forwarded to ``get_band_specs`` where the
            underlying spec requires it.

    Returns:
        Tuple ``(bsm, freq_weights, overlapping_band)`` where ``bsm`` maps
        stem name -> band specs.

    Raises:
        NameError: If the identifier is not recognized.
    """
    if band_specs_map == "musdb:all":
        bsm = {
            "vocals": VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
            "drums": DrumBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
            "bass": BassBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
            "other": OtherBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
        }
        freq_weights = None
        overlapping_band = False
    elif band_specs_map == "dnr:vox7":
        # Same vocal-style spec shared by all three DnR stems.
        bsm_, freq_weights, overlapping_band = get_band_specs(
            "dnr:speech", n_fft, fs, n_bands
        )
        bsm = {"speech": bsm_, "music": bsm_, "effects": bsm_}
    elif "dnr:vox7:" in band_specs_map:
        # Single-stem variant: the stem name is the last ":"-separated token.
        stem = band_specs_map.split(":")[-1]
        bsm_, freq_weights, overlapping_band = get_band_specs(
            "dnr:speech", n_fft, fs, n_bands
        )
        bsm = {stem: bsm_}
    else:
        # Include the offending identifier so misconfiguration is debuggable
        # (previously raised a bare NameError with no context).
        raise NameError(f"Unknown band_specs_map: {band_specs_map!r}")

    return bsm, freq_weights, overlapping_band
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class BandSplitWrapperBase(pl.LightningModule):
    """Common LightningModule base for the band-split separation wrappers.

    Subclasses are expected to assign the core separation network to
    ``self.bsrnn``.
    """

    # Core band-split network; populated by subclasses.
    bsrnn: nn.Module

    def __init__(self, **kwargs):
        # Extra keyword arguments are intentionally accepted and discarded so
        # subclasses can forward their full configuration without filtering.
        super().__init__()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class SingleMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
    """Wrapper running one band-split core per stem.

    Each core in ``self.bsrnn`` (an ``nn.ModuleDict`` set by subclasses)
    predicts a single mask over the mixture spectrogram; masked spectrograms
    are inverted back to waveforms with the inherited ISTFT.
    """

    def __init__(
        self,
        band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
        fs: int = 44100,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: int = None,
    ) -> None:
        super().__init__(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        if isinstance(band_specs_map, str):
            # Named preset: resolve via the registry helper.
            self.band_specs_map, self.freq_weights, self.overlapping_band = (
                get_band_specs_map(band_specs_map, n_fft, fs, n_bands=n_bands)
            )
        else:
            # Fix: an explicit dict (allowed by the type annotation) previously
            # left these attributes unset, raising AttributeError below.
            self.band_specs_map = band_specs_map
            self.freq_weights = None
            # NOTE(review): no overlap metadata for explicit maps — confirm
            # False is the right default for the subclass cores.
            self.overlapping_band = False

        self.stems = list(self.band_specs_map.keys())

    def forward(self, batch):
        """Separate the mixture in ``batch``; returns ``(batch, output)``
        where ``output`` holds per-stem spectrograms and waveforms."""
        audio = batch["audio"]

        # STFTs of the references/mixture are inputs, not learned — no grads.
        with torch.no_grad():
            batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}

        X = batch["spectrogram"]["mixture"]
        length = batch["audio"]["mixture"].shape[-1]

        output = {"spectrogram": {}, "audio": {}}

        for stem, bsrnn in self.bsrnn.items():
            S = bsrnn(X)
            s = self.istft(S, length)
            output["spectrogram"][stem] = S
            output["audio"][stem] = s

        return batch, output
| 164 |
+
|
| 165 |
+
class MultiMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
    """Wrapper around a single shared band-split core that predicts one mask
    per stem from the mixture spectrogram (optionally conditioned)."""

    def __init__(
        self,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: int = None,
    ) -> None:
        super().__init__(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        if isinstance(band_specs, str):
            # Named preset: resolve via the registry helper.
            self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
                band_specs, n_fft, fs, n_bands
            )
        else:
            # Fix: an explicit band list (allowed by the type annotation)
            # previously left these attributes unset, so subclasses crashed
            # when reading self.band_specs / self.freq_weights /
            # self.overlapping_band.
            self.band_specs = band_specs
            self.freq_weights = None
            # NOTE(review): no overlap metadata for explicit lists — confirm
            # False is the right default for the subclass cores.
            self.overlapping_band = False

        self.stems = stems

    def forward(self, batch):
        """Separate the mixture in ``batch``; returns ``(batch, output)``."""
        audio = batch["audio"]
        # Optional conditioning vector forwarded to the core.
        cond = batch.get("condition", None)

        # STFTs of the references/mixture are inputs, not learned — no grads.
        with torch.no_grad():
            batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}

        X = batch["spectrogram"]["mixture"]
        length = batch["audio"]["mixture"].shape[-1]

        output = self.bsrnn(X, cond=cond)
        output["audio"] = {}

        for stem, S in output["spectrogram"].items():
            s = self.istft(S, length)
            output["audio"][stem] = s

        return batch, output
+
|
| 222 |
+
class MultiMaskMultiSourceBandSplitBaseSimple(BandSplitWrapperBase, _SpectralComponent):
    """Tensor-in/tensor-out variant of the multi-mask wrapper: takes a raw
    mixture waveform tensor and returns the separated stems stacked on dim 1."""

    def __init__(
        self,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: int = None,
    ) -> None:
        super().__init__(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        if isinstance(band_specs, str):
            # Named preset: resolve via the registry helper.
            self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
                band_specs, n_fft, fs, n_bands
            )
        else:
            # Fix: an explicit band list (allowed by the type annotation)
            # previously left these attributes unset, so subclasses crashed
            # when reading them.
            self.band_specs = band_specs
            self.freq_weights = None
            # NOTE(review): no overlap metadata for explicit lists — confirm.
            self.overlapping_band = False

        self.stems = stems

    def forward(self, batch):
        """``batch`` is a raw mixture waveform (..., n_samples); returns the
        separated stems stacked along dim 1."""
        with torch.no_grad():
            X = self.stft(batch)
            length = batch.shape[-1]
        output = self.bsrnn(X, cond=None)
        res = []
        for stem, S in output["spectrogram"].items():
            s = self.istft(S, length)
            res.append(s)
        res = torch.stack(res, dim=1)
        return res
| 272 |
+
|
| 273 |
+
class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase):
    """Per-stem RNN flavour of the single-mask wrapper: builds one
    ``SingleMaskBandsplitCoreRNN`` for every entry in the band-spec map."""

    def __init__(
        self,
        in_channel: int,
        band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ) -> None:
        super().__init__(
            band_specs_map=band_specs_map,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        # Settings shared by every per-stem core; only band_specs differs.
        shared_core_kwargs = dict(
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

        self.bsrnn = nn.ModuleDict(
            {
                src: SingleMaskBandsplitCoreRNN(band_specs=specs, **shared_core_kwargs)
                for src, specs in self.band_specs_map.items()
            }
        )
+
|
| 342 |
+
class SingleMaskMultiSourceBandSplitTransformer(SingleMaskMultiSourceBandSplitBase):
    """Per-stem Transformer flavour of the single-mask wrapper: builds one
    ``SingleMaskBandsplitCoreTransformer`` for every entry in the band-spec
    map."""

    def __init__(
        self,
        in_channel: int,
        band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        tf_dropout: float = 0.0,
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ) -> None:
        super().__init__(
            band_specs_map=band_specs_map,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        # Settings shared by every per-stem core; only band_specs differs.
        shared_core_kwargs = dict(
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            tf_dropout=tf_dropout,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

        self.bsrnn = nn.ModuleDict(
            {
                src: SingleMaskBandsplitCoreTransformer(
                    band_specs=specs, **shared_core_kwargs
                )
                for src, specs in self.band_specs_map.items()
            }
        )
| 410 |
+
|
| 411 |
+
class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
    """Shared-core RNN flavour of the multi-mask wrapper: a single
    ``MultiSourceMultiMaskBandSplitCoreRNN`` predicts one mask per stem."""

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        cond_dim: int = 0,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: int = None,
        use_freq_weights: bool = True,
        normalize_input: bool = False,
        mult_add_mask: bool = False,
        freeze_encoder: bool = False,
    ) -> None:
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        # Assemble the shared core's configuration in one place.
        core_config = dict(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )
        self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(**core_config)

        self.normalize_input = normalize_input
        self.cond_dim = cond_dim

        if freeze_encoder:
            # Freeze the shared encoder (band split + time-freq model) so
            # only the downstream mask estimators keep training.
            for frozen_module in (self.bsrnn.band_split, self.bsrnn.tf_model):
                for weight in frozen_module.parameters():
                    weight.requires_grad = False
| 500 |
+
|
| 501 |
+
class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple):
    """Tensor-in/tensor-out RNN flavour: same core as
    ``MultiMaskMultiSourceBandSplitRNN`` but with the simplified forward that
    takes a raw waveform and returns stacked stems."""

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        cond_dim: int = 0,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: int = None,
        use_freq_weights: bool = True,
        normalize_input: bool = False,
        mult_add_mask: bool = False,
        freeze_encoder: bool = False,
    ) -> None:
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        # Assemble the shared core's configuration in one place.
        core_config = dict(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )
        self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(**core_config)

        self.normalize_input = normalize_input
        self.cond_dim = cond_dim

        if freeze_encoder:
            # Freeze the shared encoder (band split + time-freq model) so
            # only the downstream mask estimators keep training.
            for frozen_module in (self.bsrnn.band_split, self.bsrnn.tf_model):
                for weight in frozen_module.parameters():
                    weight.requires_grad = False
| 590 |
+
|
| 591 |
+
class MultiMaskMultiSourceBandSplitTransformer(MultiMaskMultiSourceBandSplitBase):
    """Shared-core Transformer flavour of the multi-mask wrapper."""

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        cond_dim: int = 0,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: int = None,
        use_freq_weights: bool = True,
        normalize_input: bool = False,
        mult_add_mask: bool = False,
    ) -> None:
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        # NOTE(review): `normalize_input` is accepted but never stored here,
        # matching the previous behaviour — confirm it is unused on purpose.
        core_config = dict(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )
        self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer(**core_config)
| 669 |
+
|
| 670 |
+
class MultiMaskMultiSourceBandSplitConv(MultiMaskMultiSourceBandSplitBase):
    """Shared-core convolutional flavour of the multi-mask wrapper."""

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        cond_dim: int = 0,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: int = None,
        use_freq_weights: bool = True,
        normalize_input: bool = False,
        mult_add_mask: bool = False,
    ) -> None:
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        # NOTE(review): `normalize_input` is accepted but never stored here,
        # matching the previous behaviour — confirm it is unused on purpose.
        core_config = dict(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )
        self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv(**core_config)
| 748 |
+
|
| 749 |
+
class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
    """Multi-mask wrapper whose core predicts patch-shaped (kernelized) masks
    instead of per-bin masks."""

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        kernel_norm_mlp_version: int = 1,
        mask_kernel_freq: int = 3,
        mask_kernel_time: int = 3,
        conv_kernel_freq: int = 1,
        conv_kernel_time: int = 1,
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: int = None,
    ) -> None:
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        # Assemble the patching core's configuration in one place.
        core_config = dict(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            mask_kernel_freq=mask_kernel_freq,
            mask_kernel_time=mask_kernel_time,
            conv_kernel_freq=conv_kernel_freq,
            conv_kernel_time=conv_kernel_time,
            kernel_norm_mlp_version=kernel_norm_mlp_version,
        )
        self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN(**core_config)
|
mvsepless/models/bandit/core/utils/audio.py
CHANGED
|
@@ -1,324 +1,324 @@
|
|
| 1 |
-
from collections import defaultdict
|
| 2 |
-
|
| 3 |
-
from tqdm.auto import tqdm
|
| 4 |
-
from typing import Callable, Dict, List, Optional, Tuple
|
| 5 |
-
|
| 6 |
-
import numpy as np
|
| 7 |
-
import torch
|
| 8 |
-
from torch import nn
|
| 9 |
-
from torch.nn import functional as F
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
@torch.jit.script
def merge(
    combined: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    n_chunks: int,
    chunk_size: int,
):
    """Regroup flat per-chunk model output into overlap-add layout.

    Input is (batch * n_chunks, n_channel, chunk_size) flattened; output is
    (batch * n_channel, chunk_size, n_chunks), i.e. chunks moved to the last
    axis as expected by F.fold.
    """
    by_item = combined.reshape(
        original_batch_size, n_chunks, n_channel, chunk_size
    )
    chunks_last = by_item.permute(0, 2, 3, 1)
    return chunks_last.reshape(
        original_batch_size * n_channel, chunk_size, n_chunks
    )
| 28 |
-
|
| 29 |
-
|
| 30 |
-
@torch.jit.script
|
| 31 |
-
def unfold(
|
| 32 |
-
padded_audio: torch.Tensor,
|
| 33 |
-
original_batch_size: int,
|
| 34 |
-
n_channel: int,
|
| 35 |
-
chunk_size: int,
|
| 36 |
-
hop_size: int,
|
| 37 |
-
) -> torch.Tensor:
|
| 38 |
-
|
| 39 |
-
unfolded_input = F.unfold(
|
| 40 |
-
padded_audio[:, :, None, :], kernel_size=(1, chunk_size), stride=(1, hop_size)
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
_, _, n_chunks = unfolded_input.shape
|
| 44 |
-
unfolded_input = unfolded_input.view(
|
| 45 |
-
original_batch_size, n_channel, chunk_size, n_chunks
|
| 46 |
-
)
|
| 47 |
-
unfolded_input = torch.permute(unfolded_input, (0, 3, 1, 2)).reshape(
|
| 48 |
-
original_batch_size * n_chunks, n_channel, chunk_size
|
| 49 |
-
)
|
| 50 |
-
|
| 51 |
-
return unfolded_input
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
@torch.jit.script
def merge_chunks_all(
    combined: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    n_samples: int,
    n_padded_samples: int,
    n_chunks: int,
    chunk_size: int,
    hop_size: int,
    edge_frame_pad_sizes: Tuple[int, int],
    standard_window: torch.Tensor,
    first_window: torch.Tensor,
    last_window: torch.Tensor,
):
    """Overlap-add all chunks back into a waveform, applying the same fade
    window to every chunk; used when the input was reflect-padded at the
    edges (fade_edge_frames=True). first_window/last_window are accepted for
    signature parity with merge_chunks_edge but unused here.
    """
    # (batch*n_chunks, ch, chunk) -> (batch*ch, chunk, n_chunks)
    combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)

    # Cross-fade every chunk with the standard window (broadcast over chunks).
    combined = combined * standard_window[:, None].to(combined.device)

    # Overlap-add via F.fold; cast to float32 since fold requires float input.
    combined = F.fold(
        combined.to(torch.float32),
        output_size=(1, n_padded_samples),
        kernel_size=(1, chunk_size),
        stride=(1, hop_size),
    )

    combined = combined.view(original_batch_size, n_channel, n_padded_samples)

    # Strip the reflect padding added before chunking, then the zero padding
    # added to reach a whole number of chunks.
    pad_front, pad_back = edge_frame_pad_sizes
    combined = combined[..., pad_front:-pad_back]

    combined = combined[..., :n_samples]

    return combined
| 88 |
-
|
| 89 |
-
|
| 90 |
-
def merge_chunks_edge(
    combined: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    n_samples: int,
    n_padded_samples: int,
    n_chunks: int,
    chunk_size: int,
    hop_size: int,
    edge_frame_pad_sizes: Tuple[int, int],
    standard_window: torch.Tensor,
    first_window: torch.Tensor,
    last_window: torch.Tensor,
):
    """Overlap-add all chunks back into a waveform, using special fade
    windows for the first and last chunks (no fade-in at the start, no
    fade-out at the end); used when fade_edge_frames=False.

    NOTE(review): unlike merge_chunks_all this is not @torch.jit.script'd —
    confirm whether that is intentional.
    """
    # (batch*n_chunks, ch, chunk) -> (batch*ch, chunk, n_chunks)
    combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)

    # Edge chunks get their own windows; interior chunks the standard one.
    combined[..., 0] = combined[..., 0] * first_window
    combined[..., -1] = combined[..., -1] * last_window
    combined[..., 1:-1] = combined[..., 1:-1] * standard_window[:, None]

    # Overlap-add via F.fold.
    combined = F.fold(
        combined,
        output_size=(1, n_padded_samples),
        kernel_size=(1, chunk_size),
        stride=(1, hop_size),
    )

    combined = combined.view(original_batch_size, n_channel, n_padded_samples)

    # Drop the zero padding added to reach a whole number of chunks.
    # (edge_frame_pad_sizes is unused here: no reflect padding was applied.)
    combined = combined[..., :n_samples]

    return combined
| 122 |
-
|
| 123 |
-
|
| 124 |
-
class BaseFader(nn.Module):
    """Chunk-wise inference runner: splits long audio into overlapping
    chunks, runs a model function batch-by-batch on CPU-staged chunks, and
    overlap-adds the per-stem outputs back into full-length waveforms.

    Subclasses must provide the fade windows (``standard_window``, optionally
    ``first_window``/``last_window``) and ``edge_frame_pad_sizes``.
    """

    def __init__(
        self,
        chunk_size_second: float,  # chunk length in seconds
        hop_size_second: float,    # hop between chunk starts in seconds
        fs: int,                   # sample rate used to convert to samples
        fade_edge_frames: bool,    # reflect-pad edges and fade all chunks
        batch_size: int,           # chunks per model call
    ) -> None:
        super().__init__()

        self.chunk_size = int(chunk_size_second * fs)
        self.hop_size = int(hop_size_second * fs)
        self.overlap_size = self.chunk_size - self.hop_size
        self.fade_edge_frames = fade_edge_frames
        self.batch_size = batch_size

    def prepare(self, audio):
        """Pad ``audio`` so it divides into a whole number of chunks.

        Returns ``(padded_audio, n_chunks)``.
        """
        # NOTE(review): self.edge_frame_pad_sizes is not defined on this base
        # class — presumably set by subclasses (see LinearFader); confirm.
        if self.fade_edge_frames:
            audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect")

        n_samples = audio.shape[-1]
        n_chunks = int(np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1)

        # Zero-pad the tail so the last chunk is full-length.
        padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size
        pad_size = padded_size - n_samples

        padded_audio = F.pad(audio, (0, pad_size))

        return padded_audio, n_chunks

    def forward(
        self,
        audio: torch.Tensor,
        model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
    ):
        """Run ``model_fn`` over ``audio`` chunk-wise.

        ``audio`` is (batch, channel, samples); ``model_fn`` maps a chunk
        batch to a dict of per-stem chunk batches. Returns
        ``{"audio": {stem: waveform}}`` with original dtype/device restored.
        """
        original_dtype = audio.dtype
        original_device = audio.device

        # Stage the full signal on CPU; only chunk batches go to the device.
        audio = audio.to("cpu")

        original_batch_size, n_channel, n_samples = audio.shape
        padded_audio, n_chunks = self.prepare(audio)
        del audio
        n_padded_samples = padded_audio.shape[-1]

        # Fold channels into the batch axis so unfold treats them uniformly.
        if n_channel > 1:
            padded_audio = padded_audio.view(
                original_batch_size * n_channel, 1, n_padded_samples
            )

        unfolded_input = unfold(
            padded_audio, original_batch_size, n_channel, self.chunk_size, self.hop_size
        )

        n_total_chunks, n_channel, chunk_size = unfolded_input.shape

        n_batch = np.ceil(n_total_chunks / self.batch_size).astype(int)

        chunks_in = [
            unfolded_input[b * self.batch_size : (b + 1) * self.batch_size, ...].clone()
            for b in range(n_batch)
        ]

        # Per-stem output accumulators, lazily created on first write.
        all_chunks_out = defaultdict(
            lambda: torch.zeros_like(unfolded_input, device="cpu")
        )

        for b, cin in enumerate(chunks_in):
            # Skip all-silent chunk batches: outputs stay zero.
            if torch.allclose(cin, torch.tensor(0.0)):
                del cin
                continue

            chunks_out = model_fn(cin.to(original_device))
            del cin
            for s, c in chunks_out.items():
                all_chunks_out[s][
                    b * self.batch_size : (b + 1) * self.batch_size, ...
                ] = c.cpu()
            del chunks_out

        del unfolded_input
        del padded_audio

        # Choose the overlap-add variant matching how prepare() padded.
        if self.fade_edge_frames:
            fn = merge_chunks_all
        else:
            fn = merge_chunks_edge
        outputs = {}

        torch.cuda.empty_cache()

        for s, c in all_chunks_out.items():
            combined: torch.Tensor = fn(
                c,
                original_batch_size,
                n_channel,
                n_samples,
                n_padded_samples,
                n_chunks,
                self.chunk_size,
                self.hop_size,
                self.edge_frame_pad_sizes,
                self.standard_window,
                # Fall back to the standard window when a subclass defines no
                # dedicated first/last window.
                self.__dict__.get("first_window", self.standard_window),
                self.__dict__.get("last_window", self.standard_window),
            )

            outputs[s] = combined.to(dtype=original_dtype, device=original_device)

        return {"audio": outputs}
| 237 |
-
|
| 238 |
-
|
| 239 |
-
class LinearFader(BaseFader):
|
| 240 |
-
def __init__(
|
| 241 |
-
self,
|
| 242 |
-
chunk_size_second: float,
|
| 243 |
-
hop_size_second: float,
|
| 244 |
-
fs: int,
|
| 245 |
-
fade_edge_frames: bool = False,
|
| 246 |
-
batch_size: int = 1,
|
| 247 |
-
) -> None:
|
| 248 |
-
|
| 249 |
-
assert hop_size_second >= chunk_size_second / 2
|
| 250 |
-
|
| 251 |
-
super().__init__(
|
| 252 |
-
chunk_size_second=chunk_size_second,
|
| 253 |
-
hop_size_second=hop_size_second,
|
| 254 |
-
fs=fs,
|
| 255 |
-
fade_edge_frames=fade_edge_frames,
|
| 256 |
-
batch_size=batch_size,
|
| 257 |
-
)
|
| 258 |
-
|
| 259 |
-
in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1]
|
| 260 |
-
out_fade = torch.linspace(1.0, 0.0, self.overlap_size + 1)[1:]
|
| 261 |
-
center_ones = torch.ones(self.chunk_size - 2 * self.overlap_size)
|
| 262 |
-
inout_ones = torch.ones(self.overlap_size)
|
| 263 |
-
|
| 264 |
-
self.register_buffer(
|
| 265 |
-
"standard_window", torch.concat([in_fade, center_ones, out_fade])
|
| 266 |
-
)
|
| 267 |
-
|
| 268 |
-
self.fade_edge_frames = fade_edge_frames
|
| 269 |
-
self.edge_frame_pad_size = (self.overlap_size, self.overlap_size)
|
| 270 |
-
|
| 271 |
-
if not self.fade_edge_frames:
|
| 272 |
-
self.first_window = nn.Parameter(
|
| 273 |
-
torch.concat([inout_ones, center_ones, out_fade]), requires_grad=False
|
| 274 |
-
)
|
| 275 |
-
self.last_window = nn.Parameter(
|
| 276 |
-
torch.concat([in_fade, center_ones, inout_ones]), requires_grad=False
|
| 277 |
-
)
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
class OverlapAddFader(BaseFader):
|
| 281 |
-
def __init__(
|
| 282 |
-
self,
|
| 283 |
-
window_type: str,
|
| 284 |
-
chunk_size_second: float,
|
| 285 |
-
hop_size_second: float,
|
| 286 |
-
fs: int,
|
| 287 |
-
batch_size: int = 1,
|
| 288 |
-
) -> None:
|
| 289 |
-
assert (chunk_size_second / hop_size_second) % 2 == 0
|
| 290 |
-
assert int(chunk_size_second * fs) % 2 == 0
|
| 291 |
-
|
| 292 |
-
super().__init__(
|
| 293 |
-
chunk_size_second=chunk_size_second,
|
| 294 |
-
hop_size_second=hop_size_second,
|
| 295 |
-
fs=fs,
|
| 296 |
-
fade_edge_frames=True,
|
| 297 |
-
batch_size=batch_size,
|
| 298 |
-
)
|
| 299 |
-
|
| 300 |
-
self.hop_multiplier = self.chunk_size / (2 * self.hop_size)
|
| 301 |
-
|
| 302 |
-
self.edge_frame_pad_sizes = (2 * self.overlap_size, 2 * self.overlap_size)
|
| 303 |
-
|
| 304 |
-
self.register_buffer(
|
| 305 |
-
"standard_window",
|
| 306 |
-
torch.windows.__dict__[window_type](
|
| 307 |
-
self.chunk_size,
|
| 308 |
-
sym=False,
|
| 309 |
-
)
|
| 310 |
-
/ self.hop_multiplier,
|
| 311 |
-
)
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
if __name__ == "__main__":
|
| 315 |
-
import torchaudio as ta
|
| 316 |
-
|
| 317 |
-
fs = 44100
|
| 318 |
-
ola = OverlapAddFader("hann", 6.0, 1.0, fs, batch_size=16)
|
| 319 |
-
audio_, _ = ta.load(
|
| 320 |
-
"$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too " "Much/vocals.wav"
|
| 321 |
-
)
|
| 322 |
-
audio_ = audio_[None, ...]
|
| 323 |
-
out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"]
|
| 324 |
-
print(torch.allclose(out, audio_))
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
|
| 3 |
+
from tqdm.auto import tqdm
|
| 4 |
+
from typing import Callable, Dict, List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@torch.jit.script
def merge(
    combined: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    n_chunks: int,
    chunk_size: int,
):
    """Regroup a flat chunk batch into per-channel chunk stacks.

    Takes the (original_batch_size * n_chunks, n_channel, chunk_size)
    layout produced by `unfold` and returns a
    (original_batch_size * n_channel, chunk_size, n_chunks) tensor, i.e.
    one column per chunk, ready for `F.fold` overlap-add.
    """
    stacked = combined.reshape(
        original_batch_size, n_chunks, n_channel, chunk_size
    )
    # Move the chunk axis last so each row holds all chunks of one channel.
    stacked = stacked.permute(0, 2, 3, 1)
    return stacked.reshape(original_batch_size * n_channel, chunk_size, n_chunks)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@torch.jit.script
def unfold(
    padded_audio: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    chunk_size: int,
    hop_size: int,
) -> torch.Tensor:
    """Slice padded audio into overlapping chunks.

    `padded_audio` is (batch[*channel], 1, n_padded_samples); the result is
    (original_batch_size * n_chunks, n_channel, chunk_size) so that chunks
    can be fed to the model as an ordinary batch dimension.
    """
    # F.unfold over a (N, 1, 1, T) "image" extracts hop-spaced windows.
    windows = F.unfold(
        padded_audio[:, :, None, :],
        kernel_size=(1, chunk_size),
        stride=(1, hop_size),
    )

    n_chunks = windows.shape[-1]
    windows = windows.view(original_batch_size, n_channel, chunk_size, n_chunks)
    # Chunk axis becomes the leading (batch-like) axis.
    windows = windows.permute(0, 3, 1, 2)
    return windows.reshape(original_batch_size * n_chunks, n_channel, chunk_size)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@torch.jit.script
def merge_chunks_all(
    combined: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    n_samples: int,
    n_padded_samples: int,
    n_chunks: int,
    chunk_size: int,
    hop_size: int,
    edge_frame_pad_sizes: Tuple[int, int],
    standard_window: torch.Tensor,
    first_window: torch.Tensor,
    last_window: torch.Tensor,
):
    """Overlap-add every chunk with the same synthesis window.

    Used when the signal edges were reflect-padded in `prepare`, so edge
    chunks can be faded exactly like interior ones. `first_window` and
    `last_window` are accepted only for signature parity with
    `merge_chunks_edge` and are unused here.
    """
    # (B * n_chunks, C, chunk) -> (B * C, chunk, n_chunks)
    stacked = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)

    # Apply the synthesis window to all chunks at once.
    windowed = stacked * standard_window[:, None].to(stacked.device)

    # Overlap-add the chunk columns back onto a padded-length canvas.
    folded = F.fold(
        windowed.to(torch.float32),
        output_size=(1, n_padded_samples),
        kernel_size=(1, chunk_size),
        stride=(1, hop_size),
    )
    folded = folded.view(original_batch_size, n_channel, n_padded_samples)

    # Remove the reflect padding added in prepare(), then the tail zero-pad.
    pad_front, pad_back = edge_frame_pad_sizes
    folded = folded[..., pad_front:-pad_back]

    return folded[..., :n_samples]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def merge_chunks_edge(
    combined: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    n_samples: int,
    n_padded_samples: int,
    n_chunks: int,
    chunk_size: int,
    hop_size: int,
    edge_frame_pad_sizes: Tuple[int, int],
    standard_window: torch.Tensor,
    first_window: torch.Tensor,
    last_window: torch.Tensor,
):
    """Overlap-add chunks, giving the edge chunks their own windows.

    Without edge padding, the first and last chunks have no neighbour on
    one side, so they are weighted with `first_window` / `last_window`
    (flat on the outer edge) instead of the symmetric `standard_window`.
    """
    stacked = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)

    # Chunk axis is last; window edge and interior chunks separately.
    stacked[..., 0] *= first_window
    stacked[..., -1] *= last_window
    stacked[..., 1:-1] *= standard_window[:, None]

    folded = F.fold(
        stacked,
        output_size=(1, n_padded_samples),
        kernel_size=(1, chunk_size),
        stride=(1, hop_size),
    )
    folded = folded.view(original_batch_size, n_channel, n_padded_samples)

    # Drop only the tail zero-pad; no edge padding was added in this mode.
    return folded[..., :n_samples]
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class BaseFader(nn.Module):
    """Base class for chunk-wise inference with windowed recombination.

    Splits a long waveform into overlapping fixed-size chunks, runs a
    caller-supplied model function over them in mini-batches, and
    overlap-adds the per-chunk outputs back into full-length stems using
    the synthesis window(s) defined by the subclass.
    """

    def __init__(
        self,
        chunk_size_second: float,
        hop_size_second: float,
        fs: int,
        fade_edge_frames: bool,
        batch_size: int,
    ) -> None:
        super().__init__()

        # All sizes below are in samples.
        self.chunk_size = int(chunk_size_second * fs)
        self.hop_size = int(hop_size_second * fs)
        self.overlap_size = self.chunk_size - self.hop_size
        self.fade_edge_frames = fade_edge_frames
        self.batch_size = batch_size

    def prepare(self, audio):
        """Pad `audio` so it divides evenly into hop-spaced chunks.

        When `fade_edge_frames` is set, reflect padding (sized by the
        subclass-defined `edge_frame_pad_sizes`) is first added on both
        ends. Returns (padded_audio, n_chunks).
        """
        if self.fade_edge_frames:
            audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect")

        n_samples = audio.shape[-1]
        n_chunks = int(np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1)

        # Zero-pad the tail so the final chunk is full length.
        padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size
        pad_size = padded_size - n_samples

        padded_audio = F.pad(audio, (0, pad_size))

        return padded_audio, n_chunks

    def forward(
        self,
        audio: torch.Tensor,
        model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
    ):
        """Run `model_fn` over chunks of `audio` and merge the results.

        `audio` is (batch, channel, time); `model_fn` maps a chunk batch to
        a dict of per-stem chunk batches of the same shape. Returns
        {"audio": {stem: tensor}} with tensors restored to the original
        device and dtype.
        """
        original_dtype = audio.dtype
        original_device = audio.device

        # Chunk bookkeeping lives on CPU to bound accelerator memory; only
        # each mini-batch is moved to the original device for inference.
        audio = audio.to("cpu")

        original_batch_size, n_channel, n_samples = audio.shape
        padded_audio, n_chunks = self.prepare(audio)
        del audio
        n_padded_samples = padded_audio.shape[-1]

        # Fold channels into the batch dim so unfold() treats them alike.
        if n_channel > 1:
            padded_audio = padded_audio.view(
                original_batch_size * n_channel, 1, n_padded_samples
            )

        unfolded_input = unfold(
            padded_audio, original_batch_size, n_channel, self.chunk_size, self.hop_size
        )

        n_total_chunks, n_channel, chunk_size = unfolded_input.shape

        n_batch = np.ceil(n_total_chunks / self.batch_size).astype(int)

        chunks_in = [
            unfolded_input[b * self.batch_size : (b + 1) * self.batch_size, ...].clone()
            for b in range(n_batch)
        ]

        # Per-stem output accumulators, lazily created on first write.
        all_chunks_out = defaultdict(
            lambda: torch.zeros_like(unfolded_input, device="cpu")
        )

        for b, cin in enumerate(chunks_in):
            # Skip all-silent batches; the accumulator already holds zeros.
            if torch.allclose(cin, torch.tensor(0.0)):
                del cin
                continue

            chunks_out = model_fn(cin.to(original_device))
            del cin
            for s, c in chunks_out.items():
                all_chunks_out[s][
                    b * self.batch_size : (b + 1) * self.batch_size, ...
                ] = c.cpu()
            del chunks_out

        del unfolded_input
        del padded_audio

        # With edge padding, every frame (edges included) uses the standard
        # window; otherwise the first/last frames get dedicated windows.
        if self.fade_edge_frames:
            fn = merge_chunks_all
        else:
            fn = merge_chunks_edge
        outputs = {}

        torch.cuda.empty_cache()

        for s, c in all_chunks_out.items():
            combined: torch.Tensor = fn(
                c,
                original_batch_size,
                n_channel,
                n_samples,
                n_padded_samples,
                n_chunks,
                self.chunk_size,
                self.hop_size,
                self.edge_frame_pad_sizes,
                self.standard_window,
                # NOTE(review): nn.Parameter/buffer attributes are stored in
                # _parameters/_buffers, not in the instance __dict__, so these
                # lookups appear to always fall back to standard_window when
                # the subclass registers the windows as Parameters — verify.
                self.__dict__.get("first_window", self.standard_window),
                self.__dict__.get("last_window", self.standard_window),
            )

            outputs[s] = combined.to(dtype=original_dtype, device=original_device)

        return {"audio": outputs}
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class LinearFader(BaseFader):
    """Fader that cross-fades adjacent chunks with linear ramps.

    The synthesis window ramps linearly over the overlap region and is flat
    in the middle, so overlapping chunk pairs sum to unity gain.
    """

    def __init__(
        self,
        chunk_size_second: float,
        hop_size_second: float,
        fs: int,
        fade_edge_frames: bool = False,
        batch_size: int = 1,
    ) -> None:
        # Linear cross-fading assumes at most two chunks overlap at any
        # sample, i.e. the overlap is no more than half a chunk.
        assert hop_size_second >= chunk_size_second / 2

        super().__init__(
            chunk_size_second=chunk_size_second,
            hop_size_second=hop_size_second,
            fs=fs,
            fade_edge_frames=fade_edge_frames,
            batch_size=batch_size,
        )

        # Piecewise-linear window: ramp up, flat middle, ramp down.
        in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1]
        out_fade = torch.linspace(1.0, 0.0, self.overlap_size + 1)[1:]
        center_ones = torch.ones(self.chunk_size - 2 * self.overlap_size)
        inout_ones = torch.ones(self.overlap_size)

        self.register_buffer(
            "standard_window", torch.concat([in_fade, center_ones, out_fade])
        )

        self.fade_edge_frames = fade_edge_frames
        # BUGFIX: BaseFader.prepare()/forward() read `edge_frame_pad_sizes`
        # (plural); the old attribute name `edge_frame_pad_size` was never
        # read, so forward() raised AttributeError for this class.
        self.edge_frame_pad_sizes = (self.overlap_size, self.overlap_size)
        # Old (misspelled) name kept for backward compatibility.
        self.edge_frame_pad_size = self.edge_frame_pad_sizes

        if not self.fade_edge_frames:
            # Edge chunks have no neighbour on one side, so that side stays
            # flat. Stored as plain tensors (not nn.Parameter) so that
            # BaseFader.forward()'s `self.__dict__.get("first_window", ...)`
            # lookup actually finds them; Parameters live in `_parameters`,
            # not in the instance `__dict__`. Gradients were already disabled
            # on the previous Parameter versions, so training is unaffected.
            self.first_window = torch.concat([inout_ones, center_ones, out_fade])
            self.last_window = torch.concat([in_fade, center_ones, inout_ones])
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class OverlapAddFader(BaseFader):
    """Fader performing windowed overlap-add reconstruction.

    Uses a periodic window (e.g. Hann) scaled so that the hop-spaced
    window sum is unity, and always reflect-pads the signal edges
    (`fade_edge_frames=True`).
    """

    def __init__(
        self,
        window_type: str,
        chunk_size_second: float,
        hop_size_second: float,
        fs: int,
        batch_size: int = 1,
    ) -> None:
        # The hop must divide the chunk an even number of times, and the
        # chunk length must be even, for the normalisation below to hold.
        assert (chunk_size_second / hop_size_second) % 2 == 0
        assert int(chunk_size_second * fs) % 2 == 0

        super().__init__(
            chunk_size_second=chunk_size_second,
            hop_size_second=hop_size_second,
            fs=fs,
            fade_edge_frames=True,
            batch_size=batch_size,
        )

        # Number of overlapping window pairs per sample; dividing the window
        # by this normalises the constant overlap-add gain to 1.
        self.hop_multiplier = self.chunk_size / (2 * self.hop_size)

        self.edge_frame_pad_sizes = (2 * self.overlap_size, 2 * self.overlap_size)

        # BUGFIX: `torch.windows` does not exist; the scipy-style window
        # generators that accept `sym=` live in `torch.signal.windows`
        # (hann, hamming, blackman, ...). sym=False gives a periodic window,
        # as required for overlap-add.
        window = getattr(torch.signal.windows, window_type)(
            self.chunk_size,
            sym=False,
        )
        self.register_buffer("standard_window", window / self.hop_multiplier)
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
if __name__ == "__main__":
    # Smoke test: run an identity "model" through the overlap-add fader and
    # check that the input waveform is reconstructed exactly.
    import os

    import torchaudio as ta

    fs = 44100
    ola = OverlapAddFader("hann", 6.0, 1.0, fs, batch_size=16)
    # BUGFIX: expand $DATA_ROOT from the environment; the raw string was
    # previously handed to torchaudio verbatim and could never resolve.
    audio_path = os.path.expandvars(
        "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too Much/vocals.wav"
    )
    audio_, _ = ta.load(audio_path)
    audio_ = audio_[None, ...]
    out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"]
    # With a perfect-reconstruction window this should print True.
    print(torch.allclose(out, audio_))
|
mvsepless/models/bandit/model_from_config.py
CHANGED
|
@@ -1,26 +1,26 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
import os.path
|
| 3 |
-
import torch
|
| 4 |
-
|
| 5 |
-
import yaml
|
| 6 |
-
from ml_collections import ConfigDict
|
| 7 |
-
|
| 8 |
-
torch.set_float32_matmul_precision("medium")
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def get_model(
|
| 12 |
-
config_path,
|
| 13 |
-
weights_path,
|
| 14 |
-
device,
|
| 15 |
-
):
|
| 16 |
-
from .core.model import MultiMaskMultiSourceBandSplitRNNSimple
|
| 17 |
-
|
| 18 |
-
f = open(config_path)
|
| 19 |
-
config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
|
| 20 |
-
f.close()
|
| 21 |
-
|
| 22 |
-
model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
|
| 23 |
-
d = torch.load(code_path + "model_bandit_plus_dnr_sdr_11.47.chpt")
|
| 24 |
-
model.load_state_dict(d)
|
| 25 |
-
model.to(device)
|
| 26 |
-
return model, config
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os.path
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
import yaml
|
| 6 |
+
from ml_collections import ConfigDict
|
| 7 |
+
|
| 8 |
+
torch.set_float32_matmul_precision("medium")
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_model(
    config_path,
    weights_path,
    device,
):
    """Build a MultiMaskMultiSourceBandSplitRNNSimple from a YAML config.

    Args:
        config_path: Path to the YAML file whose `model` section holds the
            constructor kwargs.
        weights_path: Path to the checkpoint (state dict) to load.
        device: Device to move the model to.

    Returns:
        (model, config) — the model in its loaded state and the parsed
        ConfigDict.
    """
    from .core.model import MultiMaskMultiSourceBandSplitRNNSimple

    with open(config_path) as f:
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))

    model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
    # BUGFIX: the old code referenced an undefined `code_path` plus a
    # hard-coded checkpoint filename, and never used `weights_path`.
    # map_location lets a GPU-trained checkpoint load on any device.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    return model, config
|
mvsepless/models/bandit_v2/bandit.py
CHANGED
|
@@ -1,360 +1,360 @@
|
|
| 1 |
-
from typing import Dict, List, Optional
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
import torchaudio as ta
|
| 5 |
-
from torch import nn
|
| 6 |
-
import pytorch_lightning as pl
|
| 7 |
-
|
| 8 |
-
from .bandsplit import BandSplitModule
|
| 9 |
-
from .maskestim import OverlappingMaskEstimationModule
|
| 10 |
-
from .tfmodel import SeqBandModellingModule
|
| 11 |
-
from .utils import MusicalBandsplitSpecification
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class BaseEndToEndModule(pl.LightningModule):
    """Minimal LightningModule base class for end-to-end separation models."""

    def __init__(
        self,
    ) -> None:
        super().__init__()
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
class BaseBandit(BaseEndToEndModule):
    """Common skeleton of the Bandit band-split separation model.

    Builds the STFT/iSTFT pair, the band-split encoder and the
    time-frequency modelling stack. Subclasses provide mask estimation by
    implementing `separate()` and defining `self.stems`.
    """

    def __init__(
        self,
        in_channels: int,
        fs: int,
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ):
        super().__init__()

        self.in_channels = in_channels

        # NOTE: the method name keeps its original (misspelled) spelling
        # "instantitate_spectral" because it is part of the public interface.
        self.instantitate_spectral(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        self.instantiate_bandsplit(
            in_channels=in_channels,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
            n_fft=n_fft,
            fs=fs,
        )

        self.instantiate_tf_modelling(
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

    def instantitate_spectral(
        self,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        normalized: bool = True,
        center: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ):
        """Create the forward and inverse STFT transforms.

        `power` must be None so the spectrogram stays complex-valued
        (required for complex masking and exact inversion).
        """
        assert power is None

        # Resolve the window factory (e.g. "hann_window") from the torch
        # top-level namespace.
        window_fn = torch.__dict__[window_fn]

        self.stft = ta.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )

        self.istft = ta.transforms.InverseSpectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )

    def instantiate_bandsplit(
        self,
        in_channels: int,
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        emb_dim: int = 128,
        n_fft: int = 2048,
        fs: int = 44100,
    ):
        """Create the frequency-band specification and band-split encoder.

        Only the "musical" band layout is supported.
        """
        assert band_type == "musical"

        self.band_specs = MusicalBandsplitSpecification(
            nfft=n_fft, fs=fs, n_bands=n_bands
        )

        self.band_split = BandSplitModule(
            in_channels=in_channels,
            band_specs=self.band_specs.get_band_specs(),
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

    def instantiate_tf_modelling(
        self,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
    ):
        """Create the sequential band-modelling (RNN) stack.

        torch.compile is called with disable=True (i.e. a no-op wrapper);
        the except branch keeps construction working on torch builds
        without torch.compile.
        """
        try:
            self.tf_model = torch.compile(
                SeqBandModellingModule(
                    n_modules=n_sqm_modules,
                    emb_dim=emb_dim,
                    rnn_dim=rnn_dim,
                    bidirectional=bidirectional,
                    rnn_type=rnn_type,
                ),
                disable=True,
            )
        except Exception as e:
            self.tf_model = SeqBandModellingModule(
                n_modules=n_sqm_modules,
                emb_dim=emb_dim,
                rnn_dim=rnn_dim,
                bidirectional=bidirectional,
                rnn_type=rnn_type,
            )

    def mask(self, x, m):
        # Elementwise (complex) masking of the mixture spectrogram.
        return x * m

    def forward(self, batch, mode="train"):
        """Separate `batch` into stems.

        Returns a tensor of shape (batch, n_stems, channel, time) stacked
        in `self.stems` order (self.stems is defined by the subclass).
        """
        # NOTE(review): `batch.shape` is read before the isinstance check,
        # so a dict input would raise here — it appears callers always pass
        # a (batch, channel, time) tensor; confirm.
        init_shape = batch.shape
        if not isinstance(batch, dict):
            # Flatten channels into the batch dim; the model runs mono.
            mono = batch.view(-1, 1, batch.shape[-1])
            batch = {"mixture": {"audio": mono}}

        with torch.no_grad():
            mixture = batch["mixture"]["audio"]

            x = self.stft(mixture)
            batch["mixture"]["spectrogram"] = x

            if "sources" in batch.keys():
                for stem in batch["sources"].keys():
                    s = batch["sources"][stem]["audio"]
                    s = self.stft(s)
                    batch["sources"][stem]["spectrogram"] = s

        batch = self.separate(batch)

        # NOTE(review): `if 1:` is a leftover always-true guard; kept
        # verbatim to avoid any behavior change.
        if 1:
            b = []
            for s in self.stems:
                # Restore the original (batch, channel, time) layout per stem.
                r = batch["estimates"][s]["audio"].view(
                    -1, init_shape[1], init_shape[2]
                )
                b.append(r)
            batch = torch.stack(b, dim=1)
        return batch

    def encode(self, batch):
        """Encode the mixture spectrogram into band embeddings.

        Returns (spectrogram, tf-model output, audio length in samples).
        """
        x = batch["mixture"]["spectrogram"]
        length = batch["mixture"]["audio"].shape[-1]

        z = self.band_split(x)
        q = self.tf_model(z)

        return x, q, length

    def separate(self, batch):
        # Subclasses implement mask estimation and stem reconstruction.
        raise NotImplementedError
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
class Bandit(BaseBandit):
    """Bandit separation model: band-split encoder + per-stem mask heads.

    Adds one OverlappingMaskEstimationModule per stem on top of the
    BaseBandit encoder; `separate()` applies each estimated (complex) mask
    to the mixture spectrogram and inverts it back to audio.
    """

    def __init__(
        self,
        in_channels: int,
        stems: List[str],
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict | None = None,
        complex_mask: bool = True,
        use_freq_weights: bool = True,
        n_fft: int = 2048,
        win_length: int | None = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Dict | None = None,
        power: int | None = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        fs: int = 44100,
        stft_precisions="32",
        bandsplit_precisions="bf16",
        tf_model_precisions="bf16",
        mask_estim_precisions="bf16",
    ):
        # NOTE(review): the *_precisions arguments are accepted but not used
        # anywhere in this class — confirm whether they are consumed by an
        # external config loader.
        super().__init__(
            in_channels=in_channels,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            fs=fs,
        )

        # Stem order here determines the stacking order in forward().
        self.stems = stems

        self.instantiate_mask_estim(
            in_channels=in_channels,
            stems=stems,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
        )

    def instantiate_mask_estim(
        self,
        in_channels: int,
        stems: List[str],
        emb_dim: int,
        mlp_dim: int,
        hidden_activation: str,
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = False,
    ):
        """Create one mask-estimation head per stem (keyed by stem name)."""
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        assert n_freq is not None

        self.mask_estim = nn.ModuleDict(
            {
                stem: OverlappingMaskEstimationModule(
                    band_specs=self.band_specs.get_band_specs(),
                    freq_weights=self.band_specs.get_freq_weights(),
                    n_freq=n_freq,
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channels=in_channels,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    use_freq_weights=use_freq_weights,
                )
                for stem in stems
            }
        )

    def separate(self, batch):
        """Estimate per-stem masks and reconstruct stem audio.

        Populates batch["estimates"][stem] with both the masked
        spectrogram and its inverse-STFT audio (trimmed to the mixture
        length).
        """
        batch["estimates"] = {}

        x, q, length = self.encode(batch)

        for stem, mem in self.mask_estim.items():
            m = mem(q)

            # Cast the mask to the spectrogram dtype (complex when
            # complex_mask is enabled) before elementwise masking.
            s = self.mask(x, m.to(x.dtype))
            s = torch.reshape(s, x.shape)
            batch["estimates"][stem] = {
                "audio": self.istft(s, length),
                "spectrogram": s,
            }

        return batch
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torchaudio as ta
|
| 5 |
+
from torch import nn
|
| 6 |
+
import pytorch_lightning as pl
|
| 7 |
+
|
| 8 |
+
from .bandsplit import BandSplitModule
|
| 9 |
+
from .maskestim import OverlappingMaskEstimationModule
|
| 10 |
+
from .tfmodel import SeqBandModellingModule
|
| 11 |
+
from .utils import MusicalBandsplitSpecification
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class BaseEndToEndModule(pl.LightningModule):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
) -> None:
|
| 18 |
+
super().__init__()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class BaseBandit(BaseEndToEndModule):
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
in_channels: int,
|
| 25 |
+
fs: int,
|
| 26 |
+
band_type: str = "musical",
|
| 27 |
+
n_bands: int = 64,
|
| 28 |
+
require_no_overlap: bool = False,
|
| 29 |
+
require_no_gap: bool = True,
|
| 30 |
+
normalize_channel_independently: bool = False,
|
| 31 |
+
treat_channel_as_feature: bool = True,
|
| 32 |
+
n_sqm_modules: int = 12,
|
| 33 |
+
emb_dim: int = 128,
|
| 34 |
+
rnn_dim: int = 256,
|
| 35 |
+
bidirectional: bool = True,
|
| 36 |
+
rnn_type: str = "LSTM",
|
| 37 |
+
n_fft: int = 2048,
|
| 38 |
+
win_length: Optional[int] = 2048,
|
| 39 |
+
hop_length: int = 512,
|
| 40 |
+
window_fn: str = "hann_window",
|
| 41 |
+
wkwargs: Optional[Dict] = None,
|
| 42 |
+
power: Optional[int] = None,
|
| 43 |
+
center: bool = True,
|
| 44 |
+
normalized: bool = True,
|
| 45 |
+
pad_mode: str = "constant",
|
| 46 |
+
onesided: bool = True,
|
| 47 |
+
):
|
| 48 |
+
super().__init__()
|
| 49 |
+
|
| 50 |
+
self.in_channels = in_channels
|
| 51 |
+
|
| 52 |
+
self.instantitate_spectral(
|
| 53 |
+
n_fft=n_fft,
|
| 54 |
+
win_length=win_length,
|
| 55 |
+
hop_length=hop_length,
|
| 56 |
+
window_fn=window_fn,
|
| 57 |
+
wkwargs=wkwargs,
|
| 58 |
+
power=power,
|
| 59 |
+
normalized=normalized,
|
| 60 |
+
center=center,
|
| 61 |
+
pad_mode=pad_mode,
|
| 62 |
+
onesided=onesided,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
self.instantiate_bandsplit(
|
| 66 |
+
in_channels=in_channels,
|
| 67 |
+
band_type=band_type,
|
| 68 |
+
n_bands=n_bands,
|
| 69 |
+
require_no_overlap=require_no_overlap,
|
| 70 |
+
require_no_gap=require_no_gap,
|
| 71 |
+
normalize_channel_independently=normalize_channel_independently,
|
| 72 |
+
treat_channel_as_feature=treat_channel_as_feature,
|
| 73 |
+
emb_dim=emb_dim,
|
| 74 |
+
n_fft=n_fft,
|
| 75 |
+
fs=fs,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
self.instantiate_tf_modelling(
|
| 79 |
+
n_sqm_modules=n_sqm_modules,
|
| 80 |
+
emb_dim=emb_dim,
|
| 81 |
+
rnn_dim=rnn_dim,
|
| 82 |
+
bidirectional=bidirectional,
|
| 83 |
+
rnn_type=rnn_type,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
def instantitate_spectral(
    self,
    n_fft: int = 2048,
    win_length: Optional[int] = 2048,
    hop_length: int = 512,
    window_fn: str = "hann_window",
    wkwargs: Optional[Dict] = None,
    power: Optional[int] = None,
    center: bool = True,
    normalized: bool = True,
    pad_mode: str = "constant",
    onesided: bool = True,
):
    """Build the forward/inverse STFT transforms used by the model.

    NOTE: the method name carries a historical typo ("instantitate");
    it is called by that name in ``__init__``, so it must not be renamed.

    Creates ``self.stft`` (torchaudio ``Spectrogram``) and ``self.istft``
    (torchaudio ``InverseSpectrogram``) with matching parameters so that
    istft(stft(x)) round-trips.

    Args:
        n_fft: FFT size.
        win_length: analysis window length (defaults to ``n_fft``).
        hop_length: hop between frames.
        window_fn: name of a window factory looked up on the ``torch``
            module (e.g. ``"hann_window"``).
        wkwargs: extra kwargs forwarded to the window factory.
        power: must be ``None`` — the model needs the complex spectrogram,
            not a magnitude/power one.
        center / normalized / pad_mode / onesided: forwarded verbatim to
            both transforms.
    """
    # Complex output is required downstream (masking, view_as_real).
    assert power is None

    # Resolve the window factory by name on the torch module.
    window_fn = torch.__dict__[window_fn]

    self.stft = ta.transforms.Spectrogram(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        pad_mode=pad_mode,
        pad=0,
        window_fn=window_fn,
        wkwargs=wkwargs,
        power=power,
        normalized=normalized,
        center=center,
        onesided=onesided,
    )

    # Inverse transform mirrors the forward one (no ``power`` argument:
    # InverseSpectrogram always expects a complex input).
    self.istft = ta.transforms.InverseSpectrogram(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        pad_mode=pad_mode,
        pad=0,
        window_fn=window_fn,
        wkwargs=wkwargs,
        normalized=normalized,
        center=center,
        onesided=onesided,
    )
|
| 129 |
+
|
| 130 |
+
def instantiate_bandsplit(
    self,
    in_channels: int,
    band_type: str = "musical",
    n_bands: int = 64,
    require_no_overlap: bool = False,
    require_no_gap: bool = True,
    normalize_channel_independently: bool = False,
    treat_channel_as_feature: bool = True,
    emb_dim: int = 128,
    n_fft: int = 2048,
    fs: int = 44100,
):
    """Create the band specification and the band-split front-end.

    Stores ``self.band_specs`` (a ``MusicalBandsplitSpecification``) and
    ``self.band_split`` (a ``BandSplitModule`` projecting each sub-band
    to an ``emb_dim``-dimensional embedding).

    Args:
        in_channels: number of audio channels per example.
        band_type: only ``"musical"`` is supported here (enforced below).
        n_bands: number of sub-bands in the specification.
        require_no_overlap / require_no_gap: validity checks forwarded to
            the band-split module.
        normalize_channel_independently / treat_channel_as_feature:
            channel-handling options forwarded to the band-split module.
        emb_dim: embedding size per band.
        n_fft: FFT size (determines the number of frequency bins).
        fs: sampling rate in Hz used to place the band edges.
    """
    # Only the musical band layout is implemented in this class.
    assert band_type == "musical"

    self.band_specs = MusicalBandsplitSpecification(
        nfft=n_fft, fs=fs, n_bands=n_bands
    )

    self.band_split = BandSplitModule(
        in_channels=in_channels,
        band_specs=self.band_specs.get_band_specs(),
        require_no_overlap=require_no_overlap,
        require_no_gap=require_no_gap,
        normalize_channel_independently=normalize_channel_independently,
        treat_channel_as_feature=treat_channel_as_feature,
        emb_dim=emb_dim,
    )
|
| 158 |
+
|
| 159 |
+
def instantiate_tf_modelling(
    self,
    n_sqm_modules: int = 12,
    emb_dim: int = 128,
    rnn_dim: int = 256,
    bidirectional: bool = True,
    rnn_type: str = "LSTM",
):
    """Create the sequential band-modelling (time/freq RNN) stack.

    Builds a ``SeqBandModellingModule`` and, when available, wraps it with
    ``torch.compile(..., disable=True)`` (a no-op wrapper kept for parity
    with checkpoints trained through the compile wrapper).

    Fix over the previous version: the module was constructed twice — once
    inside ``try`` and again, with fresh random weights, in the ``except``
    fallback. Now it is built exactly once and only the compile wrapping is
    attempted, so a failing ``torch.compile`` can no longer re-initialize
    the weights.

    Args:
        n_sqm_modules: number of stacked band-modelling blocks.
        emb_dim: embedding size per band.
        rnn_dim: hidden size of each RNN.
        bidirectional: whether the RNNs are bidirectional.
        rnn_type: RNN class name, e.g. ``"LSTM"``.
    """
    tf_model = SeqBandModellingModule(
        n_modules=n_sqm_modules,
        emb_dim=emb_dim,
        rnn_dim=rnn_dim,
        bidirectional=bidirectional,
        rnn_type=rnn_type,
    )
    try:
        # disable=True keeps eager execution; older torch builds without
        # torch.compile fall through to the bare module.
        tf_model = torch.compile(tf_model, disable=True)
    except Exception:
        pass
    self.tf_model = tf_model
|
| 186 |
+
|
| 187 |
+
def mask(self, x, m):
    """Apply mask ``m`` to spectrogram ``x`` by element-wise multiplication."""
    masked = x * m
    return masked
|
| 189 |
+
|
| 190 |
+
def forward(self, batch, mode="train"):
    """Separate a mixture into per-stem estimates.

    Accepts either a raw waveform tensor or an already-assembled batch
    dict (``{"mixture": {"audio": ...}, ...}``).

    Fix over the previous version: ``batch.shape`` was read before the
    ``isinstance(batch, dict)`` guard, so any dict input crashed with
    ``AttributeError``; the restacking step was also wrapped in a dead
    ``if 1:`` block. Shape capture and restacking now happen only on the
    tensor path; dict inputs return the enriched batch dict.

    Args:
        batch: waveform tensor (..., channels, samples) or a batch dict.
        mode: unused here; kept for interface compatibility.

    Returns:
        For tensor input: a tensor of stem estimates stacked on dim 1,
        shaped ``(-1, n_stems, channels, samples)`` per the captured input
        shape. For dict input: the batch dict with ``"estimates"`` filled
        in by ``self.separate``.
    """
    is_tensor_input = not isinstance(batch, dict)
    if is_tensor_input:
        init_shape = batch.shape
        # Flatten channels into the batch dimension: each channel is
        # processed as an independent mono signal.
        mono = batch.view(-1, 1, batch.shape[-1])
        batch = {"mixture": {"audio": mono}}

    with torch.no_grad():
        mixture = batch["mixture"]["audio"]

        batch["mixture"]["spectrogram"] = self.stft(mixture)

        # Pre-compute source spectrograms when ground truth is supplied
        # (training/validation batches).
        if "sources" in batch:
            for stem in batch["sources"]:
                batch["sources"][stem]["spectrogram"] = self.stft(
                    batch["sources"][stem]["audio"]
                )

    batch = self.separate(batch)

    if is_tensor_input:
        # Undo the channel flattening and stack stems on a new dim 1.
        stacked = [
            batch["estimates"][stem]["audio"].view(
                -1, init_shape[1], init_shape[2]
            )
            for stem in self.stems
        ]
        batch = torch.stack(stacked, dim=1)
    return batch
|
| 219 |
+
|
| 220 |
+
def encode(self, batch):
    """Run the band-split front-end and TF modelling on the mixture.

    Returns:
        Tuple ``(spectrogram, embedding, n_samples)`` where ``n_samples``
        is the mixture's time-domain length (used later by the iSTFT).
    """
    spec = batch["mixture"]["spectrogram"]
    n_samples = batch["mixture"]["audio"].shape[-1]

    q = self.tf_model(self.band_split(spec))

    return spec, q, n_samples
|
| 228 |
+
|
| 229 |
+
def separate(self, batch):
    """Subclass hook: fill ``batch["estimates"]`` from the encoded mixture.

    The base class performs no separation itself; concrete models (e.g.
    ``Bandit`` below) override this.
    """
    raise NotImplementedError
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class Bandit(BaseBandit):
    """Concrete band-split separator with one mask-estimation head per stem.

    Extends ``BaseBandit`` (STFT front-end, band split, TF modelling) with
    per-stem ``OverlappingMaskEstimationModule`` heads and implements
    ``separate``.
    """

    def __init__(
        self,
        in_channels: int,
        stems: List[str],
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict | None = None,
        complex_mask: bool = True,
        use_freq_weights: bool = True,
        n_fft: int = 2048,
        win_length: int | None = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Dict | None = None,
        power: int | None = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        fs: int = 44100,
        # NOTE(review): the four *_precisions parameters below are accepted
        # but never used in this class — presumably kept for config/checkpoint
        # compatibility. Confirm before removing.
        stft_precisions="32",
        bandsplit_precisions="bf16",
        tf_model_precisions="bf16",
        mask_estim_precisions="bf16",
    ):
        super().__init__(
            in_channels=in_channels,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            fs=fs,
        )

        self.stems = stems

        self.instantiate_mask_estim(
            in_channels=in_channels,
            stems=stems,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            # Number of frequency bins of a one-sided STFT.
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
        )

    def instantiate_mask_estim(
        self,
        in_channels: int,
        stems: List[str],
        emb_dim: int,
        mlp_dim: int,
        hidden_activation: str,
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = False,
    ):
        """Create one ``OverlappingMaskEstimationModule`` per stem.

        The heads are stored in ``self.mask_estim`` (a ``nn.ModuleDict``
        keyed by stem name) and share the band specification built by
        ``instantiate_bandsplit``.
        """
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        # The overlapping heads need to know the full bin count.
        assert n_freq is not None

        self.mask_estim = nn.ModuleDict(
            {
                stem: OverlappingMaskEstimationModule(
                    band_specs=self.band_specs.get_band_specs(),
                    freq_weights=self.band_specs.get_freq_weights(),
                    n_freq=n_freq,
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channels=in_channels,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    use_freq_weights=use_freq_weights,
                )
                for stem in stems
            }
        )

    def separate(self, batch):
        """Estimate a mask per stem, apply it, and invert to audio.

        Fills ``batch["estimates"][stem]`` with ``"audio"`` (iSTFT of the
        masked spectrogram, trimmed/padded to the mixture length) and
        ``"spectrogram"`` (the masked complex spectrogram).
        """
        batch["estimates"] = {}

        x, q, length = self.encode(batch)

        for stem, mem in self.mask_estim.items():
            m = mem(q)

            # Cast the mask to the spectrogram dtype before multiplying.
            s = self.mask(x, m.to(x.dtype))
            s = torch.reshape(s, x.shape)
            batch["estimates"][stem] = {
                "audio": self.istft(s, length),
                "spectrogram": s,
            }

        return batch
|
mvsepless/models/bandit_v2/bandsplit.py
CHANGED
|
@@ -1,127 +1,127 @@
|
|
| 1 |
-
from typing import List, Tuple
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import nn
|
| 5 |
-
from torch.utils.checkpoint import checkpoint_sequential
|
| 6 |
-
|
| 7 |
-
from .utils import (
|
| 8 |
-
band_widths_from_specs,
|
| 9 |
-
check_no_gap,
|
| 10 |
-
check_no_overlap,
|
| 11 |
-
check_nonzero_bandwidth,
|
| 12 |
-
)
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
class NormFC(nn.Module):
|
| 16 |
-
def __init__(
|
| 17 |
-
self,
|
| 18 |
-
emb_dim: int,
|
| 19 |
-
bandwidth: int,
|
| 20 |
-
in_channels: int,
|
| 21 |
-
normalize_channel_independently: bool = False,
|
| 22 |
-
treat_channel_as_feature: bool = True,
|
| 23 |
-
) -> None:
|
| 24 |
-
super().__init__()
|
| 25 |
-
|
| 26 |
-
if not treat_channel_as_feature:
|
| 27 |
-
raise NotImplementedError
|
| 28 |
-
|
| 29 |
-
self.treat_channel_as_feature = treat_channel_as_feature
|
| 30 |
-
|
| 31 |
-
if normalize_channel_independently:
|
| 32 |
-
raise NotImplementedError
|
| 33 |
-
|
| 34 |
-
reim = 2
|
| 35 |
-
|
| 36 |
-
norm = nn.LayerNorm(in_channels * bandwidth * reim)
|
| 37 |
-
|
| 38 |
-
fc_in = bandwidth * reim
|
| 39 |
-
|
| 40 |
-
if treat_channel_as_feature:
|
| 41 |
-
fc_in *= in_channels
|
| 42 |
-
else:
|
| 43 |
-
assert emb_dim % in_channels == 0
|
| 44 |
-
emb_dim = emb_dim // in_channels
|
| 45 |
-
|
| 46 |
-
fc = nn.Linear(fc_in, emb_dim)
|
| 47 |
-
|
| 48 |
-
self.combined = nn.Sequential(norm, fc)
|
| 49 |
-
|
| 50 |
-
def forward(self, xb):
|
| 51 |
-
return checkpoint_sequential(self.combined, 1, xb, use_reentrant=False)
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
class BandSplitModule(nn.Module):
|
| 55 |
-
def __init__(
|
| 56 |
-
self,
|
| 57 |
-
band_specs: List[Tuple[float, float]],
|
| 58 |
-
emb_dim: int,
|
| 59 |
-
in_channels: int,
|
| 60 |
-
require_no_overlap: bool = False,
|
| 61 |
-
require_no_gap: bool = True,
|
| 62 |
-
normalize_channel_independently: bool = False,
|
| 63 |
-
treat_channel_as_feature: bool = True,
|
| 64 |
-
) -> None:
|
| 65 |
-
super().__init__()
|
| 66 |
-
|
| 67 |
-
check_nonzero_bandwidth(band_specs)
|
| 68 |
-
|
| 69 |
-
if require_no_gap:
|
| 70 |
-
check_no_gap(band_specs)
|
| 71 |
-
|
| 72 |
-
if require_no_overlap:
|
| 73 |
-
check_no_overlap(band_specs)
|
| 74 |
-
|
| 75 |
-
self.band_specs = band_specs
|
| 76 |
-
self.band_widths = band_widths_from_specs(band_specs)
|
| 77 |
-
self.n_bands = len(band_specs)
|
| 78 |
-
self.emb_dim = emb_dim
|
| 79 |
-
|
| 80 |
-
try:
|
| 81 |
-
self.norm_fc_modules = nn.ModuleList(
|
| 82 |
-
[ # type: ignore
|
| 83 |
-
torch.compile(
|
| 84 |
-
NormFC(
|
| 85 |
-
emb_dim=emb_dim,
|
| 86 |
-
bandwidth=bw,
|
| 87 |
-
in_channels=in_channels,
|
| 88 |
-
normalize_channel_independently=normalize_channel_independently,
|
| 89 |
-
treat_channel_as_feature=treat_channel_as_feature,
|
| 90 |
-
),
|
| 91 |
-
disable=True,
|
| 92 |
-
)
|
| 93 |
-
for bw in self.band_widths
|
| 94 |
-
]
|
| 95 |
-
)
|
| 96 |
-
except Exception as e:
|
| 97 |
-
self.norm_fc_modules = nn.ModuleList(
|
| 98 |
-
[ # type: ignore
|
| 99 |
-
NormFC(
|
| 100 |
-
emb_dim=emb_dim,
|
| 101 |
-
bandwidth=bw,
|
| 102 |
-
in_channels=in_channels,
|
| 103 |
-
normalize_channel_independently=normalize_channel_independently,
|
| 104 |
-
treat_channel_as_feature=treat_channel_as_feature,
|
| 105 |
-
)
|
| 106 |
-
for bw in self.band_widths
|
| 107 |
-
]
|
| 108 |
-
)
|
| 109 |
-
|
| 110 |
-
def forward(self, x: torch.Tensor):
|
| 111 |
-
|
| 112 |
-
batch, in_chan, band_width, n_time = x.shape
|
| 113 |
-
|
| 114 |
-
z = torch.zeros(
|
| 115 |
-
size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
|
| 116 |
-
)
|
| 117 |
-
|
| 118 |
-
x = torch.permute(x, (0, 3, 1, 2)).contiguous()
|
| 119 |
-
|
| 120 |
-
for i, nfm in enumerate(self.norm_fc_modules):
|
| 121 |
-
fstart, fend = self.band_specs[i]
|
| 122 |
-
xb = x[:, :, :, fstart:fend]
|
| 123 |
-
xb = torch.view_as_real(xb)
|
| 124 |
-
xb = torch.reshape(xb, (batch, n_time, -1))
|
| 125 |
-
z[:, i, :, :] = nfm(xb)
|
| 126 |
-
|
| 127 |
-
return z
|
|
|
|
| 1 |
+
from typing import List, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.utils.checkpoint import checkpoint_sequential
|
| 6 |
+
|
| 7 |
+
from .utils import (
|
| 8 |
+
band_widths_from_specs,
|
| 9 |
+
check_no_gap,
|
| 10 |
+
check_no_overlap,
|
| 11 |
+
check_nonzero_bandwidth,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class NormFC(nn.Module):
    """LayerNorm followed by a linear projection for one frequency band.

    The input is the flattened (channels x bins x real/imag) slice of one
    band; the output is an ``emb_dim``-dimensional embedding per frame.
    """

    def __init__(
        self,
        emb_dim: int,
        bandwidth: int,
        in_channels: int,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        if not treat_channel_as_feature:
            raise NotImplementedError

        self.treat_channel_as_feature = treat_channel_as_feature

        if normalize_channel_independently:
            raise NotImplementedError

        # Real and imaginary parts are stacked as extra features.
        n_reim = 2

        layer_norm = nn.LayerNorm(in_channels * bandwidth * n_reim)

        n_fc_inputs = bandwidth * n_reim
        if treat_channel_as_feature:
            n_fc_inputs *= in_channels
        else:
            # Unreachable given the guard above; kept for completeness.
            assert emb_dim % in_channels == 0
            emb_dim = emb_dim // in_channels

        projection = nn.Linear(n_fc_inputs, emb_dim)

        self.combined = nn.Sequential(layer_norm, projection)

    def forward(self, xb):
        """Project one band's flattened slice; checkpointed to save memory."""
        return checkpoint_sequential(self.combined, 1, xb, use_reentrant=False)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class BandSplitModule(nn.Module):
    """Splits a complex spectrogram into sub-bands and embeds each one.

    Each band ``(fstart, fend)`` in ``band_specs`` is sliced out, its
    real/imaginary parts flattened, and projected by a per-band ``NormFC``
    to an ``emb_dim``-dimensional embedding per time frame.
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        in_channels: int,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        # Validate the band layout before building any modules.
        check_nonzero_bandwidth(band_specs)

        if require_no_gap:
            check_no_gap(band_specs)

        if require_no_overlap:
            check_no_overlap(band_specs)

        # NOTE(review): band_specs entries are annotated as float pairs but
        # are used as slice indices in forward() — presumably they hold
        # integer bin indices at runtime; confirm against the caller.
        self.band_specs = band_specs
        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)
        self.emb_dim = emb_dim

        # torch.compile(..., disable=True) is a no-op wrapper; the except
        # branch keeps construction working on torch builds without compile.
        try:
            self.norm_fc_modules = nn.ModuleList(
                [  # type: ignore
                    torch.compile(
                        NormFC(
                            emb_dim=emb_dim,
                            bandwidth=bw,
                            in_channels=in_channels,
                            normalize_channel_independently=normalize_channel_independently,
                            treat_channel_as_feature=treat_channel_as_feature,
                        ),
                        disable=True,
                    )
                    for bw in self.band_widths
                ]
            )
        except Exception as e:
            self.norm_fc_modules = nn.ModuleList(
                [  # type: ignore
                    NormFC(
                        emb_dim=emb_dim,
                        bandwidth=bw,
                        in_channels=in_channels,
                        normalize_channel_independently=normalize_channel_independently,
                        treat_channel_as_feature=treat_channel_as_feature,
                    )
                    for bw in self.band_widths
                ]
            )

    def forward(self, x: torch.Tensor):
        """Embed every band of a complex spectrogram.

        Args:
            x: complex tensor of shape (batch, channels, freq, time).

        Returns:
            Real tensor of shape (batch, n_bands, time, emb_dim).
        """

        batch, in_chan, band_width, n_time = x.shape

        # Pre-allocate the output and fill band by band.
        z = torch.zeros(
            size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
        )

        # Move time ahead of (channel, freq) so each band slice flattens
        # to (batch, time, features).
        x = torch.permute(x, (0, 3, 1, 2)).contiguous()

        for i, nfm in enumerate(self.norm_fc_modules):
            fstart, fend = self.band_specs[i]
            xb = x[:, :, :, fstart:fend]
            # Split complex values into a trailing real/imag axis, then
            # flatten (channel, bins, reim) into one feature axis.
            xb = torch.view_as_real(xb)
            xb = torch.reshape(xb, (batch, n_time, -1))
            z[:, i, :, :] = nfm(xb)

        return z
|
mvsepless/models/bandit_v2/film.py
CHANGED
|
@@ -1,23 +1,23 @@
|
|
| 1 |
-
from torch import nn
|
| 2 |
-
import torch
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
class FiLM(nn.Module):
|
| 6 |
-
def __init__(self):
|
| 7 |
-
super().__init__()
|
| 8 |
-
|
| 9 |
-
def forward(self, x, gamma, beta):
|
| 10 |
-
return gamma * x + beta
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class BTFBroadcastedFiLM(nn.Module):
|
| 14 |
-
def __init__(self):
|
| 15 |
-
super().__init__()
|
| 16 |
-
self.film = FiLM()
|
| 17 |
-
|
| 18 |
-
def forward(self, x, gamma, beta):
|
| 19 |
-
|
| 20 |
-
gamma = gamma[None, None, None, :]
|
| 21 |
-
beta = beta[None, None, None, :]
|
| 22 |
-
|
| 23 |
-
return self.film(x, gamma, beta)
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class FiLM(nn.Module):
    """Feature-wise linear modulation: scale by ``gamma``, shift by ``beta``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, gamma, beta):
        """Return ``gamma * x + beta`` (shapes broadcast as usual)."""
        scaled = x * gamma
        return scaled + beta
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BTFBroadcastedFiLM(nn.Module):
    """FiLM whose 1-D ``gamma``/``beta`` broadcast over batch/band/time axes."""

    def __init__(self):
        super().__init__()
        self.film = FiLM()

    def forward(self, x, gamma, beta):
        """Modulate ``x`` with per-feature parameters.

        ``gamma`` and ``beta`` gain three leading singleton axes so they
        broadcast across the first three dimensions of ``x``.
        """
        gamma = gamma.unsqueeze(0).unsqueeze(0).unsqueeze(0)
        beta = beta.unsqueeze(0).unsqueeze(0).unsqueeze(0)

        return self.film(x, gamma, beta)
|
mvsepless/models/bandit_v2/maskestim.py
CHANGED
|
@@ -1,269 +1,269 @@
|
|
| 1 |
-
from typing import Dict, List, Optional, Tuple, Type
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import nn
|
| 5 |
-
from torch.nn.modules import activation
|
| 6 |
-
from torch.utils.checkpoint import checkpoint_sequential
|
| 7 |
-
|
| 8 |
-
from .utils import (
|
| 9 |
-
band_widths_from_specs,
|
| 10 |
-
check_no_gap,
|
| 11 |
-
check_no_overlap,
|
| 12 |
-
check_nonzero_bandwidth,
|
| 13 |
-
)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class BaseNormMLP(nn.Module):
    """Shared state for the per-band norm+MLP mask-estimation heads.

    Builds the LayerNorm and the first (Linear + activation) stage; the
    output stage is supplied by subclasses.
    """

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ):
        super().__init__()
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}
        self.hidden_activation_kwargs = hidden_activation_kwargs

        self.norm = nn.LayerNorm(emb_dim)

        # Resolve the activation class by name from torch.nn.modules.activation.
        act_cls = activation.__dict__[hidden_activation]
        self.hidden = nn.Sequential(
            nn.Linear(in_features=emb_dim, out_features=mlp_dim),
            act_cls(**self.hidden_activation_kwargs),
        )

        self.bandwidth = bandwidth
        self.in_channels = in_channels

        self.complex_mask = complex_mask
        # Two mask components (real/imag) when complex, one otherwise.
        self.reim = 2 if complex_mask else 1
        # GLU halves its input, so the output layer doubles its width.
        self.glu_mult = 2
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
class NormMLP(BaseNormMLP):
    """Per-band mask head: LayerNorm -> Linear+activation -> Linear+GLU.

    Produces a (complex) mask of shape (batch, channels, bandwidth, time)
    for one frequency band.
    """

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__(
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            bandwidth=bandwidth,
            in_channels=in_channels,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

        # Output width is doubled because GLU consumes half as its gate.
        self.output = nn.Sequential(
            nn.Linear(
                in_features=mlp_dim,
                out_features=bandwidth * in_channels * self.reim * 2,
            ),
            nn.GLU(dim=-1),
        )

        full_stack = nn.Sequential(self.norm, self.hidden, self.output)
        try:
            # disable=True keeps eager execution; the fallback covers torch
            # builds without torch.compile.
            self.combined = torch.compile(full_stack, disable=True)
        except Exception:
            self.combined = full_stack

    def reshape_output(self, mb):
        """Reshape flat head output to (batch, channels, bandwidth, time)."""
        n_batch, n_time, _ = mb.shape
        if self.complex_mask:
            mb = mb.reshape(
                n_batch, n_time, self.in_channels, self.bandwidth, self.reim
            ).contiguous()
            # Pair the trailing real/imag axis into complex values.
            mb = torch.view_as_complex(mb)
        else:
            mb = mb.reshape(n_batch, n_time, self.in_channels, self.bandwidth)

        return torch.permute(mb, (0, 2, 3, 1))

    def forward(self, qb):
        """Compute the band mask from per-frame embeddings ``qb``."""
        raw = checkpoint_sequential(self.combined, 2, qb, use_reentrant=False)
        return self.reshape_output(raw)
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
class MaskEstimationModuleSuperBase(nn.Module):
    """Root marker type shared by all mask-estimation modules."""

    pass
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
    """Owns one norm+MLP head per band and evaluates them on embeddings."""

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Dict = None,
    ) -> None:
        super().__init__()

        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)

        hidden_activation_kwargs = (
            {} if hidden_activation_kwargs is None else hidden_activation_kwargs
        )
        norm_mlp_kwargs = {} if norm_mlp_kwargs is None else norm_mlp_kwargs

        # One independent head per band, each sized for that band's width.
        heads = []
        for width in self.band_widths:
            heads.append(
                norm_mlp_cls(
                    bandwidth=width,
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channels=in_channels,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    **norm_mlp_kwargs,
                )
            )
        self.norm_mlp = nn.ModuleList(heads)

    def compute_masks(self, q):
        """Evaluate every per-band head on ``q`` (batch, bands, time, emb).

        Returns a list with one mask tensor per band.
        """
        return [head(q[:, idx, :, :]) for idx, head in enumerate(self.norm_mlp)]

    def compute_mask(self, q, b):
        """Evaluate only band ``b``'s head on its slice of ``q``."""
        return self.norm_mlp[b](q[:, b, :, :])
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
    """Mask estimator that sums per-band masks into a full-spectrum mask.

    Bands may overlap (and must have no gaps); each band's mask is added
    into its ``[fstart, fend)`` slice of a complex full-resolution mask,
    optionally weighted by per-band frequency weights.
    """

    def __init__(
        self,
        in_channels: int,
        band_specs: List[Tuple[float, float]],
        freq_weights: List[torch.Tensor],
        n_freq: int,
        emb_dim: int,
        mlp_dim: int,
        cond_dim: int = 0,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Dict = None,
        use_freq_weights: bool = False,
    ) -> None:
        # Overlap is allowed here, but gaps and empty bands are not.
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)

        # Conditioning is not implemented; the cond_dim parameter only
        # widens emb_dim below when it is zero (i.e. a no-op today).
        if cond_dim > 0:
            raise NotImplementedError

        super().__init__(
            band_specs=band_specs,
            emb_dim=emb_dim + cond_dim,
            mlp_dim=mlp_dim,
            in_channels=in_channels,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            norm_mlp_cls=norm_mlp_cls,
            norm_mlp_kwargs=norm_mlp_kwargs,
        )

        self.n_freq = n_freq
        self.band_specs = band_specs
        self.in_channels = in_channels

        if freq_weights is not None and use_freq_weights:
            # Register each band's weight vector as a non-parameter buffer
            # so it moves with the module across devices/dtypes.
            for i, fw in enumerate(freq_weights):
                self.register_buffer(f"freq_weights/{i}", fw)

            self.use_freq_weights = use_freq_weights
        else:
            self.use_freq_weights = False

    def forward(self, q):
        """Build the full-spectrum complex mask from band embeddings ``q``.

        Args:
            q: tensor of shape (batch, n_bands, time, emb_dim).

        Returns:
            Complex64 tensor of shape (batch, channels, n_freq, time);
            overlapping bands contribute additively to shared bins.
        """

        batch, n_bands, n_time, emb_dim = q.shape

        masks = torch.zeros(
            (batch, self.in_channels, self.n_freq, n_time),
            device=q.device,
            dtype=torch.complex64,
        )

        for im in range(n_bands):
            fstart, fend = self.band_specs[im]

            mask = self.compute_mask(q, im)

            if self.use_freq_weights:
                # Per-bin weight, broadcast over the time axis.
                fw = self.get_buffer(f"freq_weights/{im}")[:, None]
                mask = mask * fw
            masks[:, :, fstart:fend, :] += mask

        return masks
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
class MaskEstimationModule(OverlappingMaskEstimationModule):
    """Mask estimator for strictly partitioned bands.

    Requires the bands to tile the spectrum exactly (no gaps, no overlap)
    so the per-band masks can simply be concatenated along frequency.
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict = None,
        complex_mask: bool = True,
        **kwargs,
    ) -> None:
        # Concatenation only works when bands partition the spectrum.
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)
        check_no_overlap(band_specs)
        super().__init__(
            in_channels=in_channels,
            band_specs=band_specs,
            freq_weights=None,
            n_freq=None,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

    def forward(self, q, cond=None):
        """Concatenate the per-band masks along the frequency axis.

        ``cond`` is accepted for interface compatibility but unused.
        """
        per_band = self.compute_masks(q)

        return torch.concat(per_band, dim=2)
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional, Tuple, Type
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn.modules import activation
|
| 6 |
+
from torch.utils.checkpoint import checkpoint_sequential
|
| 7 |
+
|
| 8 |
+
from .utils import (
|
| 9 |
+
band_widths_from_specs,
|
| 10 |
+
check_no_gap,
|
| 11 |
+
check_no_overlap,
|
| 12 |
+
check_nonzero_bandwidth,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BaseNormMLP(nn.Module):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
emb_dim: int,
|
| 20 |
+
mlp_dim: int,
|
| 21 |
+
bandwidth: int,
|
| 22 |
+
in_channels: Optional[int],
|
| 23 |
+
hidden_activation: str = "Tanh",
|
| 24 |
+
hidden_activation_kwargs=None,
|
| 25 |
+
complex_mask: bool = True,
|
| 26 |
+
):
|
| 27 |
+
super().__init__()
|
| 28 |
+
if hidden_activation_kwargs is None:
|
| 29 |
+
hidden_activation_kwargs = {}
|
| 30 |
+
self.hidden_activation_kwargs = hidden_activation_kwargs
|
| 31 |
+
self.norm = nn.LayerNorm(emb_dim)
|
| 32 |
+
self.hidden = nn.Sequential(
|
| 33 |
+
nn.Linear(in_features=emb_dim, out_features=mlp_dim),
|
| 34 |
+
activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
self.bandwidth = bandwidth
|
| 38 |
+
self.in_channels = in_channels
|
| 39 |
+
|
| 40 |
+
self.complex_mask = complex_mask
|
| 41 |
+
self.reim = 2 if complex_mask else 1
|
| 42 |
+
self.glu_mult = 2
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class NormMLP(BaseNormMLP):
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
emb_dim: int,
|
| 49 |
+
mlp_dim: int,
|
| 50 |
+
bandwidth: int,
|
| 51 |
+
in_channels: Optional[int],
|
| 52 |
+
hidden_activation: str = "Tanh",
|
| 53 |
+
hidden_activation_kwargs=None,
|
| 54 |
+
complex_mask: bool = True,
|
| 55 |
+
) -> None:
|
| 56 |
+
super().__init__(
|
| 57 |
+
emb_dim=emb_dim,
|
| 58 |
+
mlp_dim=mlp_dim,
|
| 59 |
+
bandwidth=bandwidth,
|
| 60 |
+
in_channels=in_channels,
|
| 61 |
+
hidden_activation=hidden_activation,
|
| 62 |
+
hidden_activation_kwargs=hidden_activation_kwargs,
|
| 63 |
+
complex_mask=complex_mask,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
self.output = nn.Sequential(
|
| 67 |
+
nn.Linear(
|
| 68 |
+
in_features=mlp_dim,
|
| 69 |
+
out_features=bandwidth * in_channels * self.reim * 2,
|
| 70 |
+
),
|
| 71 |
+
nn.GLU(dim=-1),
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
try:
|
| 75 |
+
self.combined = torch.compile(
|
| 76 |
+
nn.Sequential(self.norm, self.hidden, self.output), disable=True
|
| 77 |
+
)
|
| 78 |
+
except Exception as e:
|
| 79 |
+
self.combined = nn.Sequential(self.norm, self.hidden, self.output)
|
| 80 |
+
|
| 81 |
+
def reshape_output(self, mb):
|
| 82 |
+
batch, n_time, _ = mb.shape
|
| 83 |
+
if self.complex_mask:
|
| 84 |
+
mb = mb.reshape(
|
| 85 |
+
batch, n_time, self.in_channels, self.bandwidth, self.reim
|
| 86 |
+
).contiguous()
|
| 87 |
+
mb = torch.view_as_complex(mb)
|
| 88 |
+
else:
|
| 89 |
+
mb = mb.reshape(batch, n_time, self.in_channels, self.bandwidth)
|
| 90 |
+
|
| 91 |
+
mb = torch.permute(mb, (0, 2, 3, 1))
|
| 92 |
+
|
| 93 |
+
return mb
|
| 94 |
+
|
| 95 |
+
def forward(self, qb):
|
| 96 |
+
|
| 97 |
+
mb = checkpoint_sequential(self.combined, 2, qb, use_reentrant=False)
|
| 98 |
+
mb = self.reshape_output(mb)
|
| 99 |
+
|
| 100 |
+
return mb
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class MaskEstimationModuleSuperBase(nn.Module):
|
| 104 |
+
pass
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
    """Per-band mask estimation: one norm+MLP head per frequency band.

    ``q`` is expected shaped (batch, n_bands, n_time, emb_dim); each band's
    embedding slice is fed to its own ``norm_mlp_cls`` instance.
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Dict = None,
    ) -> None:
        super().__init__()

        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)

        # Avoid mutable default arguments.
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}
        if norm_mlp_kwargs is None:
            norm_mlp_kwargs = {}

        self.norm_mlp = nn.ModuleList(
            [
                norm_mlp_cls(
                    bandwidth=width,
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channels=in_channels,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    **norm_mlp_kwargs,
                )
                for width in self.band_widths
            ]
        )

    def compute_masks(self, q):
        """Return a list with one estimated mask per band."""
        return [mlp(q[:, band, :, :]) for band, mlp in enumerate(self.norm_mlp)]

    def compute_mask(self, q, b):
        """Return the estimated mask for band ``b`` only."""
        return self.norm_mlp[b](q[:, b, :, :])
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
    """Mask estimation over bands that may overlap in frequency.

    Per-band masks are accumulated into one full-spectrum complex mask;
    optional per-band frequency weights blend overlapping regions.
    """

    def __init__(
        self,
        in_channels: int,
        band_specs: List[Tuple[float, float]],
        freq_weights: List[torch.Tensor],
        n_freq: int,
        emb_dim: int,
        mlp_dim: int,
        cond_dim: int = 0,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Dict = None,
        use_freq_weights: bool = False,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)

        # Conditioning is not supported yet.
        if cond_dim > 0:
            raise NotImplementedError

        super().__init__(
            band_specs=band_specs,
            emb_dim=emb_dim + cond_dim,
            mlp_dim=mlp_dim,
            in_channels=in_channels,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            norm_mlp_cls=norm_mlp_cls,
            norm_mlp_kwargs=norm_mlp_kwargs,
        )

        self.n_freq = n_freq
        self.band_specs = band_specs
        self.in_channels = in_channels

        if freq_weights is not None and use_freq_weights:
            # Buffers follow the module across devices / state_dict saves.
            for idx, weight in enumerate(freq_weights):
                self.register_buffer(f"freq_weights/{idx}", weight)
            self.use_freq_weights = use_freq_weights
        else:
            self.use_freq_weights = False

    def forward(self, q):
        """Accumulate per-band masks into a (batch, ch, n_freq, n_time) mask."""
        batch, n_bands, n_time, _ = q.shape

        full_mask = torch.zeros(
            (batch, self.in_channels, self.n_freq, n_time),
            device=q.device,
            dtype=torch.complex64,
        )

        for band in range(n_bands):
            fstart, fend = self.band_specs[band]
            band_mask = self.compute_mask(q, band)
            if self.use_freq_weights:
                weight = self.get_buffer(f"freq_weights/{band}")[:, None]
                band_mask = band_mask * weight
            full_mask[:, :, fstart:fend, :] += band_mask

        return full_mask
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class MaskEstimationModule(OverlappingMaskEstimationModule):
    """Mask estimation for contiguous, non-overlapping bands.

    Because bands are disjoint and gap-free, the per-band masks can simply be
    concatenated along the frequency axis.
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict = None,
        complex_mask: bool = True,
        **kwargs,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)
        check_no_overlap(band_specs)
        super().__init__(
            in_channels=in_channels,
            band_specs=band_specs,
            freq_weights=None,
            n_freq=None,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

    def forward(self, q, cond=None):
        """Concatenate the per-band masks along the frequency axis."""
        band_masks = self.compute_masks(q)
        return torch.concat(band_masks, dim=2)
|
mvsepless/models/bandit_v2/tfmodel.py
CHANGED
|
@@ -1,141 +1,141 @@
|
|
| 1 |
-
import warnings
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
import torch.backends.cuda
|
| 5 |
-
from torch import nn
|
| 6 |
-
from torch.nn.modules import rnn
|
| 7 |
-
from torch.utils.checkpoint import checkpoint_sequential
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class TimeFrequencyModellingModule(nn.Module):
    """Abstract parent for time-frequency sequence-modelling modules."""

    def __init__(self) -> None:
        super().__init__()
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
class ResidualRNN(nn.Module):
    """LayerNorm -> RNN -> Linear with a residual connection.

    Input and output are shaped (batch, n_uncrossed, n_across, emb_dim); the
    RNN runs along ``n_across`` by folding ``n_uncrossed`` into the batch
    (the "batch trick").
    """

    def __init__(
        self,
        emb_dim: int,
        rnn_dim: int,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        use_batch_trick: bool = True,
        use_layer_norm: bool = True,
    ) -> None:
        super().__init__()

        # Only the normalized, batch-folded configuration is supported.
        assert use_layer_norm
        assert use_batch_trick

        self.use_layer_norm = use_layer_norm
        self.norm = nn.LayerNorm(emb_dim)
        rnn_cls = getattr(rnn, rnn_type)
        self.rnn = rnn_cls(
            input_size=emb_dim,
            hidden_size=rnn_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional,
        )

        n_directions = 2 if bidirectional else 1
        self.fc = nn.Linear(in_features=rnn_dim * n_directions, out_features=emb_dim)

        self.use_batch_trick = use_batch_trick
        if not self.use_batch_trick:
            warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")

    def forward(self, z):
        residual = torch.clone(z)
        z = self.norm(z)

        batch, n_uncrossed, n_across, emb_dim = z.shape
        folded = z.reshape(batch * n_uncrossed, n_across, emb_dim)
        folded = self.rnn(folded)[0]
        z = folded.reshape(batch, n_uncrossed, n_across, -1)

        return self.fc(z) + residual
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
class Transpose(nn.Module):
    """Swap two tensor dimensions; usable inside ``nn.Sequential``."""

    def __init__(self, dim0: int, dim1: int) -> None:
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, z):
        return torch.transpose(z, self.dim0, self.dim1)
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
class SeqBandModellingModule(TimeFrequencyModellingModule):
    """Stack of RNNs that alternately model the time and band axes."""

    def __init__(
        self,
        n_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        parallel_mode=False,
    ) -> None:
        super().__init__()

        self.n_modules = n_modules

        def make_rnn():
            # Factory keeps all RNNs identically configured.
            return ResidualRNN(
                emb_dim=emb_dim,
                rnn_dim=rnn_dim,
                bidirectional=bidirectional,
                rnn_type=rnn_type,
            )

        if parallel_mode:
            # One (time, band) RNN pair per module; both read the same input.
            self.seqband = nn.ModuleList(
                [nn.ModuleList([make_rnn(), make_rnn()]) for _ in range(n_modules)]
            )
        else:
            # Interleave RNNs with axis swaps so consecutive RNNs alternate
            # between modelling across time and across bands.
            layers = []
            for _ in range(2 * n_modules):
                layers.append(make_rnn())
                layers.append(Transpose(1, 2))
            self.seqband = nn.Sequential(*layers)

        self.parallel_mode = parallel_mode

    def forward(self, z):
        if self.parallel_mode:
            for pair in self.seqband:
                time_rnn, band_rnn = pair[0], pair[1]
                zt = time_rnn(z)
                zf = band_rnn(z.transpose(1, 2))
                z = zt + zf.transpose(1, 2)
            return z

        # Sequential path: gradient-checkpoint in n_modules segments.
        return checkpoint_sequential(
            self.seqband, self.n_modules, z, use_reentrant=False
        )
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.backends.cuda
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn.modules import rnn
|
| 7 |
+
from torch.utils.checkpoint import checkpoint_sequential
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TimeFrequencyModellingModule(nn.Module):
    """Abstract parent for time-frequency sequence-modelling modules."""

    def __init__(self) -> None:
        super().__init__()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ResidualRNN(nn.Module):
    """LayerNorm -> RNN -> Linear with a residual connection.

    Input and output are shaped (batch, n_uncrossed, n_across, emb_dim); the
    RNN runs along ``n_across`` by folding ``n_uncrossed`` into the batch
    (the "batch trick").
    """

    def __init__(
        self,
        emb_dim: int,
        rnn_dim: int,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        use_batch_trick: bool = True,
        use_layer_norm: bool = True,
    ) -> None:
        super().__init__()

        # Only the normalized, batch-folded configuration is supported.
        assert use_layer_norm
        assert use_batch_trick

        self.use_layer_norm = use_layer_norm
        self.norm = nn.LayerNorm(emb_dim)
        rnn_cls = getattr(rnn, rnn_type)
        self.rnn = rnn_cls(
            input_size=emb_dim,
            hidden_size=rnn_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional,
        )

        n_directions = 2 if bidirectional else 1
        self.fc = nn.Linear(in_features=rnn_dim * n_directions, out_features=emb_dim)

        self.use_batch_trick = use_batch_trick
        if not self.use_batch_trick:
            warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")

    def forward(self, z):
        residual = torch.clone(z)
        z = self.norm(z)

        batch, n_uncrossed, n_across, emb_dim = z.shape
        folded = z.reshape(batch * n_uncrossed, n_across, emb_dim)
        folded = self.rnn(folded)[0]
        z = folded.reshape(batch, n_uncrossed, n_across, -1)

        return self.fc(z) + residual
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class Transpose(nn.Module):
    """Swap two tensor dimensions; usable inside ``nn.Sequential``."""

    def __init__(self, dim0: int, dim1: int) -> None:
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, z):
        return torch.transpose(z, self.dim0, self.dim1)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class SeqBandModellingModule(TimeFrequencyModellingModule):
    """Stack of RNNs that alternately model the time and band axes."""

    def __init__(
        self,
        n_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        parallel_mode=False,
    ) -> None:
        super().__init__()

        self.n_modules = n_modules

        def make_rnn():
            # Factory keeps all RNNs identically configured.
            return ResidualRNN(
                emb_dim=emb_dim,
                rnn_dim=rnn_dim,
                bidirectional=bidirectional,
                rnn_type=rnn_type,
            )

        if parallel_mode:
            # One (time, band) RNN pair per module; both read the same input.
            self.seqband = nn.ModuleList(
                [nn.ModuleList([make_rnn(), make_rnn()]) for _ in range(n_modules)]
            )
        else:
            # Interleave RNNs with axis swaps so consecutive RNNs alternate
            # between modelling across time and across bands.
            layers = []
            for _ in range(2 * n_modules):
                layers.append(make_rnn())
                layers.append(Transpose(1, 2))
            self.seqband = nn.Sequential(*layers)

        self.parallel_mode = parallel_mode

    def forward(self, z):
        if self.parallel_mode:
            for pair in self.seqband:
                time_rnn, band_rnn = pair[0], pair[1]
                zt = time_rnn(z)
                zf = band_rnn(z.transpose(1, 2))
                z = zt + zf.transpose(1, 2)
            return z

        # Sequential path: gradient-checkpoint in n_modules segments.
        return checkpoint_sequential(
            self.seqband, self.n_modules, z, use_reentrant=False
        )
|
mvsepless/models/bandit_v2/utils.py
CHANGED
|
@@ -1,384 +1,384 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from abc import abstractmethod
|
| 3 |
-
from typing import Callable
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import torch
|
| 7 |
-
from librosa import hz_to_midi, midi_to_hz
|
| 8 |
-
from torchaudio import functional as taF
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def band_widths_from_specs(band_specs):
    """Return the width (end - start) of each (start, end) band spec."""
    return [end - start for start, end in band_specs]
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def check_nonzero_bandwidth(band_specs):
    """Raise ValueError if any (start, end) band has non-positive width."""
    for start, end in band_specs:
        if end <= start:
            raise ValueError("Bands cannot be zero-width")
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def check_no_overlap(band_specs):
    """Raise ValueError if any band starts before the previous band ends.

    Bands are half-open ``(start, end)`` index ranges (masks are applied with
    ``[fstart:fend]`` slicing elsewhere in this package), so an adjacent pair
    like ``(0, 5), (5, 10)`` does NOT overlap.

    Bug fix: the original never updated ``fend_prev`` inside the loop, so no
    overlap could ever be detected.  It also compared with ``<=``, which —
    once tracking is fixed — would wrongly reject adjacent half-open bands;
    ``<`` is the correct overlap test for half-open ranges.
    """
    fend_prev = -1
    for fstart_curr, fend_curr in band_specs:
        if fstart_curr < fend_prev:
            raise ValueError("Bands cannot overlap")
        fend_prev = fend_curr
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def check_no_gap(band_specs):
    """Raise ValueError if consecutive bands leave a gap.

    The first band must start at bin 0 (enforced with ``assert``, matching
    the original behavior).
    """
    first_start, _ = band_specs[0]
    assert first_start == 0

    prev_end = -1
    for start, end in band_specs:
        # A start more than one past the previous end is a gap.
        if start - prev_end > 1:
            raise ValueError("Bands cannot leave gap")
        prev_end = end
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
class BandsplitSpecification:
    """Maps frequencies to STFT bin indices and builds band partitions."""

    def __init__(self, nfft: int, fs: int) -> None:
        self.fs = fs
        self.nfft = nfft
        self.nyquist = fs / 2
        # Number of bins in a one-sided spectrum.
        self.max_index = nfft // 2 + 1

        # Pre-computed bin indices for the common crossover frequencies.
        self.split500 = self.hertz_to_index(500)
        self.split1k = self.hertz_to_index(1000)
        self.split2k = self.hertz_to_index(2000)
        self.split4k = self.hertz_to_index(4000)
        self.split8k = self.hertz_to_index(8000)
        self.split16k = self.hertz_to_index(16000)
        self.split20k = self.hertz_to_index(20000)

        self.above20k = [(self.split20k, self.max_index)]
        self.above16k = [(self.split16k, self.split20k)] + self.above20k

    def index_to_hertz(self, index: int):
        """Convert an STFT bin index back to Hertz."""
        return index * self.fs / self.nfft

    def hertz_to_index(self, hz: float, round: bool = True):
        """Convert Hertz to an (optionally rounded) STFT bin index."""
        index = hz * self.nfft / self.fs
        if round:
            index = int(np.round(index))
        return index

    def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
        """Tile [start_index, end_index) with bands of roughly bandwidth_hz."""
        specs = []
        lower = start_index
        # Bin step per band is loop-invariant; hoist it.
        step = self.hertz_to_index(bandwidth_hz)

        while lower < end_index:
            upper = min(int(np.floor(lower + step)), end_index)
            specs.append((lower, upper))
            lower = upper

        return specs

    @abstractmethod
    def get_band_specs(self):
        raise NotImplementedError
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
class VocalBandsplitSpecification(BandsplitSpecification):
    """Hand-designed band layouts for vocal separation.

    ``version`` selects one of the layouts below; higher versions use finer
    low-frequency resolution.

    Bug fix: ``version1`` is a ``@property``, so the original
    ``get_band_specs`` — which unconditionally called the looked-up
    attribute — raised ``TypeError: 'list' object is not callable`` for
    ``version == "1"``.  The attribute is now only called when callable,
    which is backward-compatible for every version.
    """

    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

        self.version = version

    def get_band_specs(self):
        """Return the (start, end) bin ranges for the configured version."""
        spec = getattr(self, f"version{self.version}")
        # ``version1`` is a property (already a list); the rest are methods.
        return spec() if callable(spec) else spec

    @property
    def version1(self):
        # Uniform 1 kHz bands over the whole spectrum.
        return self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.max_index, bandwidth_hz=1000
        )

    def version2(self):
        # 1 kHz bands to 16 kHz, 2 kHz bands to 20 kHz, one band above.
        below16k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )

        return below16k + below20k + self.above20k

    def version3(self):
        # 1 kHz bands to 8 kHz, 2 kHz bands to 16 kHz, coarse above.
        below8k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )

        return below8k + below16k + self.above16k

    def version4(self):
        # 100 Hz bands to 1 kHz, then progressively coarser.
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )

        return below1k + below8k + below16k + self.above16k

    def version5(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )
        return below1k + below16k + below20k + self.above20k

    def version6(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + self.above16k

    def version7(self):
        # Finest layout: 100 -> 250 -> 500 -> 1000 -> 2000 Hz bands.
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + below20k + self.above20k
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
class OtherBandsplitSpecification(VocalBandsplitSpecification):
    """The "other" stem reuses the vocal version-7 band layout."""

    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs, version="7")
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
class BassBandsplitSpecification(BandsplitSpecification):
    """Band layout tuned for bass: fine resolution below 500 Hz."""

    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        # NOTE(review): ``version`` is accepted but unused; kept for
        # signature compatibility with the other specifications.
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        # (start, end, bandwidth_hz) segments, finest at the bottom.
        layout = [
            (0, self.split500, 50),
            (self.split500, self.split1k, 100),
            (self.split1k, self.split4k, 500),
            (self.split4k, self.split8k, 1000),
            (self.split8k, self.split16k, 2000),
        ]
        specs = []
        for start, end, bw in layout:
            specs += self.get_band_specs_with_bandwidth(
                start_index=start, end_index=end, bandwidth_hz=bw
            )
        # Single catch-all band above 16 kHz.
        return specs + [(self.split16k, self.max_index)]
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
class DrumBandsplitSpecification(BandsplitSpecification):
    """Band layout tuned for drums: fine resolution below 1 kHz."""

    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        # (start, end, bandwidth_hz) segments, finest at the bottom.
        layout = [
            (0, self.split1k, 50),
            (self.split1k, self.split2k, 100),
            (self.split2k, self.split4k, 250),
            (self.split4k, self.split8k, 500),
            (self.split8k, self.split16k, 1000),
        ]
        specs = []
        for start, end, bw in layout:
            specs += self.get_band_specs_with_bandwidth(
                start_index=start, end_index=end, bandwidth_hz=bw
            )
        # Single catch-all band above 16 kHz.
        return specs + [(self.split16k, self.max_index)]
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
class PerceptualBandsplitSpecification(BandsplitSpecification):
    """Band specification derived from a perceptual filterbank (e.g. mel)."""

    def __init__(
        self,
        nfft: int,
        fs: int,
        fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
        n_bands: int,
        f_min: float = 0.0,
        f_max: float = None,
    ) -> None:
        """Build band specs and per-bin weights from ``fbank_fn``.

        ``fbank_fn(n_bands, fs, f_min, f_max, n_freqs)`` must return an
        (n_bands, n_freqs) weight matrix; each band's spec is the half-open
        span of its non-zero bins.
        """
        super().__init__(nfft=nfft, fs=fs)
        self.n_bands = n_bands
        if f_max is None:
            f_max = fs / 2

        self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)

        # Normalize so each frequency bin's weights sum to 1 across bands.
        weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True)
        normalized_mel_fb = self.filterbank / weight_per_bin

        freq_weights = []
        band_specs = []
        for i in range(self.n_bands):
            active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
            # A single active bin squeezes to a bare int; normalize to a pair.
            if isinstance(active_bins, int):
                active_bins = (active_bins, active_bins)
            # Bands with no support contribute no spec.
            if len(active_bins) == 0:
                continue
            # Half-open (start, end) span covering the band's support.
            start_index = active_bins[0]
            end_index = active_bins[-1] + 1
            band_specs.append((start_index, end_index))
            freq_weights.append(normalized_mel_fb[i, start_index:end_index])

        self.freq_weights = freq_weights
        self.band_specs = band_specs

    def get_band_specs(self):
        """Return the derived (start, end) bin spans."""
        return self.band_specs

    def get_freq_weights(self):
        """Return the per-band normalized weight vectors."""
        return self.freq_weights

    def save_to_file(self, dir_path: str) -> None:
        """Pickle band specs, weights, and the raw filterbank into ``dir_path``."""
        os.makedirs(dir_path, exist_ok=True)

        import pickle

        with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
            pickle.dump(
                {
                    "band_specs": self.band_specs,
                    "freq_weights": self.freq_weights,
                    "filterbank": self.filterbank,
                },
                f,
            )
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    """Mel filterbank of shape (n_bands, n_freqs) for perceptual band splits."""
    filterbank = taF.melscale_fbanks(
        n_mels=n_bands,
        sample_rate=fs,
        f_min=f_min,
        f_max=f_max,
        n_freqs=n_freqs,
    ).T

    # Give the DC bin full weight in the first band.
    filterbank[0, 0] = 1.0

    return filterbank
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
class MelBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split backed by a mel-scale filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(
            nfft=nfft,
            fs=fs,
            fbank_fn=mel_filterbank,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
    """Rectangular, log-spaced ("musical") filterbank of shape (n_bands, n_freqs).

    NOTE(review): ``scale`` is accepted but never used, and ``f_min`` is
    unconditionally overwritten with one bin width (fs / nfft) below —
    presumably to keep ``log2(f_max / f_min)`` finite; confirm before relying
    on the ``f_min`` argument.
    """
    nfft = 2 * (n_freqs - 1)
    df = fs / nfft
    f_max = f_max or fs / 2
    f_min = f_min or 0
    f_min = fs / nfft  # overrides the argument; see NOTE in the docstring

    n_octaves = np.log2(f_max / f_min)
    n_octaves_per_band = n_octaves / n_bands
    bandwidth_mult = np.power(2.0, n_octaves_per_band)

    # Band centers spaced uniformly on the MIDI (log-frequency) scale.
    low_midi = max(0, hz_to_midi(f_min))
    high_midi = hz_to_midi(f_max)
    midi_points = np.linspace(low_midi, high_midi, n_bands)
    hz_pts = midi_to_hz(midi_points)

    # Each band spans one band-spacing below and above its center.
    low_pts = hz_pts / bandwidth_mult
    high_pts = hz_pts * bandwidth_mult

    low_bins = np.floor(low_pts / df).astype(int)
    high_bins = np.ceil(high_pts / df).astype(int)

    fb = np.zeros((n_bands, n_freqs))

    for i in range(n_bands):
        fb[i, low_bins[i] : high_bins[i] + 1] = 1.0

    # Extend the edge bands to cover the ends of the spectrum.
    fb[0, : low_bins[0]] = 1.0
    fb[-1, high_bins[-1] + 1 :] = 1.0

    return torch.as_tensor(fb)
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split backed by a log-spaced (musical) filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(
            nfft=nfft,
            fs=fs,
            fbank_fn=musical_filterbank,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
if __name__ == "__main__":
    import pandas as pd

    # Dump the vocal band layout to CSV for inspection.
    rows = []

    for spec_cls in [VocalBandsplitSpecification]:
        stem = spec_cls.__name__.replace("BandsplitSpecification", "")

        for index, (f_min, f_max) in enumerate(
            spec_cls(nfft=2048, fs=44100).get_band_specs()
        ):
            rows.append(
                {"band": stem, "band_index": index, "f_min": f_min, "f_max": f_max}
            )

    pd.DataFrame(rows).to_csv("vox7bands.csv", index=False)
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from abc import abstractmethod
|
| 3 |
+
from typing import Callable
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from librosa import hz_to_midi, midi_to_hz
|
| 8 |
+
from torchaudio import functional as taF
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def band_widths_from_specs(band_specs):
    """Return the width (end - start) of each (start, end) band spec."""
    return [end - start for start, end in band_specs]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def check_nonzero_bandwidth(band_specs):
    """Raise ValueError if any (start, end) band has non-positive width."""
    for start, end in band_specs:
        if end <= start:
            raise ValueError("Bands cannot be zero-width")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def check_no_overlap(band_specs):
    """Raise ValueError if any band starts before the previous band ends.

    Bands are half-open ``(start, end)`` index ranges, so an adjacent pair
    like ``(0, 5), (5, 10)`` does NOT overlap.

    Bug fix: the original never updated ``fend_prev`` inside the loop, so no
    overlap could ever be detected.  It also compared with ``<=``, which —
    once tracking is fixed — would wrongly reject adjacent half-open bands;
    ``<`` is the correct overlap test for half-open ranges.
    """
    fend_prev = -1
    for fstart_curr, fend_curr in band_specs:
        if fstart_curr < fend_prev:
            raise ValueError("Bands cannot overlap")
        fend_prev = fend_curr
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def check_no_gap(band_specs):
    """Raise ValueError if consecutive bands leave a gap.

    The first band must start at bin 0 (enforced with ``assert``, matching
    the original behavior).
    """
    first_start, _ = band_specs[0]
    assert first_start == 0

    prev_end = -1
    for start, end in band_specs:
        # A start more than one past the previous end is a gap.
        if start - prev_end > 1:
            raise ValueError("Bands cannot leave gap")
        prev_end = end
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class BandsplitSpecification:
    """Maps frequencies to STFT bin indices and builds band partitions."""

    def __init__(self, nfft: int, fs: int) -> None:
        self.fs = fs
        self.nfft = nfft
        self.nyquist = fs / 2
        # Number of bins in a one-sided spectrum.
        self.max_index = nfft // 2 + 1

        # Pre-computed bin indices for the common crossover frequencies.
        self.split500 = self.hertz_to_index(500)
        self.split1k = self.hertz_to_index(1000)
        self.split2k = self.hertz_to_index(2000)
        self.split4k = self.hertz_to_index(4000)
        self.split8k = self.hertz_to_index(8000)
        self.split16k = self.hertz_to_index(16000)
        self.split20k = self.hertz_to_index(20000)

        self.above20k = [(self.split20k, self.max_index)]
        self.above16k = [(self.split16k, self.split20k)] + self.above20k

    def index_to_hertz(self, index: int):
        """Convert an STFT bin index back to Hertz."""
        return index * self.fs / self.nfft

    def hertz_to_index(self, hz: float, round: bool = True):
        """Convert Hertz to an (optionally rounded) STFT bin index."""
        index = hz * self.nfft / self.fs
        if round:
            index = int(np.round(index))
        return index

    def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
        """Tile [start_index, end_index) with bands of roughly bandwidth_hz."""
        specs = []
        lower = start_index
        # Bin step per band is loop-invariant; hoist it.
        step = self.hertz_to_index(bandwidth_hz)

        while lower < end_index:
            upper = min(int(np.floor(lower + step)), end_index)
            specs.append((lower, upper))
            lower = upper

        return specs

    @abstractmethod
    def get_band_specs(self):
        raise NotImplementedError
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class VocalBandsplitSpecification(BandsplitSpecification):
|
| 87 |
+
    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        """Store the layout ``version``; bin math comes from the base class."""
        super().__init__(nfft=nfft, fs=fs)

        self.version = version
|
| 91 |
+
|
| 92 |
+
def get_band_specs(self):
|
| 93 |
+
return getattr(self, f"version{self.version}")()
|
| 94 |
+
|
| 95 |
+
    @property
    def version1(self):
        # NOTE(review): unlike version2..7 (plain methods) this is a
        # property; get_band_specs must not call the returned list.
        # Uniform 1 kHz bands over the whole spectrum.
        return self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.max_index, bandwidth_hz=1000
        )
|
| 100 |
+
|
| 101 |
+
def version2(self):
|
| 102 |
+
below16k = self.get_band_specs_with_bandwidth(
|
| 103 |
+
start_index=0, end_index=self.split16k, bandwidth_hz=1000
|
| 104 |
+
)
|
| 105 |
+
below20k = self.get_band_specs_with_bandwidth(
|
| 106 |
+
start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
return below16k + below20k + self.above20k
|
| 110 |
+
|
| 111 |
+
def version3(self):
|
| 112 |
+
below8k = self.get_band_specs_with_bandwidth(
|
| 113 |
+
start_index=0, end_index=self.split8k, bandwidth_hz=1000
|
| 114 |
+
)
|
| 115 |
+
below16k = self.get_band_specs_with_bandwidth(
|
| 116 |
+
start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
return below8k + below16k + self.above16k
|
| 120 |
+
|
| 121 |
+
def version4(self):
|
| 122 |
+
below1k = self.get_band_specs_with_bandwidth(
|
| 123 |
+
start_index=0, end_index=self.split1k, bandwidth_hz=100
|
| 124 |
+
)
|
| 125 |
+
below8k = self.get_band_specs_with_bandwidth(
|
| 126 |
+
start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
|
| 127 |
+
)
|
| 128 |
+
below16k = self.get_band_specs_with_bandwidth(
|
| 129 |
+
start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
return below1k + below8k + below16k + self.above16k
|
| 133 |
+
|
| 134 |
+
def version5(self):
|
| 135 |
+
below1k = self.get_band_specs_with_bandwidth(
|
| 136 |
+
start_index=0, end_index=self.split1k, bandwidth_hz=100
|
| 137 |
+
)
|
| 138 |
+
below16k = self.get_band_specs_with_bandwidth(
|
| 139 |
+
start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
|
| 140 |
+
)
|
| 141 |
+
below20k = self.get_band_specs_with_bandwidth(
|
| 142 |
+
start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
|
| 143 |
+
)
|
| 144 |
+
return below1k + below16k + below20k + self.above20k
|
| 145 |
+
|
| 146 |
+
def version6(self):
|
| 147 |
+
below1k = self.get_band_specs_with_bandwidth(
|
| 148 |
+
start_index=0, end_index=self.split1k, bandwidth_hz=100
|
| 149 |
+
)
|
| 150 |
+
below4k = self.get_band_specs_with_bandwidth(
|
| 151 |
+
start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
|
| 152 |
+
)
|
| 153 |
+
below8k = self.get_band_specs_with_bandwidth(
|
| 154 |
+
start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
|
| 155 |
+
)
|
| 156 |
+
below16k = self.get_band_specs_with_bandwidth(
|
| 157 |
+
start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
|
| 158 |
+
)
|
| 159 |
+
return below1k + below4k + below8k + below16k + self.above16k
|
| 160 |
+
|
| 161 |
+
def version7(self):
|
| 162 |
+
below1k = self.get_band_specs_with_bandwidth(
|
| 163 |
+
start_index=0, end_index=self.split1k, bandwidth_hz=100
|
| 164 |
+
)
|
| 165 |
+
below4k = self.get_band_specs_with_bandwidth(
|
| 166 |
+
start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
|
| 167 |
+
)
|
| 168 |
+
below8k = self.get_band_specs_with_bandwidth(
|
| 169 |
+
start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
|
| 170 |
+
)
|
| 171 |
+
below16k = self.get_band_specs_with_bandwidth(
|
| 172 |
+
start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
|
| 173 |
+
)
|
| 174 |
+
below20k = self.get_band_specs_with_bandwidth(
|
| 175 |
+
start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
|
| 176 |
+
)
|
| 177 |
+
return below1k + below4k + below8k + below16k + below20k + self.above20k
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class OtherBandsplitSpecification(VocalBandsplitSpecification):
    """Band split for the 'other' stem; reuses the vocal version-7 layout."""

    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs, version="7")
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class BassBandsplitSpecification(BandsplitSpecification):
    """Hand-designed band layout for the bass stem: very fine 50 Hz bands
    below 500 Hz, progressively wider bands above, plus one catch-all band
    from 16 kHz to Nyquist.

    ``version`` is accepted for signature parity with the other stem
    specifications but is currently unused.
    """

    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        # (start_bin, end_bin, bandwidth_hz) for each region of the spectrum.
        regions = [
            (0, self.split500, 50),
            (self.split500, self.split1k, 100),
            (self.split1k, self.split4k, 500),
            (self.split4k, self.split8k, 1000),
            (self.split8k, self.split16k, 2000),
        ]

        specs = []
        for lo, hi, bw in regions:
            specs += self.get_band_specs_with_bandwidth(
                start_index=lo, end_index=hi, bandwidth_hz=bw
            )

        # Everything above 16 kHz is a single band up to Nyquist.
        specs.append((self.split16k, self.max_index))
        return specs
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class DrumBandsplitSpecification(BandsplitSpecification):
    """Hand-designed band layout for the drum stem: 50 Hz bands up to 1 kHz,
    progressively wider bands above, plus one catch-all band from 16 kHz to
    Nyquist."""

    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        # (start_bin, end_bin, bandwidth_hz) for each region of the spectrum.
        regions = [
            (0, self.split1k, 50),
            (self.split1k, self.split2k, 100),
            (self.split2k, self.split4k, 250),
            (self.split4k, self.split8k, 500),
            (self.split8k, self.split16k, 1000),
        ]

        specs = []
        for lo, hi, bw in regions:
            specs += self.get_band_specs_with_bandwidth(
                start_index=lo, end_index=hi, bandwidth_hz=bw
            )

        # Everything above 16 kHz is a single band up to Nyquist.
        specs.append((self.split16k, self.max_index))
        return specs
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class PerceptualBandsplitSpecification(BandsplitSpecification):
    """Band split derived from a perceptual filterbank (e.g. mel scale).

    ``fbank_fn(n_bands, fs, f_min, f_max, n_freqs)`` must return an
    ``(n_bands, n_freqs)`` filterbank matrix. Each band's nonzero support
    becomes one ``(start_bin, end_bin)`` range, and the column-normalized
    filter weights over that range are kept as per-band frequency weights.
    """

    def __init__(
        self,
        nfft: int,
        fs: int,
        fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
        n_bands: int,
        f_min: float = 0.0,
        f_max: float = None,
    ) -> None:
        super().__init__(nfft=nfft, fs=fs)
        self.n_bands = n_bands
        # Default the upper edge to Nyquist.
        if f_max is None:
            f_max = fs / 2

        self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)

        # Normalize so the filter weights covering each frequency bin sum to 1
        # across bands. NOTE(review): bins covered by no band would divide by
        # zero here — assumes the filterbank covers every bin; confirm.
        weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True)
        normalized_mel_fb = self.filterbank / weight_per_bin

        freq_weights = []
        band_specs = []
        for i in range(self.n_bands):
            active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
            # squeeze() collapses a single-bin result to a scalar int; restore
            # a 2-tuple so the [0]/[-1] indexing below still works.
            if isinstance(active_bins, int):
                active_bins = (active_bins, active_bins)
            # Skip bands with no spectral support at all.
            if len(active_bins) == 0:
                continue
            start_index = active_bins[0]
            # end index is exclusive, hence +1.
            end_index = active_bins[-1] + 1
            band_specs.append((start_index, end_index))
            freq_weights.append(normalized_mel_fb[i, start_index:end_index])

        self.freq_weights = freq_weights
        self.band_specs = band_specs

    def get_band_specs(self):
        """Return the list of (start_bin, end_bin) pairs."""
        return self.band_specs

    def get_freq_weights(self):
        """Return the per-band normalized frequency-weight tensors."""
        return self.freq_weights

    def save_to_file(self, dir_path: str) -> None:
        """Pickle the band specs, weights and raw filterbank into *dir_path*."""
        os.makedirs(dir_path, exist_ok=True)

        import pickle

        with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
            pickle.dump(
                {
                    "band_specs": self.band_specs,
                    "freq_weights": self.freq_weights,
                    "filterbank": self.filterbank,
                },
                f,
            )
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    """Return an (n_bands, n_freqs) mel filterbank matrix.

    Transposed from torchaudio's (n_freqs, n_mels) convention so rows are
    bands, matching what PerceptualBandsplitSpecification expects.
    """
    fb = taF.melscale_fbanks(
        n_mels=n_bands,
        sample_rate=fs,
        f_min=f_min,
        f_max=f_max,
        n_freqs=n_freqs,
    ).T

    # The DC bin gets zero weight from every triangular mel filter; assign it
    # to the first band so all bins have coverage downstream.
    fb[0, 0] = 1.0

    return fb
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class MelBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split built from a mel-scale filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(
            nfft=nfft,
            fs=fs,
            fbank_fn=mel_filterbank,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
    """Build an (n_bands, n_freqs) rectangular filterbank whose band centers
    are spaced evenly on a musical (MIDI / log-frequency) scale.

    NOTE(review): the ``scale`` parameter is currently unused, and the
    ``f_min`` argument is unconditionally overwritten with fs/nfft (one bin
    above DC) a few lines below — confirm both are intentional.
    """
    nfft = 2 * (n_freqs - 1)
    df = fs / nfft
    f_max = f_max or fs / 2
    f_min = f_min or 0
    f_min = fs / nfft  # clobbers the caller's f_min (see NOTE above)

    n_octaves = np.log2(f_max / f_min)
    n_octaves_per_band = n_octaves / n_bands
    # Ratio between adjacent band centers; also used to widen each band.
    bandwidth_mult = np.power(2.0, n_octaves_per_band)

    # Band centers placed uniformly in MIDI (log-frequency) space.
    low_midi = max(0, hz_to_midi(f_min))
    high_midi = hz_to_midi(f_max)
    midi_points = np.linspace(low_midi, high_midi, n_bands)
    hz_pts = midi_to_hz(midi_points)

    # Each band spans one spacing factor below/above its center, so adjacent
    # bands overlap.
    low_pts = hz_pts / bandwidth_mult
    high_pts = hz_pts * bandwidth_mult

    low_bins = np.floor(low_pts / df).astype(int)
    high_bins = np.ceil(high_pts / df).astype(int)

    fb = np.zeros((n_bands, n_freqs))

    for i in range(n_bands):
        fb[i, low_bins[i] : high_bins[i] + 1] = 1.0

    # Extend the edge bands so every bin from DC to Nyquist is covered.
    fb[0, : low_bins[0]] = 1.0
    fb[-1, high_bins[-1] + 1 :] = 1.0

    return torch.as_tensor(fb)
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split built from a log-frequency (musical) filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(
            nfft=nfft,
            fs=fs,
            fbank_fn=musical_filterbank,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
if __name__ == "__main__":
    # Dump the vocal v7 band layout (nfft=2048 @ 44.1 kHz) to a CSV table.
    import pandas as pd

    band_defs = []

    for spec_cls in [VocalBandsplitSpecification]:
        stem = spec_cls.__name__.replace("BandsplitSpecification", "")

        specs = spec_cls(nfft=2048, fs=44100).get_band_specs()

        band_defs.extend(
            {"band": stem, "band_index": idx, "f_min": lo, "f_max": hi}
            for idx, (lo, hi) in enumerate(specs)
        )

    pd.DataFrame(band_defs).to_csv("vox7bands.csv", index=False)
|
mvsepless/models/bs_roformer/__init__.py
CHANGED
|
@@ -1,14 +1,14 @@
|
|
| 1 |
-
from .bs_roformer import BSRoformer
|
| 2 |
-
from .bs_conformer import BSConformer
|
| 3 |
-
from .bs_roformer_sw import BSRoformer_SW
|
| 4 |
-
try:
|
| 5 |
-
from neuralop.models import FNO1d
|
| 6 |
-
from .bs_roformer_fno import BSRoformer_FNO
|
| 7 |
-
except:
|
| 8 |
-
pass
|
| 9 |
-
from .bs_roformer_hyperace import BSRoformerHyperACE
|
| 10 |
-
from .bs_roformer_hyperace2 import BSRoformerHyperACE_2
|
| 11 |
-
from .bs_roformer_conditional import BSRoformer_Conditional
|
| 12 |
-
from .bs_roformer_unwa_inst_large_2 import BSRoformer_2
|
| 13 |
-
from .mel_band_roformer import MelBandRoformer
|
| 14 |
-
from .mel_band_conformer import MelBandConformer
|
|
|
|
| 1 |
+
# Model registry for the bs_roformer family. The FNO variant additionally
# requires the optional `neuralop` package and is only exported when that
# import succeeds.
from .bs_roformer import BSRoformer
from .bs_conformer import BSConformer
from .bs_roformer_sw import BSRoformer_SW
try:
    from neuralop.models import FNO1d
    from .bs_roformer_fno import BSRoformer_FNO
# FIX: was a bare `except:`, which also swallowed SystemExit and
# KeyboardInterrupt; `Exception` keeps the best-effort import without that.
except Exception:
    pass
from .bs_roformer_hyperace import BSRoformerHyperACE
from .bs_roformer_hyperace2 import BSRoformerHyperACE_2
from .bs_roformer_conditional import BSRoformer_Conditional
from .bs_roformer_unwa_inst_large_2 import BSRoformer_2
from .mel_band_roformer import MelBandRoformer
from .mel_band_conformer import MelBandConformer
|
mvsepless/models/bs_roformer/attend.py
CHANGED
|
@@ -1,128 +1,128 @@
|
|
| 1 |
-
from functools import wraps
|
| 2 |
-
from packaging import version
|
| 3 |
-
from collections import namedtuple
|
| 4 |
-
|
| 5 |
-
import os
|
| 6 |
-
import torch
|
| 7 |
-
from torch import nn, einsum
|
| 8 |
-
import torch.nn.functional as F
|
| 9 |
-
|
| 10 |
-
from einops import rearrange, reduce
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
FlashAttentionConfig = namedtuple(
|
| 14 |
-
"FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
|
| 15 |
-
)
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def exists(val):
|
| 19 |
-
return val is not None
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def default(v, d):
|
| 23 |
-
return v if exists(v) else d
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def once(fn):
|
| 27 |
-
called = False
|
| 28 |
-
|
| 29 |
-
@wraps(fn)
|
| 30 |
-
def inner(x):
|
| 31 |
-
nonlocal called
|
| 32 |
-
if called:
|
| 33 |
-
return
|
| 34 |
-
called = True
|
| 35 |
-
return fn(x)
|
| 36 |
-
|
| 37 |
-
return inner
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
print_once = once(print)
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
class Attend(nn.Module):
|
| 44 |
-
def __init__(self, dropout=0.0, flash=False, scale=None):
|
| 45 |
-
super().__init__()
|
| 46 |
-
self.scale = scale
|
| 47 |
-
self.dropout = dropout
|
| 48 |
-
self.attn_dropout = nn.Dropout(dropout)
|
| 49 |
-
|
| 50 |
-
self.flash = flash
|
| 51 |
-
self.use_torch_2_sdpa = False
|
| 52 |
-
self._config_checked = False
|
| 53 |
-
|
| 54 |
-
# Проверяем версию PyTorch при первом вызове
|
| 55 |
-
if flash and not self._config_checked:
|
| 56 |
-
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
| 57 |
-
print_once("PyTorch >= 2.0 detected, will use SDPA if available.")
|
| 58 |
-
self.use_torch_2_sdpa = True
|
| 59 |
-
|
| 60 |
-
# Настройки для PyTorch >= 2.0
|
| 61 |
-
self.cpu_config = FlashAttentionConfig(True, True, True)
|
| 62 |
-
self.cuda_config = None
|
| 63 |
-
|
| 64 |
-
if torch.cuda.is_available():
|
| 65 |
-
device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
|
| 66 |
-
device_version = version.parse(
|
| 67 |
-
f"{device_properties.major}.{device_properties.minor}"
|
| 68 |
-
)
|
| 69 |
-
|
| 70 |
-
if device_version >= version.parse("8.0"):
|
| 71 |
-
if os.name == "nt":
|
| 72 |
-
print_once(
|
| 73 |
-
"Windows OS detected, using math or mem efficient attention if input tensor is on cuda"
|
| 74 |
-
)
|
| 75 |
-
self.cuda_config = FlashAttentionConfig(False, True, True)
|
| 76 |
-
else:
|
| 77 |
-
print_once(
|
| 78 |
-
"GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda"
|
| 79 |
-
)
|
| 80 |
-
self.cuda_config = FlashAttentionConfig(True, False, False)
|
| 81 |
-
else:
|
| 82 |
-
print_once(
|
| 83 |
-
"GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda"
|
| 84 |
-
)
|
| 85 |
-
self.cuda_config = FlashAttentionConfig(False, True, True)
|
| 86 |
-
else:
|
| 87 |
-
print_once("PyTorch < 2.0 detected, flash attention will use einsum fallback.")
|
| 88 |
-
self.use_torch_2_sdpa = False
|
| 89 |
-
|
| 90 |
-
self._config_checked = True
|
| 91 |
-
|
| 92 |
-
def flash_attn_torch2(self, q, k, v):
|
| 93 |
-
"""SDPA для PyTorch >= 2.0"""
|
| 94 |
-
if exists(self.scale):
|
| 95 |
-
default_scale = q.shape[-1] ** -0.5
|
| 96 |
-
q = q * (self.scale / default_scale)
|
| 97 |
-
|
| 98 |
-
is_cuda = q.is_cuda
|
| 99 |
-
config = self.cuda_config if is_cuda else self.cpu_config
|
| 100 |
-
|
| 101 |
-
with torch.backends.cuda.sdp_kernel(**config._asdict()):
|
| 102 |
-
out = F.scaled_dot_product_attention(
|
| 103 |
-
q, k, v, dropout_p=self.dropout if self.training else 0.0
|
| 104 |
-
)
|
| 105 |
-
|
| 106 |
-
return out
|
| 107 |
-
|
| 108 |
-
def forward(self, q, k, v):
|
| 109 |
-
q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
|
| 110 |
-
|
| 111 |
-
scale = default(self.scale, q.shape[-1] ** -0.5)
|
| 112 |
-
|
| 113 |
-
if self.flash and self.use_torch_2_sdpa:
|
| 114 |
-
try:
|
| 115 |
-
return self.flash_attn_torch2(q, k, v)
|
| 116 |
-
except Exception as e:
|
| 117 |
-
print(f"Flash attention failed: {e}. Falling back to einsum.")
|
| 118 |
-
self.use_torch_2_sdpa = False
|
| 119 |
-
|
| 120 |
-
# Fallback для PyTorch < 2.0 или если flash отключен
|
| 121 |
-
sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
|
| 122 |
-
|
| 123 |
-
attn = sim.softmax(dim=-1)
|
| 124 |
-
attn = self.attn_dropout(attn)
|
| 125 |
-
|
| 126 |
-
out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
|
| 127 |
-
|
| 128 |
return out
|
|
|
|
| 1 |
+
from functools import wraps
|
| 2 |
+
from packaging import version
|
| 3 |
+
from collections import namedtuple
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn, einsum
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from einops import rearrange, reduce
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
FlashAttentionConfig = namedtuple(
|
| 14 |
+
"FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def exists(val):
    """Return True when *val* is anything other than None."""
    if val is None:
        return False
    return True
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def default(v, d):
    """Return *v* if it is set (not None), otherwise the fallback *d*."""
    if exists(v):
        return v
    return d
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def once(fn):
    """Wrap *fn* so that only its first call executes; later calls are no-ops
    returning None."""
    state = {"called": False}

    @wraps(fn)
    def inner(x):
        if state["called"]:
            return None
        state["called"] = True
        return fn(x)

    return inner
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
print_once = once(print)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class Attend(nn.Module):
    """Scaled dot-product attention with an optional PyTorch >= 2.0 SDPA fast
    path and a plain einsum fallback.

    Args:
        dropout: attention-probability dropout (applied only while training).
        flash: request the fused SDPA kernel when available.
        scale: optional custom softmax scale; defaults to 1/sqrt(head_dim).
    """

    def __init__(self, dropout=0.0, flash=False, scale=None):
        super().__init__()
        self.scale = scale
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        self.use_torch_2_sdpa = False
        self._config_checked = False

        # Check the PyTorch version up front (runs in __init__, not lazily).
        if flash and not self._config_checked:
            if version.parse(torch.__version__) >= version.parse("2.0.0"):
                print_once("PyTorch >= 2.0 detected, will use SDPA if available.")
                self.use_torch_2_sdpa = True

                # SDPA kernel configs for PyTorch >= 2.0; CPU allows all three
                # backends (flash, math, mem-efficient).
                self.cpu_config = FlashAttentionConfig(True, True, True)
                self.cuda_config = None

                if torch.cuda.is_available():
                    device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
                    device_version = version.parse(
                        f"{device_properties.major}.{device_properties.minor}"
                    )

                    # Flash attention only on compute capability >= 8.0 and
                    # non-Windows; otherwise math/mem-efficient.
                    if device_version >= version.parse("8.0"):
                        if os.name == "nt":
                            print_once(
                                "Windows OS detected, using math or mem efficient attention if input tensor is on cuda"
                            )
                            self.cuda_config = FlashAttentionConfig(False, True, True)
                        else:
                            print_once(
                                "GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda"
                            )
                            self.cuda_config = FlashAttentionConfig(True, False, False)
                    else:
                        print_once(
                            "GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda"
                        )
                        self.cuda_config = FlashAttentionConfig(False, True, True)
            else:
                print_once("PyTorch < 2.0 detected, flash attention will use einsum fallback.")
                self.use_torch_2_sdpa = False

            self._config_checked = True

    def flash_attn_torch2(self, q, k, v):
        """SDPA path for PyTorch >= 2.0."""
        # Fold a custom scale into q, since scaled_dot_product_attention
        # applies the default 1/sqrt(d) scaling internally.
        if exists(self.scale):
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)

        is_cuda = q.is_cuda
        config = self.cuda_config if is_cuda else self.cpu_config

        # NOTE(review): torch.backends.cuda.sdp_kernel is deprecated in newer
        # PyTorch in favor of torch.nn.attention.sdpa_kernel — confirm target
        # torch version.
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v, dropout_p=self.dropout if self.training else 0.0
            )

        return out

    def forward(self, q, k, v):
        # NOTE(review): q_len/k_len/device are computed but never used.
        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        scale = default(self.scale, q.shape[-1] ** -0.5)

        if self.flash and self.use_torch_2_sdpa:
            try:
                return self.flash_attn_torch2(q, k, v)
            except Exception as e:
                # Permanently disable the fast path after the first failure.
                print(f"Flash attention failed: {e}. Falling back to einsum.")
                self.use_torch_2_sdpa = False

        # Fallback for PyTorch < 2.0 or when flash is disabled
        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)

        return out
|
mvsepless/models/bs_roformer/attend_sage.py
CHANGED
|
@@ -1,147 +1,147 @@
|
|
| 1 |
-
from functools import wraps
|
| 2 |
-
from packaging import version
|
| 3 |
-
from collections import namedtuple
|
| 4 |
-
|
| 5 |
-
import os
|
| 6 |
-
import torch
|
| 7 |
-
from torch import nn, einsum
|
| 8 |
-
import torch.nn.functional as F
|
| 9 |
-
|
| 10 |
-
from einops import rearrange, reduce
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def _print_once(msg):
|
| 14 |
-
printed = False
|
| 15 |
-
|
| 16 |
-
def inner():
|
| 17 |
-
nonlocal printed
|
| 18 |
-
if not printed:
|
| 19 |
-
print(msg)
|
| 20 |
-
printed = True
|
| 21 |
-
|
| 22 |
-
return inner
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
# Проверяем доступность SageAttention
|
| 26 |
-
try:
|
| 27 |
-
from sageattention import sageattn
|
| 28 |
-
_has_sage_attention = True
|
| 29 |
-
except ImportError:
|
| 30 |
-
_has_sage_attention = False
|
| 31 |
-
_print_sage_not_found = _print_once(
|
| 32 |
-
"SageAttention not found. Will fall back to PyTorch SDPA (if available) or manual einsum."
|
| 33 |
-
)
|
| 34 |
-
_print_sage_not_found()
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def exists(val):
|
| 38 |
-
return val is not None
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def default(v, d):
|
| 42 |
-
return v if exists(v) else d
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
class Attend(nn.Module):
|
| 46 |
-
def __init__(self, dropout=0.0, flash=False, scale=None):
|
| 47 |
-
super().__init__()
|
| 48 |
-
self.scale = scale
|
| 49 |
-
self.dropout = dropout
|
| 50 |
-
|
| 51 |
-
self.use_sage = flash and _has_sage_attention
|
| 52 |
-
self.use_pytorch_sdpa = False
|
| 53 |
-
self._sdpa_checked = False
|
| 54 |
-
self.flash = flash
|
| 55 |
-
|
| 56 |
-
# Инициализируем сообщения
|
| 57 |
-
self._init_messages = False
|
| 58 |
-
|
| 59 |
-
if flash and not self.use_sage:
|
| 60 |
-
if not self._sdpa_checked:
|
| 61 |
-
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
| 62 |
-
self.use_pytorch_sdpa = True
|
| 63 |
-
self._sdpa_checked = True
|
| 64 |
-
|
| 65 |
-
self.attn_dropout = nn.Dropout(dropout)
|
| 66 |
-
|
| 67 |
-
def _print_init_messages(self):
|
| 68 |
-
"""Печатаем сообщения инициализации один раз"""
|
| 69 |
-
if self._init_messages:
|
| 70 |
-
return
|
| 71 |
-
|
| 72 |
-
if self.flash:
|
| 73 |
-
if self.use_sage:
|
| 74 |
-
print_once = _print_once("Using SageAttention backend.")
|
| 75 |
-
print_once()
|
| 76 |
-
elif self.use_pytorch_sdpa:
|
| 77 |
-
print_once = _print_once(
|
| 78 |
-
"Using PyTorch SDPA backend (FlashAttention-2, Memory-Efficient, or Math)."
|
| 79 |
-
)
|
| 80 |
-
print_once()
|
| 81 |
-
else:
|
| 82 |
-
print_once = _print_once(
|
| 83 |
-
"Flash attention requested but Pytorch < 2.0 and SageAttention not found. Falling back to einsum."
|
| 84 |
-
)
|
| 85 |
-
print_once()
|
| 86 |
-
|
| 87 |
-
self._init_messages = True
|
| 88 |
-
|
| 89 |
-
def forward(self, q, k, v):
|
| 90 |
-
q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
|
| 91 |
-
|
| 92 |
-
# Печатаем сообщения инициализации при первом вызове
|
| 93 |
-
self._print_init_messages()
|
| 94 |
-
|
| 95 |
-
# Пробуем SageAttention если доступен
|
| 96 |
-
if self.use_sage and self.flash:
|
| 97 |
-
try:
|
| 98 |
-
# Исправленный вызов: убрали повторный try-except
|
| 99 |
-
out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
|
| 100 |
-
return out
|
| 101 |
-
except Exception as e:
|
| 102 |
-
print(f"SageAttention failed with error: {e}. Falling back.")
|
| 103 |
-
self.use_sage = False
|
| 104 |
-
if not self._sdpa_checked:
|
| 105 |
-
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
| 106 |
-
self.use_pytorch_sdpa = True
|
| 107 |
-
print_once = _print_once(
|
| 108 |
-
"Falling back to PyTorch SDPA."
|
| 109 |
-
)
|
| 110 |
-
print_once()
|
| 111 |
-
else:
|
| 112 |
-
print_once = _print_once("Falling back to einsum.")
|
| 113 |
-
print_once()
|
| 114 |
-
self._sdpa_checked = True
|
| 115 |
-
|
| 116 |
-
# Пробуем PyTorch SDPA если доступен
|
| 117 |
-
if self.use_pytorch_sdpa and self.flash:
|
| 118 |
-
try:
|
| 119 |
-
# Для PyTorch >= 2.0
|
| 120 |
-
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
| 121 |
-
with torch.backends.cuda.sdp_kernel(
|
| 122 |
-
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
| 123 |
-
):
|
| 124 |
-
out = F.scaled_dot_product_attention(
|
| 125 |
-
q,
|
| 126 |
-
k,
|
| 127 |
-
v,
|
| 128 |
-
attn_mask=None,
|
| 129 |
-
dropout_p=self.dropout if self.training else 0.0,
|
| 130 |
-
is_causal=False,
|
| 131 |
-
)
|
| 132 |
-
return out
|
| 133 |
-
except Exception as e:
|
| 134 |
-
print(f"PyTorch SDPA failed with error: {e}. Falling back to einsum.")
|
| 135 |
-
self.use_pytorch_sdpa = False
|
| 136 |
-
|
| 137 |
-
# Fallback на einsum (работает в PyTorch 1.13+)
|
| 138 |
-
scale = default(self.scale, q.shape[-1] ** -0.5)
|
| 139 |
-
|
| 140 |
-
sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
|
| 141 |
-
|
| 142 |
-
attn = sim.softmax(dim=-1)
|
| 143 |
-
attn = self.attn_dropout(attn)
|
| 144 |
-
|
| 145 |
-
out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
|
| 146 |
-
|
| 147 |
return out
|
|
|
|
| 1 |
+
from functools import wraps
|
| 2 |
+
from packaging import version
|
| 3 |
+
from collections import namedtuple
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn, einsum
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from einops import rearrange, reduce
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _print_once(msg):
|
| 14 |
+
printed = False
|
| 15 |
+
|
| 16 |
+
def inner():
|
| 17 |
+
nonlocal printed
|
| 18 |
+
if not printed:
|
| 19 |
+
print(msg)
|
| 20 |
+
printed = True
|
| 21 |
+
|
| 22 |
+
return inner
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Проверяем доступность SageAttention
|
| 26 |
+
try:
|
| 27 |
+
from sageattention import sageattn
|
| 28 |
+
_has_sage_attention = True
|
| 29 |
+
except ImportError:
|
| 30 |
+
_has_sage_attention = False
|
| 31 |
+
_print_sage_not_found = _print_once(
|
| 32 |
+
"SageAttention not found. Will fall back to PyTorch SDPA (if available) or manual einsum."
|
| 33 |
+
)
|
| 34 |
+
_print_sage_not_found()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def exists(val):
    """True unless *val* is None."""
    return not (val is None)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def default(v, d):
|
| 42 |
+
return v if exists(v) else d
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class Attend(nn.Module):
    """Attention that prefers SageAttention, then PyTorch >= 2.0 SDPA, and
    finally falls back to a plain einsum implementation. Failed backends are
    disabled permanently after the first error."""

    def __init__(self, dropout=0.0, flash=False, scale=None):
        super().__init__()
        self.scale = scale
        self.dropout = dropout

        # Backend selection flags.
        self.use_sage = flash and _has_sage_attention
        self.use_pytorch_sdpa = False
        self._sdpa_checked = False
        self.flash = flash

        # One-shot flag for the init messages
        self._init_messages = False

        if flash and not self.use_sage:
            if not self._sdpa_checked:
                if version.parse(torch.__version__) >= version.parse("2.0.0"):
                    self.use_pytorch_sdpa = True
                self._sdpa_checked = True

        self.attn_dropout = nn.Dropout(dropout)

    def _print_init_messages(self):
        """Print the backend-selection messages once."""
        if self._init_messages:
            return

        if self.flash:
            if self.use_sage:
                print_once = _print_once("Using SageAttention backend.")
                print_once()
            elif self.use_pytorch_sdpa:
                print_once = _print_once(
                    "Using PyTorch SDPA backend (FlashAttention-2, Memory-Efficient, or Math)."
                )
                print_once()
            else:
                print_once = _print_once(
                    "Flash attention requested but Pytorch < 2.0 and SageAttention not found. Falling back to einsum."
                )
                print_once()

        self._init_messages = True

    def forward(self, q, k, v):
        # NOTE(review): q_len/k_len/device are computed but never used.
        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        # Print the init messages on the first call
        self._print_init_messages()

        # Try SageAttention if available
        if self.use_sage and self.flash:
            try:
                # Fixed call: removed the duplicated try-except
                out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
                return out
            except Exception as e:
                print(f"SageAttention failed with error: {e}. Falling back.")
                self.use_sage = False
                if not self._sdpa_checked:
                    if version.parse(torch.__version__) >= version.parse("2.0.0"):
                        self.use_pytorch_sdpa = True
                        print_once = _print_once(
                            "Falling back to PyTorch SDPA."
                        )
                        print_once()
                    else:
                        print_once = _print_once("Falling back to einsum.")
                        print_once()
                    self._sdpa_checked = True

        # Try PyTorch SDPA if available
        if self.use_pytorch_sdpa and self.flash:
            try:
                # For PyTorch >= 2.0
                if version.parse(torch.__version__) >= version.parse("2.0.0"):
                    # NOTE(review): sdp_kernel is deprecated in newer torch in
                    # favor of torch.nn.attention.sdpa_kernel — confirm.
                    with torch.backends.cuda.sdp_kernel(
                        enable_flash=True, enable_math=True, enable_mem_efficient=True
                    ):
                        out = F.scaled_dot_product_attention(
                            q,
                            k,
                            v,
                            attn_mask=None,
                            dropout_p=self.dropout if self.training else 0.0,
                            is_causal=False,
                        )
                    return out
            except Exception as e:
                print(f"PyTorch SDPA failed with error: {e}. Falling back to einsum.")
                self.use_pytorch_sdpa = False

        # Fallback to einsum (works on PyTorch 1.13+)
        scale = default(self.scale, q.shape[-1] ** -0.5)

        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)

        return out
|
mvsepless/models/bs_roformer/attend_sw.py
CHANGED
|
@@ -1,88 +1,88 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
import torch.nn.functional as F
|
| 6 |
-
from packaging import version
|
| 7 |
-
from torch import Tensor, einsum, nn
|
| 8 |
-
from torch.nn.attention import SDPBackend, sdpa_kernel
|
| 9 |
-
|
| 10 |
-
logger = logging.getLogger(__name__)
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class Attend(nn.Module):
    """Scaled dot-product attention using torch.nn.attention.sdpa_kernel,
    with per-device backend lists chosen at construction time and an einsum
    path when flash is disabled."""

    def __init__(self, dropout: float = 0.0, flash: bool = False, scale=None):
        super().__init__()
        self.scale = scale
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (
            flash and version.parse(torch.__version__) < version.parse("2.0.0")
        ), "expected pytorch >= 2.0.0 to use flash attention"

        # CPU may try any backend, in this preference order.
        self.cpu_backends = [
            SDPBackend.FLASH_ATTENTION,
            SDPBackend.EFFICIENT_ATTENTION,
            SDPBackend.MATH,
        ]
        self.cuda_backends: list | None = None

        # No CUDA backend list needed when CUDA is absent or flash is off.
        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
        device_version = version.parse(
            f"{device_properties.major}.{device_properties.minor}"
        )

        # Flash attention only on compute capability >= 8.0 and non-Windows.
        if device_version >= version.parse("8.0"):
            if os.name == "nt":
                cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
                logger.info(f"windows detected, {cuda_backends=}")
            else:
                cuda_backends = [SDPBackend.FLASH_ATTENTION]
                logger.info(f"gpu compute capability >= 8.0, {cuda_backends=}")
        else:
            cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
            logger.info(f"gpu compute capability < 8.0, {cuda_backends=}")

        self.cuda_backends = cuda_backends

    def flash_attn(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Run F.scaled_dot_product_attention under the configured backends."""
        # Starred unpack assumes q is 4-D (batch, heads, q_len, dim); only
        # is_cuda is actually used below.
        _, _heads, _q_len, _, _k_len, is_cuda, _device = (
            *q.shape,
            k.shape[-2],
            q.is_cuda,
            q.device,
        )  # type: ignore

        # Fold a custom scale into q, since SDPA applies 1/sqrt(d) internally.
        if self.scale is not None:
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)

        backends = self.cuda_backends if is_cuda else self.cpu_backends
        with sdpa_kernel(backends=backends):  # type: ignore
            out = F.scaled_dot_product_attention(
                q, k, v, dropout_p=self.dropout if self.training else 0.0
            )

        return out

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # NOTE(review): these locals are computed but never used.
        _q_len, _k_len, _device = q.shape[-2], k.shape[-2], q.device

        # NOTE(review): `or` treats scale=0 as unset — confirm that is fine.
        scale = self.scale or q.shape[-1] ** -0.5

        if self.flash:
            return self.flash_attn(q, k, v)

        # Manual einsum attention path.
        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from packaging import version
|
| 7 |
+
from torch import Tensor, einsum, nn
|
| 8 |
+
from torch.nn.attention import SDPBackend, sdpa_kernel
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Attend(nn.Module):
|
| 14 |
+
def __init__(self, dropout: float = 0.0, flash: bool = False, scale=None):
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.scale = scale
|
| 17 |
+
self.dropout = dropout
|
| 18 |
+
self.attn_dropout = nn.Dropout(dropout)
|
| 19 |
+
|
| 20 |
+
self.flash = flash
|
| 21 |
+
assert not (
|
| 22 |
+
flash and version.parse(torch.__version__) < version.parse("2.0.0")
|
| 23 |
+
), "expected pytorch >= 2.0.0 to use flash attention"
|
| 24 |
+
|
| 25 |
+
self.cpu_backends = [
|
| 26 |
+
SDPBackend.FLASH_ATTENTION,
|
| 27 |
+
SDPBackend.EFFICIENT_ATTENTION,
|
| 28 |
+
SDPBackend.MATH,
|
| 29 |
+
]
|
| 30 |
+
self.cuda_backends: list | None = None
|
| 31 |
+
|
| 32 |
+
if not torch.cuda.is_available() or not flash:
|
| 33 |
+
return
|
| 34 |
+
|
| 35 |
+
device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
|
| 36 |
+
device_version = version.parse(
|
| 37 |
+
f"{device_properties.major}.{device_properties.minor}"
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
if device_version >= version.parse("8.0"):
|
| 41 |
+
if os.name == "nt":
|
| 42 |
+
cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
|
| 43 |
+
logger.info(f"windows detected, {cuda_backends=}")
|
| 44 |
+
else:
|
| 45 |
+
cuda_backends = [SDPBackend.FLASH_ATTENTION]
|
| 46 |
+
logger.info(f"gpu compute capability >= 8.0, {cuda_backends=}")
|
| 47 |
+
else:
|
| 48 |
+
cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
|
| 49 |
+
logger.info(f"gpu compute capability < 8.0, {cuda_backends=}")
|
| 50 |
+
|
| 51 |
+
self.cuda_backends = cuda_backends
|
| 52 |
+
|
| 53 |
+
def flash_attn(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
| 54 |
+
_, _heads, _q_len, _, _k_len, is_cuda, _device = (
|
| 55 |
+
*q.shape,
|
| 56 |
+
k.shape[-2],
|
| 57 |
+
q.is_cuda,
|
| 58 |
+
q.device,
|
| 59 |
+
) # type: ignore
|
| 60 |
+
|
| 61 |
+
if self.scale is not None:
|
| 62 |
+
default_scale = q.shape[-1] ** -0.5
|
| 63 |
+
q = q * (self.scale / default_scale)
|
| 64 |
+
|
| 65 |
+
backends = self.cuda_backends if is_cuda else self.cpu_backends
|
| 66 |
+
with sdpa_kernel(backends=backends): # type: ignore
|
| 67 |
+
out = F.scaled_dot_product_attention(
|
| 68 |
+
q, k, v, dropout_p=self.dropout if self.training else 0.0
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
return out
|
| 72 |
+
|
| 73 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
| 74 |
+
_q_len, _k_len, _device = q.shape[-2], k.shape[-2], q.device
|
| 75 |
+
|
| 76 |
+
scale = self.scale or q.shape[-1] ** -0.5
|
| 77 |
+
|
| 78 |
+
if self.flash:
|
| 79 |
+
return self.flash_attn(q, k, v)
|
| 80 |
+
|
| 81 |
+
sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
|
| 82 |
+
|
| 83 |
+
attn = sim.softmax(dim=-1)
|
| 84 |
+
attn = self.attn_dropout(attn)
|
| 85 |
+
|
| 86 |
+
out = einsum("b h i j, b h j d -> b h i d", attn, v)
|
| 87 |
+
|
| 88 |
+
return out
|
mvsepless/models/bs_roformer/bs_roformer.py
CHANGED
|
@@ -1,696 +1,696 @@
|
|
| 1 |
-
from functools import partial
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import nn, einsum, tensor, Tensor
|
| 5 |
-
from torch.nn import Module, ModuleList
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
-
|
| 8 |
-
from .attend import Attend
|
| 9 |
-
|
| 10 |
-
# SageAttention is an optional acceleration backend; fall back silently to
# the standard Attend implementation when the module is not installed.
try:
    from .attend_sage import Attend as AttendSage
except ImportError:
    # Narrowed from a bare `except:` so genuine errors raised while
    # importing attend_sage (bad deps, syntax errors) are not swallowed.
    pass
| 14 |
-
from torch.utils.checkpoint import checkpoint
|
| 15 |
-
|
| 16 |
-
from beartype.typing import Tuple, Optional, List, Callable
|
| 17 |
-
from beartype import beartype
|
| 18 |
-
|
| 19 |
-
from rotary_embedding_torch import RotaryEmbedding
|
| 20 |
-
|
| 21 |
-
from einops import rearrange, pack, unpack
|
| 22 |
-
from einops.layers.torch import Rearrange
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def exists(val):
    """Return True when *val* is not None."""
    return val is not None
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def default(v, d):
    """Return *v*, or the fallback *d* when *v* is None."""
    if v is None:
        return d
    return v
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def pack_one(t, pattern):
    """einops.pack for a single tensor; returns (packed, packed_shapes)."""
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def unpack_one(t, ps, pattern):
    """einops.unpack for a single tensor, returning just that tensor."""
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def l2norm(t):
    """L2-normalize *t* along its last dimension."""
    return F.normalize(t, p=2, dim=-1)
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
class RMSNorm(Module):
    """Root-mean-square normalization with a learned per-dimension gain."""

    def __init__(self, dim):
        super().__init__()
        # F.normalize divides by the full L2 norm; multiplying by sqrt(dim)
        # turns that into an RMS normalization.
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        normed = F.normalize(x, dim=-1)
        return normed * self.scale * self.gamma
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
class FeedForward(Module):
    """Pre-norm two-layer MLP (GELU activation) with dropout."""

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        layers = [
            RMSNorm(dim),
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
class Attention(Module):
    """Multi-head self-attention with per-head sigmoid output gating.

    Optionally applies rotary position embeddings to q/k, and dispatches the
    attention product to AttendSage when ``sage_attention=True``, otherwise
    to the standard Attend module.
    """

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        dropout=0.0,
        rotary_embed=None,
        flash=True,
        sage_attention=False,
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head

        self.rotary_embed = rotary_embed

        if sage_attention:
            self.attend = AttendSage(flash=flash, dropout=dropout)
        else:
            self.attend = Attend(flash=flash, dropout=dropout)

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        # Per-head gates computed from the normed input tokens.
        self.to_gates = nn.Linear(dim, heads)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias=False), nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x)
        q, k, v = rearrange(
            qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
        )

        if self.rotary_embed is not None:
            q = self.rotary_embed.rotate_queries_or_keys(q)
            k = self.rotary_embed.rotate_queries_or_keys(k)

        attended = self.attend(q, k, v)

        gates = self.to_gates(x)
        attended = attended * rearrange(gates, "b n h -> b h n 1").sigmoid()

        merged = rearrange(attended, "b h n d -> b n (h d)")
        return self.to_out(merged)
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
class LinearAttention(Module):
    """Attention over a transposed (b h d n) layout with cosine-sim q/k and a
    learned per-head temperature, following the scale-8 cosine-attention
    formulation."""

    @beartype
    def __init__(
        self,
        *,
        dim,
        dim_head=32,
        heads=8,
        scale=8,
        flash=True,
        dropout=0.0,
        sage_attention=False,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.norm = RMSNorm(dim)

        self.to_qkv = nn.Sequential(
            nn.Linear(dim, inner_dim * 3, bias=False),
            Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
        )

        # Learned per-head temperature applied to the normalized queries.
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        if sage_attention:
            self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash)
        else:
            self.attend = Attend(scale=scale, dropout=dropout, flash=flash)

        self.to_out = nn.Sequential(
            Rearrange("b h d n -> b n (h d)"), nn.Linear(inner_dim, dim, bias=False)
        )

    def forward(self, x):
        x = self.norm(x)

        q, k, v = self.to_qkv(x)

        q, k = l2norm(q), l2norm(k)
        q = q * self.temperature.exp()

        attended = self.attend(q, k, v)

        return self.to_out(attended)
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
class Transformer(Module):
    """A stack of ``depth`` (attention, feed-forward) residual layers, with
    an optional final RMSNorm (``norm_output``).

    When ``linear_attn=True`` every layer uses LinearAttention; otherwise the
    standard Attention (with the given rotary embedding) is used.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        ff_mult=4,
        norm_output=True,
        rotary_embed=None,
        flash_attn=True,
        linear_attn=False,
        sage_attention=False,
    ):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            if linear_attn:
                attn_layer = LinearAttention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    dropout=attn_dropout,
                    flash=flash_attn,
                    sage_attention=sage_attention,
                )
            else:
                attn_layer = Attention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    dropout=attn_dropout,
                    rotary_embed=rotary_embed,
                    flash=flash_attn,
                    sage_attention=sage_attention,
                )

            ff_layer = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
            self.layers.append(ModuleList([attn_layer, ff_layer]))

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        # Standard residual wiring: x + attn(x), then x + ff(x).
        for attn_layer, ff_layer in self.layers:
            x = attn_layer(x) + x
            x = ff_layer(x) + x

        return self.norm(x)
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
class BandSplit(Module):
    """Project each frequency band (a slice of the last axis, sized by
    ``dim_inputs``) into a common model dimension ``dim``."""

    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...]):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList([])

        for band_dim in dim_inputs:
            self.to_features.append(
                nn.Sequential(RMSNorm(band_dim), nn.Linear(band_dim, dim))
            )

    def forward(self, x):
        # Split the feature axis into per-band chunks, project each band,
        # then stack along a new band axis (second-to-last).
        bands = x.split(self.dim_inputs, dim=-1)

        projected = [proj(band) for band, proj in zip(bands, self.to_features)]

        return torch.stack(projected, dim=-2)
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
    """Build a simple MLP as an ``nn.Sequential``.

    ``depth`` is the number of Linear layers; ``activation`` is inserted
    between consecutive Linear layers (never after the last one).
    ``dim_hidden`` falls back to ``dim_in`` when not given.
    """
    hidden = dim_in if dim_hidden is None else dim_hidden

    dims = (dim_in, *((hidden,) * (depth - 1)), dim_out)

    layers = []
    pairs = list(zip(dims[:-1], dims[1:]))
    for idx, (d_in, d_out) in enumerate(pairs):
        layers.append(nn.Linear(d_in, d_out))
        if idx < len(pairs) - 1:
            layers.append(activation())

    return nn.Sequential(*layers)
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
class MaskEstimator(Module):
    """Per-band MLP heads producing GLU-gated mask coefficients, concatenated
    back along the frequency-feature axis."""

    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        hidden = dim * mlp_expansion_factor

        for band_dim in dim_inputs:
            # The MLP doubles the output width and the GLU halves it back,
            # acting as a learned gate on each band's mask values.
            head = nn.Sequential(
                MLP(dim, band_dim * 2, dim_hidden=hidden, depth=depth),
                nn.GLU(dim=-1),
            )
            self.to_freqs.append(head)

    def forward(self, x):
        bands = x.unbind(dim=-2)

        outputs = [head(band) for band, head in zip(bands, self.to_freqs)]

        return torch.cat(outputs, dim=-1)
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
# Default band layout for a 1025-bin spectrogram (n_fft=2048): narrow 2-bin
# bands at low frequencies, progressively wider bands above, and two large
# top bands (128 + 129 bins). 62 bands, summing to 1025.
DEFAULT_FREQS_PER_BANDS = (
    (2,) * 24
    + (4,) * 12
    + (12,) * 8
    + (24,) * 8
    + (48,) * 8
    + (128, 129)
)
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
class BSRoformer(Module):
    """Band-Split RoFormer for source separation.

    Pipeline: STFT the mixture, split the (complex, per-channel) spectrogram
    into frequency bands, then alternate transformer passes along the time
    axis and the band axis for ``depth`` rounds, and finally predict one
    complex mask per stem that is applied to the mixture STFT before iSTFT.

    ``forward`` returns separated audio when no ``target`` is given, or a
    training loss (waveform L1 + weighted multi-resolution STFT L1) when a
    target is supplied.
    """

    @beartype
    def __init__(
        self,
        dim,
        *,
        depth,
        stereo=False,
        num_stems=1,
        time_transformer_depth=2,
        freq_transformer_depth=2,
        linear_transformer_depth=0,
        freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        flash_attn=True,
        dim_freqs_in=1025,
        stft_n_fft=2048,
        stft_hop_length=512,
        stft_win_length=2048,
        stft_normalized=False,
        stft_window_fn: Optional[Callable] = None,
        zero_dc=True,
        mask_estimator_depth=2,
        multi_stft_resolution_loss_weight=1.0,
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
            4096,
            2048,
            1024,
            512,
            256,
        ),
        multi_stft_hop_size=147,
        multi_stft_normalized=False,
        multi_stft_window_fn: Callable = torch.hann_window,
        mlp_expansion_factor=4,
        use_torch_checkpoint=False,
        skip_connection=False,
        sage_attention=False,
    ):
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems
        self.use_torch_checkpoint = use_torch_checkpoint
        self.skip_connection = skip_connection

        self.layers = ModuleList([])

        if sage_attention:
            print("Use Sage Attention")

        # Shared settings for every Transformer built below. norm_output is
        # False because a single final_norm is applied after all layers.
        transformer_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            flash_attn=flash_attn,
            norm_output=False,
            sage_attention=sage_attention,
        )

        # Rotary embeddings are shared across depth for each axis.
        time_rotary_embed = RotaryEmbedding(dim=dim_head)
        freq_rotary_embed = RotaryEmbedding(dim=dim_head)

        for _ in range(depth):
            tran_modules = []
            # Optional leading linear-attention transformer over the full
            # (time x band) token grid; forward() detects it by block length.
            if linear_transformer_depth > 0:
                tran_modules.append(
                    Transformer(
                        depth=linear_transformer_depth,
                        linear_attn=True,
                        **transformer_kwargs,
                    )
                )
            tran_modules.append(
                Transformer(
                    depth=time_transformer_depth,
                    rotary_embed=time_rotary_embed,
                    **transformer_kwargs,
                )
            )
            tran_modules.append(
                Transformer(
                    depth=freq_transformer_depth,
                    rotary_embed=freq_rotary_embed,
                    **transformer_kwargs,
                )
            )
            self.layers.append(nn.ModuleList(tran_modules))

        self.final_norm = RMSNorm(dim)

        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized,
        )

        self.stft_window_fn = partial(
            default(stft_window_fn, torch.hann_window), stft_win_length
        )

        # Probe the STFT once on dummy audio to learn the frequency-bin
        # count implied by the settings (e.g. 1025 for n_fft=2048).
        freqs = torch.stft(
            torch.randn(1, 4096),
            **self.stft_kwargs,
            window=torch.ones(stft_win_length),
            return_complex=True,
        ).shape[1]

        assert len(freqs_per_bands) > 1
        assert (
            sum(freqs_per_bands) == freqs
        ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"

        # Each band's feature width: bins * 2 (real/imag) * audio channels.
        freqs_per_bands_with_complex = tuple(
            2 * f * self.audio_channels for f in freqs_per_bands
        )

        self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)

        self.mask_estimators = nn.ModuleList([])

        # One mask estimator per stem to separate.
        for _ in range(num_stems):
            mask_estimator = MaskEstimator(
                dim=dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=mask_estimator_depth,
                mlp_expansion_factor=mlp_expansion_factor,
            )

            self.mask_estimators.append(mask_estimator)

        # Whether to zero the DC bin of the masked spectrogram before iSTFT.
        self.zero_dc = zero_dc

        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(
            hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
        )

    def forward(self, raw_audio, target=None, return_loss_breakdown=False):
        """Separate ``raw_audio`` (b t, or b s t with stereo) into stems.

        Returns separated audio (b s t, or b n s t for multiple stems) when
        ``target`` is None; otherwise returns the training loss, optionally
        with its (l1, multi-stft) breakdown.
        """

        device = raw_audio.device

        # torch.stft/istft can fail on Apple MPS; the except branches below
        # retry on CPU and move the result back.
        x_is_mps = True if device.type == "mps" else False

        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, "b t -> b 1 t")

        channels = raw_audio.shape[1]
        assert (not self.stereo and channels == 1) or (
            self.stereo and channels == 2
        ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"

        # Fold the channel axis into the batch for a single STFT call.
        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")

        stft_window = self.stft_window_fn(device=device)

        try:
            stft_repr = torch.stft(
                raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
            )
        except:
            # NOTE(review): bare except is intentionally broad here (MPS
            # fallback), but it will also mask unrelated STFT errors.
            stft_repr = torch.stft(
                raw_audio.cpu() if x_is_mps else raw_audio,
                **self.stft_kwargs,
                window=stft_window.cpu() if x_is_mps else stft_window,
                return_complex=True,
            ).to(device)
        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")

        # Interleave audio channels next to each frequency: (f s) ordering
        # must match the (f s) split used after masking below.
        stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")

        # Tokens over time; features are all freq bins x real/imag.
        x = rearrange(stft_repr, "b f t c -> b t (f c)")

        if self.use_torch_checkpoint:
            x = checkpoint(self.band_split, x, use_reentrant=False)
        else:
            x = self.band_split(x)

        # x is now (b, t, bands, d). Alternate axial transformer passes.
        store = [None] * len(self.layers)
        for i, transformer_block in enumerate(self.layers):

            # A 3-module block carries a leading linear-attention pass over
            # the flattened (time x band) grid; a 2-module block does not.
            if len(transformer_block) == 3:
                linear_transformer, time_transformer, freq_transformer = (
                    transformer_block
                )

                x, ft_ps = pack([x], "b * d")
                if self.use_torch_checkpoint:
                    x = checkpoint(linear_transformer, x, use_reentrant=False)
                else:
                    x = linear_transformer(x)
                (x,) = unpack(x, ft_ps, "b * d")
            else:
                time_transformer, freq_transformer = transformer_block

            # Dense skip connections: add the outputs of all earlier layers.
            if self.skip_connection:
                for j in range(i):
                    x = x + store[j]

            # Attend across time within each band.
            x = rearrange(x, "b t f d -> b f t d")
            x, ps = pack([x], "* t d")

            if self.use_torch_checkpoint:
                x = checkpoint(time_transformer, x, use_reentrant=False)
            else:
                x = time_transformer(x)

            (x,) = unpack(x, ps, "* t d")
            # Attend across bands within each time step.
            x = rearrange(x, "b f t d -> b t f d")
            x, ps = pack([x], "* f d")

            if self.use_torch_checkpoint:
                x = checkpoint(freq_transformer, x, use_reentrant=False)
            else:
                x = freq_transformer(x)

            (x,) = unpack(x, ps, "* f d")

            if self.skip_connection:
                store[i] = x

        x = self.final_norm(x)

        num_stems = len(self.mask_estimators)

        if self.use_torch_checkpoint:
            mask = torch.stack(
                [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators],
                dim=1,
            )
        else:
            mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
        mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)

        # Broadcast the mixture spectrogram across the stem axis.
        stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")

        stft_repr = torch.view_as_complex(stft_repr)
        mask = torch.view_as_complex(mask)

        # Complex masking of the mixture.
        stft_repr = stft_repr * mask

        # Undo the (f s) interleave and fold stems/channels into the batch
        # for a single iSTFT call.
        stft_repr = rearrange(
            stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
        )

        if self.zero_dc:
            # Zero the DC bin to suppress any constant offset in the output.
            stft_repr = stft_repr.index_fill(1, tensor(0, device=device), 0.0)

        try:
            recon_audio = torch.istft(
                stft_repr,
                **self.stft_kwargs,
                window=stft_window,
                return_complex=False,
                length=raw_audio.shape[-1],
            )
        except:
            # Same MPS fallback as the forward STFT above.
            recon_audio = torch.istft(
                stft_repr.cpu() if x_is_mps else stft_repr,
                **self.stft_kwargs,
                window=stft_window.cpu() if x_is_mps else stft_window,
                return_complex=False,
                length=raw_audio.shape[-1],
            ).to(device)

        recon_audio = rearrange(
            recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
        )

        if num_stems == 1:
            recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")

        # Inference path: no target, return separated audio.
        if not exists(target):
            return recon_audio

        if self.num_stems > 1:
            assert target.ndim == 4 and target.shape[1] == self.num_stems

        if target.ndim == 2:
            target = rearrange(target, "... t -> ... 1 t")

        # Trim the target to the reconstructed length before comparing.
        target = target[..., : recon_audio.shape[-1]]

        loss = F.l1_loss(recon_audio, target)

        multi_stft_resolution_loss = 0.0

        # Multi-resolution STFT loss: L1 on complex spectrograms across
        # several window sizes (n_fft never drops below multi_stft_n_fft).
        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
                n_fft=max(window_size, self.multi_stft_n_fft),
                win_length=window_size,
                return_complex=True,
                window=self.multi_stft_window_fn(window_size, device=device),
                **self.multi_stft_kwargs,
            )

            recon_Y = torch.stft(
                rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
            )
            target_Y = torch.stft(
                rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
            )

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
                recon_Y, target_Y
            )

        weighted_multi_resolution_loss = (
            multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
        )

        total_loss = loss + weighted_multi_resolution_loss

        if not return_loss_breakdown:
            return total_loss

        return total_loss, (loss, multi_stft_resolution_loss)
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn, einsum, tensor, Tensor
|
| 5 |
+
from torch.nn import Module, ModuleList
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from .attend import Attend
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from .attend_sage import Attend as AttendSage
|
| 12 |
+
except:
|
| 13 |
+
pass
|
| 14 |
+
from torch.utils.checkpoint import checkpoint
|
| 15 |
+
|
| 16 |
+
from beartype.typing import Tuple, Optional, List, Callable
|
| 17 |
+
from beartype import beartype
|
| 18 |
+
|
| 19 |
+
from rotary_embedding_torch import RotaryEmbedding
|
| 20 |
+
|
| 21 |
+
from einops import rearrange, pack, unpack
|
| 22 |
+
from einops.layers.torch import Rearrange
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def exists(val):
|
| 26 |
+
return val is not None
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def default(v, d):
|
| 30 |
+
return v if exists(v) else d
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def pack_one(t, pattern):
|
| 34 |
+
return pack([t], pattern)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def unpack_one(t, ps, pattern):
|
| 38 |
+
return unpack(t, ps, pattern)[0]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def l2norm(t):
|
| 42 |
+
return F.normalize(t, dim=-1, p=2)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class RMSNorm(Module):
|
| 46 |
+
def __init__(self, dim):
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.scale = dim**0.5
|
| 49 |
+
self.gamma = nn.Parameter(torch.ones(dim))
|
| 50 |
+
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
return F.normalize(x, dim=-1) * self.scale * self.gamma
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class FeedForward(Module):
|
| 56 |
+
def __init__(self, dim, mult=4, dropout=0.0):
|
| 57 |
+
super().__init__()
|
| 58 |
+
dim_inner = int(dim * mult)
|
| 59 |
+
self.net = nn.Sequential(
|
| 60 |
+
RMSNorm(dim),
|
| 61 |
+
nn.Linear(dim, dim_inner),
|
| 62 |
+
nn.GELU(),
|
| 63 |
+
nn.Dropout(dropout),
|
| 64 |
+
nn.Linear(dim_inner, dim),
|
| 65 |
+
nn.Dropout(dropout),
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
def forward(self, x):
|
| 69 |
+
return self.net(x)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class Attention(Module):
|
| 73 |
+
def __init__(
|
| 74 |
+
self,
|
| 75 |
+
dim,
|
| 76 |
+
heads=8,
|
| 77 |
+
dim_head=64,
|
| 78 |
+
dropout=0.0,
|
| 79 |
+
rotary_embed=None,
|
| 80 |
+
flash=True,
|
| 81 |
+
sage_attention=False,
|
| 82 |
+
):
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.heads = heads
|
| 85 |
+
self.scale = dim_head**-0.5
|
| 86 |
+
dim_inner = heads * dim_head
|
| 87 |
+
|
| 88 |
+
self.rotary_embed = rotary_embed
|
| 89 |
+
|
| 90 |
+
if sage_attention:
|
| 91 |
+
self.attend = AttendSage(flash=flash, dropout=dropout)
|
| 92 |
+
else:
|
| 93 |
+
self.attend = Attend(flash=flash, dropout=dropout)
|
| 94 |
+
|
| 95 |
+
self.norm = RMSNorm(dim)
|
| 96 |
+
self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
|
| 97 |
+
|
| 98 |
+
self.to_gates = nn.Linear(dim, heads)
|
| 99 |
+
|
| 100 |
+
self.to_out = nn.Sequential(
|
| 101 |
+
nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
def forward(self, x):
|
| 105 |
+
x = self.norm(x)
|
| 106 |
+
|
| 107 |
+
q, k, v = rearrange(
|
| 108 |
+
self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
if exists(self.rotary_embed):
|
| 112 |
+
q = self.rotary_embed.rotate_queries_or_keys(q)
|
| 113 |
+
k = self.rotary_embed.rotate_queries_or_keys(k)
|
| 114 |
+
|
| 115 |
+
out = self.attend(q, k, v)
|
| 116 |
+
|
| 117 |
+
gates = self.to_gates(x)
|
| 118 |
+
out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
|
| 119 |
+
|
| 120 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
| 121 |
+
return self.to_out(out)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class LinearAttention(Module):
|
| 125 |
+
|
| 126 |
+
@beartype
|
| 127 |
+
def __init__(
|
| 128 |
+
self,
|
| 129 |
+
*,
|
| 130 |
+
dim,
|
| 131 |
+
dim_head=32,
|
| 132 |
+
heads=8,
|
| 133 |
+
scale=8,
|
| 134 |
+
flash=True,
|
| 135 |
+
dropout=0.0,
|
| 136 |
+
sage_attention=False,
|
| 137 |
+
):
|
| 138 |
+
super().__init__()
|
| 139 |
+
dim_inner = dim_head * heads
|
| 140 |
+
self.norm = RMSNorm(dim)
|
| 141 |
+
|
| 142 |
+
self.to_qkv = nn.Sequential(
|
| 143 |
+
nn.Linear(dim, dim_inner * 3, bias=False),
|
| 144 |
+
Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
|
| 148 |
+
|
| 149 |
+
if sage_attention:
|
| 150 |
+
self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash)
|
| 151 |
+
else:
|
| 152 |
+
self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
|
| 153 |
+
|
| 154 |
+
self.to_out = nn.Sequential(
|
| 155 |
+
Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
def forward(self, x):
|
| 159 |
+
x = self.norm(x)
|
| 160 |
+
|
| 161 |
+
q, k, v = self.to_qkv(x)
|
| 162 |
+
|
| 163 |
+
q, k = map(l2norm, (q, k))
|
| 164 |
+
q = q * self.temperature.exp()
|
| 165 |
+
|
| 166 |
+
out = self.attend(q, k, v)
|
| 167 |
+
|
| 168 |
+
return self.to_out(out)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class Transformer(Module):
|
| 172 |
+
def __init__(
|
| 173 |
+
self,
|
| 174 |
+
*,
|
| 175 |
+
dim,
|
| 176 |
+
depth,
|
| 177 |
+
dim_head=64,
|
| 178 |
+
heads=8,
|
| 179 |
+
attn_dropout=0.0,
|
| 180 |
+
ff_dropout=0.0,
|
| 181 |
+
ff_mult=4,
|
| 182 |
+
norm_output=True,
|
| 183 |
+
rotary_embed=None,
|
| 184 |
+
flash_attn=True,
|
| 185 |
+
linear_attn=False,
|
| 186 |
+
sage_attention=False,
|
| 187 |
+
):
|
| 188 |
+
super().__init__()
|
| 189 |
+
self.layers = ModuleList([])
|
| 190 |
+
|
| 191 |
+
for _ in range(depth):
|
| 192 |
+
if linear_attn:
|
| 193 |
+
attn = LinearAttention(
|
| 194 |
+
dim=dim,
|
| 195 |
+
dim_head=dim_head,
|
| 196 |
+
heads=heads,
|
| 197 |
+
dropout=attn_dropout,
|
| 198 |
+
flash=flash_attn,
|
| 199 |
+
sage_attention=sage_attention,
|
| 200 |
+
)
|
| 201 |
+
else:
|
| 202 |
+
attn = Attention(
|
| 203 |
+
dim=dim,
|
| 204 |
+
dim_head=dim_head,
|
| 205 |
+
heads=heads,
|
| 206 |
+
dropout=attn_dropout,
|
| 207 |
+
rotary_embed=rotary_embed,
|
| 208 |
+
flash=flash_attn,
|
| 209 |
+
sage_attention=sage_attention,
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
self.layers.append(
|
| 213 |
+
ModuleList(
|
| 214 |
+
[attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
|
| 215 |
+
)
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
self.norm = RMSNorm(dim) if norm_output else nn.Identity()
|
| 219 |
+
|
| 220 |
+
def forward(self, x):
|
| 221 |
+
|
| 222 |
+
for attn, ff in self.layers:
|
| 223 |
+
x = attn(x) + x
|
| 224 |
+
x = ff(x) + x
|
| 225 |
+
|
| 226 |
+
return self.norm(x)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class BandSplit(Module):
|
| 230 |
+
@beartype
|
| 231 |
+
def __init__(self, dim, dim_inputs: Tuple[int, ...]):
|
| 232 |
+
super().__init__()
|
| 233 |
+
self.dim_inputs = dim_inputs
|
| 234 |
+
self.to_features = ModuleList([])
|
| 235 |
+
|
| 236 |
+
for dim_in in dim_inputs:
|
| 237 |
+
net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
|
| 238 |
+
|
| 239 |
+
self.to_features.append(net)
|
| 240 |
+
|
| 241 |
+
def forward(self, x):
|
| 242 |
+
x = x.split(self.dim_inputs, dim=-1)
|
| 243 |
+
|
| 244 |
+
outs = []
|
| 245 |
+
for split_input, to_feature in zip(x, self.to_features):
|
| 246 |
+
split_output = to_feature(split_input)
|
| 247 |
+
outs.append(split_output)
|
| 248 |
+
|
| 249 |
+
return torch.stack(outs, dim=-2)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
|
| 253 |
+
dim_hidden = default(dim_hidden, dim_in)
|
| 254 |
+
|
| 255 |
+
net = []
|
| 256 |
+
dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
|
| 257 |
+
|
| 258 |
+
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
|
| 259 |
+
is_last = ind == (len(dims) - 2)
|
| 260 |
+
|
| 261 |
+
net.append(nn.Linear(layer_dim_in, layer_dim_out))
|
| 262 |
+
|
| 263 |
+
if is_last:
|
| 264 |
+
continue
|
| 265 |
+
|
| 266 |
+
net.append(activation())
|
| 267 |
+
|
| 268 |
+
return nn.Sequential(*net)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class MaskEstimator(Module):
|
| 272 |
+
@beartype
|
| 273 |
+
def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
|
| 274 |
+
super().__init__()
|
| 275 |
+
self.dim_inputs = dim_inputs
|
| 276 |
+
self.to_freqs = ModuleList([])
|
| 277 |
+
dim_hidden = dim * mlp_expansion_factor
|
| 278 |
+
|
| 279 |
+
for dim_in in dim_inputs:
|
| 280 |
+
net = []
|
| 281 |
+
|
| 282 |
+
mlp = nn.Sequential(
|
| 283 |
+
MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1)
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
self.to_freqs.append(mlp)
|
| 287 |
+
|
| 288 |
+
def forward(self, x):
|
| 289 |
+
x = x.unbind(dim=-2)
|
| 290 |
+
|
| 291 |
+
outs = []
|
| 292 |
+
|
| 293 |
+
for band_features, mlp in zip(x, self.to_freqs):
|
| 294 |
+
freq_out = mlp(band_features)
|
| 295 |
+
outs.append(freq_out)
|
| 296 |
+
|
| 297 |
+
return torch.cat(outs, dim=-1)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
DEFAULT_FREQS_PER_BANDS = (
|
| 301 |
+
2,
|
| 302 |
+
2,
|
| 303 |
+
2,
|
| 304 |
+
2,
|
| 305 |
+
2,
|
| 306 |
+
2,
|
| 307 |
+
2,
|
| 308 |
+
2,
|
| 309 |
+
2,
|
| 310 |
+
2,
|
| 311 |
+
2,
|
| 312 |
+
2,
|
| 313 |
+
2,
|
| 314 |
+
2,
|
| 315 |
+
2,
|
| 316 |
+
2,
|
| 317 |
+
2,
|
| 318 |
+
2,
|
| 319 |
+
2,
|
| 320 |
+
2,
|
| 321 |
+
2,
|
| 322 |
+
2,
|
| 323 |
+
2,
|
| 324 |
+
2,
|
| 325 |
+
4,
|
| 326 |
+
4,
|
| 327 |
+
4,
|
| 328 |
+
4,
|
| 329 |
+
4,
|
| 330 |
+
4,
|
| 331 |
+
4,
|
| 332 |
+
4,
|
| 333 |
+
4,
|
| 334 |
+
4,
|
| 335 |
+
4,
|
| 336 |
+
4,
|
| 337 |
+
12,
|
| 338 |
+
12,
|
| 339 |
+
12,
|
| 340 |
+
12,
|
| 341 |
+
12,
|
| 342 |
+
12,
|
| 343 |
+
12,
|
| 344 |
+
12,
|
| 345 |
+
24,
|
| 346 |
+
24,
|
| 347 |
+
24,
|
| 348 |
+
24,
|
| 349 |
+
24,
|
| 350 |
+
24,
|
| 351 |
+
24,
|
| 352 |
+
24,
|
| 353 |
+
48,
|
| 354 |
+
48,
|
| 355 |
+
48,
|
| 356 |
+
48,
|
| 357 |
+
48,
|
| 358 |
+
48,
|
| 359 |
+
48,
|
| 360 |
+
48,
|
| 361 |
+
128,
|
| 362 |
+
129,
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
class BSRoformer(Module):
|
| 367 |
+
|
| 368 |
+
@beartype
|
| 369 |
+
def __init__(
|
| 370 |
+
self,
|
| 371 |
+
dim,
|
| 372 |
+
*,
|
| 373 |
+
depth,
|
| 374 |
+
stereo=False,
|
| 375 |
+
num_stems=1,
|
| 376 |
+
time_transformer_depth=2,
|
| 377 |
+
freq_transformer_depth=2,
|
| 378 |
+
linear_transformer_depth=0,
|
| 379 |
+
freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
|
| 380 |
+
dim_head=64,
|
| 381 |
+
heads=8,
|
| 382 |
+
attn_dropout=0.0,
|
| 383 |
+
ff_dropout=0.0,
|
| 384 |
+
flash_attn=True,
|
| 385 |
+
dim_freqs_in=1025,
|
| 386 |
+
stft_n_fft=2048,
|
| 387 |
+
stft_hop_length=512,
|
| 388 |
+
stft_win_length=2048,
|
| 389 |
+
stft_normalized=False,
|
| 390 |
+
stft_window_fn: Optional[Callable] = None,
|
| 391 |
+
zero_dc=True,
|
| 392 |
+
mask_estimator_depth=2,
|
| 393 |
+
multi_stft_resolution_loss_weight=1.0,
|
| 394 |
+
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
|
| 395 |
+
4096,
|
| 396 |
+
2048,
|
| 397 |
+
1024,
|
| 398 |
+
512,
|
| 399 |
+
256,
|
| 400 |
+
),
|
| 401 |
+
multi_stft_hop_size=147,
|
| 402 |
+
multi_stft_normalized=False,
|
| 403 |
+
multi_stft_window_fn: Callable = torch.hann_window,
|
| 404 |
+
mlp_expansion_factor=4,
|
| 405 |
+
use_torch_checkpoint=False,
|
| 406 |
+
skip_connection=False,
|
| 407 |
+
sage_attention=False,
|
| 408 |
+
):
|
| 409 |
+
super().__init__()
|
| 410 |
+
|
| 411 |
+
self.stereo = stereo
|
| 412 |
+
self.audio_channels = 2 if stereo else 1
|
| 413 |
+
self.num_stems = num_stems
|
| 414 |
+
self.use_torch_checkpoint = use_torch_checkpoint
|
| 415 |
+
self.skip_connection = skip_connection
|
| 416 |
+
|
| 417 |
+
self.layers = ModuleList([])
|
| 418 |
+
|
| 419 |
+
if sage_attention:
|
| 420 |
+
print("Use Sage Attention")
|
| 421 |
+
|
| 422 |
+
transformer_kwargs = dict(
|
| 423 |
+
dim=dim,
|
| 424 |
+
heads=heads,
|
| 425 |
+
dim_head=dim_head,
|
| 426 |
+
attn_dropout=attn_dropout,
|
| 427 |
+
ff_dropout=ff_dropout,
|
| 428 |
+
flash_attn=flash_attn,
|
| 429 |
+
norm_output=False,
|
| 430 |
+
sage_attention=sage_attention,
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
time_rotary_embed = RotaryEmbedding(dim=dim_head)
|
| 434 |
+
freq_rotary_embed = RotaryEmbedding(dim=dim_head)
|
| 435 |
+
|
| 436 |
+
for _ in range(depth):
|
| 437 |
+
tran_modules = []
|
| 438 |
+
if linear_transformer_depth > 0:
|
| 439 |
+
tran_modules.append(
|
| 440 |
+
Transformer(
|
| 441 |
+
depth=linear_transformer_depth,
|
| 442 |
+
linear_attn=True,
|
| 443 |
+
**transformer_kwargs,
|
| 444 |
+
)
|
| 445 |
+
)
|
| 446 |
+
tran_modules.append(
|
| 447 |
+
Transformer(
|
| 448 |
+
depth=time_transformer_depth,
|
| 449 |
+
rotary_embed=time_rotary_embed,
|
| 450 |
+
**transformer_kwargs,
|
| 451 |
+
)
|
| 452 |
+
)
|
| 453 |
+
tran_modules.append(
|
| 454 |
+
Transformer(
|
| 455 |
+
depth=freq_transformer_depth,
|
| 456 |
+
rotary_embed=freq_rotary_embed,
|
| 457 |
+
**transformer_kwargs,
|
| 458 |
+
)
|
| 459 |
+
)
|
| 460 |
+
self.layers.append(nn.ModuleList(tran_modules))
|
| 461 |
+
|
| 462 |
+
self.final_norm = RMSNorm(dim)
|
| 463 |
+
|
| 464 |
+
self.stft_kwargs = dict(
|
| 465 |
+
n_fft=stft_n_fft,
|
| 466 |
+
hop_length=stft_hop_length,
|
| 467 |
+
win_length=stft_win_length,
|
| 468 |
+
normalized=stft_normalized,
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
self.stft_window_fn = partial(
|
| 472 |
+
default(stft_window_fn, torch.hann_window), stft_win_length
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
freqs = torch.stft(
|
| 476 |
+
torch.randn(1, 4096),
|
| 477 |
+
**self.stft_kwargs,
|
| 478 |
+
window=torch.ones(stft_win_length),
|
| 479 |
+
return_complex=True,
|
| 480 |
+
).shape[1]
|
| 481 |
+
|
| 482 |
+
assert len(freqs_per_bands) > 1
|
| 483 |
+
assert (
|
| 484 |
+
sum(freqs_per_bands) == freqs
|
| 485 |
+
), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
|
| 486 |
+
|
| 487 |
+
freqs_per_bands_with_complex = tuple(
|
| 488 |
+
2 * f * self.audio_channels for f in freqs_per_bands
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
|
| 492 |
+
|
| 493 |
+
self.mask_estimators = nn.ModuleList([])
|
| 494 |
+
|
| 495 |
+
for _ in range(num_stems):
|
| 496 |
+
mask_estimator = MaskEstimator(
|
| 497 |
+
dim=dim,
|
| 498 |
+
dim_inputs=freqs_per_bands_with_complex,
|
| 499 |
+
depth=mask_estimator_depth,
|
| 500 |
+
mlp_expansion_factor=mlp_expansion_factor,
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
self.mask_estimators.append(mask_estimator)
|
| 504 |
+
|
| 505 |
+
self.zero_dc = zero_dc
|
| 506 |
+
|
| 507 |
+
self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
|
| 508 |
+
self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
|
| 509 |
+
self.multi_stft_n_fft = stft_n_fft
|
| 510 |
+
self.multi_stft_window_fn = multi_stft_window_fn
|
| 511 |
+
|
| 512 |
+
self.multi_stft_kwargs = dict(
|
| 513 |
+
hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
def forward(self, raw_audio, target=None, return_loss_breakdown=False):
|
| 517 |
+
|
| 518 |
+
device = raw_audio.device
|
| 519 |
+
|
| 520 |
+
x_is_mps = True if device.type == "mps" else False
|
| 521 |
+
|
| 522 |
+
if raw_audio.ndim == 2:
|
| 523 |
+
raw_audio = rearrange(raw_audio, "b t -> b 1 t")
|
| 524 |
+
|
| 525 |
+
channels = raw_audio.shape[1]
|
| 526 |
+
assert (not self.stereo and channels == 1) or (
|
| 527 |
+
self.stereo and channels == 2
|
| 528 |
+
), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
|
| 529 |
+
|
| 530 |
+
raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
|
| 531 |
+
|
| 532 |
+
stft_window = self.stft_window_fn(device=device)
|
| 533 |
+
|
| 534 |
+
try:
|
| 535 |
+
stft_repr = torch.stft(
|
| 536 |
+
raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
|
| 537 |
+
)
|
| 538 |
+
except:
|
| 539 |
+
stft_repr = torch.stft(
|
| 540 |
+
raw_audio.cpu() if x_is_mps else raw_audio,
|
| 541 |
+
**self.stft_kwargs,
|
| 542 |
+
window=stft_window.cpu() if x_is_mps else stft_window,
|
| 543 |
+
return_complex=True,
|
| 544 |
+
).to(device)
|
| 545 |
+
stft_repr = torch.view_as_real(stft_repr)
|
| 546 |
+
|
| 547 |
+
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
|
| 548 |
+
|
| 549 |
+
stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
|
| 550 |
+
|
| 551 |
+
x = rearrange(stft_repr, "b f t c -> b t (f c)")
|
| 552 |
+
|
| 553 |
+
if self.use_torch_checkpoint:
|
| 554 |
+
x = checkpoint(self.band_split, x, use_reentrant=False)
|
| 555 |
+
else:
|
| 556 |
+
x = self.band_split(x)
|
| 557 |
+
|
| 558 |
+
store = [None] * len(self.layers)
|
| 559 |
+
for i, transformer_block in enumerate(self.layers):
|
| 560 |
+
|
| 561 |
+
if len(transformer_block) == 3:
|
| 562 |
+
linear_transformer, time_transformer, freq_transformer = (
|
| 563 |
+
transformer_block
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
x, ft_ps = pack([x], "b * d")
|
| 567 |
+
if self.use_torch_checkpoint:
|
| 568 |
+
x = checkpoint(linear_transformer, x, use_reentrant=False)
|
| 569 |
+
else:
|
| 570 |
+
x = linear_transformer(x)
|
| 571 |
+
(x,) = unpack(x, ft_ps, "b * d")
|
| 572 |
+
else:
|
| 573 |
+
time_transformer, freq_transformer = transformer_block
|
| 574 |
+
|
| 575 |
+
if self.skip_connection:
|
| 576 |
+
for j in range(i):
|
| 577 |
+
x = x + store[j]
|
| 578 |
+
|
| 579 |
+
x = rearrange(x, "b t f d -> b f t d")
|
| 580 |
+
x, ps = pack([x], "* t d")
|
| 581 |
+
|
| 582 |
+
if self.use_torch_checkpoint:
|
| 583 |
+
x = checkpoint(time_transformer, x, use_reentrant=False)
|
| 584 |
+
else:
|
| 585 |
+
x = time_transformer(x)
|
| 586 |
+
|
| 587 |
+
(x,) = unpack(x, ps, "* t d")
|
| 588 |
+
x = rearrange(x, "b f t d -> b t f d")
|
| 589 |
+
x, ps = pack([x], "* f d")
|
| 590 |
+
|
| 591 |
+
if self.use_torch_checkpoint:
|
| 592 |
+
x = checkpoint(freq_transformer, x, use_reentrant=False)
|
| 593 |
+
else:
|
| 594 |
+
x = freq_transformer(x)
|
| 595 |
+
|
| 596 |
+
(x,) = unpack(x, ps, "* f d")
|
| 597 |
+
|
| 598 |
+
if self.skip_connection:
|
| 599 |
+
store[i] = x
|
| 600 |
+
|
| 601 |
+
x = self.final_norm(x)
|
| 602 |
+
|
| 603 |
+
num_stems = len(self.mask_estimators)
|
| 604 |
+
|
| 605 |
+
if self.use_torch_checkpoint:
|
| 606 |
+
mask = torch.stack(
|
| 607 |
+
[checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators],
|
| 608 |
+
dim=1,
|
| 609 |
+
)
|
| 610 |
+
else:
|
| 611 |
+
mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
|
| 612 |
+
mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
|
| 613 |
+
|
| 614 |
+
stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
|
| 615 |
+
|
| 616 |
+
stft_repr = torch.view_as_complex(stft_repr)
|
| 617 |
+
mask = torch.view_as_complex(mask)
|
| 618 |
+
|
| 619 |
+
stft_repr = stft_repr * mask
|
| 620 |
+
|
| 621 |
+
stft_repr = rearrange(
|
| 622 |
+
stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
|
| 623 |
+
)
|
| 624 |
+
|
| 625 |
+
if self.zero_dc:
|
| 626 |
+
stft_repr = stft_repr.index_fill(1, tensor(0, device=device), 0.0)
|
| 627 |
+
|
| 628 |
+
try:
|
| 629 |
+
recon_audio = torch.istft(
|
| 630 |
+
stft_repr,
|
| 631 |
+
**self.stft_kwargs,
|
| 632 |
+
window=stft_window,
|
| 633 |
+
return_complex=False,
|
| 634 |
+
length=raw_audio.shape[-1],
|
| 635 |
+
)
|
| 636 |
+
except:
|
| 637 |
+
recon_audio = torch.istft(
|
| 638 |
+
stft_repr.cpu() if x_is_mps else stft_repr,
|
| 639 |
+
**self.stft_kwargs,
|
| 640 |
+
window=stft_window.cpu() if x_is_mps else stft_window,
|
| 641 |
+
return_complex=False,
|
| 642 |
+
length=raw_audio.shape[-1],
|
| 643 |
+
).to(device)
|
| 644 |
+
|
| 645 |
+
recon_audio = rearrange(
|
| 646 |
+
recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
if num_stems == 1:
|
| 650 |
+
recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
|
| 651 |
+
|
| 652 |
+
if not exists(target):
|
| 653 |
+
return recon_audio
|
| 654 |
+
|
| 655 |
+
if self.num_stems > 1:
|
| 656 |
+
assert target.ndim == 4 and target.shape[1] == self.num_stems
|
| 657 |
+
|
| 658 |
+
if target.ndim == 2:
|
| 659 |
+
target = rearrange(target, "... t -> ... 1 t")
|
| 660 |
+
|
| 661 |
+
target = target[..., : recon_audio.shape[-1]]
|
| 662 |
+
|
| 663 |
+
loss = F.l1_loss(recon_audio, target)
|
| 664 |
+
|
| 665 |
+
multi_stft_resolution_loss = 0.0
|
| 666 |
+
|
| 667 |
+
for window_size in self.multi_stft_resolutions_window_sizes:
|
| 668 |
+
res_stft_kwargs = dict(
|
| 669 |
+
n_fft=max(window_size, self.multi_stft_n_fft),
|
| 670 |
+
win_length=window_size,
|
| 671 |
+
return_complex=True,
|
| 672 |
+
window=self.multi_stft_window_fn(window_size, device=device),
|
| 673 |
+
**self.multi_stft_kwargs,
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
recon_Y = torch.stft(
|
| 677 |
+
rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
|
| 678 |
+
)
|
| 679 |
+
target_Y = torch.stft(
|
| 680 |
+
rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
|
| 681 |
+
)
|
| 682 |
+
|
| 683 |
+
multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
|
| 684 |
+
recon_Y, target_Y
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
weighted_multi_resolution_loss = (
|
| 688 |
+
multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
total_loss = loss + weighted_multi_resolution_loss
|
| 692 |
+
|
| 693 |
+
if not return_loss_breakdown:
|
| 694 |
+
return total_loss
|
| 695 |
+
|
| 696 |
+
return total_loss, (loss, multi_stft_resolution_loss)
|
mvsepless/models/bs_roformer/bs_roformer_fno.py
CHANGED
|
@@ -1,704 +1,704 @@
|
|
| 1 |
-
from functools import partial
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import nn, einsum, Tensor
|
| 5 |
-
from torch.nn import Module, ModuleList
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
-
from neuralop.models import FNO1d
|
| 8 |
-
|
| 9 |
-
from .attend import Attend
|
| 10 |
-
|
| 11 |
-
try:
|
| 12 |
-
from .attend_sage import Attend as AttendSage
|
| 13 |
-
except:
|
| 14 |
-
pass
|
| 15 |
-
from torch.utils.checkpoint import checkpoint
|
| 16 |
-
|
| 17 |
-
from beartype.typing import Tuple, Optional, List, Callable
|
| 18 |
-
from beartype import beartype
|
| 19 |
-
|
| 20 |
-
from rotary_embedding_torch import RotaryEmbedding
|
| 21 |
-
|
| 22 |
-
from einops import rearrange, pack, unpack
|
| 23 |
-
from einops.layers.torch import Rearrange
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def exists(val):
|
| 27 |
-
return val is not None
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def default(v, d):
|
| 31 |
-
return v if exists(v) else d
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def pack_one(t, pattern):
|
| 35 |
-
return pack([t], pattern)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def unpack_one(t, ps, pattern):
|
| 39 |
-
return unpack(t, ps, pattern)[0]
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def l2norm(t):
|
| 43 |
-
return F.normalize(t, dim=-1, p=2)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
class RMSNorm(Module):
|
| 47 |
-
def __init__(self, dim):
|
| 48 |
-
super().__init__()
|
| 49 |
-
self.scale = dim**0.5
|
| 50 |
-
self.gamma = nn.Parameter(torch.ones(dim))
|
| 51 |
-
|
| 52 |
-
def forward(self, x):
|
| 53 |
-
return F.normalize(x, dim=-1) * self.scale * self.gamma
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
class FeedForward(Module):
|
| 57 |
-
def __init__(self, dim, mult=4, dropout=0.0):
|
| 58 |
-
super().__init__()
|
| 59 |
-
dim_inner = int(dim * mult)
|
| 60 |
-
self.net = nn.Sequential(
|
| 61 |
-
RMSNorm(dim),
|
| 62 |
-
nn.Linear(dim, dim_inner),
|
| 63 |
-
nn.GELU(),
|
| 64 |
-
nn.Dropout(dropout),
|
| 65 |
-
nn.Linear(dim_inner, dim),
|
| 66 |
-
nn.Dropout(dropout),
|
| 67 |
-
)
|
| 68 |
-
|
| 69 |
-
def forward(self, x):
|
| 70 |
-
return self.net(x)
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
class Attention(Module):
|
| 74 |
-
def __init__(
|
| 75 |
-
self,
|
| 76 |
-
dim,
|
| 77 |
-
heads=8,
|
| 78 |
-
dim_head=64,
|
| 79 |
-
dropout=0.0,
|
| 80 |
-
rotary_embed=None,
|
| 81 |
-
flash=True,
|
| 82 |
-
sage_attention=False,
|
| 83 |
-
):
|
| 84 |
-
super().__init__()
|
| 85 |
-
self.heads = heads
|
| 86 |
-
self.scale = dim_head**-0.5
|
| 87 |
-
dim_inner = heads * dim_head
|
| 88 |
-
|
| 89 |
-
self.rotary_embed = rotary_embed
|
| 90 |
-
|
| 91 |
-
if sage_attention:
|
| 92 |
-
self.attend = AttendSage(flash=flash, dropout=dropout)
|
| 93 |
-
else:
|
| 94 |
-
self.attend = Attend(flash=flash, dropout=dropout)
|
| 95 |
-
|
| 96 |
-
self.norm = RMSNorm(dim)
|
| 97 |
-
self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
|
| 98 |
-
|
| 99 |
-
self.to_gates = nn.Linear(dim, heads)
|
| 100 |
-
|
| 101 |
-
self.to_out = nn.Sequential(
|
| 102 |
-
nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
|
| 103 |
-
)
|
| 104 |
-
|
| 105 |
-
def forward(self, x):
|
| 106 |
-
x = self.norm(x)
|
| 107 |
-
|
| 108 |
-
q, k, v = rearrange(
|
| 109 |
-
self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
|
| 110 |
-
)
|
| 111 |
-
|
| 112 |
-
if exists(self.rotary_embed):
|
| 113 |
-
q = self.rotary_embed.rotate_queries_or_keys(q)
|
| 114 |
-
k = self.rotary_embed.rotate_queries_or_keys(k)
|
| 115 |
-
|
| 116 |
-
out = self.attend(q, k, v)
|
| 117 |
-
|
| 118 |
-
gates = self.to_gates(x)
|
| 119 |
-
out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
|
| 120 |
-
|
| 121 |
-
out = rearrange(out, "b h n d -> b n (h d)")
|
| 122 |
-
return self.to_out(out)
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
class LinearAttention(Module):
|
| 126 |
-
|
| 127 |
-
@beartype
|
| 128 |
-
def __init__(
|
| 129 |
-
self,
|
| 130 |
-
*,
|
| 131 |
-
dim,
|
| 132 |
-
dim_head=32,
|
| 133 |
-
heads=8,
|
| 134 |
-
scale=8,
|
| 135 |
-
flash=False,
|
| 136 |
-
dropout=0.0,
|
| 137 |
-
sage_attention=False,
|
| 138 |
-
):
|
| 139 |
-
super().__init__()
|
| 140 |
-
dim_inner = dim_head * heads
|
| 141 |
-
self.norm = RMSNorm(dim)
|
| 142 |
-
|
| 143 |
-
self.to_qkv = nn.Sequential(
|
| 144 |
-
nn.Linear(dim, dim_inner * 3, bias=False),
|
| 145 |
-
Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
|
| 146 |
-
)
|
| 147 |
-
|
| 148 |
-
self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
|
| 149 |
-
|
| 150 |
-
if sage_attention:
|
| 151 |
-
self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash)
|
| 152 |
-
else:
|
| 153 |
-
self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
|
| 154 |
-
|
| 155 |
-
self.to_out = nn.Sequential(
|
| 156 |
-
Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
|
| 157 |
-
)
|
| 158 |
-
|
| 159 |
-
def forward(self, x):
|
| 160 |
-
x = self.norm(x)
|
| 161 |
-
|
| 162 |
-
q, k, v = self.to_qkv(x)
|
| 163 |
-
|
| 164 |
-
q, k = map(l2norm, (q, k))
|
| 165 |
-
q = q * self.temperature.exp()
|
| 166 |
-
|
| 167 |
-
out = self.attend(q, k, v)
|
| 168 |
-
|
| 169 |
-
return self.to_out(out)
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
class Transformer(Module):
|
| 173 |
-
def __init__(
|
| 174 |
-
self,
|
| 175 |
-
*,
|
| 176 |
-
dim,
|
| 177 |
-
depth,
|
| 178 |
-
dim_head=64,
|
| 179 |
-
heads=8,
|
| 180 |
-
attn_dropout=0.0,
|
| 181 |
-
ff_dropout=0.0,
|
| 182 |
-
ff_mult=4,
|
| 183 |
-
norm_output=True,
|
| 184 |
-
rotary_embed=None,
|
| 185 |
-
flash_attn=True,
|
| 186 |
-
linear_attn=False,
|
| 187 |
-
sage_attention=False,
|
| 188 |
-
):
|
| 189 |
-
super().__init__()
|
| 190 |
-
self.layers = ModuleList([])
|
| 191 |
-
|
| 192 |
-
for _ in range(depth):
|
| 193 |
-
if linear_attn:
|
| 194 |
-
attn = LinearAttention(
|
| 195 |
-
dim=dim,
|
| 196 |
-
dim_head=dim_head,
|
| 197 |
-
heads=heads,
|
| 198 |
-
dropout=attn_dropout,
|
| 199 |
-
flash=flash_attn,
|
| 200 |
-
sage_attention=sage_attention,
|
| 201 |
-
)
|
| 202 |
-
else:
|
| 203 |
-
attn = Attention(
|
| 204 |
-
dim=dim,
|
| 205 |
-
dim_head=dim_head,
|
| 206 |
-
heads=heads,
|
| 207 |
-
dropout=attn_dropout,
|
| 208 |
-
rotary_embed=rotary_embed,
|
| 209 |
-
flash=flash_attn,
|
| 210 |
-
sage_attention=sage_attention,
|
| 211 |
-
)
|
| 212 |
-
|
| 213 |
-
self.layers.append(
|
| 214 |
-
ModuleList(
|
| 215 |
-
[attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
|
| 216 |
-
)
|
| 217 |
-
)
|
| 218 |
-
|
| 219 |
-
self.norm = RMSNorm(dim) if norm_output else nn.Identity()
|
| 220 |
-
|
| 221 |
-
def forward(self, x):
|
| 222 |
-
|
| 223 |
-
for attn, ff in self.layers:
|
| 224 |
-
x = attn(x) + x
|
| 225 |
-
x = ff(x) + x
|
| 226 |
-
|
| 227 |
-
return self.norm(x)
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
class BandSplit(Module):
|
| 231 |
-
@beartype
|
| 232 |
-
def __init__(self, dim, dim_inputs: Tuple[int, ...]):
|
| 233 |
-
super().__init__()
|
| 234 |
-
self.dim_inputs = dim_inputs
|
| 235 |
-
self.to_features = ModuleList([])
|
| 236 |
-
|
| 237 |
-
for dim_in in dim_inputs:
|
| 238 |
-
net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
|
| 239 |
-
|
| 240 |
-
self.to_features.append(net)
|
| 241 |
-
|
| 242 |
-
def forward(self, x):
|
| 243 |
-
x = x.split(self.dim_inputs, dim=-1)
|
| 244 |
-
|
| 245 |
-
outs = []
|
| 246 |
-
for split_input, to_feature in zip(x, self.to_features):
|
| 247 |
-
split_output = to_feature(split_input)
|
| 248 |
-
outs.append(split_output)
|
| 249 |
-
|
| 250 |
-
return torch.stack(outs, dim=-2)
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
|
| 254 |
-
dim_hidden = default(dim_hidden, dim_in)
|
| 255 |
-
|
| 256 |
-
net = []
|
| 257 |
-
dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
|
| 258 |
-
|
| 259 |
-
for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
|
| 260 |
-
is_last = ind == (len(dims) - 2)
|
| 261 |
-
|
| 262 |
-
net.append(nn.Linear(layer_dim_in, layer_dim_out))
|
| 263 |
-
|
| 264 |
-
if is_last:
|
| 265 |
-
continue
|
| 266 |
-
|
| 267 |
-
net.append(activation())
|
| 268 |
-
|
| 269 |
-
return nn.Sequential(*net)
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
class MaskEstimator(Module):
|
| 273 |
-
@beartype
|
| 274 |
-
def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
|
| 275 |
-
super().__init__()
|
| 276 |
-
self.dim_inputs = dim_inputs
|
| 277 |
-
self.to_freqs = ModuleList([])
|
| 278 |
-
dim_hidden = dim * mlp_expansion_factor
|
| 279 |
-
|
| 280 |
-
for dim_in in dim_inputs:
|
| 281 |
-
net = []
|
| 282 |
-
|
| 283 |
-
mlp = nn.Sequential(
|
| 284 |
-
FNO1d(
|
| 285 |
-
n_modes_height=64,
|
| 286 |
-
hidden_channels=dim,
|
| 287 |
-
in_channels=dim,
|
| 288 |
-
out_channels=dim_in * 2,
|
| 289 |
-
lifting_channels=dim,
|
| 290 |
-
projection_channels=dim,
|
| 291 |
-
n_layers=3,
|
| 292 |
-
separable=True,
|
| 293 |
-
),
|
| 294 |
-
nn.GLU(dim=-2),
|
| 295 |
-
)
|
| 296 |
-
|
| 297 |
-
self.to_freqs.append(mlp)
|
| 298 |
-
|
| 299 |
-
def forward(self, x):
|
| 300 |
-
x = x.unbind(dim=-2)
|
| 301 |
-
|
| 302 |
-
outs = []
|
| 303 |
-
|
| 304 |
-
for band_features, mlp in zip(x, self.to_freqs):
|
| 305 |
-
band_features = rearrange(band_features, "b t c -> b c t")
|
| 306 |
-
with torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32):
|
| 307 |
-
freq_out = mlp(band_features).float()
|
| 308 |
-
freq_out = rearrange(freq_out, "b c t -> b t c")
|
| 309 |
-
outs.append(freq_out)
|
| 310 |
-
|
| 311 |
-
return torch.cat(outs, dim=-1)
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
DEFAULT_FREQS_PER_BANDS = (
|
| 315 |
-
2,
|
| 316 |
-
2,
|
| 317 |
-
2,
|
| 318 |
-
2,
|
| 319 |
-
2,
|
| 320 |
-
2,
|
| 321 |
-
2,
|
| 322 |
-
2,
|
| 323 |
-
2,
|
| 324 |
-
2,
|
| 325 |
-
2,
|
| 326 |
-
2,
|
| 327 |
-
2,
|
| 328 |
-
2,
|
| 329 |
-
2,
|
| 330 |
-
2,
|
| 331 |
-
2,
|
| 332 |
-
2,
|
| 333 |
-
2,
|
| 334 |
-
2,
|
| 335 |
-
2,
|
| 336 |
-
2,
|
| 337 |
-
2,
|
| 338 |
-
2,
|
| 339 |
-
4,
|
| 340 |
-
4,
|
| 341 |
-
4,
|
| 342 |
-
4,
|
| 343 |
-
4,
|
| 344 |
-
4,
|
| 345 |
-
4,
|
| 346 |
-
4,
|
| 347 |
-
4,
|
| 348 |
-
4,
|
| 349 |
-
4,
|
| 350 |
-
4,
|
| 351 |
-
12,
|
| 352 |
-
12,
|
| 353 |
-
12,
|
| 354 |
-
12,
|
| 355 |
-
12,
|
| 356 |
-
12,
|
| 357 |
-
12,
|
| 358 |
-
12,
|
| 359 |
-
24,
|
| 360 |
-
24,
|
| 361 |
-
24,
|
| 362 |
-
24,
|
| 363 |
-
24,
|
| 364 |
-
24,
|
| 365 |
-
24,
|
| 366 |
-
24,
|
| 367 |
-
48,
|
| 368 |
-
48,
|
| 369 |
-
48,
|
| 370 |
-
48,
|
| 371 |
-
48,
|
| 372 |
-
48,
|
| 373 |
-
48,
|
| 374 |
-
48,
|
| 375 |
-
128,
|
| 376 |
-
129,
|
| 377 |
-
)
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
class BSRoformer_FNO(Module):
|
| 381 |
-
|
| 382 |
-
@beartype
|
| 383 |
-
def __init__(
|
| 384 |
-
self,
|
| 385 |
-
dim,
|
| 386 |
-
*,
|
| 387 |
-
depth,
|
| 388 |
-
stereo=False,
|
| 389 |
-
num_stems=1,
|
| 390 |
-
time_transformer_depth=2,
|
| 391 |
-
freq_transformer_depth=2,
|
| 392 |
-
linear_transformer_depth=0,
|
| 393 |
-
freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
|
| 394 |
-
dim_head=64,
|
| 395 |
-
heads=8,
|
| 396 |
-
attn_dropout=0.0,
|
| 397 |
-
ff_dropout=0.0,
|
| 398 |
-
flash_attn=True,
|
| 399 |
-
dim_freqs_in=1025,
|
| 400 |
-
stft_n_fft=2048,
|
| 401 |
-
stft_hop_length=512,
|
| 402 |
-
stft_win_length=2048,
|
| 403 |
-
stft_normalized=False,
|
| 404 |
-
stft_window_fn: Optional[Callable] = None,
|
| 405 |
-
mask_estimator_depth=2,
|
| 406 |
-
multi_stft_resolution_loss_weight=1.0,
|
| 407 |
-
multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
|
| 408 |
-
4096,
|
| 409 |
-
2048,
|
| 410 |
-
1024,
|
| 411 |
-
512,
|
| 412 |
-
256,
|
| 413 |
-
),
|
| 414 |
-
multi_stft_hop_size=147,
|
| 415 |
-
multi_stft_normalized=False,
|
| 416 |
-
multi_stft_window_fn: Callable = torch.hann_window,
|
| 417 |
-
mlp_expansion_factor=4,
|
| 418 |
-
use_torch_checkpoint=False,
|
| 419 |
-
skip_connection=False,
|
| 420 |
-
sage_attention=False,
|
| 421 |
-
):
|
| 422 |
-
super().__init__()
|
| 423 |
-
|
| 424 |
-
self.stereo = stereo
|
| 425 |
-
self.audio_channels = 2 if stereo else 1
|
| 426 |
-
self.num_stems = num_stems
|
| 427 |
-
self.use_torch_checkpoint = use_torch_checkpoint
|
| 428 |
-
self.skip_connection = skip_connection
|
| 429 |
-
|
| 430 |
-
self.layers = ModuleList([])
|
| 431 |
-
|
| 432 |
-
if sage_attention:
|
| 433 |
-
print("Use Sage Attention")
|
| 434 |
-
|
| 435 |
-
transformer_kwargs = dict(
|
| 436 |
-
dim=dim,
|
| 437 |
-
heads=heads,
|
| 438 |
-
dim_head=dim_head,
|
| 439 |
-
attn_dropout=attn_dropout,
|
| 440 |
-
ff_dropout=ff_dropout,
|
| 441 |
-
flash_attn=flash_attn,
|
| 442 |
-
norm_output=False,
|
| 443 |
-
sage_attention=sage_attention,
|
| 444 |
-
)
|
| 445 |
-
|
| 446 |
-
time_rotary_embed = RotaryEmbedding(dim=dim_head)
|
| 447 |
-
freq_rotary_embed = RotaryEmbedding(dim=dim_head)
|
| 448 |
-
|
| 449 |
-
for _ in range(depth):
|
| 450 |
-
tran_modules = []
|
| 451 |
-
if linear_transformer_depth > 0:
|
| 452 |
-
tran_modules.append(
|
| 453 |
-
Transformer(
|
| 454 |
-
depth=linear_transformer_depth,
|
| 455 |
-
linear_attn=True,
|
| 456 |
-
**transformer_kwargs,
|
| 457 |
-
)
|
| 458 |
-
)
|
| 459 |
-
tran_modules.append(
|
| 460 |
-
Transformer(
|
| 461 |
-
depth=time_transformer_depth,
|
| 462 |
-
rotary_embed=time_rotary_embed,
|
| 463 |
-
**transformer_kwargs,
|
| 464 |
-
)
|
| 465 |
-
)
|
| 466 |
-
tran_modules.append(
|
| 467 |
-
Transformer(
|
| 468 |
-
depth=freq_transformer_depth,
|
| 469 |
-
rotary_embed=freq_rotary_embed,
|
| 470 |
-
**transformer_kwargs,
|
| 471 |
-
)
|
| 472 |
-
)
|
| 473 |
-
self.layers.append(nn.ModuleList(tran_modules))
|
| 474 |
-
|
| 475 |
-
self.final_norm = RMSNorm(dim)
|
| 476 |
-
|
| 477 |
-
self.stft_kwargs = dict(
|
| 478 |
-
n_fft=stft_n_fft,
|
| 479 |
-
hop_length=stft_hop_length,
|
| 480 |
-
win_length=stft_win_length,
|
| 481 |
-
normalized=stft_normalized,
|
| 482 |
-
)
|
| 483 |
-
|
| 484 |
-
self.stft_window_fn = partial(
|
| 485 |
-
default(stft_window_fn, torch.hann_window), stft_win_length
|
| 486 |
-
)
|
| 487 |
-
|
| 488 |
-
freqs = torch.stft(
|
| 489 |
-
torch.randn(1, 4096),
|
| 490 |
-
**self.stft_kwargs,
|
| 491 |
-
window=torch.ones(stft_win_length),
|
| 492 |
-
return_complex=True,
|
| 493 |
-
).shape[1]
|
| 494 |
-
|
| 495 |
-
assert len(freqs_per_bands) > 1
|
| 496 |
-
assert (
|
| 497 |
-
sum(freqs_per_bands) == freqs
|
| 498 |
-
), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
|
| 499 |
-
|
| 500 |
-
freqs_per_bands_with_complex = tuple(
|
| 501 |
-
2 * f * self.audio_channels for f in freqs_per_bands
|
| 502 |
-
)
|
| 503 |
-
|
| 504 |
-
self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
|
| 505 |
-
|
| 506 |
-
self.mask_estimators = nn.ModuleList([])
|
| 507 |
-
|
| 508 |
-
for _ in range(num_stems):
|
| 509 |
-
mask_estimator = MaskEstimator(
|
| 510 |
-
dim=dim,
|
| 511 |
-
dim_inputs=freqs_per_bands_with_complex,
|
| 512 |
-
depth=mask_estimator_depth,
|
| 513 |
-
mlp_expansion_factor=mlp_expansion_factor,
|
| 514 |
-
)
|
| 515 |
-
|
| 516 |
-
self.mask_estimators.append(mask_estimator)
|
| 517 |
-
|
| 518 |
-
self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
|
| 519 |
-
self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
|
| 520 |
-
self.multi_stft_n_fft = stft_n_fft
|
| 521 |
-
self.multi_stft_window_fn = multi_stft_window_fn
|
| 522 |
-
|
| 523 |
-
self.multi_stft_kwargs = dict(
|
| 524 |
-
hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
|
| 525 |
-
)
|
| 526 |
-
|
| 527 |
-
def forward(self, raw_audio, target=None, return_loss_breakdown=False):
|
| 528 |
-
|
| 529 |
-
device = raw_audio.device
|
| 530 |
-
|
| 531 |
-
x_is_mps = True if device.type == "mps" else False
|
| 532 |
-
|
| 533 |
-
if raw_audio.ndim == 2:
|
| 534 |
-
raw_audio = rearrange(raw_audio, "b t -> b 1 t")
|
| 535 |
-
|
| 536 |
-
channels = raw_audio.shape[1]
|
| 537 |
-
assert (not self.stereo and channels == 1) or (
|
| 538 |
-
self.stereo and channels == 2
|
| 539 |
-
), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
|
| 540 |
-
|
| 541 |
-
raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
|
| 542 |
-
|
| 543 |
-
stft_window = self.stft_window_fn(device=device)
|
| 544 |
-
|
| 545 |
-
try:
|
| 546 |
-
stft_repr = torch.stft(
|
| 547 |
-
raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
|
| 548 |
-
)
|
| 549 |
-
except:
|
| 550 |
-
stft_repr = torch.stft(
|
| 551 |
-
raw_audio.cpu() if x_is_mps else raw_audio,
|
| 552 |
-
**self.stft_kwargs,
|
| 553 |
-
window=stft_window.cpu() if x_is_mps else stft_window,
|
| 554 |
-
return_complex=True,
|
| 555 |
-
).to(device)
|
| 556 |
-
stft_repr = torch.view_as_real(stft_repr)
|
| 557 |
-
|
| 558 |
-
stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
|
| 559 |
-
|
| 560 |
-
stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
|
| 561 |
-
|
| 562 |
-
x = rearrange(stft_repr, "b f t c -> b t (f c)")
|
| 563 |
-
|
| 564 |
-
if self.use_torch_checkpoint:
|
| 565 |
-
x = checkpoint(self.band_split, x, use_reentrant=False)
|
| 566 |
-
else:
|
| 567 |
-
x = self.band_split(x)
|
| 568 |
-
|
| 569 |
-
store = [None] * len(self.layers)
|
| 570 |
-
for i, transformer_block in enumerate(self.layers):
|
| 571 |
-
|
| 572 |
-
if len(transformer_block) == 3:
|
| 573 |
-
linear_transformer, time_transformer, freq_transformer = (
|
| 574 |
-
transformer_block
|
| 575 |
-
)
|
| 576 |
-
|
| 577 |
-
x, ft_ps = pack([x], "b * d")
|
| 578 |
-
if self.use_torch_checkpoint:
|
| 579 |
-
x = checkpoint(linear_transformer, x, use_reentrant=False)
|
| 580 |
-
else:
|
| 581 |
-
x = linear_transformer(x)
|
| 582 |
-
(x,) = unpack(x, ft_ps, "b * d")
|
| 583 |
-
else:
|
| 584 |
-
time_transformer, freq_transformer = transformer_block
|
| 585 |
-
|
| 586 |
-
if self.skip_connection:
|
| 587 |
-
for j in range(i):
|
| 588 |
-
x = x + store[j]
|
| 589 |
-
|
| 590 |
-
x = rearrange(x, "b t f d -> b f t d")
|
| 591 |
-
x, ps = pack([x], "* t d")
|
| 592 |
-
|
| 593 |
-
if self.use_torch_checkpoint:
|
| 594 |
-
x = checkpoint(time_transformer, x, use_reentrant=False)
|
| 595 |
-
else:
|
| 596 |
-
x = time_transformer(x)
|
| 597 |
-
|
| 598 |
-
(x,) = unpack(x, ps, "* t d")
|
| 599 |
-
x = rearrange(x, "b f t d -> b t f d")
|
| 600 |
-
x, ps = pack([x], "* f d")
|
| 601 |
-
|
| 602 |
-
if self.use_torch_checkpoint:
|
| 603 |
-
x = checkpoint(freq_transformer, x, use_reentrant=False)
|
| 604 |
-
else:
|
| 605 |
-
x = freq_transformer(x)
|
| 606 |
-
|
| 607 |
-
(x,) = unpack(x, ps, "* f d")
|
| 608 |
-
|
| 609 |
-
if self.skip_connection:
|
| 610 |
-
store[i] = x
|
| 611 |
-
|
| 612 |
-
x = self.final_norm(x)
|
| 613 |
-
|
| 614 |
-
num_stems = len(self.mask_estimators)
|
| 615 |
-
|
| 616 |
-
if self.use_torch_checkpoint:
|
| 617 |
-
mask = torch.stack(
|
| 618 |
-
[checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators],
|
| 619 |
-
dim=1,
|
| 620 |
-
)
|
| 621 |
-
else:
|
| 622 |
-
mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
|
| 623 |
-
mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
|
| 624 |
-
|
| 625 |
-
stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
|
| 626 |
-
|
| 627 |
-
stft_repr = torch.view_as_complex(stft_repr)
|
| 628 |
-
mask = torch.view_as_complex(mask)
|
| 629 |
-
|
| 630 |
-
stft_repr = stft_repr * mask
|
| 631 |
-
|
| 632 |
-
stft_repr = rearrange(
|
| 633 |
-
stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
|
| 634 |
-
)
|
| 635 |
-
|
| 636 |
-
try:
|
| 637 |
-
recon_audio = torch.istft(
|
| 638 |
-
stft_repr,
|
| 639 |
-
**self.stft_kwargs,
|
| 640 |
-
window=stft_window,
|
| 641 |
-
return_complex=False,
|
| 642 |
-
length=raw_audio.shape[-1],
|
| 643 |
-
)
|
| 644 |
-
except:
|
| 645 |
-
recon_audio = torch.istft(
|
| 646 |
-
stft_repr.cpu() if x_is_mps else stft_repr,
|
| 647 |
-
**self.stft_kwargs,
|
| 648 |
-
window=stft_window.cpu() if x_is_mps else stft_window,
|
| 649 |
-
return_complex=False,
|
| 650 |
-
length=raw_audio.shape[-1],
|
| 651 |
-
).to(device)
|
| 652 |
-
|
| 653 |
-
recon_audio = rearrange(
|
| 654 |
-
recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
|
| 655 |
-
)
|
| 656 |
-
|
| 657 |
-
if num_stems == 1:
|
| 658 |
-
recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
|
| 659 |
-
|
| 660 |
-
if not exists(target):
|
| 661 |
-
return recon_audio
|
| 662 |
-
|
| 663 |
-
if self.num_stems > 1:
|
| 664 |
-
assert target.ndim == 4 and target.shape[1] == self.num_stems
|
| 665 |
-
|
| 666 |
-
if target.ndim == 2:
|
| 667 |
-
target = rearrange(target, "... t -> ... 1 t")
|
| 668 |
-
|
| 669 |
-
target = target[..., : recon_audio.shape[-1]]
|
| 670 |
-
|
| 671 |
-
loss = F.l1_loss(recon_audio, target)
|
| 672 |
-
|
| 673 |
-
multi_stft_resolution_loss = 0.0
|
| 674 |
-
|
| 675 |
-
for window_size in self.multi_stft_resolutions_window_sizes:
|
| 676 |
-
res_stft_kwargs = dict(
|
| 677 |
-
n_fft=max(window_size, self.multi_stft_n_fft),
|
| 678 |
-
win_length=window_size,
|
| 679 |
-
return_complex=True,
|
| 680 |
-
window=self.multi_stft_window_fn(window_size, device=device),
|
| 681 |
-
**self.multi_stft_kwargs,
|
| 682 |
-
)
|
| 683 |
-
|
| 684 |
-
recon_Y = torch.stft(
|
| 685 |
-
rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
|
| 686 |
-
)
|
| 687 |
-
target_Y = torch.stft(
|
| 688 |
-
rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
|
| 689 |
-
)
|
| 690 |
-
|
| 691 |
-
multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
|
| 692 |
-
recon_Y, target_Y
|
| 693 |
-
)
|
| 694 |
-
|
| 695 |
-
weighted_multi_resolution_loss = (
|
| 696 |
-
multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
|
| 697 |
-
)
|
| 698 |
-
|
| 699 |
-
total_loss = loss + weighted_multi_resolution_loss
|
| 700 |
-
|
| 701 |
-
if not return_loss_breakdown:
|
| 702 |
-
return total_loss
|
| 703 |
-
|
| 704 |
-
return total_loss, (loss, multi_stft_resolution_loss)
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn, einsum, Tensor
|
| 5 |
+
from torch.nn import Module, ModuleList
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from neuralop.models import FNO1d
|
| 8 |
+
|
| 9 |
+
from .attend import Attend
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from .attend_sage import Attend as AttendSage
|
| 13 |
+
except:
|
| 14 |
+
pass
|
| 15 |
+
from torch.utils.checkpoint import checkpoint
|
| 16 |
+
|
| 17 |
+
from beartype.typing import Tuple, Optional, List, Callable
|
| 18 |
+
from beartype import beartype
|
| 19 |
+
|
| 20 |
+
from rotary_embedding_torch import RotaryEmbedding
|
| 21 |
+
|
| 22 |
+
from einops import rearrange, pack, unpack
|
| 23 |
+
from einops.layers.torch import Rearrange
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def exists(val):
    """True when *val* holds a real value (i.e. is not None)."""
    return not (val is None)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def default(v, d):
    """Return *v* unless it is None, in which case fall back to *d*."""
    if v is None:
        return d
    return v
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def pack_one(t, pattern):
    """einops ``pack`` specialised to a single tensor.

    Returns the packed tensor together with the pack-shape needed to undo it.
    """
    packed, ps = pack([t], pattern)
    return packed, ps
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def unpack_one(t, ps, pattern):
    """Inverse of :func:`pack_one` — unpack and return the single tensor."""
    (single,) = unpack(t, ps, pattern)
    return single
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def l2norm(t):
    """L2-normalise *t* along its last dimension."""
    return F.normalize(t, p=2, dim=-1)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class RMSNorm(Module):
    """RMS normalisation over the last dim with a learned per-channel gain.

    Equivalent to unit-normalising the feature vector, rescaling by
    sqrt(dim), then applying the learnable ``gamma`` gain.
    """

    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        normed = F.normalize(x, dim=-1)
        return normed * self.scale * self.gamma
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class FeedForward(Module):
    """Pre-norm two-layer MLP block: RMSNorm → Linear → GELU → Linear.

    Hidden width is ``dim * mult``; dropout is applied after both the
    activation and the output projection.
    """

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        layers = [
            RMSNorm(dim),
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class Attention(Module):
    """Multi-head attention with per-head sigmoid gating.

    Optionally applies rotary position embeddings to queries and keys, and
    can swap the attention kernel for the Sage implementation.
    """

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        dropout=0.0,
        rotary_embed=None,
        flash=True,
        sage_attention=False,
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        inner_dim = heads * dim_head

        self.rotary_embed = rotary_embed

        # Choose the attention kernel; AttendSage is an optional drop-in.
        attend_cls = AttendSage if sage_attention else Attend
        self.attend = attend_cls(flash=flash, dropout=dropout)

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_gates = nn.Linear(dim, heads)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias=False),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)

        if self.rotary_embed is not None:
            q = self.rotary_embed.rotate_queries_or_keys(q)
            k = self.rotary_embed.rotate_queries_or_keys(k)

        attended = self.attend(q, k, v)

        # Per-head scalar gates computed from the normed input tokens.
        gates = self.to_gates(x)
        attended = attended * rearrange(gates, "b n h -> b h n 1").sigmoid()

        merged = rearrange(attended, "b h n d -> b n (h d)")
        return self.to_out(merged)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class LinearAttention(Module):
    """Linear-complexity attention operating over the feature axis.

    Queries and keys are l2-normalised (cosine-similarity attention) and
    queries are rescaled by a learned per-head temperature.
    """

    @beartype
    def __init__(
        self,
        *,
        dim,
        dim_head=32,
        heads=8,
        scale=8,
        flash=False,
        dropout=0.0,
        sage_attention=False,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.norm = RMSNorm(dim)

        # Note: projects to a "b h d n" layout (features x sequence).
        self.to_qkv = nn.Sequential(
            nn.Linear(dim, inner_dim * 3, bias=False),
            Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
        )

        # Learned per-head temperature for the normalised queries.
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        attend_cls = AttendSage if sage_attention else Attend
        self.attend = attend_cls(scale=scale, dropout=dropout, flash=flash)

        self.to_out = nn.Sequential(
            Rearrange("b h d n -> b n (h d)"),
            nn.Linear(inner_dim, dim, bias=False),
        )

    def forward(self, x):
        x = self.norm(x)

        q, k, v = self.to_qkv(x)

        q = l2norm(q) * self.temperature.exp()
        k = l2norm(k)

        out = self.attend(q, k, v)
        return self.to_out(out)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class Transformer(Module):
    """A stack of ``depth`` residual (attention, feed-forward) layers.

    Each layer is either standard :class:`Attention` (with optional rotary
    embeddings) or :class:`LinearAttention`, followed by a
    :class:`FeedForward` block; both are applied with residual connections.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        ff_mult=4,
        norm_output=True,
        rotary_embed=None,
        flash_attn=True,
        linear_attn=False,
        sage_attention=False,
    ):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            if linear_attn:
                attn = LinearAttention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    dropout=attn_dropout,
                    flash=flash_attn,
                    sage_attention=sage_attention,
                )
            else:
                attn = Attention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    dropout=attn_dropout,
                    rotary_embed=rotary_embed,
                    flash=flash_attn,
                    sage_attention=sage_attention,
                )

            ff = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
            self.layers.append(ModuleList([attn, ff]))

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        for attn, ff in self.layers:
            x = x + attn(x)
            x = x + ff(x)
        return self.norm(x)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class BandSplit(Module):
    """Split the last dim into frequency bands and project each to ``dim``.

    Each band of width ``dim_inputs[i]`` gets its own RMSNorm + Linear;
    band features are stacked along a new second-to-last axis.
    """

    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...]):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList(
            [nn.Sequential(RMSNorm(d_in), nn.Linear(d_in, dim)) for d_in in dim_inputs]
        )

    def forward(self, x):
        bands = x.split(self.dim_inputs, dim=-1)
        feats = [proj(band) for band, proj in zip(bands, self.to_features)]
        return torch.stack(feats, dim=-2)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
    """Build a plain MLP: ``depth`` Linear layers with *activation* between.

    ``dim_hidden`` defaults to ``dim_in``. The final Linear has no trailing
    activation.
    """
    dim_hidden = dim_in if dim_hidden is None else dim_hidden

    dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
    last = len(dims) - 2

    layers = []
    for ind, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(d_in, d_out))
        if ind != last:
            layers.append(activation())

    return nn.Sequential(*layers)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class MaskEstimator(Module):
    """Per-band mask estimator: an FNO1d per frequency band, gated by GLU.

    Each band's feature sequence (shape ``b t c`` with ``c == dim``) is fed
    through a 1-D Fourier Neural Operator that doubles the channel count,
    then a GLU over the channel axis halves it back, producing ``dim_in``
    values per band; band outputs are concatenated along the last dim.
    """

    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
        # NOTE(review): `depth` is accepted for interface compatibility but is
        # not used by this FNO-based variant.
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        # NOTE(review): `dim_hidden` is computed but never used below.
        dim_hidden = dim * mlp_expansion_factor

        for dim_in in dim_inputs:
            # NOTE(review): `net` is assigned but unused.
            net = []

            # FNO outputs 2*dim_in channels so the GLU (over the channel
            # axis, dim=-2 in "b c t" layout) can gate down to dim_in.
            mlp = nn.Sequential(
                FNO1d(
                    n_modes_height=64,
                    hidden_channels=dim,
                    in_channels=dim,
                    out_channels=dim_in * 2,
                    lifting_channels=dim,
                    projection_channels=dim,
                    n_layers=3,
                    separable=True,
                ),
                nn.GLU(dim=-2),
            )

            self.to_freqs.append(mlp)

    def forward(self, x):
        # x: stacked band features, split back into one tensor per band.
        x = x.unbind(dim=-2)

        outs = []

        for band_features, mlp in zip(x, self.to_freqs):
            # FNO1d expects channels-first ("b c t").
            band_features = rearrange(band_features, "b t c -> b c t")
            # Force the FNO to run in float32 (autocast disabled) —
            # presumably for numerical stability under mixed precision.
            with torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32):
                freq_out = mlp(band_features).float()
            freq_out = rearrange(freq_out, "b c t -> b t c")
            outs.append(freq_out)

        # Concatenate per-band outputs back into a full-spectrum vector.
        return torch.cat(outs, dim=-1)
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
# Default band layout for n_fft=2048 (1025 frequency bins): narrow 2-bin
# bands at low frequencies, progressively wider bands toward the top of the
# spectrum. 62 bands total, summing to 1025 bins.
DEFAULT_FREQS_PER_BANDS = (
    (2,) * 24
    + (4,) * 12
    + (12,) * 8
    + (24,) * 8
    + (48,) * 8
    + (128, 129)
)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class BSRoformer_FNO(Module):
    """Band-Split Roformer with FNO-based mask estimators for source separation.

    Pipeline: STFT -> band split -> alternating time/frequency transformer
    blocks -> per-stem complex mask estimation -> iSTFT. When a ``target``
    is supplied to :meth:`forward`, returns an L1 + multi-resolution STFT
    loss instead of audio.

    Fixes over the previous revision: the two bare ``except:`` clauses in
    ``forward`` (MPS fall-back paths) are narrowed to ``except Exception:``
    so ``KeyboardInterrupt``/``SystemExit`` are no longer swallowed.
    """

    @beartype
    def __init__(
        self,
        dim,
        *,
        depth,
        stereo=False,
        num_stems=1,
        time_transformer_depth=2,
        freq_transformer_depth=2,
        linear_transformer_depth=0,
        freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        flash_attn=True,
        dim_freqs_in=1025,
        stft_n_fft=2048,
        stft_hop_length=512,
        stft_win_length=2048,
        stft_normalized=False,
        stft_window_fn: Optional[Callable] = None,
        mask_estimator_depth=2,
        multi_stft_resolution_loss_weight=1.0,
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
            4096,
            2048,
            1024,
            512,
            256,
        ),
        multi_stft_hop_size=147,
        multi_stft_normalized=False,
        multi_stft_window_fn: Callable = torch.hann_window,
        mlp_expansion_factor=4,
        use_torch_checkpoint=False,
        skip_connection=False,
        sage_attention=False,
    ):
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems
        self.use_torch_checkpoint = use_torch_checkpoint
        self.skip_connection = skip_connection

        self.layers = ModuleList([])

        if sage_attention:
            print("Use Sage Attention")

        transformer_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            flash_attn=flash_attn,
            norm_output=False,
            sage_attention=sage_attention,
        )

        # Rotary embeddings are shared across all layers of each axis.
        time_rotary_embed = RotaryEmbedding(dim=dim_head)
        freq_rotary_embed = RotaryEmbedding(dim=dim_head)

        for _ in range(depth):
            tran_modules = []
            # Optional linear-attention transformer over the flattened
            # (time x freq) axis, followed by the axial time and freq blocks.
            if linear_transformer_depth > 0:
                tran_modules.append(
                    Transformer(
                        depth=linear_transformer_depth,
                        linear_attn=True,
                        **transformer_kwargs,
                    )
                )
            tran_modules.append(
                Transformer(
                    depth=time_transformer_depth,
                    rotary_embed=time_rotary_embed,
                    **transformer_kwargs,
                )
            )
            tran_modules.append(
                Transformer(
                    depth=freq_transformer_depth,
                    rotary_embed=freq_rotary_embed,
                    **transformer_kwargs,
                )
            )
            self.layers.append(nn.ModuleList(tran_modules))

        self.final_norm = RMSNorm(dim)

        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized,
        )

        self.stft_window_fn = partial(
            default(stft_window_fn, torch.hann_window), stft_win_length
        )

        # Probe the STFT once to find the number of frequency bins implied
        # by the settings, and validate the band layout against it.
        freqs = torch.stft(
            torch.randn(1, 4096),
            **self.stft_kwargs,
            window=torch.ones(stft_win_length),
            return_complex=True,
        ).shape[1]

        assert len(freqs_per_bands) > 1
        assert (
            sum(freqs_per_bands) == freqs
        ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"

        # Each frequency bin contributes (real, imag) x audio_channels values.
        freqs_per_bands_with_complex = tuple(
            2 * f * self.audio_channels for f in freqs_per_bands
        )

        self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)

        self.mask_estimators = nn.ModuleList([])

        for _ in range(num_stems):
            mask_estimator = MaskEstimator(
                dim=dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=mask_estimator_depth,
                mlp_expansion_factor=mlp_expansion_factor,
            )

            self.mask_estimators.append(mask_estimator)

        # Multi-resolution STFT loss configuration (training only).
        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(
            hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
        )

    def forward(self, raw_audio, target=None, return_loss_breakdown=False):
        """Separate ``raw_audio`` into stems, or compute the training loss.

        Args:
            raw_audio: ``(b, t)`` mono or ``(b, s, t)`` audio tensor.
            target: optional reference stems; when given, a loss is returned.
            return_loss_breakdown: also return ``(l1, multi_stft)`` components.

        Returns:
            Separated audio ``(b, [n,] s, t)`` when ``target`` is None,
            otherwise the scalar total loss (optionally with a breakdown).
        """
        device = raw_audio.device

        # MPS lacks some STFT ops; fall back to CPU for those calls below.
        x_is_mps = device.type == "mps"

        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, "b t -> b 1 t")

        channels = raw_audio.shape[1]
        assert (not self.stereo and channels == 1) or (
            self.stereo and channels == 2
        ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"

        # Fold the channel axis into the batch for the STFT.
        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")

        stft_window = self.stft_window_fn(device=device)

        try:
            stft_repr = torch.stft(
                raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
            )
        except Exception:
            # MPS fallback: run the STFT on CPU and move the result back.
            stft_repr = torch.stft(
                raw_audio.cpu() if x_is_mps else raw_audio,
                **self.stft_kwargs,
                window=stft_window.cpu() if x_is_mps else stft_window,
                return_complex=True,
            ).to(device)
        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")

        # Merge stereo/mono channels into the frequency axis: each band now
        # carries audio-channel-interleaved bins.
        stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")

        x = rearrange(stft_repr, "b f t c -> b t (f c)")

        if self.use_torch_checkpoint:
            x = checkpoint(self.band_split, x, use_reentrant=False)
        else:
            x = self.band_split(x)

        # Axial transformer stages: alternate attention across time and
        # across frequency, with optional dense skip connections.
        store = [None] * len(self.layers)
        for i, transformer_block in enumerate(self.layers):

            if len(transformer_block) == 3:
                linear_transformer, time_transformer, freq_transformer = (
                    transformer_block
                )

                x, ft_ps = pack([x], "b * d")
                if self.use_torch_checkpoint:
                    x = checkpoint(linear_transformer, x, use_reentrant=False)
                else:
                    x = linear_transformer(x)
                (x,) = unpack(x, ft_ps, "b * d")
            else:
                time_transformer, freq_transformer = transformer_block

            if self.skip_connection:
                # Sum outputs of all previous layers (dense residuals).
                for j in range(i):
                    x = x + store[j]

            x = rearrange(x, "b t f d -> b f t d")
            x, ps = pack([x], "* t d")

            if self.use_torch_checkpoint:
                x = checkpoint(time_transformer, x, use_reentrant=False)
            else:
                x = time_transformer(x)

            (x,) = unpack(x, ps, "* t d")
            x = rearrange(x, "b f t d -> b t f d")
            x, ps = pack([x], "* f d")

            if self.use_torch_checkpoint:
                x = checkpoint(freq_transformer, x, use_reentrant=False)
            else:
                x = freq_transformer(x)

            (x,) = unpack(x, ps, "* f d")

            if self.skip_connection:
                store[i] = x

        x = self.final_norm(x)

        num_stems = len(self.mask_estimators)

        if self.use_torch_checkpoint:
            mask = torch.stack(
                [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators],
                dim=1,
            )
        else:
            mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
        mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)

        stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")

        # Apply the complex masks to the (broadcast) mixture spectrogram.
        stft_repr = torch.view_as_complex(stft_repr)
        mask = torch.view_as_complex(mask)

        stft_repr = stft_repr * mask

        # Un-interleave the audio channels from the frequency axis and fold
        # (batch, stem, channel) together for the iSTFT.
        stft_repr = rearrange(
            stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
        )

        try:
            recon_audio = torch.istft(
                stft_repr,
                **self.stft_kwargs,
                window=stft_window,
                return_complex=False,
                length=raw_audio.shape[-1],
            )
        except Exception:
            # MPS fallback mirroring the forward STFT above.
            recon_audio = torch.istft(
                stft_repr.cpu() if x_is_mps else stft_repr,
                **self.stft_kwargs,
                window=stft_window.cpu() if x_is_mps else stft_window,
                return_complex=False,
                length=raw_audio.shape[-1],
            ).to(device)

        recon_audio = rearrange(
            recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
        )

        if num_stems == 1:
            recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")

        if not exists(target):
            return recon_audio

        if self.num_stems > 1:
            assert target.ndim == 4 and target.shape[1] == self.num_stems

        if target.ndim == 2:
            target = rearrange(target, "... t -> ... 1 t")

        # Trim target to the reconstructed length (iSTFT may shorten).
        target = target[..., : recon_audio.shape[-1]]

        loss = F.l1_loss(recon_audio, target)

        # Multi-resolution STFT loss: L1 over complex spectra at several
        # window sizes.
        multi_stft_resolution_loss = 0.0

        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
                n_fft=max(window_size, self.multi_stft_n_fft),
                win_length=window_size,
                return_complex=True,
                window=self.multi_stft_window_fn(window_size, device=device),
                **self.multi_stft_kwargs,
            )

            recon_Y = torch.stft(
                rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
            )
            target_Y = torch.stft(
                rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
            )

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
                recon_Y, target_Y
            )

        weighted_multi_resolution_loss = (
            multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
        )

        total_loss = loss + weighted_multi_resolution_loss

        if not return_loss_breakdown:
            return total_loss

        return total_loss, (loss, multi_stft_resolution_loss)
|
mvsepless/models/bs_roformer/bs_roformer_hyperace.py
CHANGED
|
@@ -1,1122 +1,1122 @@
|
|
| 1 |
-
from functools import partial
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import nn, einsum, Tensor
|
| 5 |
-
from torch.nn import Module, ModuleList
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
-
|
| 8 |
-
from .attend import Attend
|
| 9 |
-
|
| 10 |
-
try:
|
| 11 |
-
from .attend_sage import Attend as AttendSage
|
| 12 |
-
except:
|
| 13 |
-
pass
|
| 14 |
-
from torch.utils.checkpoint import checkpoint
|
| 15 |
-
|
| 16 |
-
from beartype.typing import Tuple, Optional, List, Callable
|
| 17 |
-
from beartype import beartype
|
| 18 |
-
|
| 19 |
-
from rotary_embedding_torch import RotaryEmbedding
|
| 20 |
-
|
| 21 |
-
from einops import rearrange, pack, unpack
|
| 22 |
-
from einops.layers.torch import Rearrange
|
| 23 |
-
import torchaudio
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def exists(val):
|
| 27 |
-
return val is not None
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def default(v, d):
|
| 31 |
-
return v if exists(v) else d
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def pack_one(t, pattern):
|
| 35 |
-
return pack([t], pattern)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def unpack_one(t, ps, pattern):
|
| 39 |
-
return unpack(t, ps, pattern)[0]
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def l2norm(t):
|
| 43 |
-
return F.normalize(t, dim=-1, p=2)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
class RMSNorm(Module):
|
| 47 |
-
def __init__(self, dim):
|
| 48 |
-
super().__init__()
|
| 49 |
-
self.scale = dim**0.5
|
| 50 |
-
self.gamma = nn.Parameter(torch.ones(dim))
|
| 51 |
-
|
| 52 |
-
def forward(self, x):
|
| 53 |
-
return F.normalize(x, dim=-1) * self.scale * self.gamma
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
class FeedForward(Module):
|
| 57 |
-
def __init__(self, dim, mult=4, dropout=0.0):
|
| 58 |
-
super().__init__()
|
| 59 |
-
dim_inner = int(dim * mult)
|
| 60 |
-
self.net = nn.Sequential(
|
| 61 |
-
RMSNorm(dim),
|
| 62 |
-
nn.Linear(dim, dim_inner),
|
| 63 |
-
nn.GELU(),
|
| 64 |
-
nn.Dropout(dropout),
|
| 65 |
-
nn.Linear(dim_inner, dim),
|
| 66 |
-
nn.Dropout(dropout),
|
| 67 |
-
)
|
| 68 |
-
|
| 69 |
-
def forward(self, x):
|
| 70 |
-
return self.net(x)
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
class Attention(Module):
    """Multi-head self-attention with RMS pre-norm, optional rotary position
    embeddings, and a learned per-head sigmoid output gate.

    The attention kernel itself is delegated to ``Attend`` (or ``AttendSage``
    when ``sage_attention`` is set).
    """

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        dropout=0.0,
        rotary_embed=None,
        flash=True,
        sage_attention=False,
    ):
        super().__init__()
        self.heads = heads
        # NOTE(review): self.scale is set but scaling appears to be handled
        # inside Attend — confirm it is intentionally unused here.
        self.scale = dim_head**-0.5
        dim_inner = heads * dim_head

        self.rotary_embed = rotary_embed

        if sage_attention:
            self.attend = AttendSage(flash=flash, dropout=dropout)
        else:
            self.attend = Attend(flash=flash, dropout=dropout)

        self.norm = RMSNorm(dim)
        # Single projection producing q, k and v in one matmul.
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)

        # One scalar gate per head, computed from the normalized input.
        self.to_gates = nn.Linear(dim, heads)

        self.to_out = nn.Sequential(
            nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.norm(x)

        # (b, n, 3*h*d) -> three tensors of shape (b, h, n, d).
        q, k, v = rearrange(
            self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
        )

        if exists(self.rotary_embed):
            q = self.rotary_embed.rotate_queries_or_keys(q)
            k = self.rotary_embed.rotate_queries_or_keys(k)

        out = self.attend(q, k, v)

        # Gate each head's output with a sigmoid computed from the input.
        gates = self.to_gates(x)
        out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
class LinearAttention(Module):
    """Linear-complexity attention variant.

    Queries and keys are L2-normalized and rescaled by a learned per-head
    temperature before being handed to the ``Attend`` kernel. Note the qkv
    rearrange keeps features on the second-to-last axis ("b h d n"), unlike
    the softmax ``Attention`` above.
    """

    @beartype
    def __init__(
        self,
        *,
        dim,
        dim_head=32,
        heads=8,
        scale=8,
        flash=True,
        dropout=0.0,
        sage_attention=False,
    ):
        super().__init__()
        dim_inner = dim_head * heads
        self.norm = RMSNorm(dim)

        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias=False),
            Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
        )

        # Learned log-temperature per head (applied as exp() in forward).
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        if sage_attention:
            self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash)
        else:
            self.attend = Attend(scale=scale, dropout=dropout, flash=flash)

        self.to_out = nn.Sequential(
            Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
        )

    def forward(self, x):
        x = self.norm(x)

        q, k, v = self.to_qkv(x)

        # Cosine-similarity attention: unit-norm q/k, learned temperature on q.
        q, k = map(l2norm, (q, k))
        q = q * self.temperature.exp()

        out = self.attend(q, k, v)

        return self.to_out(out)
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
class Transformer(Module):
    """Stack of pre-norm attention + feed-forward blocks with residuals.

    Each layer is [attention, feed-forward]; ``linear_attn`` swaps the
    softmax attention for ``LinearAttention`` (which takes no rotary
    embedding). ``norm_output=False`` makes the final norm an identity.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        ff_mult=4,
        norm_output=True,
        rotary_embed=None,
        flash_attn=True,
        linear_attn=False,
        sage_attention=False,
    ):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            if linear_attn:
                # LinearAttention has no rotary-embedding argument.
                attn = LinearAttention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    dropout=attn_dropout,
                    flash=flash_attn,
                    sage_attention=sage_attention,
                )
            else:
                attn = Attention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    dropout=attn_dropout,
                    rotary_embed=rotary_embed,
                    flash=flash_attn,
                    sage_attention=sage_attention,
                )

            self.layers.append(
                ModuleList(
                    [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
                )
            )

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        # Residual connection around both sub-blocks of every layer.
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
class BandSplit(Module):
    """Split the last feature axis into per-band chunks and project each
    band to a common width ``dim``.

    Input (..., sum(dim_inputs)) -> output (..., num_bands, dim).
    """

    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...]):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList([])

        # One RMSNorm + Linear per band (bands may have different widths).
        for dim_in in dim_inputs:
            net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))

            self.to_features.append(net)

    def forward(self, x):
        # Split the last axis into per-band chunks of the configured widths.
        x = x.split(self.dim_inputs, dim=-1)

        outs = []
        for split_input, to_feature in zip(x, self.to_features):
            split_output = to_feature(split_input)
            outs.append(split_output)

        # Stack bands on a new second-to-last axis.
        x = torch.stack(outs, dim=-2)

        return x
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
class Conv(nn.Module):
    """Conv2d -> InstanceNorm2d -> SiLU (or identity when ``act`` is False),
    with 'same'-style padding supplied by ``autopad``."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
        self.act = nn.SiLU() if act else nn.Identity()

    def forward(self, x):
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
def autopad(k, p=None):
    """Return 'same' padding for kernel size ``k`` unless ``p`` is given explicitly."""
    if p is not None:
        return p
    return k // 2 if isinstance(k, int) else [x // 2 for x in k]
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
class DSConv(nn.Module):
    """Depthwise-separable convolution: depthwise k x k (groups=c1) followed by
    a pointwise 1 x 1, then instance norm and SiLU (or identity)."""

    def __init__(self, c1, c2, k=3, s=1, p=None, act=True):
        super().__init__()
        self.dwconv = nn.Conv2d(c1, c1, k, s, autopad(k, p), groups=c1, bias=False)
        self.pwconv = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
        self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
        self.act = nn.SiLU() if act else nn.Identity()

    def forward(self, x):
        y = self.dwconv(x)
        y = self.pwconv(y)
        return self.act(self.bn(y))
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
class DS_Bottleneck(nn.Module):
    """Two stacked depthwise-separable convs with an optional residual,
    taken only when input and output widths match."""

    def __init__(self, c1, c2, k=3, shortcut=True):
        super().__init__()
        self.dsconv1 = DSConv(c1, c1, k=3, s=1)
        self.dsconv2 = DSConv(c1, c2, k=k, s=1)
        self.shortcut = shortcut and c1 == c2

    def forward(self, x):
        y = self.dsconv2(self.dsconv1(x))
        return x + y if self.shortcut else y
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
class DS_C3k(nn.Module):
    """CSP-style block: a chain of ``n`` DS bottlenecks on the main branch,
    a 1x1 lateral branch, concatenated and fused by a final 1x1 conv."""

    def __init__(self, c1, c2, n=1, k=3, e=0.5):
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.cv3 = Conv(2 * hidden, c2, 1, 1)
        self.m = nn.Sequential(
            *(DS_Bottleneck(hidden, hidden, k=k, shortcut=True) for _ in range(n))
        )

    def forward(self, x):
        main = self.m(self.cv1(x))
        lateral = self.cv2(x)
        return self.cv3(torch.cat((main, lateral), dim=1))
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
class DS_C3k2(nn.Module):
    """Bottlenecked wrapper: squeeze with a 1x1 conv, run a full-width DS_C3k,
    then expand to the output width with another 1x1 conv."""

    def __init__(self, c1, c2, n=1, k=3, e=0.5):
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.m = DS_C3k(hidden, hidden, n=n, k=k, e=1.0)
        self.cv2 = Conv(hidden, c2, 1, 1)

    def forward(self, x):
        return self.cv2(self.m(self.cv1(x)))
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
class AdaptiveHyperedgeGeneration(nn.Module):
    """Produce a soft node-to-hyperedge participation matrix from tokens.

    Learns ``num_hyperedges`` global prototypes, offsets them per sample
    using pooled context, and scores every token against every prototype
    with multi-head scaled dot products. Returns A of shape (B, E, N) —
    a softmax over nodes for each hyperedge.
    """

    def __init__(self, in_channels, num_hyperedges, num_heads=8):
        super().__init__()
        self.num_hyperedges = num_hyperedges
        self.num_heads = num_heads
        # NOTE(review): assumes in_channels is divisible by num_heads — confirm.
        self.head_dim = in_channels // num_heads

        # Global hyperedge prototypes shared across samples.
        self.global_proto = nn.Parameter(torch.randn(num_hyperedges, in_channels))

        # Maps pooled (avg ++ max) context to per-sample prototype offsets.
        self.context_mapper = nn.Linear(
            2 * in_channels, num_hyperedges * in_channels, bias=False
        )

        self.query_proj = nn.Linear(in_channels, in_channels, bias=False)

        self.scale = self.head_dim**-0.5

    def forward(self, x):
        # x: (B, N, C) token features.
        B, N, C = x.shape

        # Global context: avg- and max-pool over the N tokens -> (B, C) each.
        f_avg = F.adaptive_avg_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
        f_max = F.adaptive_max_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
        f_ctx = torch.cat((f_avg, f_max), dim=1)

        # Per-sample prototypes: global prototypes plus a context offset.
        delta_P = self.context_mapper(f_ctx).view(B, self.num_hyperedges, C)
        P = self.global_proto.unsqueeze(0) + delta_P

        z = self.query_proj(x)

        # Split heads: z -> (B, heads, N, head_dim).
        z = z.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # P -> (B, heads, head_dim, E) so z @ P scores tokens vs. prototypes.
        P = P.view(B, self.num_hyperedges, self.num_heads, self.head_dim).permute(
            0, 2, 3, 1
        )

        # Scaled similarity per head: (B, heads, N, E).
        sim = (z @ P) * self.scale

        # Average over heads -> (B, N, E).
        s_bar = sim.mean(dim=1)

        # Softmax over nodes for each hyperedge -> (B, E, N).
        A = F.softmax(s_bar.permute(0, 2, 1), dim=-1)

        return A
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
class HypergraphConvolution(nn.Module):
    """Two-step hypergraph message passing with a residual connection:
    nodes are pooled into hyperedge messages (via A), transformed, scattered
    back to the nodes (via A^T), transformed again, and added to the input."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.W_e = nn.Linear(in_channels, in_channels, bias=False)
        self.W_v = nn.Linear(in_channels, out_channels, bias=False)
        self.act = nn.SiLU()

    def forward(self, x, A):
        # Gather node features into per-hyperedge messages: (B, E, C).
        edge_msgs = self.act(self.W_e(torch.bmm(A, x)))

        # Scatter messages back to the nodes: (B, N, C_out).
        node_update = self.act(self.W_v(torch.bmm(A.transpose(1, 2), edge_msgs)))

        # Residual (requires out_channels == in_channels).
        return x + node_update
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
class AdaptiveHypergraphComputation(nn.Module):
    """Flatten a feature map into tokens, run adaptive hypergraph message
    passing over them, and restore the spatial layout."""

    def __init__(self, in_channels, out_channels, num_hyperedges=8, num_heads=8):
        super().__init__()
        self.adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(
            in_channels, num_hyperedges, num_heads
        )
        self.hypergraph_conv = HypergraphConvolution(in_channels, out_channels)

    def forward(self, x):
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, H*W, C) token sequence.
        tokens = x.flatten(2).permute(0, 2, 1)

        participation = self.adaptive_hyperedge_gen(tokens)
        refined = self.hypergraph_conv(tokens, participation)

        # Back to a spatial map; -1 tolerates out_channels != in_channels.
        return refined.permute(0, 2, 1).view(B, -1, H, W)
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
class C3AH(nn.Module):
    """CSP-style block pairing an adaptive-hypergraph branch with a plain
    1x1 lateral branch, concatenated and fused by a final 1x1 conv."""

    def __init__(self, c1, c2, num_hyperedges=8, num_heads=8, e=0.5):
        super().__init__()
        hidden = int(c1 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.ahc = AdaptiveHypergraphComputation(hidden, hidden, num_hyperedges, num_heads)
        self.cv3 = Conv(2 * hidden, c2, 1, 1)

    def forward(self, x):
        lateral = self.cv1(x)
        enhanced = self.ahc(self.cv2(x))
        return self.cv3(torch.cat((enhanced, lateral), dim=1))
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
class HyperACE(nn.Module):
    """Hypergraph-based adaptive cross-scale enhancement.

    Resamples four backbone levels to the B4 resolution, fuses them, splits
    the channels into a high-order (hypergraph), low-order (conv) and
    pass-through part, processes the first two, and fuses everything back.
    """

    def __init__(
        self,
        in_channels: List[int],
        out_channels: int,
        num_hyperedges=8,
        num_heads=8,
        k=2,
        l=1,
        c_h=0.5,
        c_l=0.25,
    ):
        super().__init__()

        c2, c3, c4, c5 = in_channels
        # All levels are resampled to the B4 level, so adopt its width.
        c_mid = c4

        self.fuse_conv = Conv(c2 + c3 + c4 + c5, c_mid, 1, 1)

        # Channel budget: c_h for the hypergraph branch, c_l for the conv
        # branch; the remainder c_s passes through untouched.
        self.c_h = int(c_mid * c_h)
        self.c_l = int(c_mid * c_l)
        self.c_s = c_mid - self.c_h - self.c_l
        assert self.c_s > 0, "Channel split error"

        # k parallel C3AH blocks whose outputs are concatenated then fused.
        self.high_order_branch = nn.ModuleList(
            [
                C3AH(self.c_h, self.c_h, num_hyperedges, num_heads, e=1.0)
                for _ in range(k)
            ]
        )
        self.high_order_fuse = Conv(self.c_h * k, self.c_h, 1, 1)

        # l sequential depthwise-separable conv blocks.
        self.low_order_branch = nn.Sequential(
            *[DS_C3k(self.c_l, self.c_l, n=1, k=3, e=1.0) for _ in range(l)]
        )

        self.final_fuse = Conv(self.c_h + self.c_l + self.c_s, out_channels, 1, 1)

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        B2, B3, B4, B5 = x

        B, _, H4, W4 = B4.shape

        # Bring every level to B4's spatial size.
        B2_resized = F.interpolate(
            B2, size=(H4, W4), mode="bilinear", align_corners=False
        )
        B3_resized = F.interpolate(
            B3, size=(H4, W4), mode="bilinear", align_corners=False
        )
        B5_resized = F.interpolate(
            B5, size=(H4, W4), mode="bilinear", align_corners=False
        )

        x_b = self.fuse_conv(torch.cat((B2_resized, B3_resized, B4, B5_resized), dim=1))

        # Split the fused channels into the three processing groups.
        x_h, x_l, x_s = torch.split(x_b, [self.c_h, self.c_l, self.c_s], dim=1)

        x_h_outs = [m(x_h) for m in self.high_order_branch]
        x_h_fused = self.high_order_fuse(torch.cat(x_h_outs, dim=1))

        x_l_out = self.low_order_branch(x_l)

        y = self.final_fuse(torch.cat((x_h_fused, x_l_out, x_s), dim=1))

        return y
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
class GatedFusion(nn.Module):
    """Residual fusion ``f_in + gamma * h`` with a zero-initialized,
    per-channel learned gate, so the module starts as an identity on f_in."""

    def __init__(self, in_channels):
        super().__init__()
        # Starting at zero means h contributes nothing until training moves gamma.
        self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))

    def forward(self, f_in, h):
        if f_in.shape[1] != h.shape[1]:
            raise ValueError(f"Channel mismatch: f_in={f_in.shape}, h={h.shape}")
        return f_in + self.gamma * h
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
class Backbone(nn.Module):
    """Four-stage feature pyramid over 2-D feature maps.

    Every stage strides ``(2, 1)``: only the first spatial axis is
    downsampled, the second keeps its resolution throughout. Returns the
    four pyramid levels; their widths are exposed via ``out_channels``.
    """

    def __init__(self, in_channels=256, base_channels=64, base_depth=3):
        super().__init__()
        # Removed the unused local `c = base_channels` (dead assignment).
        # Stage widths; only the stem width follows base_channels, the rest
        # are fixed.
        c2 = base_channels
        c3 = 256
        c4 = 384
        c5 = 512
        c6 = 768

        self.stem = DSConv(in_channels, c2, k=3, s=(2, 1), p=1)

        self.p2 = nn.Sequential(
            DSConv(c2, c3, k=3, s=(2, 1), p=1), DS_C3k2(c3, c3, n=base_depth)
        )

        self.p3 = nn.Sequential(
            DSConv(c3, c4, k=3, s=(2, 1), p=1), DS_C3k2(c4, c4, n=base_depth * 2)
        )

        self.p4 = nn.Sequential(
            DSConv(c4, c5, k=3, s=(2, 1), p=1), DS_C3k2(c5, c5, n=base_depth * 2)
        )

        self.p5 = nn.Sequential(
            DSConv(c5, c6, k=3, s=(2, 1), p=1), DS_C3k2(c6, c6, n=base_depth)
        )

        self.out_channels = [c3, c4, c5, c6]

    def forward(self, x):
        """Return the four pyramid levels [x2, x3, x4, x5], coarse to coarser."""
        x = self.stem(x)
        x2 = self.p2(x)
        x3 = self.p3(x2)
        x4 = self.p4(x3)
        x5 = self.p5(x4)
        return [x2, x3, x4, x5]
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
class Decoder(nn.Module):
    """Top-down decoder with encoder skip connections and a gated injection
    of the HyperACE feature at every level.

    At each level: project the encoder skip, resize and project the HyperACE
    map to the same size/width, fuse with ``GatedFusion``, then upsample and
    descend to the next (finer) level.
    """

    def __init__(
        self,
        encoder_channels: List[int],
        hyperace_out_c: int,
        decoder_channels: List[int],
    ):
        super().__init__()
        c_p2, c_p3, c_p4, c_p5 = encoder_channels
        c_d2, c_d3, c_d4, c_d5 = decoder_channels

        # 1x1 projections of the HyperACE feature to each decoder width.
        self.h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
        self.h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
        self.h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
        self.h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)

        self.fusion_d5 = GatedFusion(c_d5)
        self.fusion_d4 = GatedFusion(c_d4)
        self.fusion_d3 = GatedFusion(c_d3)
        self.fusion_d2 = GatedFusion(c_d2)

        # 1x1 projections of the encoder skips to the decoder widths.
        self.skip_p5 = Conv(c_p5, c_d5, 1, 1)
        self.skip_p4 = Conv(c_p4, c_d4, 1, 1)
        self.skip_p3 = Conv(c_p3, c_d3, 1, 1)
        self.skip_p2 = Conv(c_p2, c_d2, 1, 1)

        # Width-reducing blocks applied after each bilinear upsample.
        self.up_d5 = DS_C3k2(c_d5, c_d4, n=1)
        self.up_d4 = DS_C3k2(c_d4, c_d3, n=1)
        self.up_d3 = DS_C3k2(c_d3, c_d2, n=1)

        self.final_d2 = DS_C3k2(c_d2, c_d2, n=1)

    def forward(self, enc_feats: List[torch.Tensor], h_ace: torch.Tensor):
        p2, p3, p4, p5 = enc_feats

        # Deepest level: skip + gated HyperACE injection.
        d5 = self.skip_p5(p5)
        h_d5 = self.h_to_d5(F.interpolate(h_ace, size=d5.shape[2:], mode="bilinear"))
        d5 = self.fusion_d5(d5, h_d5)

        # d5 -> d4: upsample to p4's size, reduce width, add the skip.
        d5_up = F.interpolate(d5, size=p4.shape[2:], mode="bilinear")
        d4_skip = self.skip_p4(p4)
        d4 = self.up_d5(d5_up) + d4_skip

        h_d4 = self.h_to_d4(F.interpolate(h_ace, size=d4.shape[2:], mode="bilinear"))
        d4 = self.fusion_d4(d4, h_d4)

        # d4 -> d3.
        d4_up = F.interpolate(d4, size=p3.shape[2:], mode="bilinear")
        d3_skip = self.skip_p3(p3)
        d3 = self.up_d4(d4_up) + d3_skip

        h_d3 = self.h_to_d3(F.interpolate(h_ace, size=d3.shape[2:], mode="bilinear"))
        d3 = self.fusion_d3(d3, h_d3)

        # d3 -> d2 (finest level produced here).
        d3_up = F.interpolate(d3, size=p2.shape[2:], mode="bilinear")
        d2_skip = self.skip_p2(p2)
        d2 = self.up_d3(d3_up) + d2_skip

        h_d2 = self.h_to_d2(F.interpolate(h_ace, size=d2.shape[2:], mode="bilinear"))
        d2 = self.fusion_d2(d2, h_d2)

        d2_final = self.final_d2(d2)

        return d2_final
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
class FreqPixelShuffle(nn.Module):
    """Upsample the last (frequency) axis by ``scale`` via a channel-to-width
    pixel shuffle: a DSConv produces ``out_channels * scale`` channels, which
    are then folded into the width axis.

    (B, C_in, H, W) -> (B, out_channels, H, W * scale).
    """

    def __init__(self, in_channels, out_channels, scale=2):
        super().__init__()
        self.scale = scale
        # Produce scale copies of each output channel, folded into W below.
        self.conv = DSConv(in_channels, out_channels * scale, k=3, s=1, p=1)
        # Removed the dead `self.act = nn.SiLU()` — it was never applied in
        # forward, and SiLU holds no parameters, so state_dicts are unchanged.

    def forward(self, x):
        x = self.conv(x)
        B, C_r, H, W = x.shape
        out_c = C_r // self.scale

        # Separate the shuffle factor from the channel axis.
        x = x.view(B, out_c, self.scale, H, W)

        # Move the factor next to W and fold it in: (B, out_c, H, W * scale).
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        x = x.view(B, out_c, H, W * self.scale)

        return x
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
class ProgressiveUpsampleHead(nn.Module):
    """Expand the last (frequency) axis 16x through four x2 pixel-shuffle
    stages, snap to exactly ``target_bins`` bins, and project to
    ``out_channels`` with a 1x1 conv."""

    def __init__(self, in_channels, out_channels, target_bins=1025):
        super().__init__()
        self.target_bins = target_bins

        c = in_channels

        # Four x2 stages (overall x16), channel width tapering to c // 4.
        self.block1 = FreqPixelShuffle(c, c, scale=2)
        self.block2 = FreqPixelShuffle(c, c // 2, scale=2)
        self.block3 = FreqPixelShuffle(c // 2, c // 2, scale=2)
        self.block4 = FreqPixelShuffle(c // 2, c // 4, scale=2)

        self.final_conv = nn.Conv2d(c // 4, out_channels, kernel_size=1, bias=False)

    def forward(self, x):

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)

        # x16 rarely lands exactly on target_bins; resize to the exact count.
        if x.shape[-1] != self.target_bins:
            x = F.interpolate(
                x,
                size=(x.shape[2], self.target_bins),
                mode="bilinear",
                align_corners=False,
            )

        x = self.final_conv(x)
        return x
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
class SegmModel(nn.Module):
    """Segmentation-style network over band-feature maps.

    Pipeline: Backbone pyramid -> HyperACE cross-scale enhancement ->
    gated decoder -> time-axis restore -> progressive frequency upsample.
    Input (B, in_dim, T, bands) -> output (B, out_channels, T, out_bins).

    NOTE(review): ``in_bands`` is accepted but never used — the band count
    is taken from the input tensor at runtime; confirm whether it should
    constrain anything.
    """

    def __init__(
        self,
        in_bands=62,
        in_dim=256,
        out_bins=1025,
        out_channels=4,
        base_channels=64,
        base_depth=2,
        num_hyperedges=16,
        num_heads=8,
    ):
        super().__init__()

        self.backbone = Backbone(
            in_channels=in_dim, base_channels=base_channels, base_depth=base_depth
        )
        enc_channels = self.backbone.out_channels
        c2, c3, c4, c5 = enc_channels

        hyperace_in_channels = enc_channels
        hyperace_out_channels = c4
        self.hyperace = HyperACE(
            hyperace_in_channels,
            hyperace_out_channels,
            num_hyperedges,
            num_heads,
            k=3,
            l=2,
        )

        # Decoder widths mirror the encoder level for level.
        decoder_channels = [c2, c3, c4, c5]
        self.decoder = Decoder(enc_channels, hyperace_out_channels, decoder_channels)

        self.upsample_head = ProgressiveUpsampleHead(
            in_channels=decoder_channels[0],
            out_channels=out_channels,
            target_bins=out_bins,
        )

    def forward(self, x):
        # Remember the input time size; the backbone strides it down. W (the
        # band axis size) is unpacked but not needed below.
        H, W = x.shape[2:]

        enc_feats = self.backbone(x)

        h_ace_feats = self.hyperace(enc_feats)

        dec_feat = self.decoder(enc_feats, h_ace_feats)

        # Restore the original time resolution; the band axis is expanded to
        # out_bins by the upsample head next.
        feat_time_restored = F.interpolate(
            dec_feat, size=(H, dec_feat.shape[-1]), mode="bilinear", align_corners=False
        )

        out = self.upsample_head(feat_time_restored)

        return out
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
    """Build a stack of ``depth`` Linear layers with ``activation`` between
    consecutive layers but none after the last; ``dim_hidden`` defaults to
    ``dim_in``."""
    dim_hidden = default(dim_hidden, dim_in)

    dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)

    layers = []
    last_idx = len(dims) - 2
    for idx, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(d_in, d_out))
        if idx != last_idx:
            layers.append(activation())

    return nn.Sequential(*layers)
|
| 736 |
-
|
| 737 |
-
|
| 738 |
-
class MaskEstimator(Module):
    """Estimate per-band complex mask logits from band features.

    Combines two pathways: an independent gated MLP per band, and a global
    ``SegmModel`` pathway that treats the band axis spatially; their outputs
    are summed. Input (b, t, bands, dim) -> output (b, t, sum(dim_inputs)).
    """

    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        dim_hidden = dim * mlp_expansion_factor

        # Removed the unused local `net = []` (dead assignment).
        for dim_in in dim_inputs:
            # The MLP outputs 2 * dim_in; GLU gates and halves it back to
            # dim_in values per band.
            mlp = nn.Sequential(
                MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1)
            )

            self.to_freqs.append(mlp)

        # SegmModel emits 4 channels over sum(dim_inputs) // 4 bins, which
        # flattens back to exactly sum(dim_inputs) features.
        self.segm = SegmModel(
            in_bands=len(dim_inputs), in_dim=dim, out_bins=sum(dim_inputs) // 4
        )

    def forward(self, x):
        # Global pathway over the (dim, time, band) map.
        y = rearrange(x, "b t f c -> b c t f")
        y = self.segm(y)
        y = rearrange(y, "b c t f -> b t (f c)")

        # Per-band pathway: one gated MLP per band.
        x = x.unbind(dim=-2)

        outs = []

        for band_features, mlp in zip(x, self.to_freqs):
            freq_out = mlp(band_features)
            outs.append(freq_out)

        # Sum the two pathways.
        return torch.cat(outs, dim=-1) + y
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
# Number of STFT frequency bins per band: fine resolution at low frequencies,
# progressively coarser toward the top. 62 bands totalling 1025 bins
# (matching n_fft=2048 -> 1025 one-sided bins).
DEFAULT_FREQS_PER_BANDS = (
    (2,) * 24 + (4,) * 12 + (12,) * 8 + (24,) * 8 + (48,) * 8 + (128, 129)
)
|
| 839 |
-
|
| 840 |
-
|
| 841 |
-
class BSRoformerHyperACE(Module):
    """Band-split RoFormer source separator with a HyperACE-based mask estimator.

    Pipeline: audio -> STFT -> BandSplit projection -> alternating
    time/frequency Transformer blocks -> per-stem complex masks
    (MaskEstimator) -> masked STFT -> iSTFT -> separated audio.

    With ``target`` given, returns an L1 waveform loss plus a weighted
    multi-resolution STFT loss (optionally broken down).

    Fixes vs. previous revision: the two bare ``except:`` clauses around the
    STFT/iSTFT CPU fallbacks are narrowed to ``except Exception:`` so that
    KeyboardInterrupt/SystemExit are no longer swallowed.
    """

    @beartype
    def __init__(
        self,
        dim,
        *,
        depth,
        stereo=False,
        num_stems=1,
        time_transformer_depth=2,
        freq_transformer_depth=2,
        linear_transformer_depth=0,
        freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        flash_attn=True,
        dim_freqs_in=1025,
        stft_n_fft=2048,
        stft_hop_length=512,
        stft_win_length=2048,
        stft_normalized=False,
        stft_window_fn: Optional[Callable] = None,
        mask_estimator_depth=2,
        multi_stft_resolution_loss_weight=1.0,
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
            4096,
            2048,
            1024,
            512,
            256,
        ),
        multi_stft_hop_size=147,
        multi_stft_normalized=False,
        multi_stft_window_fn: Callable = torch.hann_window,
        mlp_expansion_factor=4,
        use_torch_checkpoint=False,
        skip_connection=False,
        sage_attention=False,
    ):
        # NOTE(review): linear_transformer_depth and dim_freqs_in are accepted
        # for interface compatibility but are not used in this implementation.
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems
        self.use_torch_checkpoint = use_torch_checkpoint
        self.skip_connection = skip_connection

        self.layers = ModuleList([])

        if sage_attention:
            print("Use Sage Attention")

        transformer_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            flash_attn=flash_attn,
            norm_output=False,
            sage_attention=sage_attention,
        )

        # Rotary embeddings are shared across all layers of each axis.
        time_rotary_embed = RotaryEmbedding(dim=dim_head)
        freq_rotary_embed = RotaryEmbedding(dim=dim_head)

        # Each depth level holds a [time transformer, freq transformer] pair.
        for _ in range(depth):
            tran_modules = []
            tran_modules.append(
                Transformer(
                    depth=time_transformer_depth,
                    rotary_embed=time_rotary_embed,
                    **transformer_kwargs,
                )
            )
            tran_modules.append(
                Transformer(
                    depth=freq_transformer_depth,
                    rotary_embed=freq_rotary_embed,
                    **transformer_kwargs,
                )
            )
            self.layers.append(nn.ModuleList(tran_modules))

        self.final_norm = RMSNorm(dim)

        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized,
        )

        self.stft_window_fn = partial(
            default(stft_window_fn, torch.hann_window), stft_win_length
        )

        # Probe the STFT with a dummy signal to discover the number of
        # frequency bins implied by the settings above.
        freqs = torch.stft(
            torch.randn(1, 4096),
            **self.stft_kwargs,
            window=torch.ones(stft_win_length),
            return_complex=True,
        ).shape[1]

        assert len(freqs_per_bands) > 1
        assert (
            sum(freqs_per_bands) == freqs
        ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"

        # Each band carries (real, imag) x audio channels per frequency bin.
        freqs_per_bands_with_complex = tuple(
            2 * f * self.audio_channels for f in freqs_per_bands
        )

        self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)

        self.mask_estimators = nn.ModuleList([])

        for _ in range(num_stems):
            mask_estimator = MaskEstimator(
                dim=dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=mask_estimator_depth,
                mlp_expansion_factor=mlp_expansion_factor,
            )

            self.mask_estimators.append(mask_estimator)

        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(
            hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
        )

    def forward(self, raw_audio, target=None, return_loss_breakdown=False):
        """Separate ``raw_audio``; with ``target`` given, return the training loss.

        raw_audio: (b, t) or (b, s, t). Returns separated audio, or the total
        loss (optionally with a (l1, multi-stft) breakdown).
        """
        device = raw_audio.device

        # MPS lacks complex STFT support; used for the CPU fallbacks below.
        x_is_mps = device.type == "mps"

        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, "b t -> b 1 t")

        channels = raw_audio.shape[1]
        assert (not self.stereo and channels == 1) or (
            self.stereo and channels == 2
        ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"

        # Fold batch and audio channels together for the STFT.
        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")

        stft_window = self.stft_window_fn(device=device)

        try:
            stft_repr = torch.stft(
                raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
            )
        except Exception:  # fall back to CPU (e.g. MPS has no complex STFT)
            stft_repr = torch.stft(
                raw_audio.cpu() if x_is_mps else raw_audio,
                **self.stft_kwargs,
                window=stft_window.cpu() if x_is_mps else stft_window,
                return_complex=True,
            ).to(device)
        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")

        # Fold the audio-channel axis into the frequency axis.
        stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")

        x = rearrange(stft_repr, "b f t c -> b t (f c)")

        x = self.band_split(x)

        for i, transformer_block in enumerate(self.layers):

            time_transformer, freq_transformer = transformer_block

            # Attend across time within each band.
            x = rearrange(x, "b t f d -> b f t d")
            x, ps = pack([x], "* t d")

            x = time_transformer(x)

            (x,) = unpack(x, ps, "* t d")
            # Attend across bands within each time step.
            x = rearrange(x, "b f t d -> b t f d")
            x, ps = pack([x], "* f d")

            x = freq_transformer(x)

            (x,) = unpack(x, ps, "* f d")

        x = self.final_norm(x)

        num_stems = len(self.mask_estimators)

        mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
        mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)

        stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")

        # Complex multiplicative masking per stem.
        stft_repr = torch.view_as_complex(stft_repr)
        mask = torch.view_as_complex(mask)

        stft_repr = stft_repr * mask

        # Unfold stems and audio channels for the inverse STFT.
        stft_repr = rearrange(
            stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
        )

        try:
            recon_audio = torch.istft(
                stft_repr,
                **self.stft_kwargs,
                window=stft_window,
                return_complex=False,
                length=raw_audio.shape[-1],
            )
        except Exception:  # same CPU fallback as the forward STFT above
            recon_audio = torch.istft(
                stft_repr.cpu() if x_is_mps else stft_repr,
                **self.stft_kwargs,
                window=stft_window.cpu() if x_is_mps else stft_window,
                return_complex=False,
                length=raw_audio.shape[-1],
            ).to(device)

        recon_audio = rearrange(
            recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
        )

        if num_stems == 1:
            recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")

        # Inference path: no target means return the separated audio.
        if not exists(target):
            return recon_audio

        if self.num_stems > 1:
            assert target.ndim == 4 and target.shape[1] == self.num_stems

        if target.ndim == 2:
            target = rearrange(target, "... t -> ... 1 t")

        # Align the target length to the reconstruction.
        target = target[..., : recon_audio.shape[-1]]

        loss = F.l1_loss(recon_audio, target)

        multi_stft_resolution_loss = 0.0

        # Accumulate L1 on complex spectra across several STFT resolutions.
        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
                n_fft=max(window_size, self.multi_stft_n_fft),
                win_length=window_size,
                return_complex=True,
                window=self.multi_stft_window_fn(window_size, device=device),
                **self.multi_stft_kwargs,
            )

            recon_Y = torch.stft(
                rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
            )
            target_Y = torch.stft(
                rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
            )

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
                recon_Y, target_Y
            )

        weighted_multi_resolution_loss = (
            multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
        )

        total_loss = loss + weighted_multi_resolution_loss

        if not return_loss_breakdown:
            return total_loss

        return total_loss, (loss, multi_stft_resolution_loss)
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn, einsum, Tensor
|
| 5 |
+
from torch.nn import Module, ModuleList
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from .attend import Attend
|
| 9 |
+
|
| 10 |
+
# SageAttention support is optional; only the import failure should be
# tolerated, so catch ImportError specifically instead of a bare except
# (which would also hide syntax errors inside attend_sage).
try:
    from .attend_sage import Attend as AttendSage
except ImportError:
    pass
|
| 14 |
+
from torch.utils.checkpoint import checkpoint
|
| 15 |
+
|
| 16 |
+
from beartype.typing import Tuple, Optional, List, Callable
|
| 17 |
+
from beartype import beartype
|
| 18 |
+
|
| 19 |
+
from rotary_embedding_torch import RotaryEmbedding
|
| 20 |
+
|
| 21 |
+
from einops import rearrange, pack, unpack
|
| 22 |
+
from einops.layers.torch import Rearrange
|
| 23 |
+
import torchaudio
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def exists(val):
    """Return True when *val* is not None."""
    return not (val is None)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def default(v, d):
    """Return *v* when it is not None, otherwise the fallback *d*."""
    if exists(v):
        return v
    return d
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def pack_one(t, pattern):
    """einops.pack for a single tensor; returns (packed, packed_shapes)."""
    packed, ps = pack([t], pattern)
    return packed, ps
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def unpack_one(t, ps, pattern):
    """Inverse of pack_one: einops.unpack and return the sole tensor."""
    results = unpack(t, ps, pattern)
    return results[0]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def l2norm(t):
    """L2-normalize *t* along its last dimension."""
    return F.normalize(t, p=2, dim=-1)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class RMSNorm(Module):
    """Normalize the last dimension to unit L2 norm, then rescale by
    sqrt(dim) and a learned per-feature gain."""

    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        normed = F.normalize(x, dim=-1)
        return (normed * self.scale) * self.gamma
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class FeedForward(Module):
    """Pre-norm two-layer MLP: RMSNorm -> Linear -> GELU -> Dropout ->
    Linear -> Dropout, with a *mult* expansion of the hidden width."""

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        layers = [
            RMSNorm(dim),
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class Attention(Module):
    """Multi-head self-attention with RMS pre-norm, optional rotary
    position embedding on q/k, and per-head sigmoid gates applied to
    the attention output before the final projection."""

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        dropout=0.0,
        rotary_embed=None,
        flash=True,
        sage_attention=False,
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        inner = heads * dim_head

        self.rotary_embed = rotary_embed

        # Optionally use the SageAttention-backed kernel.
        attend_cls = AttendSage if sage_attention else Attend
        self.attend = attend_cls(flash=flash, dropout=dropout)

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, inner * 3, bias=False)

        self.to_gates = nn.Linear(dim, heads)

        self.to_out = nn.Sequential(
            nn.Linear(inner, dim, bias=False), nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)

        rot = self.rotary_embed
        if rot is not None:
            q = rot.rotate_queries_or_keys(q)
            k = rot.rotate_queries_or_keys(k)

        attended = self.attend(q, k, v)

        # Gate each head with a scalar derived from the normed input tokens.
        gate = rearrange(self.to_gates(x), "b n h -> b h n 1").sigmoid()
        attended = attended * gate

        merged = rearrange(attended, "b h n d -> b n (h d)")
        return self.to_out(merged)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class LinearAttention(Module):
    """Attention variant that attends over the feature dimension
    (q/k/v laid out as "b h d n"), with L2-normalized q/k and a learned
    per-head temperature on q."""

    @beartype
    def __init__(
        self,
        *,
        dim,
        dim_head=32,
        heads=8,
        scale=8,
        flash=True,
        dropout=0.0,
        sage_attention=False,
    ):
        super().__init__()
        inner = dim_head * heads
        self.norm = RMSNorm(dim)

        self.to_qkv = nn.Sequential(
            nn.Linear(dim, inner * 3, bias=False),
            Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
        )

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        attend_cls = AttendSage if sage_attention else Attend
        self.attend = attend_cls(scale=scale, dropout=dropout, flash=flash)

        self.to_out = nn.Sequential(
            Rearrange("b h d n -> b n (h d)"), nn.Linear(inner, dim, bias=False)
        )

    def forward(self, x):
        q, k, v = self.to_qkv(self.norm(x))

        # Cosine-sim style: unit-norm q and k, learned temperature on q only.
        q = l2norm(q) * self.temperature.exp()
        k = l2norm(k)

        return self.to_out(self.attend(q, k, v))
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class Transformer(Module):
    """Stack of *depth* residual (attention, feed-forward) pairs with an
    optional final RMSNorm. Each layer uses either standard Attention or
    LinearAttention depending on *linear_attn*."""

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        ff_mult=4,
        norm_output=True,
        rotary_embed=None,
        flash_attn=True,
        linear_attn=False,
        sage_attention=False,
    ):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            if linear_attn:
                attn = LinearAttention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    dropout=attn_dropout,
                    flash=flash_attn,
                    sage_attention=sage_attention,
                )
            else:
                attn = Attention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    dropout=attn_dropout,
                    rotary_embed=rotary_embed,
                    flash=flash_attn,
                    sage_attention=sage_attention,
                )
            ff = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
            self.layers.append(ModuleList([attn, ff]))

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        for attn, ff in self.layers:
            x = x + attn(x)
            x = x + ff(x)
        return self.norm(x)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class BandSplit(Module):
    """Split the last dimension into bands of widths *dim_inputs* and
    project each band (RMSNorm + Linear) into a shared *dim*, stacking
    the results along a new band axis (dim=-2)."""

    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...]):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList(
            nn.Sequential(RMSNorm(width), nn.Linear(width, dim))
            for width in dim_inputs
        )

    def forward(self, x):
        bands = x.split(self.dim_inputs, dim=-1)
        projected = [proj(band) for band, proj in zip(bands, self.to_features)]
        return torch.stack(projected, dim=-2)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class Conv(nn.Module):
    """Conv2d (auto-padded, no bias) -> affine InstanceNorm2d -> optional SiLU."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
        self.act = nn.SiLU() if act else nn.Identity()

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def autopad(k, p=None):
    """Return 'same'-style padding (half the kernel size) when *p* is None,
    otherwise pass *p* through unchanged. Handles int or iterable kernels."""
    if p is not None:
        return p
    if isinstance(k, int):
        return k // 2
    return [x // 2 for x in k]
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class DSConv(nn.Module):
    """Depthwise-separable conv: depthwise kxk (groups=c1) then pointwise
    1x1, followed by affine InstanceNorm2d and optional SiLU."""

    def __init__(self, c1, c2, k=3, s=1, p=None, act=True):
        super().__init__()
        self.dwconv = nn.Conv2d(c1, c1, k, s, autopad(k, p), groups=c1, bias=False)
        self.pwconv = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
        self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
        self.act = nn.SiLU() if act else nn.Identity()

    def forward(self, x):
        out = self.dwconv(x)
        out = self.pwconv(out)
        return self.act(self.bn(out))
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class DS_Bottleneck(nn.Module):
    """Two stacked DSConvs with an identity shortcut when requested and
    the channel counts match."""

    def __init__(self, c1, c2, k=3, shortcut=True):
        super().__init__()
        mid = c1
        self.dsconv1 = DSConv(c1, mid, k=3, s=1)
        self.dsconv2 = DSConv(mid, c2, k=k, s=1)
        self.shortcut = shortcut and c1 == c2

    def forward(self, x):
        out = self.dsconv2(self.dsconv1(x))
        if self.shortcut:
            return x + out
        return out
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class DS_C3k(nn.Module):
    """Two 1x1-projected branches — one run through *n* DS_Bottlenecks,
    one lateral — concatenated and fused by a final 1x1 Conv."""

    def __init__(self, c1, c2, n=1, k=3, e=0.5):
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.cv3 = Conv(2 * hidden, c2, 1, 1)
        blocks = [DS_Bottleneck(hidden, hidden, k=k, shortcut=True) for _ in range(n)]
        self.m = nn.Sequential(*blocks)

    def forward(self, x):
        deep = self.m(self.cv1(x))
        lateral = self.cv2(x)
        return self.cv3(torch.cat((deep, lateral), dim=1))
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class DS_C3k2(nn.Module):
    """1x1 squeeze -> DS_C3k (full-width, e=1.0) -> 1x1 expand."""

    def __init__(self, c1, c2, n=1, k=3, e=0.5):
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.m = DS_C3k(hidden, hidden, n=n, k=k, e=1.0)
        self.cv2 = Conv(hidden, c2, 1, 1)

    def forward(self, x):
        return self.cv2(self.m(self.cv1(x)))
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class AdaptiveHyperedgeGeneration(nn.Module):
    """Produce a soft participation matrix A of shape
    (B, num_hyperedges, N) from token features x of shape (B, N, C):
    global prototypes are shifted by a context vector pooled from x, then
    compared to projected tokens head-wise, and the head-averaged scores
    are softmaxed over the token axis."""

    def __init__(self, in_channels, num_hyperedges, num_heads=8):
        super().__init__()
        self.num_hyperedges = num_hyperedges
        self.num_heads = num_heads
        self.head_dim = in_channels // num_heads

        # Learned global hyperedge prototypes, adapted per sample in forward.
        self.global_proto = nn.Parameter(torch.randn(num_hyperedges, in_channels))

        self.context_mapper = nn.Linear(
            2 * in_channels, num_hyperedges * in_channels, bias=False
        )

        self.query_proj = nn.Linear(in_channels, in_channels, bias=False)

        self.scale = self.head_dim**-0.5

    def forward(self, x):
        B, N, C = x.shape
        heads, edges = self.num_heads, self.num_hyperedges

        # Global context: concatenated mean- and max-pool over the tokens.
        pooled = x.permute(0, 2, 1)
        ctx = torch.cat(
            (
                F.adaptive_avg_pool1d(pooled, 1).squeeze(-1),
                F.adaptive_max_pool1d(pooled, 1).squeeze(-1),
            ),
            dim=1,
        )

        # Sample-conditioned prototypes: global prototypes + context offset.
        proto = self.global_proto.unsqueeze(0) + self.context_mapper(ctx).view(
            B, edges, C
        )

        # Multi-head scaled similarity between projected tokens and prototypes.
        queries = (
            self.query_proj(x).view(B, N, heads, self.head_dim).permute(0, 2, 1, 3)
        )
        proto = proto.view(B, edges, heads, self.head_dim).permute(0, 2, 3, 1)
        sim = (queries @ proto) * self.scale

        # Average over heads, then softmax over tokens for each hyperedge.
        return F.softmax(sim.mean(dim=1).permute(0, 2, 1), dim=-1)
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
class HypergraphConvolution(nn.Module):
    """One round of vertex -> hyperedge -> vertex message passing over a
    soft incidence matrix A (edges x vertices), with a residual add
    (requires out_channels == in_channels)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.W_e = nn.Linear(in_channels, in_channels, bias=False)
        self.W_v = nn.Linear(in_channels, out_channels, bias=False)
        self.act = nn.SiLU()

    def forward(self, x, A):
        # Gather vertex features into hyperedge features.
        edge_feats = self.act(self.W_e(torch.bmm(A, x)))
        # Scatter hyperedge messages back onto the vertices.
        vertex_msgs = self.act(self.W_v(torch.bmm(A.transpose(1, 2), edge_feats)))
        return x + vertex_msgs
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class AdaptiveHypergraphComputation(nn.Module):
    """Flatten a (B, C, H, W) map into tokens, run adaptive hyperedge
    generation plus hypergraph convolution, and restore the spatial layout."""

    def __init__(self, in_channels, out_channels, num_hyperedges=8, num_heads=8):
        super().__init__()
        self.adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(
            in_channels, num_hyperedges, num_heads
        )
        self.hypergraph_conv = HypergraphConvolution(in_channels, out_channels)

    def forward(self, x):
        B, C, H, W = x.shape
        tokens = x.flatten(2).permute(0, 2, 1)  # (B, H*W, C)

        participation = self.adaptive_hyperedge_gen(tokens)
        tokens = self.hypergraph_conv(tokens, participation)

        return tokens.permute(0, 2, 1).view(B, -1, H, W)
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
class C3AH(nn.Module):
    """Two-branch block: a 1x1 lateral path and a hypergraph-processed
    path, concatenated and fused with a final 1x1 Conv."""

    def __init__(self, c1, c2, num_hyperedges=8, num_heads=8, e=0.5):
        super().__init__()
        hidden = int(c1 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.ahc = AdaptiveHypergraphComputation(
            hidden, hidden, num_hyperedges, num_heads
        )
        self.cv3 = Conv(2 * hidden, c2, 1, 1)

    def forward(self, x):
        lateral = self.cv1(x)
        hyper = self.ahc(self.cv2(x))
        return self.cv3(torch.cat((hyper, lateral), dim=1))
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
class HyperACE(nn.Module):
    """Fuse four pyramid levels at the B4 resolution, split the fused map
    into high-order / low-order / passthrough channel groups, process the
    first two with C3AH (hypergraph) and DS_C3k branches respectively,
    and fuse everything back with a 1x1 conv.

    Args:
        in_channels: channel counts of the four input maps [c2, c3, c4, c5].
        out_channels: channels of the fused output map.
        num_hyperedges, num_heads: forwarded to each C3AH block.
        k: number of parallel C3AH blocks on the high-order group.
        l: number of stacked DS_C3k blocks on the low-order group.
        c_h, c_l: fractions of the mid channels given to the high- and
            low-order groups; the remainder is passed through untouched.
    """

    def __init__(
        self,
        in_channels: List[int],
        out_channels: int,
        num_hyperedges=8,
        num_heads=8,
        k=2,
        l=1,
        c_h=0.5,
        c_l=0.25,
    ):
        super().__init__()

        c2, c3, c4, c5 = in_channels
        # All levels are resampled to B4's resolution, so c4 sets the width.
        c_mid = c4

        self.fuse_conv = Conv(c2 + c3 + c4 + c5, c_mid, 1, 1)

        # Channel budget: high-order + low-order + passthrough = c_mid.
        self.c_h = int(c_mid * c_h)
        self.c_l = int(c_mid * c_l)
        self.c_s = c_mid - self.c_h - self.c_l
        assert self.c_s > 0, "Channel split error"

        # k parallel hypergraph blocks whose outputs are concatenated.
        self.high_order_branch = nn.ModuleList(
            [
                C3AH(self.c_h, self.c_h, num_hyperedges, num_heads, e=1.0)
                for _ in range(k)
            ]
        )
        self.high_order_fuse = Conv(self.c_h * k, self.c_h, 1, 1)

        # l stacked depthwise-separable blocks for the low-order group.
        self.low_order_branch = nn.Sequential(
            *[DS_C3k(self.c_l, self.c_l, n=1, k=3, e=1.0) for _ in range(l)]
        )

        self.final_fuse = Conv(self.c_h + self.c_l + self.c_s, out_channels, 1, 1)

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        B2, B3, B4, B5 = x

        B, _, H4, W4 = B4.shape

        # Bring every level to B4's spatial size before concatenation.
        B2_resized = F.interpolate(
            B2, size=(H4, W4), mode="bilinear", align_corners=False
        )
        B3_resized = F.interpolate(
            B3, size=(H4, W4), mode="bilinear", align_corners=False
        )
        B5_resized = F.interpolate(
            B5, size=(H4, W4), mode="bilinear", align_corners=False
        )

        x_b = self.fuse_conv(torch.cat((B2_resized, B3_resized, B4, B5_resized), dim=1))

        # Split channels into high-order, low-order, and passthrough groups.
        x_h, x_l, x_s = torch.split(x_b, [self.c_h, self.c_l, self.c_s], dim=1)

        x_h_outs = [m(x_h) for m in self.high_order_branch]
        x_h_fused = self.high_order_fuse(torch.cat(x_h_outs, dim=1))

        x_l_out = self.low_order_branch(x_l)

        y = self.final_fuse(torch.cat((x_h_fused, x_l_out, x_s), dim=1))

        return y
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
class GatedFusion(nn.Module):
    """Add *h* onto *f_in*, scaled per-channel by a learned gamma that
    starts at zero — so the fusion begins as the identity on *f_in*."""

    def __init__(self, in_channels):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))

    def forward(self, f_in, h):
        if h.shape[1] != f_in.shape[1]:
            raise ValueError(f"Channel mismatch: f_in={f_in.shape}, h={h.shape}")
        return f_in + self.gamma * h
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
class Backbone(nn.Module):
    """Four-stage encoder over (B, C, H, W) maps. Every stage uses a
    stride-(2, 1) DSConv, so the height halves per stage while the width
    is preserved; each stage ends in a DS_C3k2 block. Returns the four
    stage outputs [p2, p3, p4, p5]."""

    def __init__(self, in_channels=256, base_channels=64, base_depth=3):
        super().__init__()
        c2, c3, c4, c5, c6 = base_channels, 256, 384, 512, 768

        self.stem = DSConv(in_channels, c2, k=3, s=(2, 1), p=1)

        self.p2 = nn.Sequential(
            DSConv(c2, c3, k=3, s=(2, 1), p=1), DS_C3k2(c3, c3, n=base_depth)
        )
        self.p3 = nn.Sequential(
            DSConv(c3, c4, k=3, s=(2, 1), p=1), DS_C3k2(c4, c4, n=base_depth * 2)
        )
        self.p4 = nn.Sequential(
            DSConv(c4, c5, k=3, s=(2, 1), p=1), DS_C3k2(c5, c5, n=base_depth * 2)
        )
        self.p5 = nn.Sequential(
            DSConv(c5, c6, k=3, s=(2, 1), p=1), DS_C3k2(c6, c6, n=base_depth)
        )

        self.out_channels = [c3, c4, c5, c6]

    def forward(self, x):
        feats = []
        out = self.stem(x)
        for stage in (self.p2, self.p3, self.p4, self.p5):
            out = stage(out)
            feats.append(out)
        return feats
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
class Decoder(nn.Module):
    """Top-down decoder: starting from p5, repeatedly upsample, add a 1x1
    skip projection of the next encoder level, and inject the HyperACE
    map at each scale through a GatedFusion. Returns the finest (d2) map.

    Args:
        encoder_channels: channels of [p2, p3, p4, p5] from the backbone.
        hyperace_out_c: channels of the HyperACE feature map.
        decoder_channels: target channels [c_d2, c_d3, c_d4, c_d5].
    """

    def __init__(
        self,
        encoder_channels: List[int],
        hyperace_out_c: int,
        decoder_channels: List[int],
    ):
        super().__init__()
        c_p2, c_p3, c_p4, c_p5 = encoder_channels
        c_d2, c_d3, c_d4, c_d5 = decoder_channels

        # 1x1 projections of the HyperACE map to each decoder width.
        self.h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
        self.h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
        self.h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
        self.h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)

        # Zero-initialized gated injection at every scale.
        self.fusion_d5 = GatedFusion(c_d5)
        self.fusion_d4 = GatedFusion(c_d4)
        self.fusion_d3 = GatedFusion(c_d3)
        self.fusion_d2 = GatedFusion(c_d2)

        # 1x1 skip projections from the encoder levels.
        self.skip_p5 = Conv(c_p5, c_d5, 1, 1)
        self.skip_p4 = Conv(c_p4, c_d4, 1, 1)
        self.skip_p3 = Conv(c_p3, c_d3, 1, 1)
        self.skip_p2 = Conv(c_p2, c_d2, 1, 1)

        # Refinement blocks applied after each upsample (also change width).
        self.up_d5 = DS_C3k2(c_d5, c_d4, n=1)
        self.up_d4 = DS_C3k2(c_d4, c_d3, n=1)
        self.up_d3 = DS_C3k2(c_d3, c_d2, n=1)

        self.final_d2 = DS_C3k2(c_d2, c_d2, n=1)

    def forward(self, enc_feats: List[torch.Tensor], h_ace: torch.Tensor):
        p2, p3, p4, p5 = enc_feats

        # Coarsest level: skip projection + gated HyperACE injection.
        d5 = self.skip_p5(p5)
        h_d5 = self.h_to_d5(F.interpolate(h_ace, size=d5.shape[2:], mode="bilinear"))
        d5 = self.fusion_d5(d5, h_d5)

        # d5 -> d4: upsample to p4's size, refine, add skip, inject HyperACE.
        d5_up = F.interpolate(d5, size=p4.shape[2:], mode="bilinear")
        d4_skip = self.skip_p4(p4)
        d4 = self.up_d5(d5_up) + d4_skip

        h_d4 = self.h_to_d4(F.interpolate(h_ace, size=d4.shape[2:], mode="bilinear"))
        d4 = self.fusion_d4(d4, h_d4)

        # d4 -> d3.
        d4_up = F.interpolate(d4, size=p3.shape[2:], mode="bilinear")
        d3_skip = self.skip_p3(p3)
        d3 = self.up_d4(d4_up) + d3_skip

        h_d3 = self.h_to_d3(F.interpolate(h_ace, size=d3.shape[2:], mode="bilinear"))
        d3 = self.fusion_d3(d3, h_d3)

        # d3 -> d2 (finest level).
        d3_up = F.interpolate(d3, size=p2.shape[2:], mode="bilinear")
        d2_skip = self.skip_p2(p2)
        d2 = self.up_d3(d3_up) + d2_skip

        h_d2 = self.h_to_d2(F.interpolate(h_ace, size=d2.shape[2:], mode="bilinear"))
        d2 = self.fusion_d2(d2, h_d2)

        d2_final = self.final_d2(d2)

        return d2_final
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
class FreqPixelShuffle(nn.Module):
    """Upsample the last (width/frequency) axis by *scale* pixel-shuffle
    style: a DSConv emits out_channels * scale channels, which are then
    folded into the width dimension."""

    def __init__(self, in_channels, out_channels, scale=2):
        super().__init__()
        self.scale = scale
        self.conv = DSConv(in_channels, out_channels * scale, k=3, s=1, p=1)
        # NOTE(review): self.act is constructed but not used in forward;
        # kept for module-structure parity with the original.
        self.act = nn.SiLU()

    def forward(self, x):
        x = self.conv(x)
        B, packed, H, W = x.shape
        channels = packed // self.scale

        # Unfold the scale factor out of the channel axis...
        x = x.view(B, channels, self.scale, H, W)
        # ...move it next to the width axis and merge the two.
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        return x.view(B, channels, H, W * self.scale)
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
class ProgressiveUpsampleHead(nn.Module):
    """Grow the frequency axis 16x through four FreqPixelShuffle blocks,
    bilinearly resize to exactly *target_bins*, then project channels
    with a 1x1 conv."""

    def __init__(self, in_channels, out_channels, target_bins=1025):
        super().__init__()
        self.target_bins = target_bins

        c = in_channels
        self.block1 = FreqPixelShuffle(c, c, scale=2)
        self.block2 = FreqPixelShuffle(c, c // 2, scale=2)
        self.block3 = FreqPixelShuffle(c // 2, c // 2, scale=2)
        self.block4 = FreqPixelShuffle(c // 2, c // 4, scale=2)

        self.final_conv = nn.Conv2d(c // 4, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        for block in (self.block1, self.block2, self.block3, self.block4):
            x = block(x)

        # Snap the frequency axis to the exact target bin count if needed.
        if x.shape[-1] != self.target_bins:
            x = F.interpolate(
                x,
                size=(x.shape[2], self.target_bins),
                mode="bilinear",
                align_corners=False,
            )

        return self.final_conv(x)
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
class SegmModel(nn.Module):
    """Segmentation-style network over band features: Backbone encoder ->
    HyperACE multi-scale fusion -> Decoder -> ProgressiveUpsampleHead.

    Input is (B, in_dim, T, in_bands); the output is (B, out_channels, T,
    out_bins) — time is restored to the input length, and the band axis is
    upsampled to *out_bins* frequency bins.
    """

    def __init__(
        self,
        in_bands=62,
        in_dim=256,
        out_bins=1025,
        out_channels=4,
        base_channels=64,
        base_depth=2,
        num_hyperedges=16,
        num_heads=8,
    ):
        super().__init__()

        self.backbone = Backbone(
            in_channels=in_dim, base_channels=base_channels, base_depth=base_depth
        )
        enc_channels = self.backbone.out_channels
        c2, c3, c4, c5 = enc_channels

        # HyperACE fuses all four encoder levels; its output width is c4.
        hyperace_in_channels = enc_channels
        hyperace_out_channels = c4
        self.hyperace = HyperACE(
            hyperace_in_channels,
            hyperace_out_channels,
            num_hyperedges,
            num_heads,
            k=3,
            l=2,
        )

        # Decoder widths mirror the encoder widths level-for-level.
        decoder_channels = [c2, c3, c4, c5]
        self.decoder = Decoder(enc_channels, hyperace_out_channels, decoder_channels)

        self.upsample_head = ProgressiveUpsampleHead(
            in_channels=decoder_channels[0],
            out_channels=out_channels,
            target_bins=out_bins,
        )

    def forward(self, x):
        H, W = x.shape[2:]

        enc_feats = self.backbone(x)

        h_ace_feats = self.hyperace(enc_feats)

        dec_feat = self.decoder(enc_feats, h_ace_feats)

        # Restore the time axis (height) to the input length before the head.
        feat_time_restored = F.interpolate(
            dec_feat, size=(H, dec_feat.shape[-1]), mode="bilinear", align_corners=False
        )

        out = self.upsample_head(feat_time_restored)

        return out
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
    """Build an nn.Sequential MLP with *depth* Linear layers; *activation*
    follows every layer except the last. Hidden width defaults to dim_in."""
    dim_hidden = default(dim_hidden, dim_in)

    dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
    last_idx = len(dims) - 2

    layers = []
    for idx, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(d_in, d_out))
        if idx != last_idx:
            layers.append(activation())

    return nn.Sequential(*layers)
|
| 736 |
+
|
| 737 |
+
|
| 738 |
+
class MaskEstimator(Module):
    """Estimate per-band frequency outputs from band features.

    Each band is decoded by its own MLP followed by a GLU (which halves
    the MLP's 2x-wide output back to the band width); in parallel, a
    convolutional SegmModel over the stacked band features produces a
    residual that is added to the concatenated per-band outputs.
    """

    # Fix: the original allocated a throwaway `net = []` local inside the
    # per-band loop that was never used; it has been removed.

    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        dim_hidden = dim * mlp_expansion_factor

        for dim_in in dim_inputs:
            # GLU halves the last dim, so the MLP emits 2 * dim_in.
            mlp = nn.Sequential(
                MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1)
            )

            self.to_freqs.append(mlp)

        # Residual path over the stacked bands; out_bins * out_channels (4)
        # matches sum(dim_inputs) after the flattening rearrange below.
        self.segm = SegmModel(
            in_bands=len(dim_inputs), in_dim=dim, out_bins=sum(dim_inputs) // 4
        )

    def forward(self, x):
        # x: (batch, time, bands, dim). Residual path treats it as a 2-D map.
        y = rearrange(x, "b t f c -> b c t f")
        y = self.segm(y)
        y = rearrange(y, "b c t f -> b t (f c)")

        x = x.unbind(dim=-2)

        outs = []

        for band_features, mlp in zip(x, self.to_freqs):
            freq_out = mlp(band_features)
            outs.append(freq_out)

        return torch.cat(outs, dim=-1) + y
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
# Default frequency-bin partition for a 1025-bin spectrogram (n_fft=2048):
# 24 bands of 2 bins, 12 of 4, 8 of 12, 8 of 24, 8 of 48, then 128 and 129.
# 62 bands in total, summing to exactly 1025 bins.
DEFAULT_FREQS_PER_BANDS = (
    (2,) * 24 + (4,) * 12 + (12,) * 8 + (24,) * 8 + (48,) * 8 + (128, 129)
)
|
| 839 |
+
|
| 840 |
+
|
| 841 |
+
class BSRoformerHyperACE(Module):
|
| 842 |
+
|
| 843 |
+
    @beartype
    def __init__(
        self,
        dim,
        *,
        depth,
        stereo=False,
        num_stems=1,
        time_transformer_depth=2,
        freq_transformer_depth=2,
        linear_transformer_depth=0,
        freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        flash_attn=True,
        dim_freqs_in=1025,
        stft_n_fft=2048,
        stft_hop_length=512,
        stft_win_length=2048,
        stft_normalized=False,
        stft_window_fn: Optional[Callable] = None,
        mask_estimator_depth=2,
        multi_stft_resolution_loss_weight=1.0,
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
            4096,
            2048,
            1024,
            512,
            256,
        ),
        multi_stft_hop_size=147,
        multi_stft_normalized=False,
        multi_stft_window_fn: Callable = torch.hann_window,
        mlp_expansion_factor=4,
        use_torch_checkpoint=False,
        skip_connection=False,
        sage_attention=False,
    ):
        """Build the band-split transformer separator.

        Key arguments:
            dim: shared feature width used throughout the model.
            depth: number of (time-transformer, freq-transformer) layer pairs.
            stereo: when True the model expects 2-channel audio.
            num_stems: number of output stems (one MaskEstimator each).
            freqs_per_bands: STFT-bin widths per band; must sum to the bin
                count implied by the STFT settings (checked below).
            multi_stft_*: settings for the multi-resolution STFT loss used
                during training.
        """
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems
        self.use_torch_checkpoint = use_torch_checkpoint
        self.skip_connection = skip_connection

        self.layers = ModuleList([])

        if sage_attention:
            print("Use Sage Attention")

        # Shared settings for every Transformer in the stack.
        transformer_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            flash_attn=flash_attn,
            norm_output=False,
            sage_attention=sage_attention,
        )

        # Separate rotary embeddings for the time and frequency axes.
        time_rotary_embed = RotaryEmbedding(dim=dim_head)
        freq_rotary_embed = RotaryEmbedding(dim=dim_head)

        # Each layer is a (time transformer, frequency transformer) pair.
        for _ in range(depth):
            tran_modules = []
            tran_modules.append(
                Transformer(
                    depth=time_transformer_depth,
                    rotary_embed=time_rotary_embed,
                    **transformer_kwargs,
                )
            )
            tran_modules.append(
                Transformer(
                    depth=freq_transformer_depth,
                    rotary_embed=freq_rotary_embed,
                    **transformer_kwargs,
                )
            )
            self.layers.append(nn.ModuleList(tran_modules))

        self.final_norm = RMSNorm(dim)

        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized,
        )

        self.stft_window_fn = partial(
            default(stft_window_fn, torch.hann_window), stft_win_length
        )

        # Probe the STFT with dummy audio to learn the output bin count.
        freqs = torch.stft(
            torch.randn(1, 4096),
            **self.stft_kwargs,
            window=torch.ones(stft_win_length),
            return_complex=True,
        ).shape[1]

        assert len(freqs_per_bands) > 1
        assert (
            sum(freqs_per_bands) == freqs
        ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"

        # Each band carries real+imag parts per audio channel.
        freqs_per_bands_with_complex = tuple(
            2 * f * self.audio_channels for f in freqs_per_bands
        )

        self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)

        self.mask_estimators = nn.ModuleList([])

        # One mask estimator per output stem.
        for _ in range(num_stems):
            mask_estimator = MaskEstimator(
                dim=dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=mask_estimator_depth,
                mlp_expansion_factor=mlp_expansion_factor,
            )

            self.mask_estimators.append(mask_estimator)

        # Multi-resolution STFT loss configuration (training only).
        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(
            hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
        )
|
| 979 |
+
|
| 980 |
+
def forward(self, raw_audio, target=None, return_loss_breakdown=False):
    """Separate stems from raw audio; optionally compute the training loss.

    Args:
        raw_audio: tensor of shape (b, t) for mono or (b, s, t) with
            s == 2 when ``self.stereo`` is True.
        target: optional target stems. Shape (b, n, s, t) when
            ``self.num_stems > 1``, otherwise (..., t) / (..., s, t).
            When None, only the reconstructed audio is returned.
        return_loss_breakdown: when a target is given, also return the
            (l1_loss, multi_stft_resolution_loss) pair.

    Returns:
        Reconstructed audio of shape (b, s, t) (single stem) or
        (b, n, s, t) (multiple stems) when ``target`` is None; otherwise
        the scalar total loss, optionally with the loss breakdown.
    """
    device = raw_audio.device

    # Some STFT ops are unsupported on Apple MPS; fall back to CPU there.
    x_is_mps = device.type == "mps"

    if raw_audio.ndim == 2:
        raw_audio = rearrange(raw_audio, "b t -> b 1 t")

    channels = raw_audio.shape[1]
    assert (not self.stereo and channels == 1) or (
        self.stereo and channels == 2
    ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"

    # Fold the channel dimension into the batch for a single batched STFT.
    raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")

    stft_window = self.stft_window_fn(device=device)

    try:
        stft_repr = torch.stft(
            raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
        )
    except Exception:  # narrowed from bare except: don't swallow KeyboardInterrupt/SystemExit
        # MPS (and similar) backends may not implement torch.stft; retry on CPU.
        stft_repr = torch.stft(
            raw_audio.cpu() if x_is_mps else raw_audio,
            **self.stft_kwargs,
            window=stft_window.cpu() if x_is_mps else stft_window,
            return_complex=True,
        ).to(device)
    stft_repr = torch.view_as_real(stft_repr)

    stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")

    # Merge stereo/mono channels into the frequency axis; bands treat
    # each (freq, channel) pair as a distinct feature.
    stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")

    x = rearrange(stft_repr, "b f t c -> b t (f c)")

    x = self.band_split(x)

    # Axial transformer stack: alternate attention across time and across bands.
    for transformer_block in self.layers:
        time_transformer, freq_transformer = transformer_block

        x = rearrange(x, "b t f d -> b f t d")
        x, ps = pack([x], "* t d")

        x = time_transformer(x)

        (x,) = unpack(x, ps, "* t d")
        x = rearrange(x, "b f t d -> b t f d")
        x, ps = pack([x], "* f d")

        x = freq_transformer(x)

        (x,) = unpack(x, ps, "* f d")

    x = self.final_norm(x)

    num_stems = len(self.mask_estimators)

    # One complex mask per stem, applied multiplicatively to the mixture STFT.
    mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
    mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)

    stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")

    stft_repr = torch.view_as_complex(stft_repr)
    mask = torch.view_as_complex(mask)

    stft_repr = stft_repr * mask

    # Un-merge the audio channels from the frequency axis for the inverse STFT.
    stft_repr = rearrange(
        stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
    )

    try:
        recon_audio = torch.istft(
            stft_repr,
            **self.stft_kwargs,
            window=stft_window,
            return_complex=False,
            length=raw_audio.shape[-1],
        )
    except Exception:  # narrowed from bare except — same MPS fallback as above
        recon_audio = torch.istft(
            stft_repr.cpu() if x_is_mps else stft_repr,
            **self.stft_kwargs,
            window=stft_window.cpu() if x_is_mps else stft_window,
            return_complex=False,
            length=raw_audio.shape[-1],
        ).to(device)

    recon_audio = rearrange(
        recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
    )

    if num_stems == 1:
        recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")

    # Inference path: no target means no loss computation.
    if not exists(target):
        return recon_audio

    if self.num_stems > 1:
        assert target.ndim == 4 and target.shape[1] == self.num_stems

    if target.ndim == 2:
        target = rearrange(target, "... t -> ... 1 t")

    # istft may return slightly fewer samples; align target length.
    target = target[..., : recon_audio.shape[-1]]

    loss = F.l1_loss(recon_audio, target)

    # Auxiliary multi-resolution STFT loss over several window sizes.
    multi_stft_resolution_loss = 0.0

    for window_size in self.multi_stft_resolutions_window_sizes:
        res_stft_kwargs = dict(
            n_fft=max(window_size, self.multi_stft_n_fft),
            win_length=window_size,
            return_complex=True,
            window=self.multi_stft_window_fn(window_size, device=device),
            **self.multi_stft_kwargs,
        )

        recon_Y = torch.stft(
            rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
        )
        target_Y = torch.stft(
            rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
        )

        multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
            recon_Y, target_Y
        )

    weighted_multi_resolution_loss = (
        multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
    )

    total_loss = loss + weighted_multi_resolution_loss

    if not return_loss_breakdown:
        return total_loss

    return total_loss, (loss, multi_stft_resolution_loss)