Upload 101 files
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- mvsepless/__init__.py +1579 -0
- mvsepless/__main__.py +67 -0
- mvsepless/audio.py +781 -0
- mvsepless/downloader.py +92 -0
- mvsepless/ensemble.py +224 -0
- mvsepless/infer.py +623 -0
- mvsepless/infer_utils.py +382 -0
- mvsepless/model_manager.py +540 -0
- mvsepless/models.json +0 -0
- mvsepless/models/bandit/core/__init__.py +691 -0
- mvsepless/models/bandit/core/data/__init__.py +2 -0
- mvsepless/models/bandit/core/data/_types.py +17 -0
- mvsepless/models/bandit/core/data/augmentation.py +102 -0
- mvsepless/models/bandit/core/data/augmented.py +34 -0
- mvsepless/models/bandit/core/data/base.py +60 -0
- mvsepless/models/bandit/core/data/dnr/__init__.py +0 -0
- mvsepless/models/bandit/core/data/dnr/datamodule.py +68 -0
- mvsepless/models/bandit/core/data/dnr/dataset.py +366 -0
- mvsepless/models/bandit/core/data/dnr/preprocess.py +51 -0
- mvsepless/models/bandit/core/data/musdb/__init__.py +0 -0
- mvsepless/models/bandit/core/data/musdb/datamodule.py +75 -0
- mvsepless/models/bandit/core/data/musdb/dataset.py +273 -0
- mvsepless/models/bandit/core/data/musdb/preprocess.py +226 -0
- mvsepless/models/bandit/core/data/musdb/validation.yaml +15 -0
- mvsepless/models/bandit/core/loss/__init__.py +8 -0
- mvsepless/models/bandit/core/loss/_complex.py +27 -0
- mvsepless/models/bandit/core/loss/_multistem.py +43 -0
- mvsepless/models/bandit/core/loss/_timefreq.py +95 -0
- mvsepless/models/bandit/core/loss/snr.py +139 -0
- mvsepless/models/bandit/core/metrics/__init__.py +9 -0
- mvsepless/models/bandit/core/metrics/_squim.py +443 -0
- mvsepless/models/bandit/core/metrics/snr.py +127 -0
- mvsepless/models/bandit/core/model/__init__.py +3 -0
- mvsepless/models/bandit/core/model/_spectral.py +54 -0
- mvsepless/models/bandit/core/model/bsrnn/__init__.py +23 -0
- mvsepless/models/bandit/core/model/bsrnn/bandsplit.py +135 -0
- mvsepless/models/bandit/core/model/bsrnn/core.py +651 -0
- mvsepless/models/bandit/core/model/bsrnn/maskestim.py +351 -0
- mvsepless/models/bandit/core/model/bsrnn/tfmodel.py +320 -0
- mvsepless/models/bandit/core/model/bsrnn/utils.py +525 -0
- mvsepless/models/bandit/core/model/bsrnn/wrapper.py +829 -0
- mvsepless/models/bandit/core/utils/__init__.py +0 -0
- mvsepless/models/bandit/core/utils/audio.py +412 -0
- mvsepless/models/bandit/model_from_config.py +26 -0
- mvsepless/models/bandit_v2/bandit.py +363 -0
- mvsepless/models/bandit_v2/bandsplit.py +130 -0
- mvsepless/models/bandit_v2/film.py +23 -0
- mvsepless/models/bandit_v2/maskestim.py +281 -0
- mvsepless/models/bandit_v2/tfmodel.py +145 -0
- mvsepless/models/bandit_v2/utils.py +523 -0
mvsepless/__init__.py
ADDED
|
@@ -0,0 +1,1579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import shutil
|
| 4 |
+
import logging
|
| 5 |
+
import zipfile
|
| 6 |
+
import importlib.util
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
# Configure root logging once for the whole package; show warnings and above.
logging.basicConfig(level=logging.WARNING)

# Resolve the directory containing this file and make it the process CWD so
# that relative resources (model configs, the `-m infer` subprocess module)
# resolve next to this package.
# NOTE(review): chdir at import time is a global side effect that affects the
# embedding application too — confirm callers depend on it before removing.
script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(script_dir)
|
| 12 |
+
|
| 13 |
+
if not __package__:
|
| 14 |
+
from model_manager import MvseplessModelManager
|
| 15 |
+
from audio import Audio, Inverter
|
| 16 |
+
from namer import Namer
|
| 17 |
+
from vbach_infer import vbach_inference, model_manager as VbachModel
|
| 18 |
+
from downloader import dw_file, dw_yt_dlp
|
| 19 |
+
from ensemble import ensemble_audio_files
|
| 20 |
+
else:
|
| 21 |
+
from .model_manager import MvseplessModelManager
|
| 22 |
+
from .audio import Audio, Inverter
|
| 23 |
+
from .namer import Namer
|
| 24 |
+
from .vbach_infer import vbach_inference, model_manager as VbachModel
|
| 25 |
+
from .downloader import dw_file, dw_yt_dlp
|
| 26 |
+
from .ensemble import ensemble_audio_files
|
| 27 |
+
from typing import Literal
|
| 28 |
+
import gradio as gr
|
| 29 |
+
import pandas as pd
|
| 30 |
+
import subprocess
|
| 31 |
+
import json
|
| 32 |
+
import threading
|
| 33 |
+
import queue
|
| 34 |
+
import time
|
| 35 |
+
import argparse
|
| 36 |
+
from datetime import datetime
|
| 37 |
+
import tempfile
|
| 38 |
+
import ast
|
| 39 |
+
|
| 40 |
+
class MVSEPLESS:
    """Base container for shared service objects used across the toolkit.

    All attributes are class-level singletons: they are created once at
    import (class-definition) time and shared by every subclass
    (e.g. ``Separator``) and every instance.
    """

    # NOTE(review): constructors run at class-definition time — any I/O they
    # perform happens on import of this module; confirm this is intended.
    audio = Audio()
    inverter = Inverter()
    namer = Namer()
    model_manager = MvseplessModelManager()
    vbach_model_manager = VbachModel  # alias to the imported vbach model manager
|
| 46 |
+
|
| 47 |
+
class Separator(MVSEPLESS):
|
| 48 |
+
|
| 49 |
+
class OutputReader:
    """Parses the JSON-lines progress protocol of the separation subprocess.

    Each non-empty line emitted by the child process is expected to be a JSON
    object carrying exactly one of the keys ``reading``, ``processing``,
    ``writing``, ``done`` or ``error``.  Non-JSON lines are ignored.
    """

    def __init__(self, debug=False):
        # When True, echo every raw subprocess line to stdout for debugging.
        self.debug = debug

    def parse_json_line(self, line):
        """Return the decoded JSON object, or None if *line* is not valid JSON."""
        try:
            return json.loads(line)
        except json.JSONDecodeError:
            return None

    def reaction_line(self, line, progress, add_text):
        """React to one subprocess output line.

        Drives the gradio ``progress`` callback and mirrors status to stdout.
        Returns the payload of a ``done`` message; returns None for every
        other (or unparsable) line.  Raises ``Exception`` for an ``error``
        message.
        """
        # BUGFIX: the original condition `add_text != "" or add_text is not None`
        # was true for "" (producing a stray "| ") and rendered "| None" when
        # add_text was None.  Only prefix genuinely non-empty text.
        _add_text = f"| {add_text}" if add_text else ""

        data = self.parse_json_line(line)
        if data is None:
            return None
        if "reading" in data:
            progress(0.05, desc=f"Чтение файла {_add_text}")
            print("Чтение файла")
            return None
        if "processing" in data:
            progress_a = data["processing"]
            processed = progress_a.get("processed", 0)
            total = progress_a.get("total", 1)
            if total > 0:  # guard against division by zero
                # Cap at 0.89 to keep headroom for the writing stage.
                progress_ratio = min(0.89, 0.05 + (processed / total * 0.85))
                percent = int((processed / total) * 100)
                progress(progress_ratio, desc=f"Обработано: {percent}% {_add_text}")
                print(f"\rОбработано: {percent}%", end="")
            return None
        if "writing" in data:
            progress(0.9, desc="Запись результатов")
            print(f"\rЗапись в файл {data['writing']}", end="")
            return None
        if "done" in data:
            progress(1.0, desc=f"Завершено {_add_text}")
            print("\rЗавершено", end="\n")
            return data["done"]
        if "error" in data:
            raise Exception(data["error"])

    def read_stream_to_queue(self, stream, queue_obj, stream_name):
        """Drain a subprocess stream into *queue_obj* (runs in a worker thread)."""
        try:
            for line in iter(stream.readline, ''):
                line = line.strip()
                if line:
                    if self.debug:
                        print(f"[{stream_name}] {line}")  # debug echo
                    queue_obj.put(line)
            stream.close()
        except Exception as e:
            print(f"Error reading {stream_name}: {e}")

output_reader = OutputReader()
|
| 107 |
+
|
| 108 |
+
def separator_model_loader(self, model_type: str, model_name: str, mdx_denoise: bool, vr_aggr: int, progress) -> tuple[int, str, str]:
    """Resolve a model by type/name, download its files and patch its config.

    Args:
        model_type: architecture family, e.g. ``"mel_band_roformer"`` or ``"vr"``.
        model_name: model key inside the manager's registry for that type.
        mdx_denoise: enable denoising for MDX-NET models (written to the config).
        vr_aggr: aggressiveness for VR models (written to the config).
            BUGFIX: the original annotated this as ``bool``, but callers pass
            an integer slider value (default 5, range -100..100).
        progress: gradio progress callback (accepted but unused here).

    Returns:
        Tuple ``(model_id, config_path, checkpoint_path)``.

    Raises:
        ValueError: unsupported ``model_type``, or ``model_name`` not
            registered for that type.
    """
    supported = (
        "mel_band_roformer",
        "bs_roformer",
        "mdx23c",
        "mdxnet",
        "vr",
        "scnet",
        "htdemucs",
        "bandit",
        "bandit_v2",
    )
    if model_type not in supported:
        raise ValueError("Неподдерживаемый тип модели")

    info = self.model_manager.models_info[model_type].get(model_name, None)
    if not info:
        raise ValueError(f"Модель {model_name} не найдена для типа {model_type}")

    # Renamed from `id`, which shadowed the builtin.
    model_id = self.model_manager.get_id(model_type, model_name)
    conf, ckpt = self.model_manager.download_model(
        self.model_manager.models_cache_dir,
        model_name,
        model_type,
        info["checkpoint_url"],
        info["config_url"],
    )
    # htdemucs configs are not patched.
    # NOTE(review): reason inferred — confirm against conf_editor's implementation.
    if model_type != "htdemucs":
        self.model_manager.conf_editor(conf, mdx_denoise, vr_aggr, model_type)

    return model_id, conf, ckpt
|
| 140 |
+
|
| 141 |
+
def separator_base(
    self,
    input_file: str,
    output_dir: str,
    model_type: Literal[
        "mel_band_roformer",
        "bs_roformer",
        "mdx23c",
        "mdxnet",
        "vr",  # CONSISTENCY FIX: runtime check below accepts "vr"; the Literal omitted it
        "scnet",
        "htdemucs",
        "bandit",
        "bandit_v2",
    ] = "mel_band_roformer",
    model_name: str = "Mel-Band-Roformer_Vocals_kimberley_jensen",
    ext_inst: bool = True,
    output_format: Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] = "mp3",
    output_bitrate: str = "320k",
    template: str = "NAME_(STEM)_MODEL",
    selected_stems: list = None,
    ckpt: str = None,
    conf: str = None,
    id: int = None,
    progress: any = gr.Progress(track_tqdm=True),
    add_text_progress: str = ""
) -> list[tuple[str, str]]:
    """Run the separation subprocess for one input file and report progress.

    Spawns ``python -m infer`` with the resolved model paths, drains its
    stdout/stderr on background threads, feeds each JSON line to
    ``self.output_reader`` and returns the payload of the ``done`` message
    (a list of ``(stem_name, file_path)`` pairs).

    Raises:
        ValueError: unsupported ``model_type``.
        Exception: propagated subprocess "error" message, or a non-zero
            exit with the last stderr lines attached.
    """
    supported = (
        "mel_band_roformer",
        "bs_roformer",
        "mdx23c",
        "mdxnet",
        "vr",
        "scnet",
        "htdemucs",
        "bandit",
        "bandit_v2",
    )
    if model_type not in supported:
        raise ValueError("Неподдерживаемый тип модели")

    cmd = [
        # BUGFIX: was `os.sys.executable` — `os.sys` is an undocumented
        # accident of os's implementation; `sys` is imported at module top.
        sys.executable,
        "-m",
        "infer",
        "--input", input_file,
        "--store_dir", output_dir,
        "--model_type", model_type,
        "--model_name", model_name,
        "--model_id", str(id),
        "--config_path", conf,
        "--start_check_point", ckpt,
        "--output_format", output_format,
        "--output_bitrate", str(output_bitrate),
        "--template", template,
    ]
    if ext_inst:
        cmd.append("--extract_instrumental")
    if selected_stems:
        cmd.append("--selected_instruments")
        cmd.extend(selected_stems)

    # BUGFIX: initialize before the try so the finally block cannot hit an
    # UnboundLocalError if Popen itself raises.
    process = None
    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1,
            universal_newlines=True,
            encoding='utf-8',
            errors='replace',
        )

        # Background threads push each stripped line into a queue so this
        # thread never blocks on a pipe read.
        stdout_queue = queue.Queue()
        stderr_queue = queue.Queue()
        for stream, line_queue, stream_name in (
            (process.stdout, stdout_queue, "stdout"),
            (process.stderr, stderr_queue, "stderr"),
        ):
            reader_thread = threading.Thread(
                target=self.output_reader.read_stream_to_queue,
                args=(stream, line_queue, stream_name),
                daemon=True,
            )
            reader_thread.start()

        def _poll(line_queue):
            # Process at most one pending line; return the "done" payload or None.
            try:
                line = line_queue.get_nowait()
            except queue.Empty:
                return None
            return self.output_reader.reaction_line(line, progress, add_text_progress)

        results = {'output': None, 'error': None}

        # Main message loop: run until the child exits or reports "done".
        process_completed = False
        while not process_completed:
            if process.poll() is not None:
                process_completed = True
            output = _poll(stdout_queue)
            if output is None:
                output = _poll(stderr_queue)
            if output is not None:
                results['output'] = output
                break
            if not process_completed:
                time.sleep(0.1)

        # Drain messages that may still arrive shortly after process exit.
        if results['output'] is None:
            for _ in range(10):
                output = _poll(stdout_queue)
                if output is None:
                    output = _poll(stderr_queue)
                if output is not None:
                    results['output'] = output
                    break
                time.sleep(0.1)

        # NOTE(review): 'error' is never populated here — reaction_line raises
        # instead; kept for shape-compatibility with the original code.
        if results.get('error'):
            raise Exception(results['error'])

        if results.get('output') is not None:
            return results['output']

        if process.returncode != 0:
            # Collect whatever stderr lines remain for the error report.
            error_messages = []
            try:
                while True:
                    error_messages.append(stderr_queue.get_nowait())
            except queue.Empty:
                pass
            error_text = "\n".join(error_messages[-5:])  # last 5 messages
            raise Exception(f"Процесс завершился с ошибкой. Код возврата: {process.returncode}. Сообщения об ошибках:\n{error_text}")
    finally:
        # Guarantee the child process is not left running.
        if process is not None and process.poll() is None:
            try:
                process.terminate()
                process.wait(timeout=5)
            except Exception:
                try:
                    process.kill()
                except Exception:
                    pass
|
| 324 |
+
|
| 325 |
+
def separate(
    self,
    input: str | list = None,
    output_dir: str = None,
    model_type: Literal[
        "mel_band_roformer",
        "bs_roformer",
        "mdx23c",
        "mdxnet",
        "vr",
        "scnet",
        "htdemucs",
        "bandit",
        "bandit_v2",
    ] = "mel_band_roformer",
    model_name: str = "Mel-Band-Roformer_Vocals_kimberley_jensen",
    ext_inst: bool = True,
    output_format: Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] = "mp3",
    output_bitrate: str = "320k",
    template: str = "NAME_(STEM)_MODEL",
    selected_stems: list = None,
    add_settings: dict = None,
    progress: any = gr.Progress(track_tqdm=True)
) -> list[tuple[str, str]] | list[str, list[tuple[str, str]]]:
    """Validate inputs, resolve the model once and separate one or many files.

    Args:
        input: path to an audio file, or a list of paths.
        output_dir: destination directory (defaults to the current directory).
        add_settings: optional extras: ``mdx_denoise`` (bool), ``vr_aggr``
            (int), ``add_single_sep_text_progress`` (str suffix for the
            progress label).
            BUGFIX: the original used a mutable dict literal as the default
            argument; the default is now None and normalized inside.

    Returns:
        For a single path: the ``(stem_name, file_path)`` list.
        For a list: ``[[basename, stems], ...]`` for each existing audio file.
    """
    progress(0, desc="Начало обработки")

    # --- parameter validation / normalization -------------------------------
    if output_format not in self.audio.output_formats:
        output_format = "flac"

    if output_dir is None:
        output_dir = os.getcwd()
    output_dir = os.path.abspath(output_dir)

    if selected_stems is None:
        selected_stems = []

    if add_settings is None:
        add_settings = {}

    if not input:
        raise ValueError("Входной файл не указан")

    # BUGFIX: the original tested `"STEM" not in template` BEFORE the
    # None check, raising TypeError for template=None.
    if template is not None and "STEM" not in template:
        template = template + "_STEM_"
    if not template:
        template = "mvsepless_NAME_(STEM)"

    os.makedirs(output_dir, exist_ok=True)

    mdx_denoise = add_settings.get("mdx_denoise", False)
    vr_aggr = add_settings.get("vr_aggr", 5)
    # `or ""` also normalizes an explicit None value to an empty suffix.
    add_progress_text_custom = add_settings.get("add_single_sep_text_progress", "") or ""

    # Resolve/download the model once, reused for every file.
    id, conf, ckpt = self.separator_model_loader(model_type, model_name, mdx_denoise, vr_aggr, progress)

    def _run_one(path, add_text):
        # Delegate the actual subprocess work for a single file.
        return self.separator_base(input_file=path,
                                   output_dir=output_dir,
                                   model_type=model_type,
                                   model_name=model_name,
                                   ext_inst=ext_inst,
                                   output_format=output_format,
                                   output_bitrate=output_bitrate,
                                   template=template,
                                   selected_stems=selected_stems,
                                   ckpt=ckpt,
                                   conf=conf,
                                   id=id,
                                   progress=progress,
                                   add_text_progress=add_text)

    if isinstance(input, str):
        if not os.path.exists(input):
            raise ValueError(f"Входной файл не найден: {input}")
        if not self.audio.check(input):
            raise ValueError("Входной файл не содержит аудио")
        return _run_one(input, add_progress_text_custom)

    if isinstance(input, list):
        results = []
        for i, f in enumerate(input, 1):
            print(f"Файл {i} из {len(input)}: {f}")
            # Files that are missing or contain no audio are skipped silently,
            # matching the original behavior.
            if os.path.exists(f) and self.audio.check(f):
                basename = os.path.splitext(os.path.basename(f))[0]
                seped = _run_one(f, f"({i} из {len(input)}) ")
                results.append([basename, seped])
        return results
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def UI(self, output_base_dir, output_temp_dir_check):
|
| 433 |
+
default_mts = self.model_manager.get_mt()
|
| 434 |
+
default_mt = self.model_manager.get_mt()[0]
|
| 435 |
+
default_mns = self.model_manager.get_mn(default_mt)
|
| 436 |
+
default_mn = default_mns[0]
|
| 437 |
+
default_stems = self.model_manager.get_stems(default_mt, default_mn)
|
| 438 |
+
default_tgt_inst = self.model_manager.get_tgt_inst(default_mt, default_mn)
|
| 439 |
+
with gr.Blocks():
|
| 440 |
+
with gr.Row():
|
| 441 |
+
with gr.Column():
|
| 442 |
+
with gr.Group(visible=False) as add_inputs:
|
| 443 |
+
input_path = gr.Textbox(label="Путь к входному файлу", interactive=True)
|
| 444 |
+
add_inputs_btn = gr.Button("Добавить файл", variant="primary")
|
| 445 |
+
with gr.Group(visible=False) as add_inputs_from_url:
|
| 446 |
+
input_url = gr.Textbox(label="URL входного файла", interactive=True)
|
| 447 |
+
with gr.Row():
|
| 448 |
+
inputs_url_format = gr.Dropdown(label="Формат входного файла", interactive=True,
|
| 449 |
+
choices=self.audio.output_formats,
|
| 450 |
+
value="mp3", filterable=False)
|
| 451 |
+
inputs_url_bitrate = gr.Slider(label="Битрейт входного файла", minimum=64, maximum=512, step=32, value=320, interactive=True)
|
| 452 |
+
with gr.Row():
|
| 453 |
+
inputs_url_cookie = gr.UploadButton(label="Файл cookie (необязательно)", interactive=True, type="filepath", file_count="single", file_types=[".txt", ".cookies"], variant="secondary")
|
| 454 |
+
add_inputs_url_btn = gr.Button("Добавить файл", variant="primary")
|
| 455 |
+
with gr.Row(visible=True) as add_buttons_row:
|
| 456 |
+
add_path_btn = gr.Button("Добавить файл по пути", variant="secondary")
|
| 457 |
+
add_url_btn = gr.Button("Добавить файл по URL", variant="secondary")
|
| 458 |
+
with gr.Group():
|
| 459 |
+
input_audio = gr.File(label="Входные аудио", interactive=True, type="filepath", file_count="multiple", file_types=[f".{of}" for of in self.audio.input_formats])
|
| 460 |
+
sep_state = gr.Textbox(label="Состояние разделения", interactive=False, value="", visible=False)
|
| 461 |
+
status = gr.Textbox(container=False, lines=3, interactive=False, max_lines=3)
|
| 462 |
+
input_preview_check = gr.Checkbox(label="Показать плееры для входных аудио", interactive=True, value=False)
|
| 463 |
+
@gr.render(inputs=[input_preview_check, input_audio])
|
| 464 |
+
def show_input_players(preview, audios):
|
| 465 |
+
if preview:
|
| 466 |
+
if audios:
|
| 467 |
+
with gr.Group():
|
| 468 |
+
for file in audios:
|
| 469 |
+
gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath")
|
| 470 |
+
with gr.Column():
|
| 471 |
+
with gr.Group():
|
| 472 |
+
with gr.Row():
|
| 473 |
+
model_type = gr.Dropdown(label="Тип модели", interactive=True, filterable=False,
|
| 474 |
+
choices=default_mts,
|
| 475 |
+
value=default_mt)
|
| 476 |
+
model_name = gr.Dropdown(label="Имя модели", interactive=True, filterable=False,
|
| 477 |
+
choices=default_mns, value=default_mn)
|
| 478 |
+
with gr.Group():
|
| 479 |
+
extract_instrumental = gr.Checkbox(label="Извлечь инструментал", interactive=True, value=True)
|
| 480 |
+
selected_stems = gr.CheckboxGroup(label="Выбранные стемы для разделения", interactive=False,
|
| 481 |
+
choices=default_stems, value=[])
|
| 482 |
+
|
| 483 |
+
with gr.Accordion(label="Дополнительные настройки", open=False):
|
| 484 |
+
vr_aggr_slider = gr.Slider(label="Сила подавления для VR моделей", minimum=-100, maximum=100, value=5, step=1)
|
| 485 |
+
mdx_denoise_check = gr.Checkbox(label="Включить шумоподавление для MDX-NET моделей (это повышает потребление памяти в два раза)", value=False)
|
| 486 |
+
with gr.Row():
|
| 487 |
+
output_format = gr.Dropdown(label="Формат выходного файла", interactive=True,
|
| 488 |
+
choices=self.audio.output_formats,
|
| 489 |
+
value="mp3", filterable=False)
|
| 490 |
+
output_bitrate = gr.Slider(label="Битрейт выходного файла", minimum=64, maximum=512, step=32, value=320, interactive=True)
|
| 491 |
+
|
| 492 |
+
template = gr.Textbox(label="Шаблон именования выходных файлов", interactive=True, value="NAME (STEM) MODEL", info="Используйте ключи: \nNAME - имя входного файла без расширения, \nSTEM - имя стема, \nMODEL - имя модели разделения")
|
| 493 |
+
separate_btn = gr.Button("Разделить", variant="primary")
|
| 494 |
+
|
| 495 |
+
@gr.render(inputs=[sep_state], triggers=[sep_state.change])
|
| 496 |
+
def players(state):
|
| 497 |
+
def create_archive_advanced(file_list, archive_name="archive.zip"):
    """Zip every existing stem file from *file_list* into *archive_name*.

    file_list is a list of (basename, stems) pairs, where stems is a list
    of (stem_name, stem_path) pairs. Files that are missing or fail to be
    added are logged and skipped; they do not abort the archive.

    Returns the absolute path of the created archive, or None when the
    archive itself could not be created.
    """
    try:
        print("Генерация ZIP-архива с результатами разделения...")
        with zipfile.ZipFile(archive_name, 'w', zipfile.ZIP_DEFLATED) as zipf:
            successful_files = 0

            for basename, stems in file_list:
                for stem_name, stem_path in stems:
                    try:
                        if os.path.exists(stem_path) and os.path.isfile(stem_path):
                            arcname = os.path.basename(stem_path)
                            zipf.write(stem_path, arcname)
                            successful_files += 1
                            # BUG FIX: log the archive member name (arcname),
                            # not the outer group basename.
                            print(f"✓ Добавлен: {stem_path} -> {arcname}")
                        else:
                            print(f"✗ Файл не найден или не является файлом: {stem_path}")

                    except Exception as e:
                        # One bad file must not abort the whole archive.
                        print(f"✗ Ошибка при добавлении {stem_path}: {e}")

        print(f"\nАрхив создан: {archive_name}")
        print(f"Успешно добавлено файлов: {successful_files}")
        return os.path.abspath(archive_name)

    except Exception as e:
        print(f"Ошибка при создании архива: {e}")
        # Explicit None (previously implicit) so callers can detect failure.
        return None
|
| 526 |
+
if state != "":
|
| 527 |
+
state_loaded = ast.literal_eval(state)
|
| 528 |
+
archive_stems = create_archive_advanced(state_loaded, os.path.join(tempfile.tempdir, f"mvsepless_output_{datetime.now().strftime('%Y%m%d_%H%M%S')}.zip"))
|
| 529 |
+
for basename, stems in state_loaded:
|
| 530 |
+
with gr.Group():
|
| 531 |
+
gr.Markdown(f"<h4><center>{basename}</center></h4>")
|
| 532 |
+
for stem_name, stem_path in stems:
|
| 533 |
+
with gr.Row(equal_height=True):
|
| 534 |
+
output_stem = gr.Audio(value=stem_path, label=stem_name, type="filepath", interactive=False, show_download_button=True, scale=15)
|
| 535 |
+
reuse_btn = gr.Button("Использовать снова", variant="secondary")
|
| 536 |
+
@reuse_btn.click(
|
| 537 |
+
inputs=[output_stem, input_audio],
|
| 538 |
+
outputs=input_audio
|
| 539 |
+
)
|
| 540 |
+
def reuse_fn(stem_audio, input_a):
    """Append the separated stem file back into the input-audio list."""
    # Normalize the current input value to a list (None -> [], str -> [str]).
    files = input_a
    if files is None:
        files = []
    elif isinstance(files, str):
        files = [files]
    # Only re-add the stem when it still exists and really contains audio.
    if os.path.exists(stem_audio) and self.audio.check(stem_audio):
        files.append(stem_audio)
    return files
|
| 549 |
+
|
| 550 |
+
gr.DownloadButton(label="Скачать как ZIP", value=archive_stems, interactive=True)
|
| 551 |
+
|
| 552 |
+
@add_inputs_btn.click(
|
| 553 |
+
inputs=[input_path, input_audio],
|
| 554 |
+
outputs=[add_inputs, input_audio, add_buttons_row])
|
| 555 |
+
def add_inputs_fn(input_p, input_a):
    """Validate a local path and, when it contains audio, append it to the input list."""
    if input_p and os.path.exists(input_p):
        # Normalize the current value to a list before appending.
        files = input_a
        if files is None:
            files = []
        elif isinstance(files, str):
            files = [files]
        if self.audio.check(input_p):
            files.append(input_p)
        return gr.update(visible=False), gr.update(value=files), gr.update(visible=True)
    # Invalid or missing path: hide the form, leave the current value as-is.
    return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
|
| 565 |
+
|
| 566 |
+
@add_inputs_url_btn.click(
|
| 567 |
+
inputs=[input_url, input_audio, inputs_url_format, inputs_url_bitrate, inputs_url_cookie],
|
| 568 |
+
outputs=[add_inputs_from_url, input_audio, add_buttons_row])
|
| 569 |
+
def add_inputs_from_url_fn(input_u, input_a, fmt, br, cookie):
|
| 570 |
+
if input_u:
|
| 571 |
+
if input_a is None:
|
| 572 |
+
input_a = []
|
| 573 |
+
if isinstance(input_a, str):
|
| 574 |
+
input_a = [input_a]
|
| 575 |
+
downloaded_file = dw_yt_dlp(
|
| 576 |
+
url=input_u,
|
| 577 |
+
output_format=fmt,
|
| 578 |
+
output_bitrate=str(int(br)),
|
| 579 |
+
cookie=cookie
|
| 580 |
+
)
|
| 581 |
+
if downloaded_file and os.path.exists(downloaded_file):
|
| 582 |
+
if self.audio.check(downloaded_file):
|
| 583 |
+
input_a.append(downloaded_file)
|
| 584 |
+
return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
|
| 585 |
+
return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
|
| 586 |
+
|
| 587 |
+
add_path_btn.click(
|
| 588 |
+
lambda: (gr.update(visible=True), gr.update(visible=False)),
|
| 589 |
+
outputs=[add_inputs, add_buttons_row])
|
| 590 |
+
|
| 591 |
+
add_url_btn.click(
|
| 592 |
+
lambda: (gr.update(visible=True), gr.update(visible=False)),
|
| 593 |
+
outputs=[add_inputs_from_url, add_buttons_row])
|
| 594 |
+
|
| 595 |
+
model_type.change(lambda x: gr.update(choices=self.model_manager.get_mn(x), value=self.model_manager.get_mn(x)[0]),
|
| 596 |
+
inputs=model_type, outputs=model_name)
|
| 597 |
+
model_name.change(lambda mt, mn: (gr.update(choices=self.model_manager.get_stems(mt, mn), value=[], interactive=False if self.model_manager.get_tgt_inst(mt, mn) else True), gr.update(value=True if self.model_manager.get_tgt_inst(mt, mn) else False)), inputs=[model_type, model_name], outputs=[selected_stems, extract_instrumental])
|
| 598 |
+
output_format.change(lambda x: gr.update(visible=False if x in ["wav", "flac", "aiff"] else True), inputs=output_format, outputs=output_bitrate)
|
| 599 |
+
inputs_url_format.change(lambda x: gr.update(visible=False if x in ["wav", "flac", "aiff"] else True), inputs=inputs_url_format, outputs=inputs_url_bitrate)
|
| 600 |
+
@separate_btn.click(
|
| 601 |
+
inputs=[
|
| 602 |
+
input_audio, model_type, model_name,
|
| 603 |
+
extract_instrumental, output_format, output_bitrate,
|
| 604 |
+
template, selected_stems, output_base_dir, output_temp_dir_check, mdx_denoise_check, vr_aggr_slider
|
| 605 |
+
],
|
| 606 |
+
outputs=[sep_state, status],
|
| 607 |
+
show_progress="full"
|
| 608 |
+
)
|
| 609 |
+
def wrap(i, mt, mn, ei, of, ob, t, stems, o_dir, temp_save, mdx_denoise, vr_aggr, progress=gr.Progress(track_tqdm=True)):
|
| 610 |
+
if o_dir.strip() != "" and not temp_save:
|
| 611 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 612 |
+
o = os.path.join(o_dir, f"mvsepless_outputs_{timestamp}")
|
| 613 |
+
os.makedirs(o, exist_ok=True)
|
| 614 |
+
else:
|
| 615 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 616 |
+
o = tempfile.mkdtemp(prefix=f"mvsepless_outputs_{timestamp}_")
|
| 617 |
+
os.makedirs(o, exist_ok=True)
|
| 618 |
+
results = self.separate(i, o, mt, mn, ei, of, ob, t, stems, add_settings={"mdx_denoise": mdx_denoise, "vr_aggr": int(vr_aggr)}, progress=progress)
|
| 619 |
+
return str(results), ""
|
| 620 |
+
|
| 621 |
+
class AutoEnsembless(Separator):
|
| 622 |
+
|
| 623 |
+
class ModelManager(MVSEPLESS):
    """Ordered list of auto-ensemble entries with JSON preset save/load.

    Each entry is a tuple (model_type, model_name, primary_stem,
    invert_stem, weight). Indices exposed to the UI are 1-based.
    """

    def __init__(self):
        # Entries in ensemble order: list[tuple[str, str, str, str, float]].
        self.data = []
        self.ensemble_methods = ("min_fft", "max_fft", "avg_fft", "median_fft")
        # Which FFT blend produces the complement of a given blend.
        self.ensemble_invert_methods_map = {"min_fft": "max_fft", "max_fft": "min_fft", "avg_fft": "avg_fft", "median_fft": "median_fft"}
        # BUG FIX: tempfile.tempdir is documented to be None until
        # tempfile.gettempdir() has been called somewhere, so joining on it
        # could raise TypeError. gettempdir() always returns a usable dir.
        self.dir_presets = os.path.join(tempfile.gettempdir(), "presets")
        os.makedirs(self.dir_presets, exist_ok=True)

    def save(self, name):
        """Write the current entries to <presets>/<sanitized name>.json and return the file path."""
        if not name:
            name = "ensembless_preset"
        filepath = os.path.join(self.dir_presets, f"{self.namer.short(self.namer.sanitize(name), length=50)}.json")
        with open(filepath, "w") as f:
            json.dump(self.data, f, indent=4, ensure_ascii=False)
        return filepath

    def load(self, filepath):
        """Replace the current entries with those from a preset file.

        Duplicate (model_type, model_name) pairs are skipped; the check is
        order-insensitive because it compares two-element sets.
        """
        with open(filepath, "r") as f:
            ensemble_data_temp = json.load(f)
        self.data = []
        for (mt, mn, s_stem, i_stem, weight) in ensemble_data_temp:
            if {mt, mn} not in [{model[0], model[1]} for model in self.data]:
                self.data.append((mt, mn, s_stem, i_stem, weight))

    def add(self, mt, mn, s_stem, i_stem, weight):
        """Append an entry unless the same (type, name) pair is already present or a stem is empty."""
        if {mt, mn} not in [{model[0], model[1]} for model in self.data]:
            if s_stem and i_stem:
                self.data.append((mt, mn, s_stem, i_stem, weight))

    def replace(self, mt, mn, s_stem, i_stem, weight, index=1):
        """Overwrite the entry at 1-based *index* (0 also targets the first entry; out-of-range is a no-op)."""
        if self.data:
            len_data = len(self.data)
            if index >= 1:
                if index <= len_data:
                    self.data[index - 1] = (mt, mn, s_stem, i_stem, weight)
            elif index == 0:
                self.data[0] = (mt, mn, s_stem, i_stem, weight)

    def remove(self, index=1):
        """Delete the entry at 1-based *index* (0 also targets the first entry; out-of-range is a no-op)."""
        if self.data:
            len_data = len(self.data)
            if index >= 1:
                if index <= len_data:
                    del self.data[index - 1]
            elif index == 0:
                del self.data[0]

    def clear(self):
        """Drop all entries."""
        self.data = []

    def get_df(self):
        """Render the entries as a pandas DataFrame for the gr.DataFrame widget."""
        columns = ["#", "Имя модели", "Основной стем", "Инверсия", "Вес"]
        if not self.data:
            return pd.DataFrame(columns=columns)
        rows = [
            [f"{i + 1}", model[1], model[2], model[3], model[4]]
            for i, model in enumerate(self.data)
        ]
        return pd.DataFrame(rows, columns=columns)
|
| 691 |
+
|
| 692 |
+
def UI(self, output_base_dir, output_temp_dir_check):
|
| 693 |
+
ensemble_model_manager = self.ModelManager()
|
| 694 |
+
def get_stems(mt, mn):
|
| 695 |
+
stems = []
|
| 696 |
+
for stem in self.model_manager.get_stems(mt, mn):
|
| 697 |
+
stems.append(stem)
|
| 698 |
+
|
| 699 |
+
if not self.model_manager.get_tgt_inst(mt, mn):
|
| 700 |
+
if set(stems) == {"bass", "drums", "other", "vocals"} or set(stems) == {"bass", "drums", "other", "vocals", "piano", "guitar"}:
|
| 701 |
+
stems.append("instrumental +")
|
| 702 |
+
stems.append("instrumental -")
|
| 703 |
+
|
| 704 |
+
return stems
|
| 705 |
+
|
| 706 |
+
def get_invert_stems(mt, mn, s_stem):
|
| 707 |
+
orig_stems = []
|
| 708 |
+
stems = []
|
| 709 |
+
for stem in self.model_manager.get_stems(mt, mn):
|
| 710 |
+
orig_stems.append(stem)
|
| 711 |
+
|
| 712 |
+
for stem in orig_stems:
|
| 713 |
+
if stem != s_stem:
|
| 714 |
+
stems.append(stem)
|
| 715 |
+
|
| 716 |
+
if not self.model_manager.get_tgt_inst(mt, mn):
|
| 717 |
+
if len(orig_stems) > 2:
|
| 718 |
+
if s_stem not in ["instrumental +", "instrumental -"]:
|
| 719 |
+
stems.append("inverted +")
|
| 720 |
+
stems.append("inverted -")
|
| 721 |
+
|
| 722 |
+
return stems
|
| 723 |
+
|
| 724 |
+
default_model = {
|
| 725 |
+
"mt": self.model_manager.get_mt(),
|
| 726 |
+
"mn": self.model_manager.get_mn(self.model_manager.get_mt()[0]),
|
| 727 |
+
"stem": get_stems(
|
| 728 |
+
self.model_manager.get_mt()[0],
|
| 729 |
+
self.model_manager.get_mn(self.model_manager.get_mt()[0])[0],
|
| 730 |
+
),
|
| 731 |
+
"invert_stem": get_invert_stems(
|
| 732 |
+
self.model_manager.get_mt()[0],
|
| 733 |
+
self.model_manager.get_mn(self.model_manager.get_mt()[0])[0],
|
| 734 |
+
"vocals",
|
| 735 |
+
),
|
| 736 |
+
"weight": 1,
|
| 737 |
+
}
|
| 738 |
+
|
| 739 |
+
gr.Markdown("<h3>Пресет</h3>")
|
| 740 |
+
with gr.Group():
|
| 741 |
+
with gr.Row(equal_height=True):
|
| 742 |
+
export_preset_name = gr.Textbox(
|
| 743 |
+
label="Имя пресета",
|
| 744 |
+
interactive=True,
|
| 745 |
+
value="ensembless_preset", scale=9
|
| 746 |
+
)
|
| 747 |
+
export_btn = gr.DownloadButton("Экспорт", variant="secondary", scale=3, interactive=True)
|
| 748 |
+
import_btn = gr.UploadButton(
|
| 749 |
+
"Импорт", file_types=[".json"], file_count="single", scale=3, interactive=True
|
| 750 |
+
)
|
| 751 |
+
gr.Markdown("<h3>Ансамбль</h3>")
|
| 752 |
+
with gr.Row():
|
| 753 |
+
with gr.Column(scale=3): # логика добавлеия моделей
|
| 754 |
+
model_type = gr.Dropdown(label="Тип модели", choices=default_model["mt"], value=default_model["mt"][0], interactive=True, filterable=False)
|
| 755 |
+
model_name = gr.Dropdown(label="Имя модели", choices=default_model["mn"], value=default_model["mn"][0], interactive=True, filterable=False)
|
| 756 |
+
primary_stem = gr.Dropdown(label="Основной стем", choices=default_model["stem"], value=default_model["stem"][0], interactive=True, filterable=False)
|
| 757 |
+
secondary_stem = gr.Dropdown(label="Инверсия", choices=default_model["invert_stem"], value=default_model["invert_stem"][0], interactive=True, filterable=False)
|
| 758 |
+
weight = gr.Slider(label="Вес", minimum=0, maximum=10, step=0.01, value=1, interactive=True)
|
| 759 |
+
@model_type.change(
|
| 760 |
+
inputs=[model_type],
|
| 761 |
+
outputs=[model_name]
|
| 762 |
+
)
|
| 763 |
+
def update_model_names(mt):
|
| 764 |
+
model_names = self.model_manager.get_mn(mt)
|
| 765 |
+
new_mn = model_names[0] if model_names else ""
|
| 766 |
+
|
| 767 |
+
return gr.update(choices=model_names, value=new_mn)
|
| 768 |
+
@model_name.change(
|
| 769 |
+
inputs=[model_type, model_name],
|
| 770 |
+
outputs=[primary_stem, secondary_stem]
|
| 771 |
+
)
|
| 772 |
+
def update_stems_after_model_change(mt, mn):
|
| 773 |
+
stems = get_stems(mt, mn)
|
| 774 |
+
invert_stems = get_invert_stems(mt, mn, stems[0]) if stems else []
|
| 775 |
+
|
| 776 |
+
new_s_stem = stems[0] if stems else ""
|
| 777 |
+
new_i_stem = invert_stems[0] if invert_stems else ""
|
| 778 |
+
|
| 779 |
+
return (
|
| 780 |
+
gr.update(choices=stems, value=new_s_stem),
|
| 781 |
+
gr.update(choices=invert_stems, value=new_i_stem)
|
| 782 |
+
)
|
| 783 |
+
@primary_stem.change(
|
| 784 |
+
inputs=[model_type, model_name, primary_stem],
|
| 785 |
+
outputs=[secondary_stem]
|
| 786 |
+
)
|
| 787 |
+
def update_invert_stems(mt, mn, s_stem):
|
| 788 |
+
stems = get_invert_stems(mt, mn, s_stem)
|
| 789 |
+
new_i_stem = stems[0] if stems else ""
|
| 790 |
+
return gr.update(choices=stems, value=new_i_stem)
|
| 791 |
+
|
| 792 |
+
model_add_button = gr.Button("Добавить", interactive=True)
|
| 793 |
+
with gr.Column(scale=10):
|
| 794 |
+
df = gr.DataFrame(
|
| 795 |
+
value=ensemble_model_manager.get_df(),
|
| 796 |
+
headers=["#", "Имя модели", "Основной стем", "Инверсия", "Вес"],
|
| 797 |
+
datatype=["number", "str", "str", "str", "number"],
|
| 798 |
+
interactive=False
|
| 799 |
+
)
|
| 800 |
+
|
| 801 |
+
with gr.Group():
|
| 802 |
+
with gr.Row(equal_height=True):
|
| 803 |
+
with gr.Column():
|
| 804 |
+
model_index = gr.Number(label="Индекс модели", value=1, interactive=True)
|
| 805 |
+
model_clear_btn = gr.Button("Очистить", variant="stop", interactive=True)
|
| 806 |
+
with gr.Column():
|
| 807 |
+
model_replace_btn = gr.Button("Заменить", variant="primary", interactive=True)
|
| 808 |
+
model_delete_btn = gr.Button("Удалить", variant="stop", interactive=True)
|
| 809 |
+
|
| 810 |
+
@model_add_button.click(
|
| 811 |
+
inputs=[model_type, model_name, primary_stem, secondary_stem, weight],
|
| 812 |
+
outputs=df
|
| 813 |
+
)
|
| 814 |
+
def add_model_to_auto_ensemble(mt, mn, s_stem, i_stem, weight):
|
| 815 |
+
ensemble_model_manager.add(mt, mn, s_stem, i_stem, weight)
|
| 816 |
+
return ensemble_model_manager.get_df()
|
| 817 |
+
|
| 818 |
+
@model_replace_btn.click(
|
| 819 |
+
inputs=[model_type, model_name, primary_stem, secondary_stem, weight, model_index],
|
| 820 |
+
outputs=df
|
| 821 |
+
)
|
| 822 |
+
def replace_model_to_auto_ensemble(mt, mn, s_stem, i_stem, weight, index):
|
| 823 |
+
ensemble_model_manager.replace(mt, mn, s_stem, i_stem, weight, index)
|
| 824 |
+
return ensemble_model_manager.get_df()
|
| 825 |
+
|
| 826 |
+
@model_delete_btn.click(
|
| 827 |
+
inputs=[model_index],
|
| 828 |
+
outputs=df
|
| 829 |
+
)
|
| 830 |
+
def delete_model_to_auto_ensemble(index):
|
| 831 |
+
ensemble_model_manager.remove(index)
|
| 832 |
+
return ensemble_model_manager.get_df()
|
| 833 |
+
|
| 834 |
+
@model_clear_btn.click(
|
| 835 |
+
outputs=df
|
| 836 |
+
)
|
| 837 |
+
def clear_model_to_auto_ensemble():
|
| 838 |
+
ensemble_model_manager.clear()
|
| 839 |
+
return ensemble_model_manager.get_df()
|
| 840 |
+
|
| 841 |
+
gr.on(fn=ensemble_model_manager.get_df, outputs=df)
|
| 842 |
+
|
| 843 |
+
df.change(
|
| 844 |
+
fn=ensemble_model_manager.save,
|
| 845 |
+
inputs=export_preset_name,
|
| 846 |
+
outputs=export_btn
|
| 847 |
+
)
|
| 848 |
+
|
| 849 |
+
export_preset_name.change(
|
| 850 |
+
fn=ensemble_model_manager.save,
|
| 851 |
+
inputs=export_preset_name,
|
| 852 |
+
outputs=export_btn
|
| 853 |
+
)
|
| 854 |
+
|
| 855 |
+
@import_btn.upload(
|
| 856 |
+
inputs=import_btn,
|
| 857 |
+
outputs=df
|
| 858 |
+
)
|
| 859 |
+
def load_ensemble_preset(filepath):
|
| 860 |
+
ensemble_model_manager.load(filepath)
|
| 861 |
+
return ensemble_model_manager.get_df()
|
| 862 |
+
with gr.Row():
|
| 863 |
+
with gr.Column():
|
| 864 |
+
gr.Markdown("<h3>Входное аудио</h3>")
|
| 865 |
+
with gr.Group():
|
| 866 |
+
with gr.Group(visible=False) as add_inputs:
|
| 867 |
+
input_path = gr.Textbox(label="Путь к входному файлу", interactive=True)
|
| 868 |
+
add_inputs_btn = gr.Button("Загрузить файл", variant="primary")
|
| 869 |
+
with gr.Group(visible=False) as add_inputs_from_url:
|
| 870 |
+
input_url = gr.Textbox(label="URL входного файла", interactive=True)
|
| 871 |
+
with gr.Row():
|
| 872 |
+
inputs_url_format = gr.Dropdown(label="Формат входного файла", interactive=True,
|
| 873 |
+
choices=self.audio.output_formats,
|
| 874 |
+
value="mp3", filterable=False)
|
| 875 |
+
inputs_url_bitrate = gr.Slider(label="Битрейт входного файла", minimum=64, maximum=512, step=32, value=320, interactive=True)
|
| 876 |
+
with gr.Row():
|
| 877 |
+
inputs_url_cookie = gr.UploadButton(label="Файл cookie (необязательно)", interactive=True, type="filepath", file_count="single", file_types=[".txt", ".cookies"], variant="secondary")
|
| 878 |
+
add_inputs_url_btn = gr.Button("Загрузить файл", variant="primary")
|
| 879 |
+
with gr.Row(visible=True) as add_buttons_row:
|
| 880 |
+
add_path_btn = gr.Button("Загрузить файл по пути", variant="secondary")
|
| 881 |
+
add_url_btn = gr.Button("Загрузить файл по URL", variant="secondary")
|
| 882 |
+
with gr.Group():
|
| 883 |
+
input_audio = gr.File(label="Входное аудио", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.audio.input_formats])
|
| 884 |
+
with gr.Column():
|
| 885 |
+
gr.Markdown("<h3>Настройки</h3>")
|
| 886 |
+
with gr.Group():
|
| 887 |
+
method = gr.Dropdown(label="Алгоритм склеивания", choices=["min_fft", "max_fft", "avg_fft", "median_fft"], value="avg_fft", filterable=False)
|
| 888 |
+
invert_ensemble = gr.Checkbox(label="Инверсия ансамбля", interactive=True, value=False)
|
| 889 |
+
output_format = gr.Dropdown(label="Формат выходного файла", interactive=True,
|
| 890 |
+
choices=self.audio.output_formats,
|
| 891 |
+
value="mp3", filterable=False)
|
| 892 |
+
run_btn = gr.Button("Создать ансамбль", variant="primary", interactive=True)
|
| 893 |
+
|
| 894 |
+
with gr.Row():
|
| 895 |
+
with gr.Column():
|
| 896 |
+
gr.Markdown("<h3>Результаты</h3>")
|
| 897 |
+
output_audio = gr.Audio(label="Результат", type="filepath", interactive=False, show_download_button=True)
|
| 898 |
+
output_audio_wav = gr.Textbox(label="Результат в WAV", interactive=False, visible=False)
|
| 899 |
+
with gr.Group():
|
| 900 |
+
invert_method = gr.Radio(
|
| 901 |
+
choices=["waveform", "spectrogram"],
|
| 902 |
+
label="Метод создания инверсии",
|
| 903 |
+
value="waveform",
|
| 904 |
+
)
|
| 905 |
+
invert_btn = gr.Button("Инвертировать")
|
| 906 |
+
output_inverted_audio = gr.Audio(label="Инверсия", type="filepath", interactive=False, show_download_button=True)
|
| 907 |
+
@invert_btn.click(inputs=[input_audio, output_audio_wav, invert_method, output_format], outputs=[output_inverted_audio])
|
| 908 |
+
def invert_result_ensemble(input_file, output_file, method, out_format):
|
| 909 |
+
if input_file and output_file:
|
| 910 |
+
o_dir = os.path.dirname(output_file)
|
| 911 |
+
basename = os.path.splitext(os.path.basename(input_file))[0]
|
| 912 |
+
output_path = os.path.join(o_dir, f"ensembless_{self.namer.short(basename, length=50)}_{method}_invert.{out_format}")
|
| 913 |
+
inverted = self.inverter.process_audio(audio1_path=input_file, audio2_path=output_file, out_format=out_format, method=method, output_path=output_path)
|
| 914 |
+
return inverted
|
| 915 |
+
else:
|
| 916 |
+
return None
|
| 917 |
+
|
| 918 |
+
with gr.Column():
|
| 919 |
+
gr.Markdown("<h3>Исходники ансамбля (WAV)</h3>")
|
| 920 |
+
output_source_files = gr.Files(type="filepath", interactive=False, show_label=False)
|
| 921 |
+
output_source_preview_check = gr.Checkbox(label="Показать плееры для исходников ансамбля", interactive=True, value=False)
|
| 922 |
+
@gr.render(inputs=[output_source_preview_check, output_source_files])
|
| 923 |
+
def show_output_auto_ensemble_players(preview, audios):
|
| 924 |
+
if preview:
|
| 925 |
+
if audios:
|
| 926 |
+
with gr.Group():
|
| 927 |
+
for file in audios:
|
| 928 |
+
gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath")
|
| 929 |
+
|
| 930 |
+
@run_btn.click(
|
| 931 |
+
inputs=[input_audio, method, output_format, invert_ensemble, output_base_dir, output_temp_dir_check],
|
| 932 |
+
outputs=[output_audio, output_audio_wav, output_inverted_audio, output_source_files]
|
| 933 |
+
)
|
| 934 |
+
def auto_ensemble_run(input_file, method, out_format, invert_ensemble, o_dir, temp_save, progress=gr.Progress(track_tqdm=True)):
    """Separate the input with every preset model, then blend the collected
    stems into one ensemble file (and optionally its inversion).

    Returns (ensemble_path, ensemble_wav_path, inverted_path, source_files)
    — exactly four values, matching the click handler's four outputs.
    """
    ensemble_state = ensemble_model_manager.data
    invert_methods_map = ensemble_model_manager.ensemble_invert_methods_map
    # BUG FIX: the handler has 4 outputs; the early returns previously
    # produced 5 values (None, None, None, None, []), which Gradio rejects.
    if not input_file:
        return None, None, None, []
    if not os.path.exists(input_file):
        return None, None, None, []
    if not self.audio.check(input_file):
        return None, None, None, []

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    if o_dir.strip() != "" and not temp_save:
        o = os.path.join(o_dir, f"ensembless_outputs_{timestamp}")
    else:
        o = tempfile.mkdtemp(prefix=f"ensembless_outputs_{timestamp}_")
    os.makedirs(o, exist_ok=True)

    basename = os.path.splitext(os.path.basename(input_file))[0]

    def invert_weights(weights):
        # Complement of each weight against the total, for the inverted blend.
        total_weight = sum(weights)
        return [total_weight - w for w in weights]

    success_separations = []  # list[tuple[str, str, str|None, float]]: (model_name, s_stem_path, i_stem_path, weight)
    ensemble_sources_list = []  # every stem file produced, for the "sources" output
    if ensemble_state:
        total_ensemble_models = len(ensemble_state)
        for i, (ens_mt, ens_mn, ens_s_stem, ens_i_stem, weight) in enumerate(ensemble_state, start=1):
            s_stem = None  # path to primary stem
            i_stem = None  # path to invert stem
            try:
                result_seped_auto_ensemble = self.separate(input=input_file, output_dir=os.path.join(o, ens_mn), model_type=ens_mt, model_name=ens_mn, ext_inst=True, template="NAME - MODEL - STEM", output_format="wav", add_settings={"add_single_sep_text_progress": f"{i} из {total_ensemble_models}"}, progress=progress)
                if result_seped_auto_ensemble:
                    for stem, path in result_seped_auto_ensemble:
                        ensemble_sources_list.append(path)
                        if stem == ens_s_stem:
                            s_stem = path
                        elif stem == ens_i_stem:
                            i_stem = path

                # When the first pass did not yield the invert stem, run a
                # second pass restricted to the primary stem (+instrumental).
                if invert_ensemble and not i_stem:
                    result_seped_auto_ensemble_invert = self.separate(input=input_file, output_dir=os.path.join(o, f"{ens_mn}_invert"), model_type=ens_mt, model_name=ens_mn, ext_inst=True, template="NAME - MODEL - STEM", output_format="wav", selected_stems=[ens_s_stem], add_settings={"add_single_sep_text_progress": f"{i} из {total_ensemble_models} (инверт.)"}, progress=progress)
                    if result_seped_auto_ensemble_invert:
                        # BUG FIX: iterate the invert-pass results; previously
                        # this looped over the first pass again, so the invert
                        # stem was never picked up.
                        for stem, path in result_seped_auto_ensemble_invert:
                            if stem == ens_i_stem:
                                i_stem = path
                                ensemble_sources_list.append(path)

            except Exception as e:
                print(f"\nПроизошла ошибка при разделении: {e}")
                progress(0, desc="Произошла ошибка при разделении, модель пропускается...")
                continue
            finally:
                # Record the model even on partial success (primary stem only).
                if s_stem:
                    success_separations.append((ens_mn, s_stem, i_stem, weight))

    ensemble_sources_stems = []
    ensemble_sources_invert_stems = []
    weights = []

    for out_mn, out_s_stem, out_i_stem, out_weight in success_separations:
        ensemble_sources_stems.append(out_s_stem)
        ensemble_sources_invert_stems.append(out_i_stem)
        weights.append(out_weight)

    auto_ensemble_invout_file = None

    auto_ensemble_output_name = f"ensembless_{self.namer.short(basename, length=50)}_{len(ensemble_sources_stems)}_{method}"
    auto_ensemble_inverted_output_name = f"ensembless_{self.namer.short(basename, length=50)}_{len(ensemble_sources_stems)}_{invert_methods_map[method]}_invert"
    auto_ensemble_out_file, auto_ensemble_out_file_wav = ensemble_audio_files(files=ensemble_sources_stems, weights=weights, output=os.path.join(o, auto_ensemble_output_name), ensemble_type=method, out_format=out_format, add_wav=True)

    if invert_ensemble:
        auto_ensemble_invout_file, _ = ensemble_audio_files(files=ensemble_sources_invert_stems, weights=invert_weights(weights), output=os.path.join(o, auto_ensemble_inverted_output_name), ensemble_type=invert_methods_map[method], out_format=out_format, add_wav=True)

    return auto_ensemble_out_file, auto_ensemble_out_file_wav, auto_ensemble_invout_file, ensemble_sources_list
|
| 1020 |
+
|
| 1021 |
+
@add_inputs_btn.click(
|
| 1022 |
+
inputs=[input_path, input_audio],
|
| 1023 |
+
outputs=[add_inputs, input_audio, add_buttons_row])
|
| 1024 |
+
def add_inputs_fn(input_p, input_a):
|
| 1025 |
+
if input_p and os.path.exists(input_p):
|
| 1026 |
+
if input_a is None:
|
| 1027 |
+
input_a = None
|
| 1028 |
+
if self.audio.check(input_p):
|
| 1029 |
+
input_a = input_p
|
| 1030 |
+
return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
|
| 1031 |
+
return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
|
| 1032 |
+
|
| 1033 |
+
@add_inputs_url_btn.click(
|
| 1034 |
+
inputs=[input_url, input_audio, inputs_url_format, inputs_url_bitrate, inputs_url_cookie],
|
| 1035 |
+
outputs=[add_inputs_from_url, input_audio, add_buttons_row])
|
| 1036 |
+
def add_inputs_from_url_fn(input_u, input_a, fmt, br, cookie):
|
| 1037 |
+
if input_u:
|
| 1038 |
+
if input_a is None:
|
| 1039 |
+
input_a = None
|
| 1040 |
+
downloaded_file = dw_yt_dlp(
|
| 1041 |
+
url=input_u,
|
| 1042 |
+
output_format=fmt,
|
| 1043 |
+
output_bitrate=str(int(br)),
|
| 1044 |
+
cookie=cookie
|
| 1045 |
+
)
|
| 1046 |
+
if downloaded_file and os.path.exists(downloaded_file):
|
| 1047 |
+
if self.audio.check(downloaded_file):
|
| 1048 |
+
input_a = downloaded_file
|
| 1049 |
+
return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
|
| 1050 |
+
return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
|
| 1051 |
+
|
| 1052 |
+
add_path_btn.click(
|
| 1053 |
+
lambda: (gr.update(visible=True), gr.update(visible=False)),
|
| 1054 |
+
outputs=[add_inputs, add_buttons_row])
|
| 1055 |
+
|
| 1056 |
+
add_url_btn.click(
|
| 1057 |
+
lambda: (gr.update(visible=True), gr.update(visible=False)),
|
| 1058 |
+
outputs=[add_inputs_from_url, add_buttons_row])
|
| 1059 |
+
|
| 1060 |
+
inputs_url_format.change(lambda x: gr.update(visible=False if x in ["wav", "flac", "aiff"] else True), inputs=inputs_url_format, outputs=inputs_url_bitrate)
|
| 1061 |
+
|
| 1062 |
+
class ManualEnsembless(MVSEPLESS):
|
| 1063 |
+
def UI(self, output_base_dir, output_temp_dir_check):
|
| 1064 |
+
with gr.Row():
|
| 1065 |
+
with gr.Column():
|
| 1066 |
+
with gr.Group(visible=False) as add_ensemble_inputs:
|
| 1067 |
+
input_ensemble_path = gr.Textbox(label="Путь к входному файлу", interactive=True)
|
| 1068 |
+
add_ensemble_inputs_btn = gr.Button("Добавить файл", variant="primary")
|
| 1069 |
+
add_ensemble_path_btn = gr.Button("Добавить файл по пути", variant="secondary")
|
| 1070 |
+
input_ensemble_files = gr.File(label="Входное аудио", interactive=True, type="filepath", file_count="multiple", file_types=[f".{of}" for of in self.audio.input_formats])
|
| 1071 |
+
input_ensemble_preview_check = gr.Checkbox(label="Показать плееры для входных аудио", interactive=True, value=False)
|
| 1072 |
+
@gr.render(inputs=[input_ensemble_preview_check, input_ensemble_files])
|
| 1073 |
+
def show_input_ensemble_players(preview, audios):
|
| 1074 |
+
if preview:
|
| 1075 |
+
if audios:
|
| 1076 |
+
with gr.Group():
|
| 1077 |
+
for file in audios:
|
| 1078 |
+
gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath")
|
| 1079 |
+
|
| 1080 |
+
with gr.Column():
|
| 1081 |
+
@gr.render(inputs=[input_ensemble_files])
|
| 1082 |
+
def input_ensemble_files_fn(input_files):
|
| 1083 |
+
check_ensemble_files_status = f"""Анализ входных файлов
|
| 1084 |
+
---"""
|
| 1085 |
+
hz_ = []
|
| 1086 |
+
err_list = []
|
| 1087 |
+
if input_files:
|
| 1088 |
+
for file in input_files:
|
| 1089 |
+
basename = os.path.splitext(os.path.basename(file))[0]
|
| 1090 |
+
if os.path.exists(file):
|
| 1091 |
+
if self.audio.check(file):
|
| 1092 |
+
info = self.audio.get_info(file)
|
| 1093 |
+
hz = info[0].get("sample_rate")
|
| 1094 |
+
check_ensemble_files_status += f"\n{basename} - {hz} hz"
|
| 1095 |
+
hz_.append(hz)
|
| 1096 |
+
else:
|
| 1097 |
+
check_ensemble_files_status += f"\n{basename} - Нет аудио"
|
| 1098 |
+
err_list.append(file)
|
| 1099 |
+
else:
|
| 1100 |
+
check_ensemble_files_status += f"\n{basename} - Файл не найден"
|
| 1101 |
+
err_list.append(file)
|
| 1102 |
+
|
| 1103 |
+
check_ensemble_files_result = f"Действительных файлов: {len(hz_)}"
|
| 1104 |
+
|
| 1105 |
+
all_same = True
|
| 1106 |
+
|
| 1107 |
+
common_rate = None
|
| 1108 |
+
|
| 1109 |
+
for hz_hz in hz_:
|
| 1110 |
+
if common_rate is None:
|
| 1111 |
+
common_rate = hz_hz
|
| 1112 |
+
elif common_rate != hz_hz:
|
| 1113 |
+
all_same = False
|
| 1114 |
+
|
| 1115 |
+
if hz_ and len(hz_) > 1:
|
| 1116 |
+
check_ensemble_files_result += "\nВсе действительные файлы имеют одинаковую частоту дискретизации" if all_same else "\nОшибка! Все действительные файлы имеют РАЗНУЮ частоту дискретизации"
|
| 1117 |
+
else:
|
| 1118 |
+
check_ensemble_files_result += "\nДля создания ансамбля нужно загрузить, как минимум - 2 файла, содержащие аудио"
|
| 1119 |
+
|
| 1120 |
+
|
| 1121 |
+
check_ensemble_files_status += f"\n \n{check_ensemble_files_result}"
|
| 1122 |
+
|
| 1123 |
+
gr.Textbox(container=False, lines=len(check_ensemble_files_status.split("\n")), interactive=False, value=check_ensemble_files_status)
|
| 1124 |
+
|
| 1125 |
+
weights = gr.Textbox(label="Веса", value="1.0,1.0")
|
| 1126 |
+
|
| 1127 |
+
method = gr.Dropdown(label="Алгоритм склеивания", choices=["min_fft", "max_fft", "avg_fft", "median_fft"], value="avg_fft", filterable=False)
|
| 1128 |
+
|
| 1129 |
+
output_format = gr.Dropdown(label="Формат выходного файла", interactive=True,
|
| 1130 |
+
choices=self.audio.output_formats,
|
| 1131 |
+
value="mp3", filterable=False)
|
| 1132 |
+
|
| 1133 |
+
output_manual_ensemble_filename = gr.Textbox(label="Имя выходного файла", value="ensemble", interactive=True)
|
| 1134 |
+
|
| 1135 |
+
make_manual_ensemble_btn = gr.Button(value="Создать ансамбль", variant="primary")
|
| 1136 |
+
|
| 1137 |
+
manual_ensemble_output_audio = gr.Audio(label="Результат", type="filepath", interactive=False, show_download_button=True)
|
| 1138 |
+
|
| 1139 |
+
@make_manual_ensemble_btn.click(
|
| 1140 |
+
inputs=[input_ensemble_files, method, output_format, output_base_dir, output_temp_dir_check, output_manual_ensemble_filename, weights], outputs=manual_ensemble_output_audio
|
| 1141 |
+
)
|
| 1142 |
+
def make_manual_ensemble_fn(input_files_list, method, out_format, o_dir, temp_save, o_filename, weights: str):
|
| 1143 |
+
if o_dir.strip() != "" and not temp_save:
|
| 1144 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 1145 |
+
o = os.path.join(o_dir, f"ensembless_outputs_{timestamp}")
|
| 1146 |
+
os.makedirs(o, exist_ok=True)
|
| 1147 |
+
else:
|
| 1148 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 1149 |
+
o = tempfile.mkdtemp(prefix=f"ensembless_outputs_{timestamp}_")
|
| 1150 |
+
os.makedirs(o, exist_ok=True)
|
| 1151 |
+
|
| 1152 |
+
o_filename = self.namer.sanitize(o_filename)
|
| 1153 |
+
o_filename = self.namer.short(o_filename)
|
| 1154 |
+
|
| 1155 |
+
output_file = ensemble_audio_files(files=input_files_list, output=os.path.join(o, o_filename), weights=[float(x) for x in weights.split(",")], ensemble_type=method, out_format=out_format)
|
| 1156 |
+
return output_file
|
| 1157 |
+
|
| 1158 |
+
add_ensemble_path_btn.click(
|
| 1159 |
+
lambda: (gr.update(visible=True), gr.update(visible=False)),
|
| 1160 |
+
outputs=[add_ensemble_inputs, add_ensemble_path_btn])
|
| 1161 |
+
|
| 1162 |
+
@add_ensemble_inputs_btn.click(
|
| 1163 |
+
inputs=[input_ensemble_path, input_ensemble_files],
|
| 1164 |
+
outputs=[add_ensemble_inputs, input_ensemble_files, add_ensemble_path_btn])
|
| 1165 |
+
def add_ensemble_inputs_fn(input_p, input_a):
|
| 1166 |
+
if input_p and os.path.exists(input_p):
|
| 1167 |
+
if input_a is None:
|
| 1168 |
+
input_a = []
|
| 1169 |
+
if isinstance(input_a, str):
|
| 1170 |
+
input_a = [input_a]
|
| 1171 |
+
if self.audio.check(input_p):
|
| 1172 |
+
input_a.append(input_p)
|
| 1173 |
+
return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
|
| 1174 |
+
return gr.update(visible=False), gr.update(value=input_a), gr.update(visible=True)
|
| 1175 |
+
|
| 1176 |
+
class Inverter_UI(MVSEPLESS):
    """Gradio tab that subtracts a separated stem from the original mix ("inversion")."""

    def UI(self):
        """Build the inversion UI: two file inputs, format/method controls, and a result player."""
        with gr.Group():
            with gr.Row():
                # Original mix and the stem to subtract from it.
                original_audio = gr.File(label="Оригинал", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.audio.input_formats])
                stem_audio = gr.File(label="Cтем, который будет вычтен из оригинала", interactive=True, type="filepath", file_count="single", file_types=[f".{of}" for of in self.audio.input_formats])
            with gr.Group():
                output_format = gr.Dropdown(label="Формат выходного файла", interactive=True,
                                            choices=self.audio.output_formats,
                                            value="mp3", filterable=False)
                # Subtraction domain: plain waveform difference or spectrogram-domain subtraction.
                method = gr.Radio(
                    choices=["waveform", "spectrogram"],
                    label="Метод вычитания",
                    value="waveform",
                )
            btn = gr.Button("Вычесть")
        output_audio = gr.Audio(label="Инверсия", type="filepath", interactive=False, show_download_button=True)

        @btn.click(inputs=[original_audio, stem_audio, method, output_format], outputs=[output_audio])
        def invert_result_ensemble(input_file, output_file, method, out_format):
            # NOTE(review): despite its name, `output_file` is the stem being subtracted
            # (second input), not an output path.
            if input_file and output_file:
                # Result goes into a fresh temp directory; name is derived from the
                # original file's basename (shortened) plus the chosen method.
                o_dir = tempfile.mkdtemp(suffix="_inverter")
                basename = os.path.splitext(os.path.basename(input_file))[0]
                output_path = os.path.join(o_dir, f"inverter_{self.namer.short(basename, length=50)}_{method}.{out_format}")
                inverted = self.inverter.process_audio(audio1_path=input_file, audio2_path=output_file, out_format=out_format, method=method, output_path=output_path)
                return inverted
            else:
                # Missing either input: leave the output player empty.
                return None
|
| 1203 |
+
|
| 1204 |
+
class Vbach(MVSEPLESS):
    """Gradio tabs for RVC-style voice conversion: batch inference plus a model manager."""

    # Available pitch-extraction backends; the first is the default.
    pitch_methods = ("rmvpe+", "fcpe", "mangio-crepe")
    # (min, max) bounds used to configure the corresponding sliders below.
    hop_length_values = (8, 512)
    index_rates_values = (0, 1)
    filter_radius_values = (0, 7)
    protect_values = (0, 0.5)
    rms_values = (0, 1)
    f0_min_values = (50, 3000)
    f0_max_values = (300, 6000)

    def UI(self):
        """Build the "Инференс" (inference) and "Менеджер" (model manager) tabs."""
        with gr.Tab("Инференс"):
            with gr.Row():
                with gr.Column():
                    with gr.Group():
                        input_audio = gr.File(label="Входные аудио", interactive=True, type="filepath", file_count="multiple", file_types=[f".{of}" for of in self.audio.input_formats])
                        # Hidden textbox that carries the str()-encoded list of output
                        # paths from the convert callback to the render callback below.
                        converted_state = gr.Textbox(label="Состояние разделения", interactive=False, value="", visible=False)
                        status = gr.Textbox(container=False, lines=3, interactive=False, max_lines=3)
                    input_preview_check = gr.Checkbox(label="Показать плееры для входных аудио", interactive=True, value=False)

                    @gr.render(inputs=[input_preview_check, input_audio])
                    def show_input_players(preview, audios):
                        # Dynamically render one audio player per uploaded file when enabled.
                        if preview:
                            if audios:
                                with gr.Group():
                                    for file in audios:
                                        gr.Audio(label=os.path.splitext(os.path.basename(file))[0], value=file, interactive=False, show_download_button=False, type="filepath")

                with gr.Column():
                    with gr.Group():
                        model_name = gr.Dropdown(label="Имя модели", interactive=True)
                        model_list_refresh_btn = gr.Button("Обновить", variant="secondary", interactive=True)

                        @model_list_refresh_btn.click(
                            outputs=[model_name]
                        )
                        def refresh_list_voice_models():
                            # Re-scan installed voice models and select the first, if any.
                            models = []
                            models = self.vbach_model_manager.parse_voice_models()
                            first_model = None
                            if len(models) > 0:
                                first_model = models[0]
                            return gr.update(choices=models, value=first_model)

                    with gr.Group():
                        pitch_method = gr.Radio(label="Метод извлечения высоты тона", choices=self.pitch_methods, value=self.pitch_methods[0], interactive=True)
                        pitch = gr.Slider(label="Высота тона", minimum=-48, maximum=48, step=0.5, value=0, interactive=True)
                        # Hop length is only meaningful for mangio-crepe; hidden otherwise.
                        hop_length = gr.Slider(label="Длина шага", info="Длина шага влияет на точность передачи высоты тона\nЧем меньше длина шага - тем точнее будет передана высота тона", minimum=self.hop_length_values[0], maximum=self.hop_length_values[1], step=8, value=128, interactive=True, visible=False)

                        @pitch_method.change(
                            inputs=[pitch_method],
                            outputs=[hop_length]
                        )
                        def show_mangio_crepe_hop_length(pitch_method):
                            # Toggle hop-length visibility based on the selected backend.
                            return gr.update(visible=True if pitch_method in ["mangio-crepe"] else False)

                        stereo_mode = gr.Radio(
                            choices=["mono", "left/right", "sim/dif"],
                            label="Стерео режим",
                            info="mono - монофоническая обработка аудио, \nleft/right - обработка левого и правого каналов отдельно, \nsim/dif - обработка фантомного центра и стерео-базы, разделенную на левый и правый каналы",
                            value="mono",
                            interactive=True
                        )

                    with gr.Accordion(label="Дополнительные настройки",open=False):
                        with gr.Group():
                            with gr.Row():
                                index_rate = gr.Slider(label="Влияние индекса", info="Чем ниже значение, тем больше голос похож на исходный; чем выше, тем ближе к модели", minimum=self.index_rates_values[0], maximum=self.index_rates_values[1], step=0.05, value=0, interactive=True)
                                filter_radius = gr.Slider(label="Радиус фильтра", info="Сглаживает результаты извлечения тона\nМожет снизить дыхание и шумы на выходе", minimum=self.filter_radius_values[0], maximum=self.filter_radius_values[1], step=1, value=3, interactive=True)
                            with gr.Row():
                                rms = gr.Slider(label="Соотношение огибающих громкости", info="Значение 0 - огибающая громкости как у входного аудио, 1 - как у выходного сигнала", minimum=self.rms_values[0], maximum=self.rms_values[1], step=0.05, value=0.25, interactive=True)
                                protect = gr.Slider(label="Защита согласных", info="Предовращает роботизацию дыхания и согласных (Может влиять на четкость речи)\nЗначение 0.5 - выключает защиту, 0 - максимальная защита", minimum=self.protect_values[0], maximum=self.protect_values[1], step=0.05, value=0.35, interactive=True)
                        with gr.Group():
                            with gr.Row():
                                f0_min = gr.Slider(label="Нижний предел диапазона определения высоты тона", minimum=self.f0_min_values[0], maximum=self.f0_min_values[1], step=10, value=50, interactive=True)
                                f0_max = gr.Slider(label="Верхний предел диапазона определения высоты тона", minimum=self.f0_max_values[0], maximum=self.f0_max_values[1], step=10, value=1100, interactive=True)

                    with gr.Group():
                        # Output-name template; placeholder keys are substituted when
                        # "format name" is checked (see checkbox info text).
                        output_name = gr.Textbox(label="Имя выходного файла", interactive=True, value="NAME - MODEL - F0METHOD - PITCH")
                        format_output_name_check = gr.Checkbox(label="Форматировать имя", info="Используйте ключи: \nNAME - имя входного файла без расширения, \nPITCH - высота тона, \nF0METHOD - метод извлечения высота тона, \nMODEL - имя голосовой модели", value=True, interactive=True)
                        output_format = gr.Dropdown(label="Формат выходного файла", interactive=True, choices=self.audio.output_formats, value=self.audio.output_formats[0], filterable=False)
                    convert_btn = gr.Button("Преобразовать", variant="primary", interactive=True)

            @convert_btn.click(
                inputs=[
                    input_audio,
                    model_name,
                    pitch_method,
                    pitch,
                    hop_length,
                    index_rate,
                    filter_radius,
                    rms,
                    protect,
                    f0_min,
                    f0_max,
                    output_name,
                    format_output_name_check,
                    output_format,
                    stereo_mode
                ], outputs=[converted_state, status]
            )
            def vbach_convert_batch(ifl, mn, pm, p, hl, ir, fr, rms, pr, f0min, f0max, on, fn, of, sm):
                # Convert each uploaded file in turn; failures are printed and skipped
                # so one bad file does not abort the batch.
                output_converted_files = []
                progress = gr.Progress()
                if ifl:
                    for i, file in enumerate(ifl, start=1):
                        try:
                            print(f"Файл {i} из {len(ifl)}: {file}")
                            progress(progress=(i / len(ifl)), desc=f"Файл {i} из {len(ifl)}")
                            # For multi-file batches name formatting is forced on so
                            # outputs do not collide; otherwise the checkbox decides.
                            out_conv = vbach_inference(input_file=file, model_name=mn, output_dir=tempfile.mkdtemp(), output_name=on, format_name=True if len(ifl) > 1 else fn, output_format=of, pitch=p, method_pitch=pm, output_bitrate=320, add_params={ "index_rate": ir,"filter_radius": fr,"protect": pr,"rms": rms,"mangio_crepe_hop_length": hl,"f0_min": f0min,"f0_max": f0max,"stereo_mode": sm })
                            output_converted_files.append(out_conv)
                        except Exception as e:
                            print(e)
                # The path list is serialized via str() and decoded with
                # ast.literal_eval in show_players_converted below.
                return str(output_converted_files), None

            @gr.render(inputs=[converted_state])
            def show_players_converted(state):
                # Render one player per converted output once the hidden state is set.
                if state != "":
                    output_converted_files = ast.literal_eval(state)
                    if output_converted_files:
                        with gr.Group():
                            for conv_file in output_converted_files:
                                basename = os.path.splitext(os.path.basename(conv_file))[0]
                                gr.Audio(
                                    label=basename,
                                    value=conv_file,
                                    type="filepath",
                                    interactive=False,
                                    show_download_button=True,
                                )

        with gr.TabItem("Менеджер"):
            with gr.TabItem("Загрузить по ссылке"):
                with gr.TabItem("Через zip файл"):
                    with gr.Row():
                        with gr.Column(variant="panel"):
                            url_zip = gr.Text(label="Ссылка на zip файл")
                            with gr.Group():
                                url_zip_model_name = gr.Text(
                                    label="Имя модели",
                                )
                                url_zip_download_btn = gr.Button("Загрузить", variant="primary")

                    url_zip_output = gr.Text(label="Статус", interactive=False, lines=5)
                    # Model name is sanitized and truncated to 40 chars before install.
                    url_zip_download_btn.click(
                        (lambda x, y: self.vbach_model_manager.install_model_zip(x, self.namer.short(self.namer.sanitize(y), length=40), "url")),
                        inputs=[url_zip, url_zip_model_name],
                        outputs=url_zip_output,
                    )

                with gr.TabItem("Через отдельные файлы"):
                    with gr.Row():
                        with gr.Column(variant="panel"):
                            url_pth = gr.Text(label="Ссылка на *.pth файл")
                            url_index = gr.Text(label="Ссылка на *.index файл (необязательно)")
                            with gr.Group():
                                url_file_model_name = gr.Text(
                                    label="Имя модели",
                                )
                                url_file_download_btn = gr.Button("Загрузить", variant="primary")

                    url_file_output = gr.Text(label="Статус", interactive=False, lines=5)
                    # NOTE(review): install_model_files receives (index, pth, name, src) —
                    # index first — per the inputs order; confirm against its signature.
                    url_file_download_btn.click(
                        (lambda x, y, z: self.vbach_model_manager.install_model_files(x, y, self.namer.short(self.namer.sanitize(z), length=40), "url")),
                        inputs=[url_index, url_pth, url_file_model_name],
                        outputs=url_file_output,
                    )

            with gr.Tab("Загрузить с устройства"):
                with gr.Tab("Через zip файл"):
                    with gr.Row():
                        with gr.Column():
                            local_zip = gr.File(
                                label="zip файл", file_types=[".zip"], file_count="single"
                            )
                        with gr.Column(variant="panel"):
                            with gr.Group():
                                local_zip_model_name = gr.Text(
                                    label="Имя модели",
                                )
                                local_zip_upload_btn = gr.Button("Загрузить", variant="primary")

                    local_zip_output = gr.Text(label="Статус", interactive=False, lines=5)
                    local_zip_upload_btn.click(
                        (lambda x, y: self.vbach_model_manager.install_model_zip(x, self.namer.short(self.namer.sanitize(y), length=40), "local")),
                        inputs=[local_zip, local_zip_model_name],
                        outputs=local_zip_output,
                    )

                with gr.TabItem("Через отдельные файлы"):
                    with gr.Group():
                        with gr.Row():
                            local_pth = gr.File(
                                label="*.pth файл", file_types=[".pth"], file_count="single"
                            )
                            local_index = gr.File(
                                label="*.index файл (необязательно)", file_types=[".index"], file_count="single"
                            )
                    with gr.Column(variant="panel"):
                        with gr.Group():
                            local_file_model_name = gr.Text(
                                label="Имя модели",
                            )
                            local_file_upload_btn = gr.Button("Загрузить", variant="primary")

                    local_file_output = gr.Text(
                        label="Статус", interactive=False
                    )
                    local_file_upload_btn.click(
                        (lambda x, y, z: self.vbach_model_manager.install_model_files(x, y, self.namer.short(self.namer.sanitize(z), length=40), "local")),
                        inputs=[local_index, local_pth, local_file_model_name],
                        outputs=local_file_output,
                    )

            with gr.TabItem("Удалить модель"):
                with gr.Column(variant="panel"):
                    with gr.Group():
                        delete_model_name = gr.Dropdown(
                            label="Имя модели",
                            choices=self.vbach_model_manager.parse_voice_models(),
                            interactive=True,
                            filterable=False
                        )
                        delete_refresh_btn = gr.Button("Обновить")

                        @delete_refresh_btn.click(inputs=None, outputs=delete_model_name)
                        def refresh_list_voice_models():
                            # Same refresh logic as above, scoped to the delete dropdown.
                            models = []
                            models = self.vbach_model_manager.parse_voice_models()
                            first_model = None
                            if len(models) > 0:
                                first_model = models[0]
                            return gr.update(choices=models, value=first_model)

                    delete_output = gr.Text(
                        label="Статус", interactive=False, lines=5
                    )
                    delete_btn = gr.Button("Удалить")
                    delete_btn.click(
                        fn=self.vbach_model_manager.del_voice_model,
                        inputs=delete_model_name,
                        outputs=delete_output
                    )

        # NOTE(review): gr.on with no triggers and fn="decorator" — presumably fires on
        # app load to populate both model dropdowns; confirm against Gradio's gr.on docs.
        @gr.on(fn="decorator", inputs=None, outputs=[delete_model_name, model_name])
        def refresh_list_voice_models():
            models = []
            models = self.vbach_model_manager.parse_voice_models()
            first_model = None
            if len(models) > 0:
                first_model = models[0]
            return gr.update(choices=models, value=first_model), gr.update(choices=models, value=first_model)
|
| 1451 |
+
|
| 1452 |
+
|
| 1453 |
+
class PluginManager(Separator):
    """Discovers, installs and renders third-party UI plugins.

    A plugin is a single ``.py`` file in ``plugins/`` exposing a ``Plugin``
    class with a ``name`` attribute and a ``UI()`` method; each plugin is
    rendered in its own Gradio tab.
    """

    plugins_dir = os.path.join(script_dir, "plugins")
    os.makedirs(plugins_dir, exist_ok=True)

    def restart_after_install_plugin(self):
        """Re-exec the application as a child process and terminate this one.

        Needed because Gradio cannot add new tabs to an already-built Blocks app.
        """
        # Fix: use sys.executable directly; os.sys is an undocumented alias.
        subprocess.Popen([sys.executable] + sys.argv)
        os._exit(0)

    def parse_plugins(self):
        """Import every plugin module in ``plugins_dir`` and render its UI tab.

        Import errors are printed and the offending plugin is skipped, so one
        broken plugin cannot take down the whole application.
        """
        for plugin_file in os.listdir(self.plugins_dir):
            # Skip non-Python files and __init__.py
            if not plugin_file.endswith('.py') or plugin_file == '__init__.py':
                continue

            # Module name without the .py extension
            plugin_module_name = os.path.splitext(plugin_file)[0]

            try:
                # The import path depends on whether we run as a package or a script.
                if __package__:
                    # Inside a package: relative import anchored on the package.
                    plugin_module = importlib.import_module(f".plugins.{plugin_module_name}", package=__package__)
                else:
                    try:
                        # Try a plain absolute import first.
                        plugin_module = importlib.import_module(f"plugins.{plugin_module_name}")
                    except ImportError:
                        # Fall back to loading the module straight from its file path.
                        plugin_path = os.path.join(self.plugins_dir, plugin_file)
                        spec = importlib.util.spec_from_file_location(plugin_module_name, plugin_path)
                        plugin_module = importlib.util.module_from_spec(spec)
                        spec.loader.exec_module(plugin_module)

                # Every plugin module must expose a Plugin class.
                plugin_class = getattr(plugin_module, 'Plugin')
                plugin_instance = plugin_class()

                # Render the plugin's UI inside its own tab.
                with gr.Tab(plugin_instance.name):
                    plugin_instance.UI()

            except Exception as e:
                print(f"Ошибка загрузки плагина {plugin_module_name}: {e}")
                continue

    def UI(self):
        """Build the plugin-installation tab, then render all installed plugins."""
        with gr.Tab("Установка"):
            upload_plugins_files = gr.File(label="Загрузить плагины", file_types=[".py"], file_count="multiple", interactive=True)
            install_plugins_btn = gr.Button("Установить", interactive=True)

            @install_plugins_btn.click(
                inputs=[upload_plugins_files]
            )
            def upload_plugin_list(files):
                if not files:
                    return
                for file in files:
                    try:
                        # Fix: gr.File may hand back plain filepath strings or
                        # tempfile wrappers with a .name attribute; the original
                        # mixed both access styles (file.name vs basename(file)),
                        # which crashes for one of the two shapes. Normalize once.
                        src = file if isinstance(file, str) else file.name
                        # Copy only .py files; spaces in names break module import.
                        if src.endswith('.py'):
                            shutil.copy(
                                src, os.path.join(self.plugins_dir, os.path.basename(src).replace(" ", "_"))
                            )
                    except Exception as e:
                        print(f"Ошибка копирования файла {file}: {e}")
                # Give the copies a moment to land, then restart to pick them up.
                time.sleep(2)
                self.restart_after_install_plugin()

        # Render tabs for the plugins that are already installed.
        self.parse_plugins()
|
| 1525 |
+
|
| 1526 |
+
def mvsepless_app(theme):
    """Assemble the top-level Gradio Blocks application and return it.

    Each feature (separation, ensembling, inversion, voice conversion,
    plugins) contributes its own tab; two gr.State objects share the output
    directory settings between the Settings tab and the feature tabs.
    """
    css = None
    with gr.Blocks(theme=theme, css=css, title="Разделение музыки и вокала") as app:

        # Shared settings: output base directory and "use temp dir" flag.
        output_base_dir_state = gr.State(value=os.path.join(os.getcwd(), "outputs"))
        output_temp_dir_state = gr.State(value=False)

        with gr.Tab("Инференс"):
            Separator().UI(output_base_dir_state, output_temp_dir_state)

        with gr.Tab("Ансамбль"):
            with gr.Tab("Авто-ансамбль"):
                AutoEnsembless().UI(output_base_dir_state, output_temp_dir_state)

            with gr.Tab("Ручной ансамбль"):
                ManualEnsembless().UI(output_base_dir_state, output_temp_dir_state)

        with gr.Tab("Вычитание"):
            Inverter_UI().UI()

        with gr.Tab("Преобразование"):
            Vbach().UI()

        with gr.Tab("Плагины"):
            PluginManager().UI()

        with gr.Tab("Настройки"):
            with gr.Column():
                output_base_dir_ui = gr.Textbox(
                    label="Базовый каталог для выходных файлов",
                    interactive=True,
                    value=os.path.join(os.getcwd(), "outputs"),
                    lines=5
                )
                output_temp_dir_check_ui = gr.Checkbox(
                    label="Использовать временный каталог для выходных файлов",
                    interactive=True,
                    value=False
                )

            # Mirror the settings widgets into the shared gr.State objects.
            output_base_dir_ui.change(
                lambda x: x,
                inputs=[output_base_dir_ui],
                outputs=[output_base_dir_state]
            )
            output_temp_dir_check_ui.change(
                lambda x: x,
                inputs=[output_temp_dir_check_ui],
                outputs=[output_temp_dir_state]
            )

    return app
|
| 1579 |
+
|
mvsepless/__main__.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
import gradio as gr
import os

# Support both package execution (python -m mvsepless) and direct script run.
if not __package__:
    from __init__ import mvsepless_app, Separator
else:
    from .__init__ import mvsepless_app, Separator

if __name__ == "__main__":
    # Two sub-commands: "app" launches the Gradio UI, "cli" runs one separation.
    parser = argparse.ArgumentParser(description="MVSepless")
    subparsers = parser.add_subparsers(dest="command")
    app_parser = subparsers.add_parser("app", help="Приложение MVSepless")
    app_parser.add_argument("--port", type=int, default=None, help="Порт для запуска сервера Gradio.")
    app_parser.add_argument("--share", action="store_true", help="Создать публичную ссылку для приложения Gradio.")
    cli_parser = subparsers.add_parser("cli", help="CLI MVSepless Lite")
    cli_parser.add_argument("--input", type=str, required=True, help="Входной аудиофайл или каталог.")
    cli_parser.add_argument("--output_dir", type=str, default=None, help="Каталог для выходных файлов.")
    cli_parser.add_argument("--model_type", type=str, default="mel_band_roformer", help="Тип модели разделения.")
    cli_parser.add_argument("--model_name", type=str, default="Mel-Band-Roformer_Vocals_kimberley_jensen", help="Имя модели разделения.")
    cli_parser.add_argument("--ext_inst", action="store_true", help="Извлечь инструментал.")
    # NOTE(review): Separator.audio is accessed as a class attribute here; if
    # `audio` is only assigned in Separator.__init__, this line raises at parse
    # time — confirm it is declared at class level.
    cli_parser.add_argument("--output_format", type=str, default="mp3", choices=Separator.audio.output_formats, help="Формат выходного файла.")
    cli_parser.add_argument("--output_bitrate", type=str, default="320k", help="Битрейт выходного файла.")
    cli_parser.add_argument("--template", type=str, default="NAME (STEM) MODEL", help="Шаблон именования выходных файлов.")
    cli_parser.add_argument("--selected_stems", type=str, nargs='*', default=None, help="Выбранные стемы для разделения.")
    args = parser.parse_args()

    if args.command == "app":
        theme = gr.themes.Citrus(
            primary_hue="teal",
            secondary_hue="blue",
            neutral_hue="blue",
            spacing_size="sm",
            font=[
                gr.themes.GoogleFont("Montserrat"),
                "ui-sans-serif",
                "system-ui",
                "sans-serif",
            ],
        )
        mvsepless_lite_app = mvsepless_app(theme=theme)
        # allowed_paths=["/"] exposes the whole filesystem to the web UI —
        # acceptable for a local tool, dangerous with share=True.
        mvsepless_lite_app.launch(server_name="0.0.0.0", server_port=args.port, share=args.share, allowed_paths=["/"], debug=True)
    elif args.command == "cli":
        input_data = args.input
        if os.path.isdir(input_data):
            # Directory input: collect every file that Audio.check accepts.
            list_valid_files = []
            for file in os.listdir(args.input):
                if os.path.isfile(os.path.join(args.input, file)):
                    if Separator.audio.check(os.path.join(args.input, file)):
                        list_valid_files.append(os.path.join(args.input, file))

            input_files = list_valid_files
        else:
            # Single-file input is passed through as a plain string;
            # presumably Separator.separate accepts both a str and a list — verify.
            input_files = input_data

        results = Separator().separate(
            input=input_files,
            output_dir=args.output_dir,
            model_type=args.model_type,
            model_name=args.model_name,
            ext_inst=args.ext_inst,
            output_format=args.output_format,
            output_bitrate=args.output_bitrate,
            template=args.template,
            selected_stems=args.selected_stems
        )
        print("Разделение завершено.")
|
mvsepless/audio.py
ADDED
|
@@ -0,0 +1,781 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import sys
|
| 4 |
+
import json
|
| 5 |
+
import subprocess
|
| 6 |
+
import numpy as np
|
| 7 |
+
from typing import Literal
|
| 8 |
+
from collections.abc import Callable
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from numpy.typing import DTypeLike
|
| 11 |
+
import tempfile
|
| 12 |
+
import librosa
|
| 13 |
+
if not __package__:
|
| 14 |
+
from namer import Namer
|
| 15 |
+
else:
|
| 16 |
+
from .namer import Namer
|
| 17 |
+
# Domain-specific exceptions raised by the Audio (ffmpeg wrapper) class.
class NotInputFileSpecified(Exception):
    """Raised when an operation requires an input file but none was given."""
    pass

class NotOutputFileSpecified(Exception):
    """Raised when an operation requires an output path but none was given."""
    pass

class NotSupportedDataType(Exception):
    """Raised for sample dtypes outside the supported set (int16/int32/float32/float64)."""
    pass

class ErrorDecode(Exception):
    """Raised when ffmpeg fails to decode the input audio."""
    pass

class ErrorEncode(Exception):
    """Raised when ffmpeg fails to encode the output audio."""
    pass

class NotSupportedFormat(Exception):
    """Raised for container/codec formats not in the supported format lists."""
    pass

class SampleRateError(Exception):
    """Raised when a sample rate is invalid for the chosen output format."""
    pass

class FileIsNotAudio(Exception):
    """Raised when a given file contains no decodable audio stream."""
    pass
|
| 25 |
+
|
| 26 |
+
class Audio(Namer):
    def __init__(self):
        """
        Read and write audio files directly through ffmpeg.

        Supported sample dtypes: int16, int32, float32, float64.
        """
        super().__init__()
        # Binary locations are overridable through environment variables.
        self.ffmpeg_path = os.environ.get("MVSEPLESS_FFMPEG", "ffmpeg")
        self.ffprobe_path = os.environ.get("MVSEPLESS_FFPROBE", "ffprobe")
        self.output_formats = ("mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff")
        self.input_formats = (
            "mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff",
            "mp4", "mkv", "webm", "avi", "mov", "ts",
        )
        self.supported_dtypes = ("int16", "int32", "float32", "float64")
        # Maps both dtype names and numpy scalar types to raw PCM pipe formats.
        self.dtypes_dict = {
            "int16": "s16le",
            "int32": "s32le",
            "float32": "f32le",
            "float64": "f64le",
            np.int16: "s16le",
            np.int32: "s32le",
            np.float32: "f32le",
            np.float64: "f64le",
        }
        # Valid bitrate ranges (kbit/s) per lossy codec.
        self.bitrate_limit = {
            "mp3": {"min": 8, "max": 320},
            "aac": {"min": 8, "max": 512},
            "m4a": {"min": 8, "max": 512},
            "ac3": {"min": 32, "max": 640},
            "ogg": {"min": 64, "max": 500},
            "opus": {"min": 6, "max": 512},
        }
        # Per-format sample-rate constraints: either a discrete "supported"
        # tuple or a continuous min/max range.
        self.sample_rates = {
            "mp3": {"supported": (44100, 48000, 32000, 22050, 24000, 16000, 11025, 12000, 8000)},
            "opus": {"supported": (48000, 24000, 16000, 12000, 8000)},
            "m4a": {"supported": (96000, 88200, 64000, 48000, 44100, 32000, 24000, 22050, 16000, 12000, 11025, 8000, 7350)},
            "aac": {"supported": (96000, 88200, 64000, 48000, 44100, 32000, 24000, 22050, 16000, 12000, 11025, 8000, 7350)},
            "ac3": {"supported": (48000, 44100, 32000)},
            "ogg": {"min": 6, "max": 192000},
            "wav": {"min": 0, "max": float("inf")},
            "aiff": {"min": 0, "max": float("inf")},
            "flac": {"min": 0, "max": 192000},
        }
        self.check_ffmpeg()
        self.check_ffprobe()
|
| 110 |
+
|
| 111 |
+
def check_ffmpeg(self):
    """
    Verify that the configured ffmpeg binary can be executed.

    Raises FileNotFoundError when the binary is missing, except while
    running under pytest (PYTEST_CURRENT_TEST set), where the absence
    is tolerated.
    """
    try:
        subprocess.check_output([self.ffmpeg_path, "-version"], text=True)
    except FileNotFoundError:
        # Allow unit tests to run on machines without ffmpeg installed.
        if "PYTEST_CURRENT_TEST" not in os.environ:
            raise FileNotFoundError("FFMPEG не установлен. Укажите путь к установленному FFMPEG через переменную окружения MVSEPLESS_FFMPEG")
|
| 122 |
+
|
| 123 |
+
def check_ffprobe(self):
    """
    Verify that the configured ffprobe binary can be executed.

    Raises FileNotFoundError when the binary is missing, except while
    running under pytest (PYTEST_CURRENT_TEST set), where the absence
    is tolerated.
    """
    try:
        subprocess.check_output([self.ffprobe_path, "-version"], text=True)
    except FileNotFoundError:
        # Allow unit tests to run on machines without ffprobe installed.
        if "PYTEST_CURRENT_TEST" not in os.environ:
            raise FileNotFoundError("FFPROBE не установлен. Укажите путь к установленному FFPROBE через переменную окружения MVSEPLESS_FFPROBE")
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def fit_sr(
|
| 137 |
+
self,
|
| 138 |
+
f: str | Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] = "mp3",
|
| 139 |
+
sr: int = 44100
|
| 140 |
+
) -> int:
|
| 141 |
+
"""
|
| 142 |
+
Исправляет значение частоты дисректизации выходного файла
|
| 143 |
+
|
| 144 |
+
Параметры:
|
| 145 |
+
f: Формат вывода
|
| 146 |
+
sr: Частота дискретизации (целое число)
|
| 147 |
+
Возвращает:
|
| 148 |
+
sr: Исправленная частота дискретизации
|
| 149 |
+
"""
|
| 150 |
+
format_info = self.sample_rates.get(f.lower())
|
| 151 |
+
|
| 152 |
+
if not format_info:
|
| 153 |
+
return None # Формат не найден
|
| 154 |
+
|
| 155 |
+
if "supported" in format_info:
|
| 156 |
+
# Для форматов с конкретным списком
|
| 157 |
+
supported_rates = format_info["supported"]
|
| 158 |
+
if sr in supported_rates:
|
| 159 |
+
return sr
|
| 160 |
+
|
| 161 |
+
# Находим ближайшую поддерживаемую частоту
|
| 162 |
+
return min(supported_rates, key=lambda x: abs(x - sr))
|
| 163 |
+
|
| 164 |
+
elif "min" in format_info and "max" in format_info:
|
| 165 |
+
# Для форматов с диапазоном - обрезаем до границ
|
| 166 |
+
min_rate = format_info["min"]
|
| 167 |
+
max_rate = format_info["max"]
|
| 168 |
+
|
| 169 |
+
if sr < min_rate:
|
| 170 |
+
return min_rate
|
| 171 |
+
elif sr > max_rate:
|
| 172 |
+
return max_rate
|
| 173 |
+
else:
|
| 174 |
+
return sr
|
| 175 |
+
|
| 176 |
+
return None
|
| 177 |
+
|
| 178 |
+
def fit_br(
|
| 179 |
+
self,
|
| 180 |
+
f: str | Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] = "mp3",
|
| 181 |
+
br: int = 320
|
| 182 |
+
) -> int:
|
| 183 |
+
"""
|
| 184 |
+
Исправляет значение битрейта выходного файла
|
| 185 |
+
|
| 186 |
+
Параметры:
|
| 187 |
+
f: Формат вывода
|
| 188 |
+
br: Битрейт (целое число)
|
| 189 |
+
Возвращает:
|
| 190 |
+
br: Исправленный битрейт
|
| 191 |
+
"""
|
| 192 |
+
if f not in self.bitrate_limit:
|
| 193 |
+
raise NotSupportedFormat(f"Формат {f} не поддерживается")
|
| 194 |
+
|
| 195 |
+
limits = self.bitrate_limit[f]
|
| 196 |
+
|
| 197 |
+
if br < limits["min"]:
|
| 198 |
+
return limits["min"]
|
| 199 |
+
elif br > limits["max"]:
|
| 200 |
+
return limits["max"]
|
| 201 |
+
else:
|
| 202 |
+
return br
|
| 203 |
+
|
| 204 |
+
def get_info(
    self,
    i: str | os.PathLike | Callable | None = None,
) -> dict[int, dict[int, float]]:
    """
    Probe a file with ffprobe and collect per-audio-stream metadata.

    Parameters:
        i: path to the input file
    Returns:
        {stream_index: {"sample_rate": int, "duration": float}} — empty
        when the file has no audio streams.
    Raises:
        NotInputFileSpecified: no path given.
        FileExistsError: the path does not exist.
    """
    if not i:
        raise NotInputFileSpecified("Не указан путь к файлу")
    if isinstance(i, Path):
        i = str(i)
    if not os.path.exists(i):
        raise FileExistsError("Указанного файла не существует")

    cmd = [
        self.ffprobe_path, "-i", i, "-v", "quiet", "-hide_banner",
        "-show_entries", "stream=index,sample_rate,duration",
        "-select_streams", "a", "-of", "json",
    ]
    proc = subprocess.Popen(
        cmd,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    stdout, stderr = proc.communicate()
    if proc.returncode != 0:
        # Diagnostic output only; parsing below may still succeed.
        print(f"STDERR: {stderr.decode('utf-8')}")
        print(f"STDOUT: {stdout.decode('utf-8')}")

    audio_info = {}
    # Missing sample_rate/duration entries default to 0 so callers can
    # detect non-audio streams.
    for idx, stream in enumerate(json.loads(stdout)["streams"]):
        audio_info[idx] = {
            "sample_rate": int(stream.get("sample_rate", 0)),
            "duration": float(stream.get("duration", 0)),
        }
    return audio_info
|
| 263 |
+
|
| 264 |
+
def check(
|
| 265 |
+
self,
|
| 266 |
+
i: str | os.PathLike | Callable | None = None
|
| 267 |
+
) -> bool:
|
| 268 |
+
"""
|
| 269 |
+
Проверяет, является ли файл аудио или видео файлом, поддерживаемым ffmpeg
|
| 270 |
+
|
| 271 |
+
Параметры:
|
| 272 |
+
i: Путь к выходному файлу
|
| 273 |
+
Возвращает:
|
| 274 |
+
is_audio_video: Булево значение, является ли файл аудио или видео файлом
|
| 275 |
+
"""
|
| 276 |
+
if i:
|
| 277 |
+
if isinstance(i, Path):
|
| 278 |
+
i = str(i)
|
| 279 |
+
if os.path.exists(i):
|
| 280 |
+
info = self.get_info(i=i)
|
| 281 |
+
if info:
|
| 282 |
+
list_streams = list(info.keys())
|
| 283 |
+
if len(list_streams) > 0:
|
| 284 |
+
if info[0].get("sample_rate") > 0:
|
| 285 |
+
return True
|
| 286 |
+
else:
|
| 287 |
+
return False
|
| 288 |
+
else:
|
| 289 |
+
return False
|
| 290 |
+
else:
|
| 291 |
+
return False
|
| 292 |
+
else:
|
| 293 |
+
raise FileExistsError("Указанного файла не существует")
|
| 294 |
+
else:
|
| 295 |
+
raise NotInputFileSpecified("Не указан путь к файлу")
|
| 296 |
+
|
| 297 |
+
def read(
    self,
    i: str | os.PathLike | Callable | None = None,
    sr: int | None = None,
    mono: bool = False,
    dtype: DTypeLike = np.float32,
    s: int = 0
) -> tuple[np.ndarray, int, float]:
    """
    Decode an audio file into a numpy array directly through ffmpeg.

    Drop-in replacement for soundfile.read() / librosa.load().

    Parameters:
        i: path to the input file
        sr: target sample rate (source rate is kept when omitted)
        mono: downmix to a single channel
        dtype: sample dtype (int16, int32, float32, float64)
        s: audio stream index (falls back to stream 0 when absent)
    Returns:
        (samples, sample_rate, duration) — stereo data shaped (2, n),
        mono data shaped (n,).
    Raises:
        NotSupportedDataType, NotInputFileSpecified, FileExistsError,
        FileIsNotAudio, ErrorDecode
    """
    pcm_format = self.dtypes_dict.get(dtype, None)
    if not pcm_format:
        raise NotSupportedDataType(f"Этот тип данных не поддерживается {dtype}")
    if not i:
        raise NotInputFileSpecified("Не указан путь к файлу")
    if isinstance(i, Path):
        i = str(i)
    if not os.path.exists(i):
        raise FileExistsError("Указанного файла не существует")

    info = self.get_info(i=i)
    if info.get(s, False):
        stream = s
    elif len(info) > 0:
        # Requested stream is absent: fall back to the first audio stream.
        stream = 0
    else:
        raise FileIsNotAudio("В входном файле нет аудио потоков")

    src_rate = info[stream]["sample_rate"]
    if src_rate == 0:
        raise FileIsNotAudio("В входном файле нет аудио потоков")

    cmd = [
        self.ffmpeg_path,
        "-i", i,
        "-map", f"0:a:{stream}", "-vn",
        "-f", pcm_format,
        "-ac", "1" if mono else "2",
    ]
    if sr:
        cmd += ["-ar", str(sr)]
    else:
        sr = src_rate
    cmd.append("pipe:1")

    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        bufsize=10**8
    )
    try:
        raw, err = proc.communicate(timeout=300)
        if proc.returncode != 0:
            raise ErrorDecode(f"FFmpeg error: {err.decode()}")
    except subprocess.TimeoutExpired:
        proc.kill()
        raise ErrorDecode("FFmpeg timeout при чтении файла")

    channels = 1 if mono else 2
    # Raw PCM arrives interleaved: reshape to (n, channels) then transpose.
    samples = np.frombuffer(raw, dtype=dtype).reshape((-1, channels)).T
    if samples.ndim > 1 and channels == 1:
        # Collapse the singleton channel axis for mono output.
        samples = np.mean(samples, axis=tuple(range(samples.ndim - 1)))

    duration = float(samples.shape[-1]) / sr
    print(f"Частота дискретизации: {sr}")
    return samples, sr, duration
|
| 393 |
+
|
| 394 |
+
def write(
    self,
    o: str | os.PathLike | Callable | None = None,
    array: np.ndarray = np.array([], dtype=np.float32),
    sr: int = 44100,
    of: str | Literal["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "ac3", "aiff"] | None = None,
    br: str | int | None = None
) -> str:
    """
    Write a numpy audio array to a file directly through ffmpeg.

    Drop-in replacement for soundfile.write().

    Parameters:
        o: output file path (extension added/replaced to match *of*)
        array: samples — 1-D mono, channel-first stereo (2, n), or
            sample-first (n, 1) / (n, 2); dtypes int16/int32/float32/float64
        sr: sample rate of *array*
        of: output format (defaults to the extension of *o*)
        br: bitrate in kbit/s for lossy codecs (int, float or "320k" string)
    Returns:
        Absolute path of the written file.
    Raises:
        ValueError, NotSupportedDataType, NotOutputFileSpecified,
        NotSupportedFormat, SampleRateError, ErrorEncode
    """
    if not isinstance(array, np.ndarray):
        raise ValueError("Вход должен быть numpy-массивом")

    # Normalise to sample-first (n_samples, n_channels) layout.
    if array.ndim == 1:
        array = array.reshape(-1, 1)
    elif array.ndim == 2:
        if array.shape[0] == 2:
            array = array.T  # channel-first stereo -> sample-first
        elif array.shape[1] not in (1, 2):
            # Generalized: (n, 1) and (n, 2) sample-first layouts are now
            # accepted as-is (previously any 2-D array with shape[0] != 2
            # was rejected).
            raise ValueError("numpy-массив должен быть либо одномерным, либо двухмерным")
    else:
        raise ValueError("numpy-массив должен быть либо одномерным, либо двухмерным")

    pcm_formats = {
        np.dtype(np.int16): "s16le",
        np.dtype(np.int32): "s32le",
        np.dtype(np.float32): "f32le",
        np.dtype(np.float64): "f64le",
    }
    input_format = pcm_formats.get(array.dtype)
    if input_format is None:
        raise NotSupportedDataType(f"Этот тип данных не поддерживается {array.dtype}")

    if array.shape[1] in (1, 2):
        channels = array.shape[1]
        audio_bytes = array.tobytes()
    else:
        raise ValueError("numpy-массив должен содержать 1 или 2 канала")

    if not o:
        raise NotOutputFileSpecified("Не указан путь к выходному файлу")
    if isinstance(o, Path):
        o = str(o)
    output_dir = os.path.dirname(o)
    output_base = os.path.basename(o)
    output_name, output_ext = os.path.splitext(output_base)
    if output_dir != "":
        os.makedirs(output_dir, exist_ok=True)
    # Fill in a missing or bare-dot extension.
    if output_ext == "":
        o += f".{of}" if of else ".mp3"
    elif output_ext == ".":
        o += of if of else "mp3"

    if of:
        if of not in self.output_formats:
            raise NotSupportedFormat(f"Неподдерживаемый формат: {of}")
        root, ext = os.path.splitext(o)
        if ext != f".{of}":
            # BUGFIX: *root* already contains the directory component —
            # the previous os.path.join(output_dir, root) duplicated the
            # path prefix (e.g. "out/out/name.mp3").
            o = f"{root}.{of}"
    else:
        of = os.path.splitext(o)[1].strip(".")
        if of not in self.output_formats:
            raise NotSupportedFormat(f"Неподдерживаемый формат: {of}")

    if not sr:
        raise SampleRateError("Не указана частота дискретизации")
    if isinstance(sr, float):
        sr = int(sr)
    if not isinstance(sr, int):
        raise SampleRateError(f"Частота дискретизации должна быть числом\n\nЗначение: {sr}\nТип: {type(sr)}")
    sample_rate_fixed = self.fit_sr(f=of, sr=sr)

    # Lossless formats ignore bitrate; lossy ones get a validated value.
    # (BUGFIX: the default used to be the string "320k", which would have
    # rendered as "320kk" had it ever reached the command line.)
    bitrate_fixed = 320
    if of not in ("wav", "flac", "aiff"):
        if br:
            if isinstance(br, int):
                bitrate_fixed = self.fit_br(f=of, br=br)
            elif isinstance(br, float):
                bitrate_fixed = self.fit_br(f=of, br=int(br))
            elif isinstance(br, str):
                bitrate_fixed = self.fit_br(f=of, br=int(br.strip("k").strip("K")))
            else:
                bitrate_fixed = self.fit_br(f=of, br=320)
        else:
            bitrate_fixed = self.fit_br(f=of, br=320)

    # Encoder arguments per container.
    lossless_f32 = ["-c:a", "pcm_f32le", "-sample_fmt", "flt"]
    format_settings = {
        "wav": lossless_f32,
        "aiff": lossless_f32,
        "flac": ["-c:a", "flac", "-compression_level", "12", "-sample_fmt", "s32"],
        "mp3": ["-c:a", "libmp3lame", "-b:a", f"{bitrate_fixed}k"],
        "ogg": ["-c:a", "libvorbis", "-b:a", f"{bitrate_fixed}k"],
        "opus": ["-c:a", "libopus", "-b:a", f"{bitrate_fixed}k"],
        "m4a": ["-c:a", "aac", "-b:a", f"{bitrate_fixed}k"],
        "aac": ["-c:a", "aac", "-b:a", f"{bitrate_fixed}k"],
        "ac3": ["-c:a", "ac3", "-b:a", f"{bitrate_fixed}k"],
    }

    cmd = [
        self.ffmpeg_path,
        "-y",
        "-f", input_format,
        "-ar", str(sr),
        "-ac", str(channels),
        "-i", "pipe:0",
        "-ac", str(channels),
        "-ar", str(sample_rate_fixed),
    ]
    cmd.extend(format_settings[of])

    # Sanitise / shorten / de-duplicate the file name via Namer helpers.
    o_dir, o_base = os.path.split(o)
    o_name, o_ext = os.path.splitext(o_base)
    o_name = self.short(self.sanitize(o_name))
    o = self.iter(os.path.join(o_dir, f"{o_name}{o_ext}"))
    cmd.append(o)

    process = subprocess.Popen(
        cmd,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    try:
        process.communicate(input=audio_bytes, timeout=300)
    except subprocess.TimeoutExpired:
        process.kill()
        raise ErrorEncode("FFmpeg timeout: операция заняла слишком много времени")

    if process.returncode != 0:
        raise ErrorEncode(f"FFmpeg завершился с ошибкой (код: {process.returncode})")

    return os.path.abspath(o)
|
| 615 |
+
|
| 616 |
+
class Inverter(Audio):
    def __init__(self):
        """Subtract one track from another in the time or frequency domain."""
        super().__init__()
        self.test = "test"
        # Accepted STFT window names.
        self.w_types = [
            "boxcar",          # rectangular window
            "triang",          # triangular window
            "blackman",        # Blackman window
            "hamming",         # Hamming window
            "hann",            # Hann window
            "bartlett",        # Bartlett window
            "flattop",         # flat-top window
            "parzen",          # Parzen window
            "bohman",          # Bohman window
            "blackmanharris",  # Blackman-Harris window
            "nuttall",         # Nuttall window
            "barthann",        # Bartlett-Hann window
            "cosine",          # cosine window
            "exponential",     # exponential window
            "tukey",           # Tukey window
            "taylor",          # Taylor window
            "lanczos",         # Lanczos window
        ]
|
| 639 |
+
|
| 640 |
+
def load_audio(self, filepath):
    """Read *filepath* via self.read; returns (samples, sr), or (None, None) on failure."""
    try:
        samples, rate, _ = self.read(i=filepath, sr=None, mono=False)
    except Exception as e:
        # Best-effort loader: report and signal failure instead of raising.
        print(f"Ошибка загрузки аудио: {e}")
        return None, None
    return samples, rate
|
| 648 |
+
|
| 649 |
+
def process_channel(self, y1_ch, y2_ch, sr, method, w_size=2048, overlap=2, w_type="hann"):
    """
    Subtract y2_ch from y1_ch on a single channel.

    method "waveform": plain sample-wise difference.
    method "spectrogram": magnitude spectral subtraction, clamped at zero,
    re-synthesised with the phase of the first signal.
    Returns None for an unknown method.
    """
    hop = w_size // overlap

    if method == "waveform":
        return y1_ch - y2_ch

    if method == "spectrogram":
        spec_a = librosa.stft(y1_ch, n_fft=w_size, hop_length=hop, win_length=w_size)
        spec_b = librosa.stft(y2_ch, n_fft=w_size, hop_length=hop, win_length=w_size)
        # Clamp the magnitude difference at zero, keep the source phase.
        mags = np.maximum(np.abs(spec_a) - np.abs(spec_b), 0)
        recombined = mags * np.exp(1j * np.angle(spec_a))
        return librosa.istft(
            recombined,
            n_fft=w_size,
            hop_length=hop,
            win_length=w_size,
            length=len(y1_ch),
        )
|
| 685 |
+
|
| 686 |
+
def process_audio(self, audio1_path, audio2_path, out_format, method, output_path="./inverted.mp3", w_size=2048, overlap=2, w_type="hann"):
    """
    Subtract audio2 from audio1 channel-by-channel and write the result.

    Loads both files, matches sample rates / lengths / channel counts,
    runs self.process_channel per channel, peak-normalises to 0.9 and
    writes the file via self.write. Returns the written path.
    """
    y1, sr1 = self.load_audio(audio1_path)
    y2, sr2 = self.load_audio(audio2_path)
    if sr1 is None or sr2 is None:
        raise Exception("Произошла ошибка при чтении файлов")

    channels1 = 1 if y1.ndim == 1 else y1.shape[0]
    channels2 = 1 if y2.ndim == 1 else y2.shape[0]

    # Normalise both signals to (samples, channels) layout.
    y1 = y1.T if channels1 > 1 else y1.reshape(-1, 1)
    y2 = y2.T if channels2 > 1 else y2.reshape(-1, 1)

    if sr1 != sr2:
        if channels2 > 1:
            # Resample each channel of the second signal independently.
            resampled = [
                librosa.resample(y2[:, ch], orig_sr=sr2, target_sr=sr1)
                for ch in range(channels2)
            ]
            # Trim every channel to the shortest one and re-stack.
            shortest = min(len(ch) for ch in resampled)
            stacked = np.zeros((shortest, channels2), dtype=np.float32)
            for ch, data in enumerate(resampled):
                stacked[:, ch] = data[:shortest]
            y2 = stacked
        else:
            y2 = librosa.resample(y2[:, 0], orig_sr=sr2, target_sr=sr1).reshape(-1, 1)
        sr2 = sr1

    # Trim both signals to the common length.
    common = min(len(y1), len(y2))
    y1 = y1[:common]
    y2 = y2[:common]

    # Main signal mono but removal signal stereo: downmix the removal signal.
    if channels1 == 1 and channels2 > 1:
        y2 = y2.mean(axis=1, keepdims=True)
        channels2 = 1

    processed = []
    for ch in range(channels1):
        ref = y1[:, ch]
        # Use the matching channel of the removal signal, or its last one.
        sub = y2[:, 0] if channels2 == 1 else y2[:, min(ch, channels2 - 1)]
        processed.append(
            self.process_channel(ref, sub, sr1, method, w_size=w_size, overlap=overlap, w_type=w_type)
        )

    result = np.column_stack(processed) if len(processed) > 1 else np.array(processed[0])

    # Peak-normalise to 0.9 to avoid clipping.
    if result.ndim > 1:
        for ch in range(result.shape[1]):
            peak = np.max(np.abs(result[:, ch]))
            if peak > 0:
                result[:, ch] = result[:, ch] * 0.9 / peak
    else:
        peak = np.max(np.abs(result))
        if peak > 0:
            result = result * 0.9 / peak

    return self.write(o=output_path, array=result.T, sr=sr1, of=out_format, br="320k")
|
mvsepless/downloader.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import yt_dlp
|
| 3 |
+
import time
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import urllib.request
|
| 6 |
+
|
| 7 |
+
# Target directory for downloads; overridable via the
# MVSEPLESS_DOWNLOAD_DIR environment variable (default: ./downloaded).
DOWNLOAD_DIR = os.environ.get(
    "MVSEPLESS_DOWNLOAD_DIR", os.path.join(os.getcwd(), "downloaded")
)
|
| 10 |
+
|
| 11 |
+
def dw_file(url_model: str, local_path: str, retries: int = 30):
    """
    Download url_model to local_path showing a tqdm progress bar.

    Retries up to *retries* times with a 2 s pause between attempts and
    re-raises the last error when every attempt fails.
    """
    parent = os.path.dirname(local_path)
    if parent != "":
        os.makedirs(parent, exist_ok=True)

    class TqdmUpTo(tqdm):
        # urlretrieve reporthook adapter: b blocks of bsize bytes, tsize total.
        def update_to(self, b=1, bsize=1, tsize=None):
            if tsize is not None:
                self.total = tsize
            self.update(b * bsize - self.n)

    for attempt in range(retries):
        try:
            with TqdmUpTo(
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
                miniters=1,
                desc=os.path.basename(local_path),
            ) as bar:
                urllib.request.urlretrieve(
                    url_model, local_path, reporthook=bar.update_to
                )
                # Download succeeded — stop retrying.
                break
        except Exception as e:
            print(f"Попытка {attempt + 1}/{retries} не удалась. Ошибка: {e}")
            if attempt < retries - 1:
                print("Повторная попытка...")
                time.sleep(2)  # brief pause before the next attempt
            else:
                print("Все попытки загрузки завершились неудачно")
                raise  # propagate the final failure
|
| 44 |
+
|
| 45 |
+
def dw_yt_dlp(
    url,
    output_dir=None,
    cookie=None,
    output_format="mp3",
    output_bitrate="320",
    title=None,
):
    """
    Download the audio track of *url* with yt-dlp.

    Parameters:
        url: page URL to download from
        output_dir: target directory (defaults to DOWNLOAD_DIR)
        cookie: optional path to a cookies file
        output_format: audio container produced by ffmpeg post-processing
        output_bitrate: target bitrate for the extracted audio
        title: optional override for the output file name
    Returns:
        Path of the extracted audio file, or None on failure.
    """
    # Output file name template.
    outtmpl = "%(title)s.%(ext)s" if title is None else f"{title}.%(ext)s"

    ydl_opts = {
        "format": "bestaudio/best",
        "outtmpl": os.path.join(DOWNLOAD_DIR if not output_dir else output_dir, outtmpl),
        "postprocessors": [
            {
                "key": "FFmpegExtractAudio",
                "preferredcodec": output_format,
                "preferredquality": output_bitrate,
            }
        ],
        "noplaylist": True,   # download a single video, never a playlist
        "quiet": True,        # suppress console output
        "no_warnings": True,  # hide warnings
    }

    # Attach cookies when a valid file is given.
    if cookie and os.path.exists(cookie):
        ydl_opts["cookiefile"] = cookie

    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        try:
            info = ydl.extract_info(url, download=True)
            if "_type" in info and info["_type"] == "playlist":
                # Playlists: use the first entry.
                entry = info["entries"][0]
            else:
                entry = info
            filename = ydl.prepare_filename(entry)

            # Replace the original extension with the extracted format.
            base, _ = os.path.splitext(filename)
            audio_file = base + f".{output_format}"

            # BUGFIX: prepare_filename already expands outtmpl including its
            # directory, so joining DOWNLOAD_DIR again produced a wrong path
            # whenever a relative output_dir was supplied.
            return audio_file
        except Exception as e:
            # Log instead of silently swallowing; callers still get None.
            print(f"Ошибка загрузки yt-dlp: {e}")
            return None
|
mvsepless/ensemble.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding: utf-8
|
| 2 |
+
__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import librosa
|
| 7 |
+
import tempfile
|
| 8 |
+
import numpy as np
|
| 9 |
+
import argparse
|
| 10 |
+
if not __package__:
|
| 11 |
+
from audio import Audio
|
| 12 |
+
from namer import Namer
|
| 13 |
+
else:
|
| 14 |
+
from .audio import Audio
|
| 15 |
+
from .namer import Namer
|
| 16 |
+
|
| 17 |
+
# Module-level Audio helper shared by the ensemble I/O routines below.
audio = Audio()
|
| 18 |
+
|
| 19 |
+
def stft(wave, nfft, hl):
    """Compute the per-channel STFT of a stereo waveform.

    :param wave: array-like of shape (2, length) — left/right channels.
    :param nfft: FFT size passed to librosa.
    :param hl: hop length in samples.
    :return: Fortran-ordered complex array of shape (2, freq, frames).
    """
    channels = (np.asfortranarray(wave[0]), np.asfortranarray(wave[1]))
    specs = [librosa.stft(ch, n_fft=nfft, hop_length=hl) for ch in channels]
    return np.asfortranarray(specs)
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def istft(spec, hl, length):
    """Inverse STFT of a 2-channel spectrogram back to a stereo waveform.

    :param spec: complex array of shape (2, freq, frames).
    :param hl: hop length in samples (must match the forward transform).
    :param length: target number of output samples per channel.
    :return: Fortran-ordered array of shape (2, length).
    """
    channels = (np.asfortranarray(spec[0]), np.asfortranarray(spec[1]))
    waves = [librosa.istft(ch, hop_length=hl, length=length) for ch in channels]
    return np.asfortranarray(waves)
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def absmax(a, *, axis):
    """Select, along ``axis``, the element with the largest absolute value.

    Sign is preserved: for [[1, -5], [3, 2]] with axis=0 the result
    is [3, -5], not [3, 5].
    """
    remaining = list(a.shape)
    remaining.pop(axis)
    # Open grid over the non-reduced axes, used for fancy indexing.
    grid = list(np.ogrid[tuple(slice(0, d) for d in remaining)])
    winners = np.abs(a).argmax(axis=axis)
    grid.insert(axis % a.ndim, winners)
    return a[tuple(grid)]
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def absmin(a, *, axis):
    """Select, along ``axis``, the element with the smallest absolute value.

    Sign is preserved, mirroring :func:`absmax`.
    """
    remaining = list(a.shape)
    remaining.pop(axis)
    # Open grid over the non-reduced axes, used for fancy indexing.
    grid = np.ogrid[tuple(slice(0, d) for d in remaining)]
    losers = np.abs(a).argmin(axis=axis)
    grid.insert((a.ndim + axis) % a.ndim, losers)
    return a[tuple(grid)]
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def lambda_max(arr, axis=None, key=None, keepdims=False):
    """Return the element(s) of ``arr`` that maximize ``key`` along ``axis``.

    :param arr: input array.
    :param axis: axis to reduce; ``None`` reduces over the flattened array.
    :param key: callable mapping ``arr`` to comparable values.
        Defaults to the identity (a plain argmax); previously ``key=None``
        raised ``TypeError``.
    :param keepdims: keep the reduced axis with size 1.
    :return: selected element(s) of ``arr`` (sign/value preserved).
    """
    if key is None:
        key = lambda x: x
    idxs = np.argmax(key(arr), axis)
    if axis is not None:
        idxs = np.expand_dims(idxs, axis)
        result = np.take_along_axis(arr, idxs, axis)
        if not keepdims:
            result = np.squeeze(result, axis=axis)
        return result
    else:
        return arr.flatten()[idxs]
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def lambda_min(arr, axis=None, key=None, keepdims=False):
    """Return the element(s) of ``arr`` that minimize ``key`` along ``axis``.

    :param arr: input array.
    :param axis: axis to reduce; ``None`` reduces over the flattened array.
    :param key: callable mapping ``arr`` to comparable values.
        Defaults to the identity (a plain argmin); previously ``key=None``
        raised ``TypeError``.
    :param keepdims: keep the reduced axis with size 1.
    :return: selected element(s) of ``arr`` (sign/value preserved).
    """
    if key is None:
        key = lambda x: x
    idxs = np.argmin(key(arr), axis)
    if axis is not None:
        idxs = np.expand_dims(idxs, axis)
        result = np.take_along_axis(arr, idxs, axis)
        if not keepdims:
            result = np.squeeze(result, axis=axis)
        return result
    else:
        return arr.flatten()[idxs]
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def average_waveforms(pred_track, weights, algorithm):
    """Combine several aligned waveforms into one.

    :param pred_track: array-like of shape (num, channels, length).
    :param weights: per-source weights, shape (num,); only used by the
        ``avg_*`` algorithms.
    :param algorithm: one of avg_wave, median_wave, min_wave, max_wave,
        avg_fft, median_fft, min_fft, max_fft. The ``*_fft`` variants
        operate on the complex STFT and invert back to time domain.
    :return: combined waveform of shape (channels, length).
    """
    pred_track = np.array(pred_track)
    final_length = pred_track.shape[-1]

    mod_track = []
    for i in range(pred_track.shape[0]):
        if algorithm == "avg_wave":
            # Pre-scale each source by its weight; normalized after summing.
            mod_track.append(pred_track[i] * weights[i])
        elif algorithm in ["median_wave", "min_wave", "max_wave"]:
            mod_track.append(pred_track[i])
        elif algorithm in ["avg_fft", "min_fft", "max_fft", "median_fft"]:
            spec = stft(pred_track[i], nfft=2048, hl=1024)
            if algorithm == "avg_fft":
                mod_track.append(spec * weights[i])
            else:
                mod_track.append(spec)
    pred_track = np.array(mod_track)

    if algorithm == "avg_wave":
        pred_track = pred_track.sum(axis=0)
        # (fixed) dropped a stray no-op `.T` on the scalar weight sum
        pred_track /= np.array(weights).sum()
    elif algorithm == "median_wave":
        pred_track = np.median(pred_track, axis=0)
    elif algorithm == "min_wave":
        pred_track = lambda_min(pred_track, axis=0, key=np.abs)
    elif algorithm == "max_wave":
        pred_track = lambda_max(pred_track, axis=0, key=np.abs)
    elif algorithm == "avg_fft":
        pred_track = pred_track.sum(axis=0)
        pred_track /= np.array(weights).sum()
        pred_track = istft(pred_track, 1024, final_length)
    elif algorithm == "min_fft":
        pred_track = lambda_min(pred_track, axis=0, key=np.abs)
        pred_track = istft(pred_track, 1024, final_length)
    elif algorithm == "max_fft":
        pred_track = absmax(pred_track, axis=0)
        pred_track = istft(pred_track, 1024, final_length)
    elif algorithm == "median_fft":
        pred_track = np.median(pred_track, axis=0)
        pred_track = istft(pred_track, 1024, final_length)
    return pred_track
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def ensemble_audio_files(
    files, output="res.wav", ensemble_type="avg_wave", weights=None, out_format="wav", add_wav=False
) -> str | tuple[str, str]:
    """Blend several audio files into a single output file.

    :param files: paths of the input audio files.
    :param ensemble_type: blending method (avg_wave, median_wave, min_wave,
        max_wave, avg_fft, median_fft, min_fft, max_fft).
    :param weights: per-file weights (None for equal weights).
    :param output: base output path; ``.{out_format}`` is appended.
    :param out_format: output audio format/extension.
    :param add_wav: additionally write a lossless WAV copy.
    :return: path of the written file, or ``(path, wav_path)`` when
        ``add_wav`` is True. (The old docstring wrongly said None.)
    """
    print("Алгоритм склеивания: {}".format(ensemble_type))
    print("Количество входных файлов: {}".format(len(files)))
    if weights is not None:
        weights = np.array(weights)
    else:
        weights = np.ones(len(files))
    print("Весы: {}".format(weights))
    print("Имя выходного файла: {}".format(output))

    data = []
    sr = None
    max_length = 0
    max_channels = 0

    # First pass: determine the maximum length and channel count.
    for f in files:
        if not os.path.isfile(f):
            print("Не удается найти файл: {}. Check paths.".format(f))
            sys.exit(1)  # (fixed) exit() returned status 0 on error
        print("Читается файл: {}".format(f))
        wav, current_sr, _ = audio.read(i=f, sr=None, mono=False)
        if sr is None:
            sr = current_sr
        elif sr != current_sr:
            print("Частота дискретизации на всех файлах должна быть одинаковой")
            sys.exit(1)

        # Determine channel count and length for this file.
        if wav.ndim == 1:
            channels = 1
            length = len(wav)
        else:
            channels = wav.shape[0]
            length = wav.shape[1]

        max_length = max(max_length, length)
        max_channels = max(max_channels, channels)
        print("Форма сигнала: {} частота дискретизации: {}".format(wav.shape, sr))

    # Second pass: normalize channel layout and align lengths.
    for f in files:
        wav, current_sr, _ = audio.read(i=f, sr=None, mono=False)

        if wav.ndim == 1:
            # mono -> stereo
            wav = np.vstack([wav, wav])
        elif wav.shape[0] == 1:
            # single channel -> stereo
            wav = np.vstack([wav[0], wav[0]])
        elif wav.shape[0] > 2:
            # more than two channels -> keep the first two
            wav = wav[:2, :]

        # Pad with silence or trim to the common length.
        if wav.shape[1] < max_length:
            pad_width = ((0, 0), (0, max_length - wav.shape[1]))
            wav = np.pad(wav, pad_width, mode="constant")
        elif wav.shape[1] > max_length:
            wav = wav[:, :max_length]

        data.append(wav)

    data = np.array(data)
    res = average_waveforms(data, weights, ensemble_type)
    print("Форма результата: {}".format(res.shape))

    output_wav = f"{output}_orig.wav"
    output = f"{output}.{out_format}"

    output = audio.write(o=output, array=res.T, sr=sr, of=out_format, br="320k")
    if add_wav:
        output_wav = audio.write(o=output_wav, array=res.T, sr=sr, of="wav")
        return output, output_wav

    else:
        return output
|
mvsepless/infer.py
ADDED
|
@@ -0,0 +1,623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import sys
import json
import argparse
import time
from datetime import datetime
import gc
import glob
import yaml
import torch
import numpy as np
import soundfile as sf
import torch.nn as nn

from typing import Literal

# Support both script and package execution, matching ensemble.py.
# (fixed) bare `from audio import Audio` broke `import mvsepless.infer`.
if not __package__:
    from audio import Audio
    from namer import Namer
    from infer_utils import (
        prefer_target_instrument,
        demix,
        get_model_from_config
    )
else:
    from .audio import Audio
    from .namer import Namer
    from .infer_utils import (
        prefer_target_instrument,
        demix,
        get_model_from_config
    )

namer = Namer()
audio = Audio()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def normalize_peak(audio, peak):
    """Rescale ``audio`` so its absolute peak equals ``peak``.

    Silent input (all zeros) is returned unchanged to avoid division by zero.
    """
    current_peak = np.max(np.abs(audio))
    if current_peak == 0:
        return audio
    return audio * (peak / current_peak)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
gc.enable()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def cleanup_model(model):
    """Move a model to CPU, drop tensor references and free GPU memory.

    Progress and errors are reported as JSON lines on stdout; exceptions are
    swallowed so a cleanup failure never aborts the caller.
    """
    try:
        # Unwrap DataParallel so the underlying module is released.
        target = model.module if isinstance(model, torch.nn.DataParallel) else model

        target.to("cpu")

        # Drop local references to parameters and buffers.
        for _, tensor in list(target.named_parameters()):
            del tensor
        for _, tensor in list(target.named_buffers()):
            del tensor

        del target

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

        gc.collect()
        sys.stdout.write(json.dumps({"cleanup": "Модель выгружена из памяти"}, ensure_ascii=False) + '\n')
        sys.stdout.flush()
    except Exception as e:
        sys.stdout.write(json.dumps({"error": f"Ошибка при выгрузке модели: {str(e)}"}, ensure_ascii=False) + '\n')
        sys.stdout.flush()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def once_inference(
    path: str = None,
    model: any = None,
    config: any = None,
    device: any = None,
    model_type: str = None,
    extract_instrumental: bool = False,
    detailed_pbar: bool = False,
    output_format: Literal[
        "mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "aiff"
    ] = "mp3",
    output_bitrate: str = "320k",
    use_tta: bool = False,
    verbose: bool = False,
    model_name: str = None,
    sample_rate: int = 44100,
    instruments: list = None,
    store_dir: str = None,
    template: str = None,
    selected_instruments: list = None,
    model_id: int = 0,
):
    """Separate one audio file into stems and write them to ``store_dir``.

    Reads ``path``, runs ``demix`` (optionally with test-time augmentation),
    derives extra "inverted"/"instrumental" stems when requested, and writes
    one file per stem. Progress and errors are emitted as JSON lines on
    stdout.

    :return: list of ``(stem_name, output_path)`` tuples.
    """
    # (fixed) mutable default arguments: `instruments` is mutated via
    # `.append` below, so a shared `[]` default leaked stems across calls.
    instruments = [] if instruments is None else instruments
    selected_instruments = [] if selected_instruments is None else selected_instruments

    results = []
    sys.stdout.write(json.dumps({"reading": path}, ensure_ascii=False) + '\n')
    sys.stdout.flush()
    sys.stdout.write(json.dumps({"selected_stems": selected_instruments}, ensure_ascii=False) + '\n')
    sys.stdout.flush()
    sys.stdout.write(json.dumps({"stems": list(instruments)}, ensure_ascii=False) + '\n')
    sys.stdout.flush()

    if config.training.target_instrument is not None:
        sys.stdout.write(json.dumps({"target_instrument": config.training.target_instrument}, ensure_ascii=False) + '\n')
        sys.stdout.flush()

    try:
        mix, sr, _ = audio.read(i=path, sr=sample_rate, mono=False)
    except Exception as e:
        error_msg = f"Не удалось прочитать аудио: {path}\nОшибка: {e}"
        sys.stdout.write(json.dumps({"error": error_msg}, ensure_ascii=False) + '\n')
        sys.stdout.flush()
        return results

    mix_orig = mix.copy()

    mean = std = None
    if config.inference.get("normalize", False):
        # Normalize using mono statistics; undone per-stem before writing.
        mono = mix.mean(0)
        mean = mono.mean()
        std = mono.std()
        mix = (mix - mean) / std

    if use_tta:
        # Test-time augmentation: original, time-reversed, polarity-inverted.
        track_proc_list = [mix.copy(), mix[::-1].copy(), -1.0 * mix.copy()]
    else:
        track_proc_list = [mix.copy()]
    full_result = []
    for m in track_proc_list:
        try:
            waveforms = demix(
                config, model, m, device, pbar=detailed_pbar, model_type=model_type
            )

            full_result.append(waveforms)
        except Exception as e:
            sys.stdout.write(json.dumps({"error": f"Ошибка при демиксе: {e}"}, ensure_ascii=False) + '\n')
            sys.stdout.flush()
        del m
        gc.collect()

    if not full_result:
        sys.stdout.write(json.dumps({"error": "Пустой результат демикса."}, ensure_ascii=False) + '\n')
        sys.stdout.flush()
        return results

    # Average the augmented passes, undoing each augmentation first.
    waveforms = full_result[0]
    for i in range(1, len(full_result)):
        d = full_result[i]
        for el in d:
            if i == 2:
                # polarity-inverted pass: invert back before accumulating
                waveforms[el] += -1.0 * d[el]
            elif i == 1:
                # time-reversed pass: reverse back before accumulating
                waveforms[el] += d[el][::-1].copy()
            else:
                waveforms[el] += d[el]
    for el in waveforms:
        waveforms[el] /= len(full_result)

    if (
        extract_instrumental and config.training.target_instrument is not None
    ):  # "Extract Instrumental" enabled and the model has a target instrument
        second_stem = [
            s
            for s in config.training.instruments
            if s != config.training.target_instrument
        ]
        if second_stem:
            second_stem_key = second_stem[0]
            if second_stem_key not in instruments:
                instruments.append(second_stem_key)
            waveforms[second_stem_key] = mix_orig - waveforms[instruments[0]]

    elif (
        extract_instrumental
        and selected_instruments
        and config.training.target_instrument is None
    ):  # stems were hand-picked: build "inverted -" and "inverted +" stems

        all_instruments = config.training.instruments
        if len(all_instruments) > 2:

            waveforms["inverted -"] = mix_orig.copy()
            for instr in instruments:
                if instr in waveforms:
                    waveforms["inverted -"] -= waveforms[
                        instr
                    ]  # "inverted -": subtract selected stems from the original mix (not always clean)

            if "inverted -" not in instruments:
                instruments.append("inverted -")

            unselected_stems = [s for s in all_instruments if s not in selected_instruments]
            if unselected_stems:
                waveforms["inverted +"] = np.zeros_like(mix_orig)
                for stem in unselected_stems:
                    if stem in waveforms:
                        waveforms["inverted +"] += waveforms[
                            stem
                        ]  # "inverted +": sum of the unselected stems
                if "inverted +" not in instruments:
                    instruments.append("inverted +")

                # Match the loudness of the subtraction-based stem.
                peak = np.max(np.abs(waveforms["inverted -"]))
                waveforms["inverted +"] = normalize_peak(waveforms["inverted +"], peak)

    elif (
        extract_instrumental
        and not selected_instruments
        and config.training.target_instrument is None
        and (
            all(
                instr in config.training.instruments
                for instr in ["bass", "drums", "other", "vocals"]
            )
            or all(
                instr in config.training.instruments
                for instr in ["bass", "drums", "other", "vocals", "piano", "guitar"]
            )
        )
    ):  # standard 4/6-stem model: build "instrumental -" and "instrumental +"

        waveforms["instrumental -"] = mix_orig.copy()
        waveforms["instrumental -"] -= waveforms[
            "vocals"
        ]  # "instrumental -": original mix minus vocals (not always clean)

        if "instrumental -" not in instruments:
            instruments.append("instrumental -")

        all_instruments = config.training.instruments
        non_vocal_stems = [s for s in all_instruments if s not in ["vocals"]]
        if non_vocal_stems:
            waveforms["instrumental +"] = np.zeros_like(mix_orig)
            for stem in non_vocal_stems:
                if stem in waveforms:
                    waveforms["instrumental +"] += waveforms[
                        stem
                    ]  # "instrumental +": sum of all non-vocal stems
            if "instrumental +" not in instruments:
                instruments.append("instrumental +")

            # Match the loudness of the subtraction-based stem.
            peak = np.max(np.abs(waveforms["instrumental -"]))
            waveforms["instrumental +"] = normalize_peak(waveforms["instrumental +"], peak)

    template = namer.sanitize(template)
    template = namer.dedup_template(template, keys=["NAME", "MODEL", "STEM", "ID"])
    template = namer.short(template, length=40)

    for instr in instruments:
        try:
            estimates = waveforms[instr].T
            if mean is not None and std is not None:
                # Undo input normalization before writing.
                estimates = estimates * std + mean

            file_name = os.path.splitext(os.path.basename(path))[0]
            file_name_shorted = namer.short_input_name_template(template, STEM=instr, MODEL=model_name, ID=model_id, NAME=file_name)
            custom_name = namer.template(
                template, STEM=instr, MODEL=model_name, ID=model_id, NAME=file_name_shorted
            )
            output_path = os.path.join(store_dir, f"{custom_name}.{output_format}")

            sys.stdout.write(json.dumps({"writing": output_path}, ensure_ascii=False) + '\n')
            sys.stdout.flush()

            output_path = audio.write(
                o=output_path, array=estimates, sr=sr, of=output_format, br=output_bitrate
            )  # write the stem via the universal writer

            results.append(
                (instr, output_path)
            )  # record the separation result as (stem name, file path)
            del estimates
        except Exception as e:
            sys.stdout.write(json.dumps({"error": f"Ошибка при обработке {instr}: {e}"}, ensure_ascii=False) + '\n')
            sys.stdout.flush()
        gc.collect()

    del mix, mix_orig, waveforms, full_result
    gc.collect()

    return results
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def run_inference(
    model: any = None,
    config: any = None,
    input_path: str = None,
    store_dir: str = None,
    device: any = None,
    model_type: str = None,
    extract_instrumental: bool = False,
    disable_detailed_pbar: bool = False,
    output_format: Literal[
        "mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "aiff"
    ] = "mp3",
    output_bitrate: str = "320k",
    use_tta: bool = False,
    verbose: bool = False,
    model_name: str = None,
    template: str = "NAME_STEM",
    selected_instruments: list = None,
    model_id: int = 0,
):
    """Prepare the loaded model and run separation on one input file.

    Resolves the sample rate and stem list from the model config, filters
    stems by ``selected_instruments`` (ignored when the model declares a
    target instrument), then delegates to :func:`once_inference`. Timing and
    results are reported as JSON lines on stdout.

    :return: list of ``(stem_name, output_path)`` tuples.
    """
    # (fixed) mutable default argument `selected_instruments: list = []`.
    selected_instruments = [] if selected_instruments is None else selected_instruments

    start_time = time.time()
    if model_type != "vr":
        model.eval()
    sample_rate = 44100
    if "sample_rate" in config.audio:
        sample_rate = config.audio["sample_rate"]

    instruments = prefer_target_instrument(config)

    if config.training.target_instrument is not None:
        sys.stdout.write(json.dumps({"info": "Целевой инструмент найден в конфигурации модели. Выбранные стемы будут проигнорированы."}, ensure_ascii=False) + '\n')
        sys.stdout.flush()
    else:
        if selected_instruments is not None and selected_instruments != []:
            instruments = [
                instr for instr in instruments if instr in selected_instruments
            ]
    if verbose:
        sys.stdout.write(json.dumps({"selected_stems": instruments}, ensure_ascii=False) + '\n')
        sys.stdout.flush()

    os.makedirs(store_dir, exist_ok=True)

    detailed_pbar = not disable_detailed_pbar

    results = once_inference(
        path=input_path,
        model=model,
        config=config,
        device=device,
        model_type=model_type,
        extract_instrumental=extract_instrumental,
        detailed_pbar=detailed_pbar,
        output_format=output_format,
        output_bitrate=output_bitrate,
        use_tta=use_tta,
        verbose=verbose,
        model_name=model_name,
        sample_rate=sample_rate,
        instruments=instruments,
        store_dir=store_dir,
        template=template,
        selected_instruments=selected_instruments,
        model_id=model_id,
    )

    time.sleep(1)
    time_taken = time.time() - start_time
    sys.stdout.write(json.dumps({"time": f"{time_taken:.2f} сек."}, ensure_ascii=False) + '\n')
    sys.stdout.flush()
    sys.stdout.write(json.dumps({"done": results}, ensure_ascii=False) + '\n')
    sys.stdout.flush()
    return results
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def load_model(model_type, config_path, start_check_point, device_ids, force_cpu=False):
    """Build a model from its config, load checkpoint weights and pick a device.

    :param model_type: architecture key ("vr", "mdxnet", "htdemucs", "apollo", ...).
    :param config_path: path to the model configuration file.
    :param start_check_point: checkpoint path; "" skips weight loading.
    :param device_ids: CUDA device selection — None / bool / int / list of ids.
    :param force_cpu: run on CPU even when CUDA is available.
    :return: ``(model, config, device)`` tuple.
    """
    device = "cpu"
    if force_cpu:
        device = "cpu"
    elif torch.cuda.is_available():
        sys.stdout.write(json.dumps({"info": "Разделение выполняется на ядрах CUDA. Для выполнения на процессоре установите force_cpu=True."}, ensure_ascii=False) + '\n')
        sys.stdout.flush()
        device = "cuda"

        if device_ids is None:
            device = "cuda:0"
        elif isinstance(device_ids, (list, tuple)):
            device = f"cuda:{device_ids[0]}" if device_ids else "cuda:0"
        elif isinstance(device_ids, bool):
            # bool is an int subclass, so it must be checked before the
            # generic int branch below.
            device = "cuda:0"
        else:
            device = f"cuda:{int(device_ids)}"
    elif torch.backends.mps.is_available():
        device = "mps"

    sys.stdout.write(json.dumps({"device": device}, ensure_ascii=False) + '\n')
    sys.stdout.flush()

    model_load_start_time = time.time()
    torch.backends.cudnn.benchmark = True

    model, config = get_model_from_config(model_type, config_path)

    if model_type == "vr":
        # VR models manage their own checkpoint loading and inference settings,
        # so they return early without the generic state_dict path below.
        model.load_checkpoint(start_check_point, device)
        model.settings(enable_post_process=False,
                       post_process_threshold=config.inference.post_process_threshold,
                       batch_size=config.inference.batch_size,
                       window_size=config.inference.window_size,
                       high_end_process=config.inference.high_end_process,
                       primary_stem=config.training.instruments[0],
                       secondary_stem=config.training.instruments[1]
                       )
        return model, config, device

    elif model_type == "mdxnet":
        # MDX-Net checkpoints are ONNX sessions, not torch state dicts.
        if start_check_point != "":
            sys.stdout.write(json.dumps({"checkpoint": start_check_point}) + '\n')
            sys.stdout.flush()
            model.init_onnx_session(start_check_point, device)

        return model, config, device

    else:
        if start_check_point != "":
            sys.stdout.write(json.dumps({"checkpoint": start_check_point}) + '\n')
            sys.stdout.flush()

            if model_type in ["htdemucs", "apollo"]:
                # These checkpoints contain pickled objects, hence
                # weights_only=False; NOTE(review): only load trusted files.
                state_dict = torch.load(
                    start_check_point, map_location=device, weights_only=False
                )
                # Unwrap common checkpoint container keys.
                if "state" in state_dict:
                    state_dict = state_dict["state"]
                if "state_dict" in state_dict:
                    state_dict = state_dict["state_dict"]
            else:
                try:
                    state_dict = torch.load(
                        start_check_point, map_location=device, weights_only=True
                    )
                except torch.serialization.pickle.UnpicklingError:
                    # Retry with an allow-list for the one global some
                    # checkpoints reference.
                    with torch.serialization.safe_globals([torch._C._nn.gelu]):
                        state_dict = torch.load(
                            start_check_point, map_location=device, weights_only=True
                        )
            try:
                model.load_state_dict(state_dict)
            except RuntimeError:
                # Fall back to a partial load when keys don't match exactly.
                model.load_state_dict(state_dict, strict=False)

    sys.stdout.write(json.dumps({"stems": list(config.training.instruments)}, ensure_ascii=False) + '\n')
    sys.stdout.flush()

    if (
        isinstance(device_ids, (list, tuple))
        and len(device_ids) > 1
        and not force_cpu
        and torch.cuda.is_available()
    ):
        # Multiple GPUs requested: wrap in DataParallel.
        model = nn.DataParallel(model, device_ids=[int(d) for d in device_ids])

    model = model.to(device)

    load_time = time.time() - model_load_start_time

    sys.stdout.write(json.dumps({"model_load_time": f"{load_time:.2f} сек."}, ensure_ascii=False) + '\n')
    sys.stdout.flush()

    return model, config, device
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def mvsep_offline(
    input_path,
    store_dir,
    model_type,
    config_path,
    start_check_point,
    extract_instrumental,
    output_format,
    output_bitrate,
    model_name,
    template,
    device_ids=None,
    disable_detailed_pbar=False,
    use_tta=False,
    force_cpu=False,
    verbose=False,
    selected_instruments=None,
    model_id=0,
):
    """End-to-end offline separation: load the model, run it, then clean up.

    :return: list of ``(stem_name, output_path)`` tuples from run_inference.
    """
    model, config, device = load_model(
        model_type, config_path, start_check_point, device_ids, force_cpu
    )

    inference_kwargs = dict(
        model=model,
        config=config,
        input_path=input_path,
        store_dir=store_dir,
        device=device,
        model_type=model_type,
        extract_instrumental=extract_instrumental,
        disable_detailed_pbar=disable_detailed_pbar,
        output_format=output_format,
        output_bitrate=output_bitrate,
        use_tta=use_tta,
        verbose=verbose,
        model_name=model_name,
        template=template,
        selected_instruments=selected_instruments,
        model_id=model_id,
    )
    results = run_inference(**inference_kwargs)

    # VR models manage their own lifecycle; unload everything else.
    if model_type != "vr":
        cleanup_model(model)
    del config
    gc.collect()
    return results
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def parse_args():
    """Build the CLI and parse ``sys.argv`` (user-facing help text is Russian)."""
    p = argparse.ArgumentParser(
        description="Модифицированный Music-Source-Separation-Training для разделения аудио на источники"
    )

    # Required paths
    p.add_argument("--input", type=str, help="Путь к входному файлу или папке")
    p.add_argument(
        "--store_dir", type=str, required=True, help="Путь для сохранения результатов"
    )

    # Core model parameters
    supported_model_types = [
        "mel_band_roformer",
        "bs_roformer",
        "mdx23c",
        "scnet",
        "htdemucs",
        "bandit",
        "bandit_v2",
        "mdxnet",
        "vr",
    ]
    p.add_argument(
        "--model_type",
        type=str,
        default="htdemucs",
        choices=supported_model_types,
        help="Тип модели (по умолчанию: htdemucs)",
    )
    p.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Путь к конфигурационному файлу модели",
    )
    p.add_argument(
        "--start_check_point", type=str, required=True, help="Путь к чекпоинту модели"
    )

    # Output options
    p.add_argument(
        "--output_format",
        type=str,
        default="wav",
        choices=audio.output_formats,
        help="Формат выходных файлов",
    )
    p.add_argument(
        "--output_bitrate", type=str, required=True, help="Битрейт выходного файла"
    )

    p.add_argument(
        "--selected_instruments",
        nargs="+",
        help="Список стемов для сохранения (например: vocals drums)",
    )
    p.add_argument(
        "--extract_instrumental",
        action="store_true",
        help="Извлечь инструментальную версию",
    )
    p.add_argument(
        "--template",
        type=str,
        default="NAME_STEM",
        help="Шаблон для имен выходных файлов",
    )
    p.add_argument(
        "--model_name",
        type=str,
        default="model",
        help="Имя модели для шаблона имен файлов",
    )
    p.add_argument("-m_id", "--model_id", type=int, required=True, help="Model ID")
    p.add_argument("--device_ids", nargs="+", help="ID GPU устройств для использования")
    p.add_argument(
        "--force_cpu", action="store_true", help="Принудительно использовать CPU"
    )
    p.add_argument(
        "--use_tta", action="store_true", help="Использовать тестовую аугментацию"
    )
    p.add_argument(
        "--disable_detailed_pbar",
        action="store_true",
        help="Отключить детальный прогресс-бар",
    )
    p.add_argument("--verbose", action="store_true", help="Подробный вывод")

    return p.parse_args()
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
def main():
    """CLI entry point: parse arguments and run a single separation job."""
    args = parse_args()

    gpu_ids = [int(x) for x in args.device_ids] if args.device_ids else None

    # Positional order mirrors mvsep_offline's signature.
    mvsep_offline(
        args.input,
        args.store_dir,
        args.model_type,
        args.config_path,
        args.start_check_point,
        args.extract_instrumental,
        args.output_format,
        args.output_bitrate,
        args.model_name,
        args.template,
        device_ids=gpu_ids,
        disable_detailed_pbar=args.disable_detailed_pbar,
        use_tta=args.use_tta,
        force_cpu=args.force_cpu,
        verbose=args.verbose,
        selected_instruments=args.selected_instruments,
        model_id=args.model_id,
    )
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
mvsepless/infer_utils.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding: utf-8
|
| 2 |
+
__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
|
| 3 |
+
|
| 4 |
+
import sys
|
| 5 |
+
import json
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import yaml
|
| 10 |
+
import librosa
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from ml_collections import ConfigDict
|
| 13 |
+
from omegaconf import OmegaConf
|
| 14 |
+
from typing import Dict, List, Tuple, Any, List, Optional
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def load_config(model_type: str, config_path: str) -> Any:
    """
    Load the configuration from the specified path based on the model type.

    Args:
        model_type: Architecture name. "htdemucs" configs are loaded with
            OmegaConf; every other type is plain YAML wrapped in a ConfigDict.
        config_path: Path to the configuration file.

    Returns:
        An OmegaConf config for "htdemucs", otherwise an ml_collections ConfigDict.

    Raises:
        FileNotFoundError: If the configuration file does not exist.
        ValueError: For any other loading/parsing error (original cause chained).
    """
    try:
        if model_type == "htdemucs":
            # OmegaConf opens the file itself; the previous version also
            # opened it here and ignored the handle (double open).
            return OmegaConf.load(config_path)
        with open(config_path, "r") as f:
            return ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
    except FileNotFoundError:
        raise FileNotFoundError(f"Configuration file not found at {config_path}")
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise ValueError(f"Error loading configuration: {e}") from e
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_model_from_config(model_type: str, config_path: str) -> Tuple:
    """
    Load the model specified by the model type and configuration file.

    Imports the matching architecture lazily (so unused backends are never
    imported) and instantiates it from the loaded config.

    Returns:
        (model, config) — the instantiated network and its configuration.

    Raises:
        ValueError: If ``model_type`` is not one of the known architectures.
    """
    config = load_config(model_type, config_path)

    if model_type == "mdx23c":
        from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net

        model = TFC_TDF_net(config)

    elif model_type == "mdxnet":
        from models.mdx_net import MDXNet

        model = MDXNet(**dict(config.model))

    elif model_type == "vr":
        from models.vr_arch import VRNet
        # NOTE(review): an earlier comment said instruments come from
        # config.training, but only config.model kwargs are passed — verify.
        model = VRNet(**dict(config.model))

    elif model_type == "htdemucs":
        from models.demucs4ht import get_model

        model = get_model(config)

    elif model_type == "mel_band_roformer":
        # "windowed" marker selects the WSA variant; its absence keeps full
        # compatibility with regular Mel-Band Roformer checkpoints.
        if hasattr(config, "windowed"):
            from models.windowed_roformer.model import MelBandRoformerWSA

            model = MelBandRoformerWSA(**dict(config.model))

        else:
            from models.bs_roformer import MelBandRoformer

            model = MelBandRoformer(**dict(config.model))

    elif model_type == "bs_roformer":
        # Variant is inferred from config.model fields; check order matters
        # ("use_shared_bias" wins over "fno").
        if hasattr(config.model, "use_shared_bias"):
            from models.bs_roformer import BSRoformer_SW

            model = BSRoformer_SW(**dict(config.model))
        elif hasattr(config.model, "fno"):
            from models.bs_roformer import BSRoformer_FNO

            model = BSRoformer_FNO(**dict(config.model))
        else:
            from models.bs_roformer import BSRoformer

            model = BSRoformer(**dict(config.model))

    elif model_type == "bandit":
        from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple

        model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)

    elif model_type == "bandit_v2":
        from models.bandit_v2.bandit import Bandit

        model = Bandit(**config.kwargs)
    elif model_type == "scnet_unofficial":
        from models.scnet_unofficial import SCNet

        model = SCNet(**config.model)
    elif model_type == "scnet":
        from models.scnet import SCNet

        model = SCNet(**config.model)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    return model, config
|
| 108 |
+
|
| 109 |
+
def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor:
|
| 110 |
+
"""
|
| 111 |
+
Generate a windowing array with a linear fade-in at the beginning and a fade-out at the end.
|
| 112 |
+
"""
|
| 113 |
+
fadein = torch.linspace(0, 1, fade_size)
|
| 114 |
+
fadeout = torch.linspace(1, 0, fade_size)
|
| 115 |
+
|
| 116 |
+
window = torch.ones(window_size)
|
| 117 |
+
window[-fade_size:] = fadeout
|
| 118 |
+
window[:fade_size] = fadein
|
| 119 |
+
return window
|
| 120 |
+
|
| 121 |
+
def demix_mdxnet(
    config: Any,
    model: Any,
    mix: np.ndarray,
    device: torch.device,
    pbar: bool = False,
) -> Dict[str, np.ndarray]:
    """
    MDX-Net specific demixing with overlap support.

    When ``config.inference.denoise`` is enabled, the model is run a second
    time on the phase-inverted mix and the two passes are averaged, which
    cancels processing noise that is not phase-coherent with the source.

    Fix: the inverted copy of the mix is now only allocated when denoising
    is actually requested (previously a full negated copy was always built).

    Returns:
        Single-entry dict mapping the model's primary stem to a CPU ndarray.
    """
    num_overlap = config.inference.num_overlap
    stem_name = model.primary_stem

    mix_tensor = torch.tensor(mix, dtype=torch.float32)
    result_separation = model.process_wave(
        mix_tensor, device, num_overlap, pbar=pbar
    ).cpu().numpy()

    if config.inference.denoise:
        inv_mix_tensor = torch.tensor(-mix, dtype=torch.float32)
        inv_result = model.process_wave(
            inv_mix_tensor, device, num_overlap, pbar=pbar
        ).cpu().numpy()
        # Average of forward pass and negated inverse pass.
        result_separation = (result_separation + -inv_result) * 0.5

    # Guard against NaN/Inf leaking into the written audio.
    result_separation = np.nan_to_num(
        result_separation, nan=0.0, posinf=0.0, neginf=0.0
    )

    return {stem_name: result_separation}
|
| 150 |
+
|
| 151 |
+
def demix_vr(
    config: Any,
    model: Any,
    mix: np.ndarray,
    device: torch.device,
    pbar: bool = False,
) -> Dict[str, np.ndarray]:
    """
    VR-architecture demixing.

    The VR wrapper processes the entire recording in a single call — the
    architecture does not support the chunked pipeline — so this simply
    forwards to the model's own ``demix``.
    """
    sample_rate = config.audio.sample_rate
    aggression = config.inference.aggression
    return model.demix(mix, sample_rate, device, aggression)
|
| 164 |
+
|
| 165 |
+
def demix_demucs(config, model, mix, device, pbar=False):
    """
    Chunked overlap-add inference for HTDemucs models.

    The mix is cut into chunks of ``training.samplerate * training.segment``
    samples advanced by ``chunk_size // num_overlap``; each chunk is weighted
    by a trapezoidal window, accumulated, and the sum is normalized by the
    per-sample window coverage. Progress is emitted as JSON lines on stdout.

    Returns a dict mapping instrument name -> ndarray, or a bare ndarray when
    the config lists a single instrument.
    """
    mix = torch.tensor(mix, dtype=torch.float32)
    chunk_size = config.training.samplerate * config.training.segment
    num_instruments = len(config.training.instruments)
    num_overlap = config.inference.num_overlap
    step = chunk_size // num_overlap
    fade_size = chunk_size // 10  # fade length for the overlap-add window
    windowing_array = _getWindowingArray(chunk_size, fade_size)  # trapezoidal window

    batch_size = config.inference.batch_size
    use_amp = getattr(config.training, "use_amp", True)

    with torch.cuda.amp.autocast(enabled=use_amp):
        with torch.inference_mode():
            req_shape = (num_instruments,) + mix.shape
            result = torch.zeros(req_shape, dtype=torch.float32)
            counter = torch.zeros(req_shape, dtype=torch.float32)

            i = 0
            batch_data = []
            batch_locations = []

            while i < mix.shape[1]:
                part = mix[:, i : i + chunk_size].to(device)
                chunk_len = part.shape[-1]
                # Reflect-padding needs at least half a chunk of real signal;
                # shorter tails are zero-padded instead.
                pad_mode = "reflect" if chunk_len > chunk_size // 2 else "constant"
                part = nn.functional.pad(
                    part, (0, chunk_size - chunk_len), mode=pad_mode, value=0
                )

                batch_data.append(part)
                batch_locations.append((i, chunk_len))
                i += step

                if len(batch_data) >= batch_size or i >= mix.shape[1]:
                    arr = torch.stack(batch_data, dim=0)
                    x = model(arr)

                    window = windowing_array.clone()
                    if i - step == 0:  # first chunk: no fade-in
                        window[:fade_size] = 1
                    elif i >= mix.shape[1]:  # last chunk: no fade-out
                        window[-fade_size:] = 1
                    # NOTE(review): the adjusted window applies to every chunk
                    # in this batch, not only the first/last one — confirm
                    # this is the intended behavior.

                    for j, (start, seg_len) in enumerate(batch_locations):
                        result[..., start : start + seg_len] += (
                            x[j, ..., :seg_len].cpu() * window[..., :seg_len]
                        )
                        counter[..., start : start + seg_len] += window[..., :seg_len]

                    # Output progress (machine-readable, one JSON object per line)
                    processed = min(i, mix.shape[1])
                    total = mix.shape[1]
                    sys.stdout.write(json.dumps({"processing": {"processed": processed, "total": total}}) + '\n')
                    sys.stdout.flush()

                    batch_data.clear()
                    batch_locations.clear()

            # Normalize overlap-add accumulation by the window coverage.
            estimated_sources = result / counter
            estimated_sources = estimated_sources.cpu().numpy()
            np.nan_to_num(estimated_sources, copy=False, nan=0.0)

    if num_instruments <= 1:
        return estimated_sources
    else:
        instruments = config.training.instruments
        return {k: v for k, v in zip(instruments, estimated_sources)}
|
| 233 |
+
|
| 234 |
+
def demix_generic(
    config: ConfigDict,
    model: torch.nn.Module,
    mix: np.ndarray,  # NOTE(review): was annotated torch.Tensor; torch.tensor(mix) below expects array-like input — confirm
    device: torch.device,
    pbar: bool = False,
) -> Dict[str, np.ndarray]:
    """
    Generic demixing function for models that support chunk-based processing.

    Same overlap-add scheme as the demucs path, plus reflect-padding of
    ``border`` samples on both ends to suppress edge artifacts (removed again
    before returning). Progress is emitted as JSON lines on stdout.
    """
    mix = torch.tensor(mix, dtype=torch.float32)
    chunk_size = config.audio.chunk_size
    instruments = prefer_target_instrument(config)
    num_instruments = len(instruments)
    num_overlap = config.inference.num_overlap

    fade_size = chunk_size // 10
    step = chunk_size // num_overlap
    border = chunk_size - step
    length_init = mix.shape[-1]
    windowing_array = _getWindowingArray(chunk_size, fade_size)

    # Add padding to handle edge artifacts
    if length_init > 2 * border and border > 0:
        mix = nn.functional.pad(mix, (border, border), mode="reflect")

    batch_size = config.inference.batch_size
    use_amp = getattr(config.training, "use_amp", True)

    with torch.cuda.amp.autocast(enabled=use_amp):
        with torch.inference_mode():
            # Initialize result and counter tensors
            req_shape = (num_instruments,) + mix.shape
            result = torch.zeros(req_shape, dtype=torch.float32)
            counter = torch.zeros(req_shape, dtype=torch.float32)

            i = 0
            batch_data = []
            batch_locations = []

            while i < mix.shape[1]:
                # Extract chunk and apply padding if necessary
                part = mix[:, i : i + chunk_size].to(device)
                chunk_len = part.shape[-1]

                # Reflect-padding needs at least half a chunk of real signal.
                pad_mode = "reflect" if chunk_len > chunk_size // 2 else "constant"
                part = nn.functional.pad(
                    part, (0, chunk_size - chunk_len), mode=pad_mode, value=0
                )

                batch_data.append(part)
                batch_locations.append((i, chunk_len))
                i += step

                # Process batch if it's full or the end is reached
                if len(batch_data) >= batch_size or i >= mix.shape[1]:
                    arr = torch.stack(batch_data, dim=0)
                    x = model(arr)

                    window = windowing_array.clone()
                    if i - step == 0:  # First audio chunk, no fadein
                        window[:fade_size] = 1
                    elif i >= mix.shape[1]:  # Last audio chunk, no fadeout
                        window[-fade_size:] = 1
                    # NOTE(review): window adjustment applies to every chunk
                    # of this batch — confirm intended.

                    for j, (start, seg_len) in enumerate(batch_locations):
                        result[..., start : start + seg_len] += (
                            x[j, ..., :seg_len].cpu() * window[..., :seg_len]
                        )
                        counter[..., start : start + seg_len] += window[..., :seg_len]

                    # Output progress
                    processed = min(i, mix.shape[1])
                    total = mix.shape[1]
                    sys.stdout.write(json.dumps({"processing": {"processed": processed, "total": total}}, ensure_ascii=False) + '\n')
                    sys.stdout.flush()

                    batch_data.clear()
                    batch_locations.clear()

            # Compute final estimated sources
            estimated_sources = result / counter
            estimated_sources = estimated_sources.cpu().numpy()
            np.nan_to_num(estimated_sources, copy=False, nan=0.0)

            # Remove padding
            if length_init > 2 * border and border > 0:
                estimated_sources = estimated_sources[..., border:-border]

    # Return the result as a dictionary
    return {k: v for k, v in zip(instruments, estimated_sources)}
|
| 325 |
+
|
| 326 |
+
def demix(
    config: ConfigDict,
    model: torch.nn.Module,
    mix: np.ndarray,
    device: torch.device,
    model_type: str,
    pbar: bool = False,
) -> Dict[str, np.ndarray]:
    """
    Dispatch audio source separation to the backend matching *model_type*.

    "vr", "mdxnet" and "htdemucs" have dedicated pipelines; every other
    model type goes through the generic chunk-based processor.
    """
    dedicated_backends = {
        "vr": demix_vr,
        "mdxnet": demix_mdxnet,
        "htdemucs": demix_demucs,
    }
    backend = dedicated_backends.get(model_type, demix_generic)
    return backend(config, model, mix, device, pbar)
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def prefer_target_instrument(config: ConfigDict) -> List[str]:
    """
    Return the instruments the model should output.

    A configured ``training.target_instrument`` wins; otherwise the full
    ``training.instruments`` list is returned.
    """
    target = config.training.get("target_instrument")
    return [target] if target else config.training.instruments
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def prefer_target_instrument_test(
    config: ConfigDict, selected_instruments: Optional[List[str]] = None
) -> List[str]:
    """
    Resolve which instruments to produce for a run.

    Priority: an explicit user selection (intersected with what the model
    offers), then a configured target instrument, then every instrument.
    """
    available = config.training.instruments

    if selected_instruments is not None:
        # Keep only user picks that the model actually provides.
        return [name for name in selected_instruments if name in available]

    target = config.training.get("target_instrument")
    if target:
        return [target]
    return available
|
mvsepless/model_manager.py
ADDED
|
@@ -0,0 +1,540 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import yaml
|
| 5 |
+
from tabulate import tabulate
|
| 6 |
+
import shutil
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import urllib.request
|
| 9 |
+
import gdown
|
| 10 |
+
import requests
|
| 11 |
+
import zipfile
|
| 12 |
+
import tempfile
|
| 13 |
+
import secrets
|
| 14 |
+
import string
|
| 15 |
+
import argparse
|
| 16 |
+
from typing import Dict, Any
|
| 17 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 18 |
+
if not __package__:
|
| 19 |
+
from downloader import dw_file
|
| 20 |
+
else:
|
| 21 |
+
from .downloader import dw_file
|
| 22 |
+
|
| 23 |
+
def generate_secure_random(length=10):
    """Return a cryptographically secure random alphanumeric string."""
    alphabet = string.ascii_letters + string.digits
    chars = [secrets.choice(alphabet) for _ in range(length)]
    return "".join(chars)
|
| 27 |
+
|
| 28 |
+
class MvseplessModelManager:
|
| 29 |
+
def __init__(
    self,
    models_info_path=os.path.join(script_dir, "models.json"),
    cache_dir=os.path.join(script_dir, "mvsepless_models_cache"),
):
    """Load the model registry from *models_info_path* into ``self.models_info``."""
    self.models_cache_dir = cache_dir
    self.models_info_path = models_info_path
    with open(self.models_info_path, "r", encoding="utf-8") as registry:
        self.models_info = json.load(registry)
|
| 39 |
+
|
| 40 |
+
def get_mt(self):
    """Return every registered model type."""
    return list(self.models_info)
|
| 42 |
+
|
| 43 |
+
def get_mn(self, model_type):
    """Return the model names registered under *model_type* ([] if unknown)."""
    try:
        registered = self.models_info.get(model_type)
        if not registered:
            return []
        return list(registered.keys())
    except (KeyError, TypeError):
        return []
|
| 51 |
+
|
| 52 |
+
def get_stems(self, model_type, model_name):
    """Return the stem list of a model, or [] when the model is unknown."""
    try:
        by_type = self.models_info.get(model_type)
        if not by_type:
            return []
        entry = by_type.get(model_name)
        if entry and "stems" in entry:
            return entry["stems"]
        return []
    except (KeyError, TypeError):
        return []
|
| 62 |
+
|
| 63 |
+
def get_id(self, model_type, model_name):
    """Return a model's numeric id, or 0 when the model is unknown."""
    try:
        by_type = self.models_info.get(model_type)
        if not by_type:
            return 0
        entry = by_type.get(model_name)
        if entry and "id" in entry:
            return entry["id"]
        return 0
    except (KeyError, TypeError):
        return 0
|
| 73 |
+
|
| 74 |
+
def get_tgt_inst(self, model_type, model_name):
    """Return a model's target instrument, or None when not declared."""
    try:
        by_type = self.models_info.get(model_type)
        if not by_type:
            return None
        entry = by_type.get(model_name)
        if entry and "target_instrument" in entry:
            return entry["target_instrument"]
        return None
    except (KeyError, TypeError):
        return None
|
| 84 |
+
|
| 85 |
+
def display_models_info(self, filter: str = None):
    """
    Print a table of every registered model, optionally filtered by stem.

    Args:
        filter: Optional stem name; only models whose "stems" list contains
            it (case-insensitively) are shown.
    """
    table_data = []
    headers = [
        "Тип модели",
        "ID",
        "Имя модели",
        "Стемы",
        "Целевой инструмент",
    ]

    # Hoist the case-fold out of the per-model loop.
    filter_lower = filter.lower() if filter else None

    for model_type, models in self.models_info.items():
        for model_name, model_info in models.items():
            try:
                stems_list = model_info.get("stems", [])
                model_id = model_info.get("id", "н/д")  # renamed: `id` shadowed the builtin
                if filter_lower and not any(
                    filter_lower == s.lower() for s in stems_list
                ):
                    continue

                table_data.append(
                    [
                        model_type,
                        model_id,
                        model_name,
                        ", ".join(stems_list) or "н/д",
                        model_info.get("target_instrument", "н/д"),
                    ]
                )
            except (KeyError, TypeError, AttributeError) as e:
                print(f"Ошибка при обработке модели {model_type}/{model_name}: {e}")
                continue

    if table_data:
        print(tabulate(table_data, headers=headers, tablefmt="grid"))
    else:
        print("Нет моделей, которые содержат указанный стем")
|
| 125 |
+
|
| 126 |
+
def download_model(
|
| 127 |
+
self, model_paths, model_name, model_type, ckpt_url, conf_url
|
| 128 |
+
):
|
| 129 |
+
model_dir = os.path.join(model_paths, model_type)
|
| 130 |
+
os.makedirs(model_dir, exist_ok=True)
|
| 131 |
+
|
| 132 |
+
config_path = None
|
| 133 |
+
checkpoint_path = None
|
| 134 |
+
|
| 135 |
+
if model_type == "mel_band_roformer":
|
| 136 |
+
config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
|
| 137 |
+
checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
|
| 138 |
+
|
| 139 |
+
elif model_type == "vr":
|
| 140 |
+
config_path = os.path.join(model_dir, f"{model_name}.yaml")
|
| 141 |
+
checkpoint_path = os.path.join(model_dir, f"{model_name}.pth")
|
| 142 |
+
|
| 143 |
+
elif model_type == "mdxnet":
|
| 144 |
+
config_path = os.path.join(model_dir, f"{model_name}.yaml")
|
| 145 |
+
checkpoint_path = os.path.join(model_dir, f"{model_name}.onnx")
|
| 146 |
+
|
| 147 |
+
elif model_type == "bs_roformer":
|
| 148 |
+
config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
|
| 149 |
+
checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
|
| 150 |
+
|
| 151 |
+
elif model_type == "mdx23c":
|
| 152 |
+
config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
|
| 153 |
+
checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
|
| 154 |
+
|
| 155 |
+
elif model_type == "scnet":
|
| 156 |
+
config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
|
| 157 |
+
checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
|
| 158 |
+
|
| 159 |
+
elif model_type == "bandit":
|
| 160 |
+
config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
|
| 161 |
+
checkpoint_path = os.path.join(model_dir, f"{model_name}.chpt")
|
| 162 |
+
|
| 163 |
+
elif model_type == "bandit_v2":
|
| 164 |
+
config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
|
| 165 |
+
checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt")
|
| 166 |
+
|
| 167 |
+
elif model_type == "htdemucs":
|
| 168 |
+
config_path = os.path.join(model_dir, f"{model_name}_config.yaml")
|
| 169 |
+
checkpoint_path = os.path.join(model_dir, f"{model_name}.th")
|
| 170 |
+
|
| 171 |
+
else:
|
| 172 |
+
raise ValueError(
|
| 173 |
+
f"{self.I18N_helper.t('error_unsupported_model_type')}: {model_type}"
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
# Проверяем, что пути заданы (на всякий случай)
|
| 177 |
+
if config_path is None or checkpoint_path is None:
|
| 178 |
+
raise RuntimeError()
|
| 179 |
+
|
| 180 |
+
# Если файлы уже есть — пропускаем загрузку
|
| 181 |
+
if os.path.exists(checkpoint_path) and os.path.exists(config_path):
|
| 182 |
+
if os.path.getsize(checkpoint_path) == 0 or os.path.getsize(checkpoint_path) == 0:
|
| 183 |
+
for local_path, url_model in [
|
| 184 |
+
(checkpoint_path, ckpt_url),
|
| 185 |
+
(config_path, conf_url),
|
| 186 |
+
]:
|
| 187 |
+
if not os.path.exists(local_path):
|
| 188 |
+
|
| 189 |
+
dw_file(url_model, local_path)
|
| 190 |
+
else:
|
| 191 |
+
pass
|
| 192 |
+
else:
|
| 193 |
+
for local_path, url_model in [
|
| 194 |
+
(checkpoint_path, ckpt_url),
|
| 195 |
+
(config_path, conf_url),
|
| 196 |
+
]:
|
| 197 |
+
if not os.path.exists(local_path):
|
| 198 |
+
|
| 199 |
+
dw_file(url_model, local_path)
|
| 200 |
+
|
| 201 |
+
return config_path, checkpoint_path
|
| 202 |
+
|
| 203 |
+
def conf_editor(self, config_path, mdx_denoise, vr_aggr, model_type):
|
| 204 |
+
|
| 205 |
+
class IndentDumper(yaml.Dumper):
|
| 206 |
+
def increase_indent(self, flow=False, indentless=False):
|
| 207 |
+
return super(IndentDumper, self).increase_indent(flow, False)
|
| 208 |
+
|
| 209 |
+
def tuple_constructor(loader, node):
|
| 210 |
+
# Load the sequence of values from the YAML node
|
| 211 |
+
values = loader.construct_sequence(node)
|
| 212 |
+
# Return a tuple constructed from the sequence
|
| 213 |
+
return tuple(values)
|
| 214 |
+
|
| 215 |
+
# Register the constructor with PyYAML
|
| 216 |
+
yaml.SafeLoader.add_constructor(
|
| 217 |
+
"tag:yaml.org,2002:python/tuple", tuple_constructor
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
def conf_edit(config_path, mdx_denoise, vr_aggr, model_type):
|
| 221 |
+
with open(config_path, "r") as f:
|
| 222 |
+
data = yaml.load(f, Loader=yaml.SafeLoader)
|
| 223 |
+
|
| 224 |
+
# handle cases where 'use_amp' is missing from config:
|
| 225 |
+
if "use_amp" not in data.keys():
|
| 226 |
+
data["training"]["use_amp"] = True
|
| 227 |
+
|
| 228 |
+
if model_type != "vr":
|
| 229 |
+
if data["inference"]["num_overlap"] != 2:
|
| 230 |
+
data["inference"]["num_overlap"] = 2
|
| 231 |
+
|
| 232 |
+
if data["inference"]["batch_size"] != 1:
|
| 233 |
+
data["inference"]["batch_size"] = 1
|
| 234 |
+
|
| 235 |
+
if model_type == "mdxnet":
|
| 236 |
+
data["inference"]["denoise"] = mdx_denoise
|
| 237 |
+
|
| 238 |
+
elif model_type == "vr":
|
| 239 |
+
data["inference"]["aggression"] = vr_aggr
|
| 240 |
+
|
| 241 |
+
with open(config_path, "w") as f:
|
| 242 |
+
yaml.dump(
|
| 243 |
+
data,
|
| 244 |
+
f,
|
| 245 |
+
default_flow_style=False,
|
| 246 |
+
sort_keys=False,
|
| 247 |
+
Dumper=IndentDumper,
|
| 248 |
+
allow_unicode=True,
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
conf_edit(config_path, mdx_denoise, vr_aggr, model_type)
|
| 252 |
+
|
| 253 |
+
class VbachModelManager:
|
| 254 |
+
def __init__(self):
|
| 255 |
+
self.rmvpe_path = os.path.join(script_dir, "predictors", "rmvpe.pt")
|
| 256 |
+
self.fcpe_path = os.path.join(script_dir, "predictors", "fcpe.pt")
|
| 257 |
+
self.hubert_path = os.path.join(script_dir, "embedders", "hubert_base.pt")
|
| 258 |
+
self.requirements = [["https://huggingface.co/Politrees/RVC_resources/resolve/main/predictors/rmvpe.pt", self.rmvpe_path], ["https://huggingface.co/Politrees/RVC_resources/resolve/main/predictors/fcpe.pt", self.fcpe_path], ["https://huggingface.co/Politrees/RVC_resources/resolve/main/embedders/hubert_base.pt", self.hubert_path]]
|
| 259 |
+
self.voicemodels_dir = os.path.join(script_dir, "vbach_models_cache")
|
| 260 |
+
os.makedirs(self.voicemodels_dir, exist_ok=True)
|
| 261 |
+
self.voicemodels_info = os.path.join(self.voicemodels_dir, "vbach_models.json")
|
| 262 |
+
self.voicemodels: Dict[str, Dict[str, str]] = {}
|
| 263 |
+
self.download_requirements()
|
| 264 |
+
self.check_and_load()
|
| 265 |
+
pass
|
| 266 |
+
|
| 267 |
+
def write_voicemodels_info(self):
|
| 268 |
+
with open(self.voicemodels_info, "w") as f:
|
| 269 |
+
json.dump(self.voicemodels, f, indent=4)
|
| 270 |
+
|
| 271 |
+
def load_voicemodels_info(self):
|
| 272 |
+
with open(self.voicemodels_info, "r") as f:
|
| 273 |
+
return json.load(f)
|
| 274 |
+
|
| 275 |
+
def add_voice_model(
|
| 276 |
+
self,
|
| 277 |
+
name,
|
| 278 |
+
pth_path,
|
| 279 |
+
index_path,
|
| 280 |
+
):
|
| 281 |
+
self.voicemodels[name] = {"pth": pth_path, "index": index_path}
|
| 282 |
+
self.write_voicemodels_info()
|
| 283 |
+
|
| 284 |
+
def del_voice_model(
|
| 285 |
+
self, name
|
| 286 |
+
):
|
| 287 |
+
if name in self.parse_voice_models():
|
| 288 |
+
pth = self.voicemodels[name].get("pth", None)
|
| 289 |
+
index = self.voicemodels[name].get("index", None)
|
| 290 |
+
if index:
|
| 291 |
+
os.remove(index)
|
| 292 |
+
if pth:
|
| 293 |
+
os.remove(pth)
|
| 294 |
+
del self.voicemodels[name]
|
| 295 |
+
self.write_voicemodels_info()
|
| 296 |
+
return f"Модель {name} удалена"
|
| 297 |
+
else:
|
| 298 |
+
return f"Модель не была удалена, как так её не существует"
|
| 299 |
+
|
| 300 |
+
def parse_voice_models(self):
|
| 301 |
+
list_models = list(self.voicemodels.keys())
|
| 302 |
+
return list_models
|
| 303 |
+
|
| 304 |
+
def parse_pth_and_index(self, name):
|
| 305 |
+
pth = self.voicemodels[name].get("pth", None)
|
| 306 |
+
index = self.voicemodels[name].get("index", None)
|
| 307 |
+
return pth, index
|
| 308 |
+
|
| 309 |
+
def check_and_load(self):
|
| 310 |
+
if os.path.exists(self.voicemodels_info):
|
| 311 |
+
self.voicemodels = self.load_voicemodels_info()
|
| 312 |
+
else:
|
| 313 |
+
self.write_voicemodels_info()
|
| 314 |
+
|
| 315 |
+
def clear_voicemodels_info(self):
|
| 316 |
+
self.voicemodels: Dict[str, Dict[str, str]] = {}
|
| 317 |
+
self.write_voicemodels_info()
|
| 318 |
+
|
| 319 |
+
def download_file(self, url_model, local_path):
|
| 320 |
+
dir_name = os.path.dirname(local_path)
|
| 321 |
+
if dir_name != "":
|
| 322 |
+
os.makedirs(dir_name, exist_ok=True)
|
| 323 |
+
class TqdmUpTo(tqdm):
|
| 324 |
+
def update_to(self, b=1, bsize=1, tsize=None):
|
| 325 |
+
if tsize is not None:
|
| 326 |
+
self.total = tsize
|
| 327 |
+
self.update(b * bsize - self.n)
|
| 328 |
+
|
| 329 |
+
with TqdmUpTo(
|
| 330 |
+
unit="B",
|
| 331 |
+
unit_scale=True,
|
| 332 |
+
unit_divisor=1024,
|
| 333 |
+
miniters=1,
|
| 334 |
+
desc=os.path.basename(local_path),
|
| 335 |
+
) as t:
|
| 336 |
+
urllib.request.urlretrieve(
|
| 337 |
+
url_model, local_path, reporthook=t.update_to
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
def download_requirements(self):
|
| 341 |
+
for url, file in self.requirements:
|
| 342 |
+
if not os.path.exists(file):
|
| 343 |
+
self.download_file(url_model=url, local_path=file)
|
| 344 |
+
|
| 345 |
+
def download_voice_model_file(self, url, zip_name):
|
| 346 |
+
try:
|
| 347 |
+
if "drive.google.com" in url:
|
| 348 |
+
self.download_from_google_drive(url, zip_name)
|
| 349 |
+
elif "pixeldrain.com" in url:
|
| 350 |
+
self.download_from_pixeldrain(url, zip_name)
|
| 351 |
+
elif "disk.yandex.ru" in url or "yadi.sk" in url:
|
| 352 |
+
self.download_from_yandex(url, zip_name)
|
| 353 |
+
else:
|
| 354 |
+
self.download_file(url, zip_name)
|
| 355 |
+
except Exception as e:
|
| 356 |
+
print(e)
|
| 357 |
+
|
| 358 |
+
def download_from_google_drive(self, url, zip_name):
|
| 359 |
+
file_id = (
|
| 360 |
+
url.split("file/d/")[1].split("/")[0]
|
| 361 |
+
if "file/d/" in url
|
| 362 |
+
else url.split("id=")[1].split("&")[0]
|
| 363 |
+
)
|
| 364 |
+
gdown.download(id=file_id, output=str(zip_name), quiet=False)
|
| 365 |
+
|
| 366 |
+
def download_from_pixeldrain(self, url, zip_name):
|
| 367 |
+
file_id = url.split("pixeldrain.com/u/")[1]
|
| 368 |
+
response = requests.get(f"https://pixeldrain.com/api/file/{file_id}")
|
| 369 |
+
with open(zip_name, "wb") as f:
|
| 370 |
+
f.write(response.content)
|
| 371 |
+
|
| 372 |
+
def download_from_yandex(self, url, zip_name):
|
| 373 |
+
yandex_public_key = f"download?public_key={url}"
|
| 374 |
+
yandex_api_url = f"https://cloud-api.yandex.net/v1/disk/public/resources/{yandex_public_key}"
|
| 375 |
+
response = requests.get(yandex_api_url)
|
| 376 |
+
if response.status_code == 200:
|
| 377 |
+
download_link = response.json().get("href")
|
| 378 |
+
urllib.request.urlretrieve(download_link, zip_name)
|
| 379 |
+
else:
|
| 380 |
+
print(response.status_code)
|
| 381 |
+
|
| 382 |
+
def extract_zip(self, zip_name, model_name):
|
| 383 |
+
model_dir = os.path.join(self.voicemodels_dir, f"{model_name}_{generate_secure_random(17)}")
|
| 384 |
+
os.makedirs(model_dir, exist_ok=True)
|
| 385 |
+
try:
|
| 386 |
+
with zipfile.ZipFile(zip_name, "r") as zip_ref:
|
| 387 |
+
zip_ref.extractall(model_dir)
|
| 388 |
+
os.remove(zip_name)
|
| 389 |
+
|
| 390 |
+
added_voice_models = []
|
| 391 |
+
|
| 392 |
+
index_filepath, model_filepaths = None, []
|
| 393 |
+
for root, _, files in os.walk(model_dir):
|
| 394 |
+
for name in files:
|
| 395 |
+
file_path = os.path.join(root, name)
|
| 396 |
+
if name.endswith(".index") and os.stat(file_path).st_size > 1024 * 100:
|
| 397 |
+
index_filepath = file_path
|
| 398 |
+
if name.endswith(".pth") and os.stat(file_path).st_size > 1024 * 1024 * 20:
|
| 399 |
+
model_filepaths.append(file_path)
|
| 400 |
+
|
| 401 |
+
if len(model_filepaths) == 1:
|
| 402 |
+
self.add_voice_model(model_name, model_filepaths[0], index_filepath)
|
| 403 |
+
added_voice_models.append(model_name)
|
| 404 |
+
else:
|
| 405 |
+
for i, pth in enumerate(model_filepaths):
|
| 406 |
+
self.add_voice_model(f"{model_name}_{i + 1}", pth, index_filepath)
|
| 407 |
+
added_voice_models.append(f"{model_name}_{i + 1}")
|
| 408 |
+
list_models_str = '\n'.join(added_voice_models)
|
| 409 |
+
return f"Добавленные модели:\n{list_models_str}"
|
| 410 |
+
except Exception as e:
|
| 411 |
+
return f"Произошла ошибка при загрузке модели: {e}"
|
| 412 |
+
|
| 413 |
+
def install_model_zip(self, zip, model_name, mode="url"):
|
| 414 |
+
if model_name in self.parse_voice_models():
|
| 415 |
+
print("Эта модель уже есть в списке установленных моделей. Она будут перезаписана")
|
| 416 |
+
if mode == "url":
|
| 417 |
+
with tempfile.TemporaryDirectory(prefix="vbach_temp_model", ignore_cleanup_errors=True) as tmp:
|
| 418 |
+
zip_path = os.path.join(tmp, "model.zip")
|
| 419 |
+
self.download_voice_model_file(zip, zip_path)
|
| 420 |
+
status = self.extract_zip(zip_path, model_name)
|
| 421 |
+
if mode == "local":
|
| 422 |
+
status = self.extract_zip(zip, model_name)
|
| 423 |
+
return status
|
| 424 |
+
|
| 425 |
+
def install_model_files(self, index, pth, model_name, mode="url"):
|
| 426 |
+
if model_name in self.parse_voice_models():
|
| 427 |
+
print("Эта модель уже есть в списке установленных моделей. Она будут перезаписана")
|
| 428 |
+
model_dir = os.path.join(self.voicemodels_dir, f"{model_name}_{generate_secure_random(17)}")
|
| 429 |
+
os.makedirs(model_dir, exist_ok=True)
|
| 430 |
+
local_index_path = None
|
| 431 |
+
local_pth_path = None
|
| 432 |
+
try:
|
| 433 |
+
if mode == "url":
|
| 434 |
+
if index:
|
| 435 |
+
local_index_path = os.path.join(model_dir, "model.index")
|
| 436 |
+
self.download_voice_model_file(index, local_index_path)
|
| 437 |
+
if pth:
|
| 438 |
+
local_pth_path = os.path.join(model_dir, "model.pth")
|
| 439 |
+
self.download_voice_model_file(pth, local_pth_path)
|
| 440 |
+
|
| 441 |
+
if mode == "local":
|
| 442 |
+
if index:
|
| 443 |
+
if os.path.exists(index):
|
| 444 |
+
local_index_path = os.path.join(model_dir, os.path.basename(index))
|
| 445 |
+
shutil.copy(index, local_index_path)
|
| 446 |
+
if pth:
|
| 447 |
+
if os.path.exists(pth):
|
| 448 |
+
local_pth_path = os.path.join(model_dir, os.path.basename(pth))
|
| 449 |
+
shutil.copy(pth, local_pth_path)
|
| 450 |
+
|
| 451 |
+
self.add_voice_model(model_name, local_pth_path, local_index_path)
|
| 452 |
+
return f"Модель {model_name} добавлена"
|
| 453 |
+
except Exception as e:
|
| 454 |
+
return f"Произошла ошибка при загрузке модели: {e}"
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
if __name__ == "__main__":
|
| 458 |
+
parser = argparse.ArgumentParser(description="Менеджер моделей")
|
| 459 |
+
subparsers = parser.add_subparsers(title="subcommands", dest="command", required=True)
|
| 460 |
+
|
| 461 |
+
# Mvsepless subcommand
|
| 462 |
+
mvsepless_parser = subparsers.add_parser("mvsepless", help="Скачивание моделей в MVSepLess")
|
| 463 |
+
mvsepless_parser.add_argument("--model_type", required=True, help="Тип модели")
|
| 464 |
+
mvsepless_parser.add_argument("--model_name", required=True, help="Имя модели")
|
| 465 |
+
|
| 466 |
+
# Vbach subcommand
|
| 467 |
+
vbach_parser = subparsers.add_parser("vbach", help="Установка голосовых моделей в Vbach")
|
| 468 |
+
vbach_subparsers = vbach_parser.add_subparsers(title="vbach_commands", dest="vbach_command", required=True)
|
| 469 |
+
|
| 470 |
+
# Vbach install_local
|
| 471 |
+
install_local_parser = vbach_subparsers.add_parser("install_local", help="Установка голосовой модели по локальным файлам")
|
| 472 |
+
install_local_parser.add_argument("--model_name", required=True, help="Имя голосовой модели")
|
| 473 |
+
install_local_parser.add_argument("--pth", required=True, help="Путь к *.pth файлу")
|
| 474 |
+
install_local_parser.add_argument("--index", required=False, help="Путь к *.index файлу")
|
| 475 |
+
|
| 476 |
+
# Vbach install_url_zip
|
| 477 |
+
install_url_zip_parser = vbach_subparsers.add_parser("install_url_zip", help="Установка голосовой модели по URL (архив с файлами)")
|
| 478 |
+
install_url_zip_parser.add_argument("--model_name", required=True, help="Имя голосовой модели")
|
| 479 |
+
install_url_zip_parser.add_argument("--url", required=True, help="URL *.zip файла")
|
| 480 |
+
|
| 481 |
+
# Vbach install_url_files
|
| 482 |
+
install_url_files_parser = vbach_subparsers.add_parser("install_url_files", help="Установка голосовой модели по URL (отдельные файлы)")
|
| 483 |
+
install_url_files_parser.add_argument("--model_name", required=True, help="Имя голосовой модели")
|
| 484 |
+
install_url_files_parser.add_argument("--pth_url", required=True, help="URL *.pth файла")
|
| 485 |
+
install_url_files_parser.add_argument("--index_url", required=False, help="URL *.index файла")
|
| 486 |
+
|
| 487 |
+
# Vbach list
|
| 488 |
+
list_parser = vbach_subparsers.add_parser("list", help="List installed voice models")
|
| 489 |
+
|
| 490 |
+
args = parser.parse_args()
|
| 491 |
+
|
| 492 |
+
if args.command == "mvsepless":
|
| 493 |
+
|
| 494 |
+
_model_manager = MvseplessModelManager()
|
| 495 |
+
info = _model_manager.models_info[args.model_type].get(args.model_name, None)
|
| 496 |
+
if not info:
|
| 497 |
+
raise ValueError(f"Модель {args.model_name} не найдена для типа {args.model_type}")
|
| 498 |
+
conf, ckpt = _model_manager.download_model(
|
| 499 |
+
_model_manager.models_cache_dir,
|
| 500 |
+
args.model_name,
|
| 501 |
+
args.model_type,
|
| 502 |
+
info["checkpoint_url"],
|
| 503 |
+
info["config_url"],
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
elif args.command == "vbach":
|
| 507 |
+
model_manager = VbachModelManager()
|
| 508 |
+
|
| 509 |
+
if args.vbach_command == "install_local":
|
| 510 |
+
status = model_manager.install_model_files(
|
| 511 |
+
args.index, args.pth, args.model_name, mode="local"
|
| 512 |
+
)
|
| 513 |
+
print(status)
|
| 514 |
+
|
| 515 |
+
elif args.vbach_command == "install_url_zip":
|
| 516 |
+
status = model_manager.install_model_zip(
|
| 517 |
+
args.url, args.model_name, mode="url"
|
| 518 |
+
)
|
| 519 |
+
print(status)
|
| 520 |
+
|
| 521 |
+
elif args.vbach_command == "install_url_files":
|
| 522 |
+
status = model_manager.install_model_files(
|
| 523 |
+
args.index_url, args.pth_url, args.model_name, mode="url"
|
| 524 |
+
)
|
| 525 |
+
print(status)
|
| 526 |
+
|
| 527 |
+
elif args.vbach_command == "list":
|
| 528 |
+
models = model_manager.parse_voice_models()
|
| 529 |
+
if models:
|
| 530 |
+
print("Установленные модели:")
|
| 531 |
+
for model in models:
|
| 532 |
+
print(f" - {model}")
|
| 533 |
+
else:
|
| 534 |
+
print("Нет установленных моделей")
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
|
mvsepless/models.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
mvsepless/models/bandit/core/__init__.py
ADDED
|
@@ -0,0 +1,691 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from itertools import chain, combinations
|
| 4 |
+
from typing import Any, Dict, Iterator, Mapping, Optional, Tuple, Type, TypedDict
|
| 5 |
+
|
| 6 |
+
import pytorch_lightning as pl
|
| 7 |
+
import torch
|
| 8 |
+
import torchaudio as ta
|
| 9 |
+
import torchmetrics as tm
|
| 10 |
+
from asteroid import losses as asteroid_losses
|
| 11 |
+
|
| 12 |
+
# from deepspeed.ops.adam import DeepSpeedCPUAdam
|
| 13 |
+
# from geoopt import optim as gooptim
|
| 14 |
+
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
| 15 |
+
from torch import nn, optim
|
| 16 |
+
from torch.optim import lr_scheduler
|
| 17 |
+
from torch.optim.lr_scheduler import LRScheduler
|
| 18 |
+
|
| 19 |
+
from . import loss, metrics as metrics_, model
|
| 20 |
+
from .data._types import BatchedDataDict
|
| 21 |
+
from .data.augmentation import BaseAugmentor, StemAugmentor
|
| 22 |
+
from .utils import audio as audio_
|
| 23 |
+
from .utils.audio import BaseFader
|
| 24 |
+
|
| 25 |
+
# from pandas.io.json._normalize import nested_to_record
|
| 26 |
+
|
| 27 |
+
ConfigDict = TypedDict("ConfigDict", {"name": str, "kwargs": Dict[str, Any]})
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class SchedulerConfigDict(ConfigDict):
|
| 31 |
+
monitor: str
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
OptimizerSchedulerConfigDict = TypedDict(
|
| 35 |
+
"OptimizerSchedulerConfigDict",
|
| 36 |
+
{"optimizer": ConfigDict, "scheduler": SchedulerConfigDict},
|
| 37 |
+
total=False,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class LRSchedulerReturnDict(TypedDict, total=False):
|
| 42 |
+
scheduler: LRScheduler
|
| 43 |
+
monitor: str
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class ConfigureOptimizerReturnDict(TypedDict, total=False):
|
| 47 |
+
optimizer: torch.optim.Optimizer
|
| 48 |
+
lr_scheduler: LRSchedulerReturnDict
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
OutputType = Dict[str, Any]
|
| 52 |
+
MetricsType = Dict[str, torch.Tensor]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_optimizer_class(name: str) -> Type[optim.Optimizer]:
|
| 56 |
+
|
| 57 |
+
if name == "DeepSpeedCPUAdam":
|
| 58 |
+
return DeepSpeedCPUAdam
|
| 59 |
+
|
| 60 |
+
for module in [optim, gooptim]:
|
| 61 |
+
if name in module.__dict__:
|
| 62 |
+
return module.__dict__[name]
|
| 63 |
+
|
| 64 |
+
raise NameError
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def parse_optimizer_config(
|
| 68 |
+
config: OptimizerSchedulerConfigDict, parameters: Iterator[nn.Parameter]
|
| 69 |
+
) -> ConfigureOptimizerReturnDict:
|
| 70 |
+
optim_class = get_optimizer_class(config["optimizer"]["name"])
|
| 71 |
+
optimizer = optim_class(parameters, **config["optimizer"]["kwargs"])
|
| 72 |
+
|
| 73 |
+
optim_dict: ConfigureOptimizerReturnDict = {
|
| 74 |
+
"optimizer": optimizer,
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
if "scheduler" in config:
|
| 78 |
+
|
| 79 |
+
lr_scheduler_class_ = config["scheduler"]["name"]
|
| 80 |
+
lr_scheduler_class = lr_scheduler.__dict__[lr_scheduler_class_]
|
| 81 |
+
lr_scheduler_dict: LRSchedulerReturnDict = {
|
| 82 |
+
"scheduler": lr_scheduler_class(optimizer, **config["scheduler"]["kwargs"])
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
if lr_scheduler_class_ == "ReduceLROnPlateau":
|
| 86 |
+
lr_scheduler_dict["monitor"] = config["scheduler"]["monitor"]
|
| 87 |
+
|
| 88 |
+
optim_dict["lr_scheduler"] = lr_scheduler_dict
|
| 89 |
+
|
| 90 |
+
return optim_dict
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def parse_model_config(config: ConfigDict) -> Any:
|
| 94 |
+
name = config["name"]
|
| 95 |
+
|
| 96 |
+
for module in [model]:
|
| 97 |
+
if name in module.__dict__:
|
| 98 |
+
return module.__dict__[name](**config["kwargs"])
|
| 99 |
+
|
| 100 |
+
raise NameError
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
_LEGACY_LOSS_NAMES = ["HybridL1Loss"]
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _parse_legacy_loss_config(config: ConfigDict) -> nn.Module:
|
| 107 |
+
name = config["name"]
|
| 108 |
+
|
| 109 |
+
if name == "HybridL1Loss":
|
| 110 |
+
return loss.TimeFreqL1Loss(**config["kwargs"])
|
| 111 |
+
|
| 112 |
+
raise NameError
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def parse_loss_config(config: ConfigDict) -> nn.Module:
|
| 116 |
+
name = config["name"]
|
| 117 |
+
|
| 118 |
+
if name in _LEGACY_LOSS_NAMES:
|
| 119 |
+
return _parse_legacy_loss_config(config)
|
| 120 |
+
|
| 121 |
+
for module in [loss, nn.modules.loss, asteroid_losses]:
|
| 122 |
+
if name in module.__dict__:
|
| 123 |
+
# print(config["kwargs"])
|
| 124 |
+
return module.__dict__[name](**config["kwargs"])
|
| 125 |
+
|
| 126 |
+
raise NameError
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def get_metric(config: ConfigDict) -> tm.Metric:
|
| 130 |
+
name = config["name"]
|
| 131 |
+
|
| 132 |
+
for module in [tm, metrics_]:
|
| 133 |
+
if name in module.__dict__:
|
| 134 |
+
return module.__dict__[name](**config["kwargs"])
|
| 135 |
+
raise NameError
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def parse_metric_config(config: Dict[str, ConfigDict]) -> tm.MetricCollection:
|
| 139 |
+
metrics = {}
|
| 140 |
+
|
| 141 |
+
for metric in config:
|
| 142 |
+
metrics[metric] = get_metric(config[metric])
|
| 143 |
+
|
| 144 |
+
return tm.MetricCollection(metrics)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def parse_fader_config(config: ConfigDict) -> BaseFader:
|
| 148 |
+
name = config["name"]
|
| 149 |
+
|
| 150 |
+
for module in [audio_]:
|
| 151 |
+
if name in module.__dict__:
|
| 152 |
+
return module.__dict__[name](**config["kwargs"])
|
| 153 |
+
|
| 154 |
+
raise NameError
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class LightningSystem(pl.LightningModule):
|
| 158 |
+
_VOX_STEMS = ["speech", "vocals"]
|
| 159 |
+
_BG_STEMS = ["background", "effects", "mne"]
|
| 160 |
+
|
| 161 |
+
def __init__(
|
| 162 |
+
self, config: Dict, loss_adjustment: float = 1.0, attach_fader: bool = False
|
| 163 |
+
) -> None:
|
| 164 |
+
super().__init__()
|
| 165 |
+
self.optimizer_config = config["optimizer"]
|
| 166 |
+
self.model = parse_model_config(config["model"])
|
| 167 |
+
self.loss = parse_loss_config(config["loss"])
|
| 168 |
+
self.metrics = nn.ModuleDict(
|
| 169 |
+
{
|
| 170 |
+
stem: parse_metric_config(config["metrics"]["dev"])
|
| 171 |
+
for stem in self.model.stems
|
| 172 |
+
}
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
self.metrics.disallow_fsdp = True
|
| 176 |
+
|
| 177 |
+
self.test_metrics = nn.ModuleDict(
|
| 178 |
+
{
|
| 179 |
+
stem: parse_metric_config(config["metrics"]["test"])
|
| 180 |
+
for stem in self.model.stems
|
| 181 |
+
}
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
self.test_metrics.disallow_fsdp = True
|
| 185 |
+
|
| 186 |
+
self.fs = config["model"]["kwargs"]["fs"]
|
| 187 |
+
|
| 188 |
+
self.fader_config = config["inference"]["fader"]
|
| 189 |
+
if attach_fader:
|
| 190 |
+
self.fader = parse_fader_config(config["inference"]["fader"])
|
| 191 |
+
else:
|
| 192 |
+
self.fader = None
|
| 193 |
+
|
| 194 |
+
self.augmentation: Optional[BaseAugmentor]
|
| 195 |
+
if config.get("augmentation", None) is not None:
|
| 196 |
+
self.augmentation = StemAugmentor(**config["augmentation"])
|
| 197 |
+
else:
|
| 198 |
+
self.augmentation = None
|
| 199 |
+
|
| 200 |
+
self.predict_output_path: Optional[str] = None
|
| 201 |
+
self.loss_adjustment = loss_adjustment
|
| 202 |
+
|
| 203 |
+
self.val_prefix = None
|
| 204 |
+
self.test_prefix = None
|
| 205 |
+
|
| 206 |
+
def configure_optimizers(self) -> Any:
|
| 207 |
+
return parse_optimizer_config(
|
| 208 |
+
self.optimizer_config, self.trainer.model.parameters()
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
def compute_loss(
|
| 212 |
+
self, batch: BatchedDataDict, output: OutputType
|
| 213 |
+
) -> Dict[str, torch.Tensor]:
|
| 214 |
+
return {"loss": self.loss(output, batch)}
|
| 215 |
+
|
| 216 |
+
def update_metrics(
|
| 217 |
+
self, batch: BatchedDataDict, output: OutputType, mode: str
|
| 218 |
+
) -> None:
|
| 219 |
+
|
| 220 |
+
if mode == "test":
|
| 221 |
+
metrics = self.test_metrics
|
| 222 |
+
else:
|
| 223 |
+
metrics = self.metrics
|
| 224 |
+
|
| 225 |
+
for stem, metric in metrics.items():
|
| 226 |
+
|
| 227 |
+
if stem == "mne:+":
|
| 228 |
+
stem = "mne"
|
| 229 |
+
|
| 230 |
+
# print(f"matching for {stem}")
|
| 231 |
+
if mode == "train":
|
| 232 |
+
metric.update(
|
| 233 |
+
output["audio"][stem], # .cpu(),
|
| 234 |
+
batch["audio"][stem], # .cpu()
|
| 235 |
+
)
|
| 236 |
+
else:
|
| 237 |
+
if stem not in batch["audio"]:
|
| 238 |
+
matched = False
|
| 239 |
+
if stem in self._VOX_STEMS:
|
| 240 |
+
for bstem in self._VOX_STEMS:
|
| 241 |
+
if bstem in batch["audio"]:
|
| 242 |
+
batch["audio"][stem] = batch["audio"][bstem]
|
| 243 |
+
matched = True
|
| 244 |
+
break
|
| 245 |
+
elif stem in self._BG_STEMS:
|
| 246 |
+
for bstem in self._BG_STEMS:
|
| 247 |
+
if bstem in batch["audio"]:
|
| 248 |
+
batch["audio"][stem] = batch["audio"][bstem]
|
| 249 |
+
matched = True
|
| 250 |
+
break
|
| 251 |
+
else:
|
| 252 |
+
matched = True
|
| 253 |
+
|
| 254 |
+
# print(batch["audio"].keys())
|
| 255 |
+
|
| 256 |
+
if matched:
|
| 257 |
+
# print(f"matched {stem}!")
|
| 258 |
+
if stem == "mne" and "mne" not in output["audio"]:
|
| 259 |
+
output["audio"]["mne"] = (
|
| 260 |
+
output["audio"]["music"] + output["audio"]["effects"]
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
metric.update(
|
| 264 |
+
output["audio"][stem], # .cpu(),
|
| 265 |
+
batch["audio"][stem], # .cpu(),
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
# print(metric.compute())
|
| 269 |
+
|
| 270 |
+
def compute_metrics(self, mode: str = "dev") -> Dict[str, torch.Tensor]:
|
| 271 |
+
|
| 272 |
+
if mode == "test":
|
| 273 |
+
metrics = self.test_metrics
|
| 274 |
+
else:
|
| 275 |
+
metrics = self.metrics
|
| 276 |
+
|
| 277 |
+
metric_dict = {}
|
| 278 |
+
|
| 279 |
+
for stem, metric in metrics.items():
|
| 280 |
+
md = metric.compute()
|
| 281 |
+
metric_dict.update({f"{stem}/{k}": v for k, v in md.items()})
|
| 282 |
+
|
| 283 |
+
self.log_dict(metric_dict, prog_bar=True, logger=False)
|
| 284 |
+
|
| 285 |
+
return metric_dict
|
| 286 |
+
|
| 287 |
+
def reset_metrics(self, test_mode: bool = False) -> None:
|
| 288 |
+
|
| 289 |
+
if test_mode:
|
| 290 |
+
metrics = self.test_metrics
|
| 291 |
+
else:
|
| 292 |
+
metrics = self.metrics
|
| 293 |
+
|
| 294 |
+
for _, metric in metrics.items():
|
| 295 |
+
metric.reset()
|
| 296 |
+
|
| 297 |
+
def forward(self, batch: BatchedDataDict) -> Any:
|
| 298 |
+
batch, output = self.model(batch)
|
| 299 |
+
|
| 300 |
+
return batch, output
|
| 301 |
+
|
| 302 |
+
def common_step(self, batch: BatchedDataDict, mode: str) -> Any:
|
| 303 |
+
batch, output = self.forward(batch)
|
| 304 |
+
# print(batch)
|
| 305 |
+
# print(output)
|
| 306 |
+
loss_dict = self.compute_loss(batch, output)
|
| 307 |
+
|
| 308 |
+
with torch.no_grad():
|
| 309 |
+
self.update_metrics(batch, output, mode=mode)
|
| 310 |
+
|
| 311 |
+
if mode == "train":
|
| 312 |
+
self.log("loss", loss_dict["loss"], prog_bar=True)
|
| 313 |
+
|
| 314 |
+
return output, loss_dict
|
| 315 |
+
|
| 316 |
+
def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]:
|
| 317 |
+
|
| 318 |
+
if self.augmentation is not None:
|
| 319 |
+
with torch.no_grad():
|
| 320 |
+
batch = self.augmentation(batch)
|
| 321 |
+
|
| 322 |
+
_, loss_dict = self.common_step(batch, mode="train")
|
| 323 |
+
|
| 324 |
+
with torch.inference_mode():
|
| 325 |
+
self.log_dict_with_prefix(
|
| 326 |
+
loss_dict, "train", batch_size=batch["audio"]["mixture"].shape[0]
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
loss_dict["loss"] *= self.loss_adjustment
|
| 330 |
+
|
| 331 |
+
return loss_dict
|
| 332 |
+
|
| 333 |
+
def on_train_batch_end(
|
| 334 |
+
self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int
|
| 335 |
+
) -> None:
|
| 336 |
+
|
| 337 |
+
metric_dict = self.compute_metrics()
|
| 338 |
+
self.log_dict_with_prefix(metric_dict, "train")
|
| 339 |
+
self.reset_metrics()
|
| 340 |
+
|
| 341 |
+
def validation_step(
|
| 342 |
+
self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
|
| 343 |
+
) -> Dict[str, Any]:
|
| 344 |
+
|
| 345 |
+
with torch.inference_mode():
|
| 346 |
+
curr_val_prefix = f"val{dataloader_idx}" if dataloader_idx > 0 else "val"
|
| 347 |
+
|
| 348 |
+
if curr_val_prefix != self.val_prefix:
|
| 349 |
+
# print(f"Switching to validation dataloader {dataloader_idx}")
|
| 350 |
+
if self.val_prefix is not None:
|
| 351 |
+
self._on_validation_epoch_end()
|
| 352 |
+
self.val_prefix = curr_val_prefix
|
| 353 |
+
_, loss_dict = self.common_step(batch, mode="val")
|
| 354 |
+
|
| 355 |
+
self.log_dict_with_prefix(
|
| 356 |
+
loss_dict,
|
| 357 |
+
self.val_prefix,
|
| 358 |
+
batch_size=batch["audio"]["mixture"].shape[0],
|
| 359 |
+
prog_bar=True,
|
| 360 |
+
add_dataloader_idx=False,
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
return loss_dict
|
| 364 |
+
|
| 365 |
+
    def on_validation_epoch_end(self) -> None:
        # Delegates to a private helper because the same flush/reset path is
        # also invoked mid-epoch when validation_step switches dataloaders.
        self._on_validation_epoch_end()
def _on_validation_epoch_end(self) -> None:
|
| 369 |
+
metric_dict = self.compute_metrics()
|
| 370 |
+
self.log_dict_with_prefix(
|
| 371 |
+
metric_dict, self.val_prefix, prog_bar=True, add_dataloader_idx=False
|
| 372 |
+
)
|
| 373 |
+
# self.logger.save()
|
| 374 |
+
# print(self.val_prefix, "Validation metrics:", metric_dict)
|
| 375 |
+
self.reset_metrics()
|
| 376 |
+
|
| 377 |
+
def old_predtest_step(
|
| 378 |
+
self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
|
| 379 |
+
) -> Tuple[BatchedDataDict, OutputType]:
|
| 380 |
+
|
| 381 |
+
audio_batch = batch["audio"]["mixture"]
|
| 382 |
+
track_batch = batch.get("track", ["" for _ in range(len(audio_batch))])
|
| 383 |
+
|
| 384 |
+
output_list_of_dicts = [
|
| 385 |
+
self.fader(audio[None, ...], lambda a: self.test_forward(a, track))
|
| 386 |
+
for audio, track in zip(audio_batch, track_batch)
|
| 387 |
+
]
|
| 388 |
+
|
| 389 |
+
output_dict_of_lists = defaultdict(list)
|
| 390 |
+
|
| 391 |
+
for output_dict in output_list_of_dicts:
|
| 392 |
+
for stem, audio in output_dict.items():
|
| 393 |
+
output_dict_of_lists[stem].append(audio)
|
| 394 |
+
|
| 395 |
+
output = {
|
| 396 |
+
"audio": {
|
| 397 |
+
stem: torch.concat(output_list, dim=0)
|
| 398 |
+
for stem, output_list in output_dict_of_lists.items()
|
| 399 |
+
}
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
return batch, output
|
| 403 |
+
|
| 404 |
+
def predtest_step(
|
| 405 |
+
self, batch: BatchedDataDict, batch_idx: int = -1, dataloader_idx: int = 0
|
| 406 |
+
) -> Tuple[BatchedDataDict, OutputType]:
|
| 407 |
+
|
| 408 |
+
if getattr(self.model, "bypass_fader", False):
|
| 409 |
+
batch, output = self.model(batch)
|
| 410 |
+
else:
|
| 411 |
+
audio_batch = batch["audio"]["mixture"]
|
| 412 |
+
output = self.fader(
|
| 413 |
+
audio_batch, lambda a: self.test_forward(a, "", batch=batch)
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
return batch, output
|
| 417 |
+
|
| 418 |
+
def test_forward(
|
| 419 |
+
self, audio: torch.Tensor, track: str = "", batch: BatchedDataDict = None
|
| 420 |
+
) -> torch.Tensor:
|
| 421 |
+
|
| 422 |
+
if self.fader is None:
|
| 423 |
+
self.attach_fader()
|
| 424 |
+
|
| 425 |
+
cond = batch.get("condition", None)
|
| 426 |
+
|
| 427 |
+
if cond is not None and cond.shape[0] == 1:
|
| 428 |
+
cond = cond.repeat(audio.shape[0], 1)
|
| 429 |
+
|
| 430 |
+
_, output = self.forward(
|
| 431 |
+
{
|
| 432 |
+
"audio": {"mixture": audio},
|
| 433 |
+
"track": track,
|
| 434 |
+
"condition": cond,
|
| 435 |
+
}
|
| 436 |
+
) # TODO: support track properly
|
| 437 |
+
|
| 438 |
+
return output["audio"]
|
| 439 |
+
|
| 440 |
+
    def on_test_epoch_start(self) -> None:
        # Rebuild the fader from config at the start of every test epoch so it
        # carries no state from previous runs and sits on the current device.
        self.attach_fader(force_reattach=True)
def test_step(
|
| 444 |
+
self, batch: BatchedDataDict, batch_idx: int, dataloader_idx: int = 0
|
| 445 |
+
) -> Any:
|
| 446 |
+
curr_test_prefix = f"test{dataloader_idx}"
|
| 447 |
+
|
| 448 |
+
# print(batch["audio"].keys())
|
| 449 |
+
|
| 450 |
+
if curr_test_prefix != self.test_prefix:
|
| 451 |
+
# print(f"Switching to test dataloader {dataloader_idx}")
|
| 452 |
+
if self.test_prefix is not None:
|
| 453 |
+
self._on_test_epoch_end()
|
| 454 |
+
self.test_prefix = curr_test_prefix
|
| 455 |
+
|
| 456 |
+
with torch.inference_mode():
|
| 457 |
+
_, output = self.predtest_step(batch, batch_idx, dataloader_idx)
|
| 458 |
+
# print(output)
|
| 459 |
+
self.update_metrics(batch, output, mode="test")
|
| 460 |
+
|
| 461 |
+
return output
|
| 462 |
+
|
| 463 |
+
    def on_test_epoch_end(self) -> None:
        # Delegates to a private helper because the same flush/reset path is
        # also invoked mid-epoch when test_step switches dataloaders.
        self._on_test_epoch_end()
def _on_test_epoch_end(self) -> None:
|
| 467 |
+
metric_dict = self.compute_metrics(mode="test")
|
| 468 |
+
self.log_dict_with_prefix(
|
| 469 |
+
metric_dict, self.test_prefix, prog_bar=True, add_dataloader_idx=False
|
| 470 |
+
)
|
| 471 |
+
# self.logger.save()
|
| 472 |
+
# print(self.test_prefix, "Test metrics:", metric_dict)
|
| 473 |
+
self.reset_metrics()
|
| 474 |
+
|
| 475 |
+
    def predict_step(
        self,
        batch: BatchedDataDict,
        batch_idx: int = 0,
        dataloader_idx: int = 0,
        include_track_name: Optional[bool] = None,
        get_no_vox_combinations: bool = True,
        get_residual: bool = False,
        treat_batch_as_channels: bool = False,
        fs: Optional[int] = None,
    ) -> Any:
        """Separate a batch, save every stem to WAV, and return per-stem metrics.

        Requires ``set_predict_output_path`` to have been called. Optionally adds
        a "residual" stem (mixture minus all extracted stems), sums of every
        non-vocal stem combination, and can fold the batch dim into channels.
        Metrics are only computed for stems that have a reference in ``batch``.
        """
        assert self.predict_output_path is not None

        batch_size = batch["audio"]["mixture"].shape[0]

        if include_track_name is None:
            # Multi-item batches get per-track subdirectories by default.
            include_track_name = batch_size > 1

        with torch.inference_mode():
            batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
            print("Pred test finished...")
            torch.cuda.empty_cache()
            metric_dict = {}

            if get_residual:
                mixture = batch["audio"]["mixture"]
                extracted = sum([output["audio"][stem] for stem in output["audio"]])
                residual = mixture - extracted
                print(extracted.shape, mixture.shape, residual.shape)

                output["audio"]["residual"] = residual

            if get_no_vox_combinations:
                # All sums of two or more non-vocal stems, keyed "a+b+...".
                no_vox_stems = [
                    stem for stem in output["audio"] if stem not in self._VOX_STEMS
                ]
                no_vox_combinations = chain.from_iterable(
                    combinations(no_vox_stems, r)
                    for r in range(2, len(no_vox_stems) + 1)
                )

                for combination in no_vox_combinations:
                    combination_ = list(combination)
                    output["audio"]["+".join(combination_)] = sum(
                        [output["audio"][stem] for stem in combination_]
                    )

            if treat_batch_as_channels:
                # Fold the batch dimension into channels: (B, C, T) -> (1, B*C, T).
                for stem in output["audio"]:
                    output["audio"][stem] = output["audio"][stem].reshape(
                        1, -1, output["audio"][stem].shape[-1]
                    )
                batch_size = 1

            for b in range(batch_size):
                print("!!", b)
                for stem in output["audio"]:
                    print(f"Saving audio for {stem} to {self.predict_output_path}")
                    track_name = batch["track"][b].split("/")[-1]

                    if batch.get("audio", {}).get(stem, None) is not None:
                        # Reference available: score this item and embed the
                        # scores in the output filename.
                        self.test_metrics[stem].reset()
                        metrics = self.test_metrics[stem](
                            batch["audio"][stem][[b], ...],
                            output["audio"][stem][[b], ...],
                        )
                        snr = metrics["snr"]
                        sisnr = metrics["sisnr"]
                        sdr = metrics["sdr"]
                        metric_dict[stem] = metrics
                        print(
                            track_name,
                            f"snr={snr:2.2f} dB",
                            f"sisnr={sisnr:2.2f}",
                            f"sdr={sdr:2.2f} dB",
                        )
                        filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
                    else:
                        filename = f"{stem}.wav"

                    if include_track_name:
                        output_dir = os.path.join(self.predict_output_path, track_name)
                    else:
                        output_dir = self.predict_output_path

                    os.makedirs(output_dir, exist_ok=True)

                    if fs is None:
                        fs = self.fs

                    ta.save(
                        os.path.join(output_dir, filename),
                        output["audio"][stem][b, ...].cpu(),
                        fs,
                    )

            return metric_dict
    def get_stems(
        self,
        batch: BatchedDataDict,
        batch_idx: int = 0,
        dataloader_idx: int = 0,
        include_track_name: Optional[bool] = None,
        get_no_vox_combinations: bool = True,
        get_residual: bool = False,
        treat_batch_as_channels: bool = False,
        fs: Optional[int] = None,
    ) -> Any:
        """Separate a batch and return stems as numpy arrays instead of saving.

        NOTE(review): this is largely a copy of predict_step with the WAV
        writing replaced by collecting arrays; it still creates output
        directories as a side effect and, with batch_size > 1, each stem entry
        in the result is overwritten by the last batch item — confirm intended.
        """
        assert self.predict_output_path is not None

        batch_size = batch["audio"]["mixture"].shape[0]

        if include_track_name is None:
            include_track_name = batch_size > 1

        with torch.inference_mode():
            batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
            torch.cuda.empty_cache()
            metric_dict = {}

            if get_residual:
                mixture = batch["audio"]["mixture"]
                extracted = sum([output["audio"][stem] for stem in output["audio"]])
                residual = mixture - extracted
                # print(extracted.shape, mixture.shape, residual.shape)

                output["audio"]["residual"] = residual

            if get_no_vox_combinations:
                # All sums of two or more non-vocal stems, keyed "a+b+...".
                no_vox_stems = [
                    stem for stem in output["audio"] if stem not in self._VOX_STEMS
                ]
                no_vox_combinations = chain.from_iterable(
                    combinations(no_vox_stems, r)
                    for r in range(2, len(no_vox_stems) + 1)
                )

                for combination in no_vox_combinations:
                    combination_ = list(combination)
                    output["audio"]["+".join(combination_)] = sum(
                        [output["audio"][stem] for stem in combination_]
                    )

            if treat_batch_as_channels:
                # Fold the batch dimension into channels: (B, C, T) -> (1, B*C, T).
                for stem in output["audio"]:
                    output["audio"][stem] = output["audio"][stem].reshape(
                        1, -1, output["audio"][stem].shape[-1]
                    )
                batch_size = 1

            result = {}
            for b in range(batch_size):
                for stem in output["audio"]:
                    track_name = batch["track"][b].split("/")[-1]

                    if batch.get("audio", {}).get(stem, None) is not None:
                        # Reference available: score this item.
                        self.test_metrics[stem].reset()
                        metrics = self.test_metrics[stem](
                            batch["audio"][stem][[b], ...],
                            output["audio"][stem][[b], ...],
                        )
                        snr = metrics["snr"]
                        sisnr = metrics["sisnr"]
                        sdr = metrics["sdr"]
                        metric_dict[stem] = metrics
                        print(
                            track_name,
                            f"snr={snr:2.2f} dB",
                            f"sisnr={sisnr:2.2f}",
                            f"sdr={sdr:2.2f} dB",
                        )
                        filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
                    else:
                        filename = f"{stem}.wav"

                    if include_track_name:
                        output_dir = os.path.join(self.predict_output_path, track_name)
                    else:
                        output_dir = self.predict_output_path

                    # NOTE(review): directory creation is a leftover from
                    # predict_step; nothing is written here.
                    os.makedirs(output_dir, exist_ok=True)

                    if fs is None:
                        fs = self.fs

                    result[stem] = output["audio"][stem][b, ...].cpu().numpy()

            return result
    def load_state_dict(
        self, state_dict: Mapping[str, Any], strict: bool = False
    ) -> Any:
        # NOTE(review): the `strict` argument is accepted but ignored — loading
        # is always non-strict, presumably so partial/legacy checkpoints load
        # without error. Confirm this is intentional before relying on strict=True.
        return super().load_state_dict(state_dict, strict=False)
def set_predict_output_path(self, path: str) -> None:
|
| 668 |
+
self.predict_output_path = path
|
| 669 |
+
os.makedirs(self.predict_output_path, exist_ok=True)
|
| 670 |
+
|
| 671 |
+
self.attach_fader()
|
| 672 |
+
|
| 673 |
+
def attach_fader(self, force_reattach=False) -> None:
|
| 674 |
+
if self.fader is None or force_reattach:
|
| 675 |
+
self.fader = parse_fader_config(self.fader_config)
|
| 676 |
+
self.fader.to(self.device)
|
| 677 |
+
|
| 678 |
+
def log_dict_with_prefix(
|
| 679 |
+
self,
|
| 680 |
+
dict_: Dict[str, torch.Tensor],
|
| 681 |
+
prefix: str,
|
| 682 |
+
batch_size: Optional[int] = None,
|
| 683 |
+
**kwargs: Any,
|
| 684 |
+
) -> None:
|
| 685 |
+
self.log_dict(
|
| 686 |
+
{f"{prefix}/{k}": v for k, v in dict_.items()},
|
| 687 |
+
batch_size=batch_size,
|
| 688 |
+
logger=True,
|
| 689 |
+
sync_dist=True,
|
| 690 |
+
**kwargs,
|
| 691 |
+
)
|
mvsepless/models/bandit/core/data/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .dnr.datamodule import DivideAndRemasterDataModule
|
| 2 |
+
from .musdb.datamodule import MUSDB18DataModule
|
mvsepless/models/bandit/core/data/_types.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Sequence, TypedDict

import torch

# Mapping from stem name (e.g. "mixture", "speech") to its waveform tensor.
AudioDict = Dict[str, torch.Tensor]

# A single dataset example: per-stem audio plus the track identifier.
DataDict = TypedDict("DataDict", {"audio": AudioDict, "track": str})

# A collated batch: stems hold batched tensors, tracks become a sequence.
BatchedDataDict = TypedDict(
    "BatchedDataDict", {"audio": AudioDict, "track": Sequence[str]}
)


class DataDictWithLanguage(TypedDict):
    """A DataDict extended with the language tag of its speech content."""

    audio: AudioDict
    track: str
    language: str
|
mvsepless/models/bandit/core/data/augmentation.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC
|
| 2 |
+
from typing import Any, Dict, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch_audiomentations as tam
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
from ._types import BatchedDataDict, DataDict
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BaseAugmentor(nn.Module, ABC):
    """Interface for augmentors that transform a (batched) data dict."""

    def forward(
        self, item: Union[DataDict, BatchedDataDict]
    ) -> Union[DataDict, BatchedDataDict]:
        raise NotImplementedError
+
|
| 18 |
+
class StemAugmentor(BaseAugmentor):
    """Per-stem audio augmentation built from torch-audiomentations transforms.

    ``audiomentations`` maps stem names to transform specs of the form
    ``{"name": <tam class>, "kwargs": {...}}`` ("Compose" specs nest a list of
    transforms). Two special keys are recognized: "[common]" is applied to every
    non-mixture stem, and "[default]" is applied to stems without their own
    entry (only alongside "[common]" if ``apply_both_default_and_common``).
    After augmentation the mixture is recomputed as the sum of the stems, and
    clipping is optionally fixed by rescaling all stems together.
    """

    def __init__(
        self,
        audiomentations: Dict[str, Dict[str, Any]],
        fix_clipping: bool = True,
        scaler_margin: float = 0.5,
        apply_both_default_and_common: bool = False,
    ) -> None:
        super().__init__()

        augmentations = {}

        self.has_default = "[default]" in audiomentations
        self.has_common = "[common]" in audiomentations
        self.apply_both_default_and_common = apply_both_default_and_common

        for stem in audiomentations:
            if audiomentations[stem]["name"] == "Compose":
                # Compose: build each inner transform, then wrap them together.
                augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
                    [
                        getattr(tam, aug["name"])(**aug["kwargs"])
                        for aug in audiomentations[stem]["kwargs"]["transforms"]
                    ],
                    **audiomentations[stem]["kwargs"]["kwargs"],
                )
            else:
                augmentations[stem] = getattr(tam, audiomentations[stem]["name"])(
                    **audiomentations[stem]["kwargs"]
                )

        self.augmentations = nn.ModuleDict(augmentations)
        self.fix_clipping = fix_clipping
        self.scaler_margin = scaler_margin

    def check_and_fix_clipping(
        self, item: Union[DataDict, BatchedDataDict]
    ) -> Union[DataDict, BatchedDataDict]:
        """Rescale all stems in place if any sample exceeds full scale."""
        max_abs = []

        for stem in item["audio"]:
            max_abs.append(item["audio"][stem].abs().max().item())

        if max(max_abs) > 1.0:
            # Divide by slightly more than the peak (random margin) so the
            # result is strictly below full scale.
            scaler = 1.0 / (
                max(max_abs)
                + torch.rand((1,), device=item["audio"]["mixture"].device)
                * self.scaler_margin
            )

            for stem in item["audio"]:
                item["audio"][stem] *= scaler

        return item

    def forward(
        self, item: Union[DataDict, BatchedDataDict]
    ) -> Union[DataDict, BatchedDataDict]:

        for stem in item["audio"]:
            # The mixture is rebuilt from the stems afterwards, never augmented.
            if stem == "mixture":
                continue

            if self.has_common:
                item["audio"][stem] = self.augmentations["[common]"](
                    item["audio"][stem]
                ).samples

            # Stem-specific transform takes precedence over "[default]".
            if stem in self.augmentations:
                item["audio"][stem] = self.augmentations[stem](
                    item["audio"][stem]
                ).samples
            elif self.has_default:
                if not self.has_common or self.apply_both_default_and_common:
                    item["audio"][stem] = self.augmentations["[default]"](
                        item["audio"][stem]
                    ).samples

        # Keep the mixture consistent with the augmented stems.
        item["audio"]["mixture"] = sum(
            [item["audio"][stem] for stem in item["audio"] if stem != "mixture"]
        )  # type: ignore[call-overload, assignment]

        if self.fix_clipping:
            item = self.check_and_fix_clipping(item)

        return item
|
mvsepless/models/bandit/core/data/augmented.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from typing import Dict, Optional, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.utils import data
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AugmentedDataset(data.Dataset):
    """Deprecated wrapper applying an augmentation module to dataset items.

    Indices wrap modulo the underlying dataset length, so ``target_length`` may
    exceed it to lengthen an epoch.
    """

    def __init__(
        self,
        dataset: data.Dataset,
        augmentation: nn.Module = nn.Identity(),
        target_length: Optional[int] = None,
    ) -> None:
        warnings.warn(
            "This class is no longer used. Attach augmentation to "
            "the LightningSystem instead.",
            DeprecationWarning,
        )

        self.dataset = dataset
        self.augmentation = augmentation

        self.ds_length: int = len(dataset)  # type: ignore[arg-type]
        self.length = self.ds_length if target_length is None else target_length

    def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str, torch.Tensor]]]:
        # Wrap around so indices past the underlying dataset reuse its items.
        return self.augmentation(self.dataset[index % self.ds_length])

    def __len__(self) -> int:
        return self.length
|
mvsepless/models/bandit/core/data/base.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
from typing import Any, Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pedalboard as pb
|
| 7 |
+
import torch
|
| 8 |
+
import torchaudio as ta
|
| 9 |
+
from torch.utils import data
|
| 10 |
+
|
| 11 |
+
from ._types import AudioDict, DataDict
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class BaseSourceSeparationDataset(data.Dataset, ABC):
    """Common base for source-separation datasets.

    Subclasses implement stem loading (``get_stem``) and index-to-identifier
    mapping (``get_identifier``); this base assembles full audio dicts and can
    recompute the mixture as the sum of the other stems.
    """

    def __init__(
        self,
        split: str,
        stems: List[str],
        files: List[str],
        data_path: str,
        fs: int,
        npy_memmap: bool,
        recompute_mixture: bool,
    ):
        self.split = split
        self.stems = stems
        self.stems_no_mixture = [name for name in stems if name != "mixture"]
        self.files = files
        self.data_path = data_path
        self.fs = fs
        self.npy_memmap = npy_memmap
        self.recompute_mixture = recompute_mixture

    @abstractmethod
    def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
        """Load and return the waveform of a single stem."""
        raise NotImplementedError

    def _get_audio(self, stems, identifier: Dict[str, Any]):
        # Fetch each requested stem through the subclass loader.
        return {
            name: self.get_stem(stem=name, identifier=identifier) for name in stems
        }

    def get_audio(self, identifier: Dict[str, Any]) -> AudioDict:
        """Return all stems for one item, deriving the mixture if configured."""
        if not self.recompute_mixture:
            return self._get_audio(self.stems, identifier=identifier)

        audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
        audio["mixture"] = self.compute_mixture(audio)
        return audio

    @abstractmethod
    def get_identifier(self, index: int) -> Dict[str, Any]:
        """Map a dataset index to whatever identifies one item on disk."""
        pass

    def compute_mixture(self, audio: AudioDict) -> torch.Tensor:
        """Sum all non-mixture stems into a mixture signal."""
        return sum(audio[name] for name in audio if name != "mixture")
|
mvsepless/models/bandit/core/data/dnr/__init__.py
ADDED
|
File without changes
|
mvsepless/models/bandit/core/data/dnr/datamodule.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Mapping, Optional
|
| 3 |
+
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
|
| 6 |
+
from .dataset import (
|
| 7 |
+
DivideAndRemasterDataset,
|
| 8 |
+
DivideAndRemasterDeterministicChunkDataset,
|
| 9 |
+
DivideAndRemasterRandomChunkDataset,
|
| 10 |
+
DivideAndRemasterRandomChunkDatasetWithSpeechReverb,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def DivideAndRemasterDataModule(
    data_root: str = "$DATA_ROOT/DnR/v2",
    batch_size: int = 2,
    num_workers: int = 8,
    train_kwargs: Optional[Mapping] = None,
    val_kwargs: Optional[Mapping] = None,
    test_kwargs: Optional[Mapping] = None,
    datamodule_kwargs: Optional[Mapping] = None,
    use_speech_reverb: bool = False,
    # augmentor=None
) -> pl.LightningDataModule:
    """Assemble a LightningDataModule over the DnR train/val/test splits.

    Training uses random chunks (optionally with speech reverb); validation
    uses deterministic chunks; test and predict iterate full tracks.
    """
    train_kwargs = {} if train_kwargs is None else train_kwargs
    val_kwargs = {} if val_kwargs is None else val_kwargs
    test_kwargs = {} if test_kwargs is None else test_kwargs
    datamodule_kwargs = {} if datamodule_kwargs is None else datamodule_kwargs

    if num_workers is None:
        num_workers = os.cpu_count()
    if num_workers is None:
        # os.cpu_count() itself may return None; fall back to a fixed count.
        num_workers = 32
    num_workers = min(num_workers, 64)

    train_cls = (
        DivideAndRemasterRandomChunkDatasetWithSpeechReverb
        if use_speech_reverb
        else DivideAndRemasterRandomChunkDataset
    )

    train_dataset = train_cls(data_root, "train", **train_kwargs)

    # if augmentor is not None:
    #     train_dataset = AugmentedDataset(train_dataset, augmentor)

    datamodule = pl.LightningDataModule.from_datasets(
        train_dataset=train_dataset,
        val_dataset=DivideAndRemasterDeterministicChunkDataset(
            data_root, "val", **val_kwargs
        ),
        test_dataset=DivideAndRemasterDataset(data_root, "test", **test_kwargs),
        batch_size=batch_size,
        num_workers=num_workers,
        **datamodule_kwargs
    )

    # Prediction reuses the full-track test dataloader.
    datamodule.predict_dataloader = datamodule.test_dataloader  # type: ignore[method-assign]

    return datamodule
|
mvsepless/models/bandit/core/data/dnr/dataset.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from abc import ABC
|
| 3 |
+
from typing import Any, Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pedalboard as pb
|
| 7 |
+
import torch
|
| 8 |
+
import torchaudio as ta
|
| 9 |
+
from torch.utils import data
|
| 10 |
+
|
| 11 |
+
from .._types import AudioDict, DataDict
|
| 12 |
+
from ..base import BaseSourceSeparationDataset
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC):
    """Base dataset for the Divide and Remaster (DnR) corpus.

    Stems live under ``<data_path>/<track>/<file>`` either as ``.npy`` arrays
    (memory-mapped) or ``.wav`` files. The virtual "mne" stem (music + effects)
    is derived on the fly rather than read from disk.
    """

    ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"]
    # On-disk file basenames for each stored stem ("mne" is derived, not stored).
    STEM_NAME_MAP = {
        "mixture": "mix",
        "speech": "speech",
        "music": "music",
        "effects": "sfx",
    }
    # Split directory names used by the DnR release.
    SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"}

    # Every DnR track is 60 s at 44.1 kHz.
    FULL_TRACK_LENGTH_SECOND = 60
    FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100

    def __init__(
        self,
        split: str,
        stems: List[str],
        files: List[str],
        data_path: str,
        fs: int = 44100,
        npy_memmap: bool = True,
        recompute_mixture: bool = False,
    ) -> None:
        super().__init__(
            split=split,
            stems=stems,
            files=files,
            data_path=data_path,
            fs=fs,
            npy_memmap=npy_memmap,
            recompute_mixture=recompute_mixture,
        )

    def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
        """Load one stem for the identified track; "mne" is music + effects."""

        if stem == "mne":
            return self.get_stem(stem="music", identifier=identifier) + self.get_stem(
                stem="effects", identifier=identifier
            )

        track = identifier["track"]
        path = os.path.join(self.data_path, track)

        if self.npy_memmap:
            # Memory-map so chunk datasets can slice without loading the file.
            audio = np.load(
                os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"), mmap_mode="r"
            )
        else:
            # noinspection PyUnresolvedReferences
            audio, _ = ta.load(os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav"))

        return audio

    def get_identifier(self, index):
        # A track's directory name is its identifier.
        return dict(track=self.files[index])

    def __getitem__(self, index: int) -> DataDict:
        identifier = self.get_identifier(index)
        audio = self.get_audio(identifier)

        return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class DivideAndRemasterDataset(DivideAndRemasterBaseDataset):
    """Full-length DnR dataset: one item per 60-second track."""

    def __init__(
        self,
        data_root: str,
        split: str,
        stems: Optional[List[str]] = None,
        fs: int = 44100,
        npy_memmap: bool = True,
    ) -> None:
        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])

        # Track folders only: skip hidden entries and stray plain files.
        entries = sorted(os.listdir(data_path))
        files = [
            name
            for name in entries
            if not name.startswith(".")
            and os.path.isdir(os.path.join(data_path, name))
        ]
        # pprint(list(enumerate(files)))

        # Sanity-check against the published DnR v2 split sizes.
        expected_counts = {"train": 3406, "val": 487, "test": 973}
        if split in expected_counts:
            assert len(files) == expected_counts[split], len(files)

        self.n_tracks = len(files)

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=files,
            fs=fs,
            npy_memmap=npy_memmap,
        )

    def __len__(self) -> int:
        return self.n_tracks
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset):
    """DnR dataset yielding a random fixed-size chunk per access.

    ``target_length`` is the virtual epoch length: indices wrap around the
    underlying track list, and a fresh random chunk position is drawn on
    every ``__getitem__`` call.
    """

    def __init__(
        self,
        data_root: str,
        split: str,
        target_length: int,
        chunk_size_second: float,
        stems: Optional[List[str]] = None,
        fs: int = 44100,
        npy_memmap: bool = True,
    ) -> None:
        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])

        # Track folders only: skip hidden entries and stray plain files.
        entries = sorted(os.listdir(data_path))
        files = [
            name
            for name in entries
            if not name.startswith(".")
            and os.path.isdir(os.path.join(data_path, name))
        ]

        # Sanity-check against the published DnR v2 split sizes.
        expected_counts = {"train": 3406, "val": 487, "test": 973}
        if split in expected_counts:
            assert len(files) == expected_counts[split], len(files)

        self.n_tracks = len(files)

        self.target_length = target_length
        self.chunk_size = int(chunk_size_second * fs)

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=files,
            fs=fs,
            npy_memmap=npy_memmap,
        )

    def __len__(self) -> int:
        return self.target_length

    def get_identifier(self, index):
        # Wrap the virtual index back onto the real track list.
        return super().get_identifier(index % self.n_tracks)

    def get_stem(
        self,
        *,
        stem: str,
        identifier: Dict[str, Any],
        chunk_here: bool = False,
    ) -> torch.Tensor:
        full_stem = super().get_stem(stem=stem, identifier=identifier)

        if not chunk_here:
            return full_stem

        # Independent random chunk for this stem alone.
        offset = np.random.randint(
            0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
        )
        return full_stem[:, offset : offset + self.chunk_size]

    def __getitem__(self, index: int) -> DataDict:
        identifier = self.get_identifier(index)
        # self.index_lock = index
        audio = self.get_audio(identifier)
        # self.index_lock = None

        # One chunk position shared by every stem, keeping stems aligned.
        offset = np.random.randint(
            0, self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
        )
        audio = {
            stem: waveform[:, offset : offset + self.chunk_size]
            for stem, waveform in audio.items()
        }

        return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset):
    """DnR dataset over a deterministic grid of overlapping chunks.

    Item ``k`` maps to track ``k % n_tracks`` and chunk ``k // n_tracks``;
    chunk ``c`` starts ``c * hop_size`` samples into the track.
    """

    def __init__(
        self,
        data_root: str,
        split: str,
        chunk_size_second: float,
        hop_size_second: float,
        stems: Optional[List[str]] = None,
        fs: int = 44100,
        npy_memmap: bool = True,
    ) -> None:
        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])

        # Track folders only: skip hidden entries and stray plain files.
        entries = sorted(os.listdir(data_path))
        files = [
            name
            for name in entries
            if not name.startswith(".")
            and os.path.isdir(os.path.join(data_path, name))
        ]
        # pprint(list(enumerate(files)))

        # Sanity-check against the published DnR v2 split sizes.
        expected_counts = {"train": 3406, "val": 487, "test": 973}
        if split in expected_counts:
            assert len(files) == expected_counts[split], len(files)

        self.n_tracks = len(files)

        self.chunk_size = int(chunk_size_second * fs)
        self.hop_size = int(hop_size_second * fs)
        # Number of full chunks that fit into a 60 s track at this hop.
        self.n_chunks_per_track = int(
            (self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second
        )

        self.length = self.n_tracks * self.n_chunks_per_track

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=files,
            fs=fs,
            npy_memmap=npy_memmap,
        )

    def get_identifier(self, index):
        # Wrap the flat index back onto the real track list.
        return super().get_identifier(index % self.n_tracks)

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, item: int) -> DataDict:
        track_index = item % self.n_tracks
        chunk_index = item // self.n_tracks

        data_ = super().__getitem__(track_index)
        full_audio = data_["audio"]

        begin = chunk_index * self.hop_size
        stop = begin + self.chunk_size

        for stem in self.stems:
            data_["audio"][stem] = full_audio[stem][:, begin:stop]

        return data_
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
    DivideAndRemasterRandomChunkDataset
):
    """Random-chunk DnR dataset that reverberates speech on the fly.

    The parent class is constructed WITHOUT the "mixture" stem; the mixture
    is recomputed here as the sum of the (reverbed-speech) stems so it stays
    consistent with the augmented speech.
    """

    def __init__(
        self,
        data_root: str,
        split: str,
        target_length: int,
        chunk_size_second: float,
        stems: Optional[List[str]] = None,
        fs: int = 44100,
        npy_memmap: bool = True,
    ) -> None:

        if stems is None:
            stems = self.ALLOWED_STEMS

        # Mixture is recomputed after augmentation, so the parent only
        # loads the individual source stems.
        stems_no_mixture = [s for s in stems if s != "mixture"]

        super().__init__(
            data_root=data_root,
            split=split,
            target_length=target_length,
            chunk_size_second=chunk_size_second,
            stems=stems_no_mixture,
            fs=fs,
            npy_memmap=npy_memmap,
        )

        self.stems = stems
        self.stems_no_mixture = stems_no_mixture

    def __getitem__(self, index: int) -> DataDict:

        data_ = super().__getitem__(index)

        # Shallow copy of the speech chunk; reverb output replaces it below.
        dry = data_["audio"]["speech"][:]
        n_samples = dry.shape[-1]

        wet_level = np.random.rand()

        # NOTE(review): `pb` is presumably pedalboard; its Reverb tail can
        # lengthen the signal, hence the trailing [..., :n_samples] crop.
        # The order of np.random.rand() calls here is part of the RNG
        # stream — do not reorder these keyword arguments.
        speech = pb.Reverb(
            room_size=np.random.rand(),
            damping=np.random.rand(),
            wet_level=wet_level,
            dry_level=(1 - wet_level),
            width=np.random.rand(),
        ).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples]

        data_["audio"]["speech"] = speech

        # Rebuild the mixture from the augmented sources so that
        # mixture == sum(stems) still holds after the reverb.
        data_["audio"]["mixture"] = sum(
            [data_["audio"][s] for s in self.stems_no_mixture]
        )

        return data_

    def __len__(self) -> int:
        return super().__len__()
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
if __name__ == "__main__":
    # Manual smoke test: instantiate the reverb-augmented dataset and dump
    # a few items. Note the unconditional `break` at the bottom: only the
    # "train" split is actually exercised.

    from pprint import pprint
    from tqdm.auto import tqdm

    for split_ in ["train", "val", "test"]:
        # NOTE(review): "$DATA_ROOT" is not expanded by Python — this path
        # is a placeholder that must be edited (or pre-expanded) to run.
        ds = DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
            data_root="$DATA_ROOT/DnR/v2np",
            split=split_,
            target_length=100,
            chunk_size_second=6.0,
        )

        print(split_, len(ds))

        for track_ in tqdm(ds):  # type: ignore
            pprint(track_)
            # Replace waveforms with their shapes for a compact second dump.
            track_["audio"] = {k: v.shape for k, v in track_["audio"].items()}
            pprint(track_)
            # break

        break
|
mvsepless/models/bandit/core/data/dnr/preprocess.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torchaudio as ta
|
| 7 |
+
from tqdm.contrib.concurrent import process_map
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def process_one(inputs: Tuple[str, str, int]) -> None:
    """Resample one wav file to ``target_fs`` and store it as float32 ``.npy``.

    Packaged as a single-tuple-argument worker so it can be used directly
    with ``process_map``.

    Args:
        inputs: ``(infile, outfile, target_fs)`` — source wav path,
            destination ``.npy`` path, and target sample rate in Hz.
    """
    infile, outfile, target_fs = inputs

    # Renamed from `dir`, which shadowed the builtin of the same name.
    out_dir = os.path.dirname(outfile)
    os.makedirs(out_dir, exist_ok=True)

    data, fs = ta.load(infile)

    if fs != target_fs:
        data = ta.functional.resample(
            data, fs, target_fs, resampling_method="sinc_interp_kaiser"
        )
        fs = target_fs

    data = data.numpy().astype(np.float32)

    # Idempotent reruns: skip the write if an identical file already exists.
    if os.path.exists(outfile):
        data_ = np.load(outfile)
        if np.allclose(data, data_):
            return

    np.save(outfile, data)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def preprocess(data_path: str, output_path: str, fs: int) -> None:
    """Convert every wav under ``data_path`` into ``.npy`` under ``output_path``.

    The directory layout of ``data_path`` is mirrored; files are resampled
    to ``fs`` by the ``process_one`` worker running in a process pool.
    """
    pattern = os.path.join(data_path, "**", "*.wav")
    wav_files = glob.glob(pattern, recursive=True)
    print(wav_files)

    npy_files = []
    for wav in wav_files:
        npy_files.append(
            wav.replace(data_path, output_path).replace(".wav", ".npy")
        )

    os.makedirs(output_path, exist_ok=True)
    jobs = [(wav, npy, fs) for wav, npy in zip(wav_files, npy_files)]

    process_map(process_one, jobs, chunksize=32)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
if __name__ == "__main__":
    # Expose this module's functions (preprocess, process_one) as a CLI via
    # python-fire, e.g.:  python preprocess.py preprocess IN OUT 44100
    import fire

    fire.Fire()
|
mvsepless/models/bandit/core/data/musdb/__init__.py
ADDED
|
File without changes
|
mvsepless/models/bandit/core/data/musdb/datamodule.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
from typing import Mapping, Optional
|
| 3 |
+
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
|
| 6 |
+
from .dataset import (
|
| 7 |
+
MUSDB18BaseDataset,
|
| 8 |
+
MUSDB18FullTrackDataset,
|
| 9 |
+
MUSDB18SadDataset,
|
| 10 |
+
MUSDB18SadOnTheFlyAugmentedDataset,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def MUSDB18DataModule(
    data_root: str = "$DATA_ROOT/MUSDB18/HQ",
    target_stem: str = "vocals",
    batch_size: int = 2,
    num_workers: int = 8,
    train_kwargs: Optional[Mapping] = None,
    val_kwargs: Optional[Mapping] = None,
    test_kwargs: Optional[Mapping] = None,
    datamodule_kwargs: Optional[Mapping] = None,
    use_on_the_fly: bool = True,
    npy_memmap: bool = True,
) -> pl.LightningDataModule:
    """Build a LightningDataModule over MUSDB18-HQ for one target stem.

    Train/val use the pre-extracted salient segments under ``saded-np``
    (with optional on-the-fly augmentation for train); test/predict use
    the canonical full tracks.

    Args:
        data_root: MUSDB18-HQ root containing ``saded-np`` and ``canonical``.
        target_stem: Stem the SAD segments were extracted for.
        batch_size / num_workers: Forwarded to the dataloaders.
        train_kwargs / val_kwargs / test_kwargs: Extra per-dataset kwargs.
        datamodule_kwargs: Extra kwargs for ``from_datasets``.
        use_on_the_fly: Use the augmented train dataset when True.
        npy_memmap: Memory-map the pre-converted ``.npy`` stems.

    Returns:
        A LightningDataModule whose predict dataloader aliases test.
    """
    if train_kwargs is None:
        train_kwargs = {}

    if val_kwargs is None:
        val_kwargs = {}

    if test_kwargs is None:
        test_kwargs = {}

    if datamodule_kwargs is None:
        datamodule_kwargs = {}

    sad_root = os.path.join(data_root, "saded-np")

    train_dataset: MUSDB18BaseDataset

    # Fix: `npy_memmap` was accepted but never forwarded, so the datasets
    # silently fell back to their own default (False).
    if use_on_the_fly:
        train_dataset = MUSDB18SadOnTheFlyAugmentedDataset(
            data_root=sad_root,
            split="train",
            target_stem=target_stem,
            **train_kwargs
        )
    else:
        train_dataset = MUSDB18SadDataset(
            data_root=sad_root,
            split="train",
            target_stem=target_stem,
            npy_memmap=npy_memmap,
            **train_kwargs
        )

    datamodule = pl.LightningDataModule.from_datasets(
        train_dataset=train_dataset,
        val_dataset=MUSDB18SadDataset(
            data_root=sad_root,
            split="val",
            target_stem=target_stem,
            npy_memmap=npy_memmap,
            **val_kwargs
        ),
        test_dataset=MUSDB18FullTrackDataset(
            data_root=os.path.join(data_root, "canonical"), split="test", **test_kwargs
        ),
        batch_size=batch_size,
        num_workers=num_workers,
        **datamodule_kwargs
    )

    # Prediction runs over the same full tracks as testing.
    datamodule.predict_dataloader = (  # type: ignore[method-assign]
        datamodule.test_dataloader
    )

    return datamodule
|
mvsepless/models/bandit/core/data/musdb/dataset.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from abc import ABC
|
| 3 |
+
from typing import List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torchaudio as ta
|
| 8 |
+
from torch.utils import data
|
| 9 |
+
|
| 10 |
+
from .._types import AudioDict, DataDict
|
| 11 |
+
from ..base import BaseSourceSeparationDataset
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC):
    """Shared behaviour for MUSDB18-HQ dataset variants.

    Each track directory holds one file per stem: ``<stem>.wav`` or a
    pre-converted ``<stem>.wav.npy`` for memory-mapped loading.
    """

    ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"]

    def __init__(
        self,
        split: str,
        stems: List[str],
        files: List[str],
        data_path: str,
        fs: int = 44100,
        npy_memmap=False,
    ) -> None:
        super().__init__(
            split=split,
            stems=stems,
            files=files,
            data_path=data_path,
            fs=fs,
            npy_memmap=npy_memmap,
            recompute_mixture=False,
        )

    def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
        """Load one stem of the identified track."""
        track_dir = os.path.join(self.data_path, identifier["track"])
        # noinspection PyUnresolvedReferences

        if self.npy_memmap:
            # Memory-mapped read of the pre-converted array.
            return np.load(
                os.path.join(track_dir, f"{stem}.wav.npy"), mmap_mode="r"
            )

        waveform, _ = ta.load(os.path.join(track_dir, f"{stem}.wav"))
        return waveform

    def get_identifier(self, index):
        """Map a dataset index onto the track-identifier dict."""
        return {"track": self.files[index]}

    def __getitem__(self, index: int) -> DataDict:
        identifier = self.get_identifier(index)
        return {
            "audio": self.get_audio(identifier),
            "track": f"{self.split}/{identifier['track']}",
        }
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class MUSDB18FullTrackDataset(MUSDB18BaseDataset):
    """Full-track MUSDB18-HQ dataset with a fixed train/val partition.

    The canonical 100-track train subset is split into train/val via the
    hard-coded ``VALIDATION_FILES`` list; ``split="test"`` uses the
    50-track test subset.
    """

    N_TRAIN_TRACKS = 100
    N_TEST_TRACKS = 50
    # Tracks held out of the 100-track train subset for validation.
    VALIDATION_FILES = [
        "Actions - One Minute Smile",
        "Clara Berry And Wooldog - Waltz For My Victims",
        "Johnny Lokke - Promises & Lies",
        "Patrick Talbot - A Reason To Leave",
        "Triviul - Angelsaint",
        "Alexander Ross - Goodbye Bolero",
        "Fergessen - Nos Palpitants",
        "Leaf - Summerghost",
        "Skelpolu - Human Mistakes",
        "Young Griffo - Pennies",
        "ANiMAL - Rockshow",
        "James May - On The Line",
        "Meaxic - Take A Step",
        "Traffic Experiment - Sirens",
    ]

    def __init__(
        self, data_root: str, split: str, stems: Optional[List[str]] = None
    ) -> None:
        """
        Args:
            data_root: Directory containing the ``train`` and ``test`` subsets.
            split: One of "train", "val", "test".
            stems: Stems to load; defaults to all ``ALLOWED_STEMS``.

        Raises:
            NameError: If ``split`` is not one of the three known names.
        """
        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        if split == "test":
            subset = "test"
        elif split in ["train", "val"]:
            subset = "train"
        else:
            # Same exception type as before, now with a usable message.
            raise NameError(f"Unknown split: {split!r}")

        data_path = os.path.join(data_root, subset)

        files = sorted(os.listdir(data_path))
        files = [f for f in files if not f.startswith(".")]
        # pprint(list(enumerate(files)))

        # Sanity-check directory contents against the canonical track counts,
        # then carve out the requested partition.
        if subset == "train":
            assert len(files) == self.N_TRAIN_TRACKS, len(files)
            if split == "train":
                files = [f for f in files if f not in self.VALIDATION_FILES]
                assert len(files) == self.N_TRAIN_TRACKS - len(self.VALIDATION_FILES)
            else:
                files = [f for f in files if f in self.VALIDATION_FILES]
                assert len(files) == len(self.VALIDATION_FILES)
        else:
            # (Removed a dead `split = "test"` reassignment: this branch is
            # only reachable when split is already "test".)
            assert len(files) == self.N_TEST_TRACKS

        self.n_tracks = len(files)

        super().__init__(data_path=data_path, split=split, stems=stems, files=files)

    def __len__(self) -> int:
        return self.n_tracks
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class MUSDB18SadDataset(MUSDB18BaseDataset):
    """Dataset of pre-extracted salient (SAD) segments for one target stem.

    Segment directories live under ``data_root/<target_stem>/<split>``.
    ``target_length`` sets a virtual epoch length; indices wrap around the
    real segment list.
    """

    def __init__(
        self,
        data_root: str,
        split: str,
        target_stem: str,
        stems: Optional[List[str]] = None,
        target_length: Optional[int] = None,
        npy_memmap=False,
    ) -> None:
        if stems is None:
            stems = self.ALLOWED_STEMS

        data_path = os.path.join(data_root, target_stem, split)

        segment_dirs = [
            entry
            for entry in sorted(os.listdir(data_path))
            if not entry.startswith(".")
        ]

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=segment_dirs,
            npy_memmap=npy_memmap,
        )
        self.n_segments = len(segment_dirs)
        self.target_stem = target_stem
        # Default virtual length = one pass over the real segments.
        if target_length is None:
            self.target_length = self.n_segments
        else:
            self.target_length = target_length

    def __len__(self) -> int:
        return self.target_length

    def __getitem__(self, index: int) -> DataDict:
        # Wrap virtual indices back onto the real segment list.
        return super().__getitem__(index % self.n_segments)

    def get_identifier(self, index):
        return super().get_identifier(index % self.n_segments)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset):
    """SAD-segment dataset with on-the-fly stem-shuffling augmentation.

    For each item, non-target stems are (with probability
    ``apply_probability``) swapped for stems from a random other segment,
    a random sub-chunk of each stem is gain-scaled (or zeroed with
    probability ``drop_probability``), and the mixture is recomputed from
    the augmented stems.
    """

    def __init__(
        self,
        data_root: str,
        split: str,
        target_stem: str,
        stems: Optional[List[str]] = None,
        target_length: int = 20000,
        apply_probability: Optional[float] = None,
        chunk_size_second: float = 3.0,
        random_scale_range_db: Tuple[float, float] = (-10, 10),
        drop_probability: float = 0.1,
        rescale: bool = True,
    ) -> None:
        super().__init__(data_root, split, target_stem, stems)

        # Default: the fraction of virtual epochs that exceed one real pass
        # over the segments gets augmented.
        if apply_probability is None:
            apply_probability = (target_length - self.n_segments) / target_length

        self.apply_probability = apply_probability
        self.drop_probability = drop_probability
        self.chunk_size_second = chunk_size_second
        self.random_scale_range_db = random_scale_range_db
        self.rescale = rescale

        # NOTE(review): assumes `self.fs` is set by the base class — the
        # constructor chain here does not pass `fs` explicitly.
        self.chunk_size_sample = int(self.chunk_size_second * self.fs)
        self.target_length = target_length

    def __len__(self) -> int:
        return self.target_length

    def __getitem__(self, index: int) -> DataDict:

        index = index % self.n_segments

        # if np.random.rand() > self.apply_probability:
        #     return super().__getitem__(index)

        audio = {}
        identifier = self.get_identifier(index)

        # NOTE(review): `self.stems_no_mixture` and `compute_mixture` are
        # presumably provided by BaseSourceSeparationDataset (not visible
        # here) — confirm before refactoring.
        # assert self.target_stem in self.stems_no_mixture
        for stem in self.stems_no_mixture:
            if stem == self.target_stem:
                # The target stem always comes from the indexed segment.
                identifier_ = identifier
            else:
                # Non-target stems may be swapped for a random segment's.
                if np.random.rand() < self.apply_probability:
                    index_ = np.random.randint(self.n_segments)
                    identifier_ = self.get_identifier(index_)
                else:
                    identifier_ = identifier

            audio[stem] = self.get_stem(stem=stem, identifier=identifier_)

            # if stem == self.target_stem:

            # Pick a random sub-chunk to gain-scale (whole stem if shorter).
            if self.chunk_size_sample < audio[stem].shape[-1]:
                chunk_start = np.random.randint(
                    audio[stem].shape[-1] - self.chunk_size_sample
                )
            else:
                chunk_start = 0

            # Either silence the chunk entirely, or apply a random dB gain.
            if np.random.rand() < self.drop_probability:
                # db_scale = "-inf"
                linear_scale = 0.0
            else:
                db_scale = np.random.uniform(*self.random_scale_range_db)
                linear_scale = np.power(10, db_scale / 20)
                # db_scale = f"{db_scale:+2.1f}"
            # print(linear_scale)
            # In-place scaling of the chosen chunk. NOTE(review): this
            # requires a writable array — a read-only memmap stem would
            # raise here; verify against how get_stem loads data.
            audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample] = (
                linear_scale
                * audio[stem][..., chunk_start : chunk_start + self.chunk_size_sample]
            )

        # Mixture is rebuilt from the augmented stems, not loaded from disk.
        audio["mixture"] = self.compute_mixture(audio)

        if self.rescale:
            # Peak-normalize only when the augmented mix would clip.
            max_abs_val = max(
                [torch.max(torch.abs(audio[stem])) for stem in self.stems]
            )  # type: ignore[type-var]
            if max_abs_val > 1:
                audio = {k: v / max_abs_val for k, v in audio.items()}

        track = identifier["track"]

        return {"audio": audio, "track": f"{self.split}/{track}"}
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
# if __name__ == "__main__":
|
| 256 |
+
#
|
| 257 |
+
# from pprint import pprint
|
| 258 |
+
# from tqdm.auto import tqdm
|
| 259 |
+
#
|
| 260 |
+
# for split_ in ["train", "val", "test"]:
|
| 261 |
+
# ds = MUSDB18SadOnTheFlyAugmentedDataset(
|
| 262 |
+
# data_root="$DATA_ROOT/MUSDB18/HQ/saded",
|
| 263 |
+
# split=split_,
|
| 264 |
+
# target_stem="vocals"
|
| 265 |
+
# )
|
| 266 |
+
#
|
| 267 |
+
# print(split_, len(ds))
|
| 268 |
+
#
|
| 269 |
+
# for track_ in tqdm(ds):
|
| 270 |
+
# track_["audio"] = {
|
| 271 |
+
# k: v.shape for k, v in track_["audio"].items()
|
| 272 |
+
# }
|
| 273 |
+
# pprint(track_)
|
mvsepless/models/bandit/core/data/musdb/preprocess.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torchaudio as ta
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.nn import functional as F
|
| 9 |
+
from tqdm.contrib.concurrent import process_map
|
| 10 |
+
|
| 11 |
+
from .._types import DataDict
|
| 12 |
+
from .dataset import MUSDB18FullTrackDataset
|
| 13 |
+
import pyloudnorm as pyln
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SourceActivityDetector(nn.Module):
    """Loudness-normalizes a track, finds segments where ``analysis_stem``
    is active, and writes each salient segment's stems to disk as ``.npy``.

    Saliency: a segment is split into ``n_chunks`` chunks; it is salient if
    more than ``salient_proportion_threshold`` of its chunks exceed an
    energy threshold taken at ``energy_threshold_quantile`` of all chunk
    energies. With ``analysis_stem == "none"`` every segment is kept.
    """

    def __init__(
        self,
        analysis_stem: str,
        output_path: str,
        fs: int = 44100,
        segment_length_second: float = 6.0,
        hop_length_second: float = 3.0,
        n_chunks: int = 10,
        chunk_epsilon: float = 1e-5,
        energy_threshold_quantile: float = 0.15,
        segment_epsilon: float = 1e-3,
        salient_proportion_threshold: float = 0.5,
        target_lufs: float = -24,
    ) -> None:
        super().__init__()

        self.fs = fs
        self.segment_length = int(segment_length_second * self.fs)
        self.hop_length = int(hop_length_second * self.fs)
        self.n_chunks = n_chunks
        # Segments must divide evenly into chunks for the reshape below.
        assert self.segment_length % self.n_chunks == 0
        self.chunk_size = self.segment_length // self.n_chunks
        self.chunk_epsilon = chunk_epsilon
        self.energy_threshold_quantile = energy_threshold_quantile
        self.segment_epsilon = segment_epsilon
        self.salient_proportion_threshold = salient_proportion_threshold
        self.analysis_stem = analysis_stem

        # pyloudnorm meter for integrated-loudness (LUFS) measurement.
        self.meter = pyln.Meter(self.fs)
        self.target_lufs = target_lufs

        self.output_path = output_path

    def forward(self, data: DataDict) -> None:
        """Process one track dict ({"audio": {stem: tensor}, "track": str}).

        Side effects only: writes ``.npy`` segment files under
        ``output_path/analysis_stem/``; returns nothing.
        """

        # Loudness is measured on the analysis stem (or the mixture when
        # analysis is disabled) and applied to ALL stems identically, so
        # relative stem levels are preserved.
        stem_ = self.analysis_stem if (self.analysis_stem != "none") else "mixture"

        x = data["audio"][stem_]

        # pyln expects (samples, channels); tensors here are (channels, samples).
        xnp = x.numpy()
        loudness = self.meter.integrated_loudness(xnp.T)

        for stem in data["audio"]:
            s = data["audio"][stem]
            s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T
            s = torch.as_tensor(s)
            data["audio"][stem] = s

        # Drop a singleton leading (batch) dimension if present.
        if x.ndim == 3:
            assert x.shape[0] == 1
            x = x[0]

        n_chan, n_samples = x.shape

        n_segments = (
            int(np.ceil((n_samples - self.segment_length) / self.hop_length)) + 1
        )

        # Slice the analysis signal into overlapping segments; the last,
        # short segment is padded with NaN so padded samples can be
        # excluded from the energy statistics (via nanquantile) below.
        segments = torch.zeros((n_segments, n_chan, self.segment_length))
        for i in range(n_segments):
            start = i * self.hop_length
            end = start + self.segment_length
            end = min(end, n_samples)

            xseg = x[:, start:end]

            if end - start < self.segment_length:
                xseg = F.pad(
                    xseg, pad=(0, self.segment_length - (end - start)), value=torch.nan
                )

            segments[i, :, :] = xseg

        chunks = segments.reshape((n_segments, n_chan, self.n_chunks, self.chunk_size))

        if self.analysis_stem != "none":
            # Mean energy per (segment, chunk), averaged over channels/time.
            chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3))
            chunk_energies = torch.nan_to_num(chunk_energies, nan=0)
            # Floor exact zeros so the quantile isn't dominated by silence.
            chunk_energies[chunk_energies == 0] = self.chunk_epsilon

            energy_threshold = torch.nanquantile(
                chunk_energies, q=self.energy_threshold_quantile
            )

            if energy_threshold < self.segment_epsilon:
                energy_threshold = self.segment_epsilon  # type: ignore[assignment]

            chunks_above_threshold = chunk_energies > energy_threshold
            # Fraction of active chunks per segment.
            n_chunks_above_threshold = torch.mean(
                chunks_above_threshold.to(torch.float), dim=-1
            )

            segment_above_threshold = (
                n_chunks_above_threshold > self.salient_proportion_threshold
            )

            # Nothing salient in this track: write nothing.
            if torch.sum(segment_above_threshold) == 0:
                return

        else:
            # Analysis disabled: keep every segment.
            segment_above_threshold = torch.ones((n_segments,))

        for i in range(n_segments):
            if not segment_above_threshold[i]:
                continue

            outpath = os.path.join(
                self.output_path,
                self.analysis_stem,
                f"{data['track']} - {self.analysis_stem}{i:03d}",
            )
            os.makedirs(outpath, exist_ok=True)

            for stem in data["audio"]:
                if stem == self.analysis_stem:
                    # Reuse the precomputed segment; NaN padding -> zeros.
                    segment = torch.nan_to_num(segments[i, :, :], nan=0)
                else:
                    start = i * self.hop_length
                    end = start + self.segment_length
                    end = min(n_samples, end)

                    segment = data["audio"][stem][:, start:end]

                    if end - start < self.segment_length:
                        segment = F.pad(
                            segment, (0, self.segment_length - (end - start))
                        )

                assert segment.shape[-1] == self.segment_length, segment.shape

                # ta.save(os.path.join(outpath, f"{stem}.wav"), segment, self.fs)

                # np.save appends ".npy", producing "<stem>.wav.npy" — the
                # exact filename MUSDB18BaseDataset.get_stem memmap-loads.
                np.save(os.path.join(outpath, f"{stem}.wav"), segment)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def preprocess(
    analysis_stem: str,
    output_path: str = "/data/MUSDB18/HQ/saded-np",
    fs: int = 44100,
    segment_length_second: float = 6.0,
    hop_length_second: float = 3.0,
    n_chunks: int = 10,
    chunk_epsilon: float = 1e-5,
    energy_threshold_quantile: float = 0.15,
    segment_epsilon: float = 1e-3,
    salient_proportion_threshold: float = 0.5,
) -> None:
    """Run source-activity detection for one stem over all MUSDB18-HQ splits.

    Builds a ``SourceActivityDetector`` configured for ``analysis_stem`` and
    applies it to every track of the train/val/test splits, dispatching the
    per-track work to a process pool in batches.

    Args:
        analysis_stem: Stem whose activity drives segment selection.
        output_path: Root directory where salient segments are written.
        fs: Expected sample rate in Hz.
        segment_length_second: Length of each analysis segment, in seconds.
        hop_length_second: Hop between consecutive segments, in seconds.
        n_chunks: Number of chunks each segment is split into for energy analysis.
        chunk_epsilon: Floor for per-chunk energy values.
        energy_threshold_quantile: Quantile used to derive the energy threshold.
        segment_epsilon: Minimum allowed energy threshold.
        salient_proportion_threshold: Minimum fraction of active chunks for a
            segment to be kept.
    """
    detector = SourceActivityDetector(
        analysis_stem=analysis_stem,
        output_path=output_path,
        fs=fs,
        segment_length_second=segment_length_second,
        hop_length_second=hop_length_second,
        n_chunks=n_chunks,
        chunk_epsilon=chunk_epsilon,
        energy_threshold_quantile=energy_threshold_quantile,
        segment_epsilon=segment_epsilon,
        salient_proportion_threshold=salient_proportion_threshold,
    )

    for split in ["train", "val", "test"]:
        dataset = MUSDB18FullTrackDataset(
            data_root="/data/MUSDB18/HQ/canonical",
            split=split,
        )

        pending = []
        for index, track in enumerate(tqdm(dataset, total=len(dataset))):
            # Flush the accumulated batch every 32 tracks so full-track audio
            # does not pile up in memory before being dispatched.
            if index % 32 == 0 and pending:
                process_map(detector, pending, max_workers=8)
                pending = []
            pending.append(track)
        # Flush whatever remains after the loop.
        process_map(detector, pending, max_workers=8)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def loudness_norm_one(inputs):
    """Loudness-normalize one audio file and save the result as a NumPy array.

    Args:
        inputs: 3-tuple ``(infile, outfile, target_lufs)`` — input audio path,
            output ``.npy`` path, and the integrated loudness target in LUFS.
    """
    infile, outfile, target_lufs = inputs

    waveform, sample_rate = ta.load(infile)
    # Downmix to mono, then transpose to (time, channel) as pyloudnorm expects.
    mono = waveform.mean(dim=0, keepdim=True).numpy().T

    meter = pyln.Meter(sample_rate)
    measured_lufs = meter.integrated_loudness(mono)
    normalized = pyln.normalize.loudness(mono, measured_lufs, target_lufs)

    os.makedirs(os.path.dirname(outfile), exist_ok=True)
    # Transpose back to (channel, time) before saving.
    np.save(outfile, normalized.T)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def loudness_norm(
    data_path: str,
    # output_path: str,
    target_lufs=-17.0,
):
    """Loudness-normalize every ``.wav`` under ``data_path`` in parallel.

    Output files mirror the input tree with ``.wav`` replaced by ``.npy`` and
    any ``saded`` path component replaced by ``saded-np``.

    Args:
        data_path: Root directory searched recursively for ``.wav`` files.
        target_lufs: Integrated loudness target in LUFS.
    """
    wav_files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)

    jobs = []
    for wav_path in wav_files:
        npy_path = wav_path.replace(".wav", ".npy").replace("saded", "saded-np")
        jobs.append((wav_path, npy_path, target_lufs))

    process_map(loudness_norm_one, jobs, chunksize=2)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
if __name__ == "__main__":

    # Imported lazily so importing this module as a library does not require
    # tqdm or fire. NOTE(review): `tqdm` imported here is the global used by
    # `preprocess`, so that function only works when run via this entry point.
    from tqdm.auto import tqdm
    import fire

    # Expose this module's functions (preprocess, loudness_norm, ...) as a CLI.
    fire.Fire()
|
mvsepless/models/bandit/core/data/musdb/validation.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
validation:
|
| 2 |
+
- 'Actions - One Minute Smile'
|
| 3 |
+
- 'Clara Berry And Wooldog - Waltz For My Victims'
|
| 4 |
+
- 'Johnny Lokke - Promises & Lies'
|
| 5 |
+
- 'Patrick Talbot - A Reason To Leave'
|
| 6 |
+
- 'Triviul - Angelsaint'
|
| 7 |
+
- 'Alexander Ross - Goodbye Bolero'
|
| 8 |
+
- 'Fergessen - Nos Palpitants'
|
| 9 |
+
- 'Leaf - Summerghost'
|
| 10 |
+
- 'Skelpolu - Human Mistakes'
|
| 11 |
+
- 'Young Griffo - Pennies'
|
| 12 |
+
- 'ANiMAL - Rockshow'
|
| 13 |
+
- 'James May - On The Line'
|
| 14 |
+
- 'Meaxic - Take A Step'
|
| 15 |
+
- 'Traffic Experiment - Sirens'
|
mvsepless/models/bandit/core/loss/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ._multistem import MultiStemWrapperFromConfig
|
| 2 |
+
from ._timefreq import (
|
| 3 |
+
ReImL1Loss,
|
| 4 |
+
ReImL2Loss,
|
| 5 |
+
TimeFreqL1Loss,
|
| 6 |
+
TimeFreqL2Loss,
|
| 7 |
+
TimeFreqSignalNoisePNormRatioLoss,
|
| 8 |
+
)
|
mvsepless/models/bandit/core/loss/_complex.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn.modules import loss as _loss
|
| 6 |
+
from torch.nn.modules.loss import _Loss
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ReImLossWrapper(_Loss):
    """Apply a real-valued loss to the real/imaginary parts of complex tensors.

    Both inputs are converted with ``torch.view_as_real`` (adding a trailing
    size-2 axis) before the wrapped loss is evaluated.
    """

    def __init__(self, module: _Loss) -> None:
        super().__init__()
        self.module = module

    def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        preds_ri = torch.view_as_real(preds)
        target_ri = torch.view_as_real(target)
        return self.module(preds_ri, target_ri)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ReImL1Loss(ReImLossWrapper):
    """L1 loss on the real/imaginary decomposition of complex tensors."""

    def __init__(self, **kwargs: Any) -> None:
        # kwargs are forwarded verbatim to torch.nn.L1Loss (e.g. reduction).
        super().__init__(module=_loss.L1Loss(**kwargs))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ReImL2Loss(ReImLossWrapper):
    """MSE (L2) loss on the real/imaginary decomposition of complex tensors."""

    def __init__(self, **kwargs: Any) -> None:
        # kwargs are forwarded verbatim to torch.nn.MSELoss (e.g. reduction).
        super().__init__(module=_loss.MSELoss(**kwargs))
|
mvsepless/models/bandit/core/loss/_multistem.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from asteroid import losses as asteroid_losses
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn.modules.loss import _Loss
|
| 7 |
+
|
| 8 |
+
from . import snr
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss:
    """Look up a loss class by name and instantiate it.

    The name is resolved, in order, against ``torch.nn.modules.loss``, the
    local ``snr`` module, ``asteroid.losses``, and ``asteroid.losses.sdr``;
    the first module that defines ``name`` wins.

    Args:
        name: Class name of the loss (e.g. ``"L1Loss"``).
        kwargs: Keyword arguments forwarded to the loss constructor.

    Returns:
        The instantiated loss module.

    Raises:
        NameError: If no module in the search path defines ``name``.
    """
    for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]:
        if name in module.__dict__:
            return module.__dict__[name](**kwargs)

    # Previously a bare `raise NameError`; include the offending name so
    # configuration typos are easy to diagnose.
    raise NameError(f"Could not find a loss class named {name!r}")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class MultiStemWrapper(_Loss):
    """Sum a per-stem loss over every stem present in both prediction and target.

    Inputs to ``forward`` are nested dicts ``{modality: {stem: tensor}}``; only
    the configured modality is read, and stems missing from the target are
    silently skipped.
    """

    def __init__(self, module: _Loss, modality: str = "audio") -> None:
        super().__init__()
        self.loss = module
        self.modality = modality

    def forward(
        self,
        preds: Dict[str, Dict[str, torch.Tensor]],
        target: Dict[str, Dict[str, torch.Tensor]],
    ) -> torch.Tensor:
        pred_stems = preds[self.modality]
        target_stems = target[self.modality]

        per_stem = {}
        for stem, pred in pred_stems.items():
            if stem in target_stems:
                per_stem[stem] = self.loss(pred, target_stems[stem])

        return sum(per_stem.values())
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class MultiStemWrapperFromConfig(MultiStemWrapper):
    """Multi-stem loss built from a config pair (loss class name, kwargs)."""

    def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None:
        # parse_loss resolves `name` against torch / local snr / asteroid losses.
        super().__init__(module=parse_loss(name, kwargs), modality=modality)
|
mvsepless/models/bandit/core/loss/_timefreq.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn.modules.loss import _Loss
|
| 6 |
+
|
| 7 |
+
from ._multistem import MultiStemWrapper
|
| 8 |
+
from ._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper
|
| 9 |
+
from .snr import SignalNoisePNormRatio
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TimeFreqWrapper(_Loss):
    """Weighted sum of a time-domain and a frequency-domain loss.

    If ``freq_module`` is omitted, the time-domain module is reused for the
    frequency branch. With ``multistem=True`` each branch is wrapped in a
    ``MultiStemWrapper`` reading the "audio" / "spectrogram" modalities.
    """

    def __init__(
        self,
        time_module: _Loss,
        freq_module: Optional[_Loss] = None,
        time_weight: float = 1.0,
        freq_weight: float = 1.0,
        multistem: bool = True,
    ) -> None:
        super().__init__()

        freq_module = time_module if freq_module is None else freq_module

        if multistem:
            time_module = MultiStemWrapper(time_module, modality="audio")
            freq_module = MultiStemWrapper(freq_module, modality="spectrogram")

        self.time_module = time_module
        self.freq_module = freq_module
        self.time_weight = time_weight
        self.freq_weight = freq_weight

    # TODO: add better type hints
    def forward(self, preds: Any, target: Any) -> torch.Tensor:
        time_term = self.time_weight * self.time_module(preds, target)
        freq_term = self.freq_weight * self.freq_module(preds, target)
        return time_term + freq_term
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class TimeFreqL1Loss(TimeFreqWrapper):
    """Time/frequency loss: L1 in time, real/imag L1 in frequency."""

    def __init__(
        self,
        time_weight: float = 1.0,
        freq_weight: float = 1.0,
        tkwargs: Optional[Dict[str, Any]] = None,
        fkwargs: Optional[Dict[str, Any]] = None,
        multistem: bool = True,
    ) -> None:
        # tkwargs/fkwargs are forwarded to the time- and freq-domain losses.
        super().__init__(
            nn.L1Loss(**(tkwargs if tkwargs is not None else {})),
            ReImL1Loss(**(fkwargs if fkwargs is not None else {})),
            time_weight,
            freq_weight,
            multistem,
        )
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class TimeFreqL2Loss(TimeFreqWrapper):
    """Time/frequency loss: MSE in time, real/imag MSE in frequency."""

    def __init__(
        self,
        time_weight: float = 1.0,
        freq_weight: float = 1.0,
        tkwargs: Optional[Dict[str, Any]] = None,
        fkwargs: Optional[Dict[str, Any]] = None,
        multistem: bool = True,
    ) -> None:
        # tkwargs/fkwargs are forwarded to the time- and freq-domain losses.
        super().__init__(
            nn.MSELoss(**(tkwargs if tkwargs is not None else {})),
            ReImL2Loss(**(fkwargs if fkwargs is not None else {})),
            time_weight,
            freq_weight,
            multistem,
        )
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper):
    """Time/frequency loss using ``SignalNoisePNormRatio`` in both domains."""

    def __init__(
        self,
        time_weight: float = 1.0,
        freq_weight: float = 1.0,
        tkwargs: Optional[Dict[str, Any]] = None,
        fkwargs: Optional[Dict[str, Any]] = None,
        multistem: bool = True,
    ) -> None:
        # tkwargs/fkwargs are forwarded to the time- and freq-domain losses.
        super().__init__(
            SignalNoisePNormRatio(**(tkwargs if tkwargs is not None else {})),
            SignalNoisePNormRatio(**(fkwargs if fkwargs is not None else {})),
            time_weight,
            freq_weight,
            multistem,
        )
|
mvsepless/models/bandit/core/loss/snr.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.nn.modules.loss import _Loss
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SignalNoisePNormRatio(_Loss):
    """p-norm error-to-signal ratio loss, optionally in dB and scale-invariant.

    Computes ``mean(|err|^p) / mean(|target|^p)`` per batch element (flattened
    over all non-batch dims), optionally as ``10 * log10`` of that ratio.
    Complex inputs are handled via their real/imaginary decomposition.

    Args:
        p: Norm order; only 1 and 2 are supported.
        scale_invariant: Rescale the target by its projection onto the estimate
            before computing the ratio.
        zero_mean: Not supported; must be False.
        take_log: Return the ratio in dB instead of as a raw ratio.
        reduction: "mean" or "none"; "sum" is not supported.
        EPS: Numerical-stability constant added to numerator and denominator.

    Raises:
        NotImplementedError: If ``reduction == "sum"`` or ``zero_mean`` is True.
    """

    def __init__(
        self,
        p: float = 1.0,
        scale_invariant: bool = False,
        zero_mean: bool = False,
        take_log: bool = True,
        reduction: str = "mean",
        EPS: float = 1e-3,
    ) -> None:
        # Input validation: raise instead of `assert` so it survives `python -O`.
        if reduction == "sum":
            raise NotImplementedError("reduction='sum' is not supported")
        if zero_mean:
            raise NotImplementedError("zero_mean=True is not supported")
        super().__init__(reduction=reduction)

        self.p = p

        self.EPS = EPS
        self.take_log = take_log

        self.scale_invariant = scale_invariant

    def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the (optionally log-scaled) error-to-target p-norm ratio."""
        target_ = target
        if self.scale_invariant:
            ndim = target.ndim
            # Projection of the estimate onto the target, pooled over all
            # non-batch dimensions.
            dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True)
            s_target_energy = torch.sum(
                target * torch.conj(target), dim=-1, keepdim=True
            )

            if ndim > 2:
                dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True)
                s_target_energy = torch.sum(
                    s_target_energy, dim=list(range(1, ndim)), keepdim=True
                )

            target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8)
            target = target_ * target_scaler

        if torch.is_complex(est_target):
            # Treat complex values as pairs of real numbers.
            est_target = torch.view_as_real(est_target)
            target = torch.view_as_real(target)

        # Flatten everything but the batch dimension.
        batch_size = est_target.shape[0]
        est_target = est_target.reshape(batch_size, -1)
        target = target.reshape(batch_size, -1)

        if self.p == 1:
            e_error = torch.abs(est_target - target).mean(dim=-1)
            e_target = torch.abs(target).mean(dim=-1)
        elif self.p == 2:
            e_error = torch.square(est_target - target).mean(dim=-1)
            e_target = torch.square(target).mean(dim=-1)
        else:
            raise NotImplementedError(f"p={self.p} is not supported (only 1 and 2)")

        if self.take_log:
            loss = 10 * (
                torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS)
            )
        else:
            loss = (e_error + self.EPS) / (e_target + self.EPS)

        # "sum" is rejected in __init__, so only "mean" and "none" remain
        # (the old unreachable sum branch has been removed).
        if self.reduction == "mean":
            loss = loss.mean()

        return loss
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class MultichannelSingleSrcNegSDR(_Loss):
    """Negative SNR / SI-SDR / SD-SDR loss for single-source multichannel audio.

    Energies are pooled over both channel and time, so inputs must be 3-D
    ``(batch, channel, time)`` tensors of identical shape.

    Args:
        sdr_type: One of ``"snr"``, ``"sisdr"``, ``"sdsdr"``.
        p: Norm order for the signal/noise ratio (fast energy path for p == 2).
        zero_mean: Subtract the per-sample mean from both signals first.
        take_log: Return the ratio in dB rather than as a raw ratio.
        reduction: "mean" or "none"; "sum" is not supported.
        EPS: Numerical-stability constant.

    Raises:
        NotImplementedError: If ``reduction == "sum"``.
        ValueError: If ``sdr_type`` is not a supported value.
    """

    def __init__(
        self,
        sdr_type: str,
        p: float = 2.0,
        zero_mean: bool = True,
        take_log: bool = True,
        reduction: str = "mean",
        EPS: float = 1e-8,
    ) -> None:
        # Input validation: raise instead of `assert` so it survives `python -O`.
        if reduction == "sum":
            raise NotImplementedError("reduction='sum' is not supported")
        super().__init__(reduction=reduction)

        if sdr_type not in ("snr", "sisdr", "sdsdr"):
            raise ValueError(
                f"sdr_type must be one of 'snr', 'sisdr', 'sdsdr', got {sdr_type!r}"
            )
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        # BUG FIX: the EPS argument was previously ignored (hard-coded 1e-8).
        self.EPS = EPS

        self.p = p

    def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the negative (SI-/SD-)SDR between estimate and target."""
        if target.size() != est_target.size() or target.ndim != 3:
            # Message fixed: this loss requires 3-D (batch, channel, time) inputs.
            raise TypeError(
                f"Inputs must be of shape [batch, channel, time], got {target.size()} and {est_target.size()} instead"
            )
        # Step 1. Zero-mean norm over (channel, time).
        if self.zero_mean:
            mean_source = torch.mean(target, dim=[1, 2], keepdim=True)
            mean_estimate = torch.mean(est_target, dim=[1, 2], keepdim=True)
            target = target - mean_source
            est_target = est_target - mean_estimate
        # Step 2. Scale-invariant projection of the target onto the estimate.
        if self.sdr_type in ["sisdr", "sdsdr"]:
            # [batch, 1, 1]
            dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True)
            # [batch, 1, 1]
            s_target_energy = torch.sum(target**2, dim=[1, 2], keepdim=True) + self.EPS
            # [batch, channel, time]
            scaled_target = dot * target / s_target_energy
        else:
            # [batch, channel, time]
            scaled_target = target
        if self.sdr_type in ["sdsdr", "snr"]:
            e_noise = est_target - target
        else:
            e_noise = est_target - scaled_target
        # [batch]

        if self.p == 2.0:
            # Energy ratio (squared-norm) fast path.
            losses = torch.sum(scaled_target**2, dim=[1, 2]) / (
                torch.sum(e_noise**2, dim=[1, 2]) + self.EPS
            )
        else:
            # BUG FIX: vector_norm's keyword is `ord`, not `p` — the previous
            # code raised TypeError on this path. Both norms now use
            # torch.linalg.vector_norm for consistency.
            losses = torch.linalg.vector_norm(scaled_target, ord=self.p, dim=[1, 2]) / (
                torch.linalg.vector_norm(e_noise, ord=self.p, dim=[1, 2]) + self.EPS
            )
        if self.take_log:
            losses = 10 * torch.log10(losses + self.EPS)
        losses = losses.mean() if self.reduction == "mean" else losses
        return -losses
|
mvsepless/models/bandit/core/metrics/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .snr import (
|
| 2 |
+
ChunkMedianScaleInvariantSignalDistortionRatio,
|
| 3 |
+
ChunkMedianScaleInvariantSignalNoiseRatio,
|
| 4 |
+
ChunkMedianSignalDistortionRatio,
|
| 5 |
+
ChunkMedianSignalNoiseRatio,
|
| 6 |
+
SafeSignalDistortionRatio,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
# from .mushra import EstimatedMushraScore
|
mvsepless/models/bandit/core/metrics/_squim.py
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
from torchaudio._internal import load_state_dict_from_url
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
from typing import List, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def transform_wb_pesq_range(x: float) -> float:
    """Map a narrow-band PESQ score to the wide-band (ITU-T P.862.2) scale.

    ITU-T P.862 defines the narrow-band PESQ score on [-0.5, 4.5]; this module
    works with the wide-band variant defined by ITU-T P.862.2, obtained from
    the narrow-band score via a fixed logistic mapping.

    Args:
        x (float): Narrow-band PESQ score.

    Returns:
        (float): Wide-band PESQ score.
    """
    low, high = 0.999, 4.999
    return low + (high - low) / (1 + math.exp(-1.3669 * x + 3.8224))
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Valid (lower, upper) bound for wide-band PESQ predictions, used by the
# "pesq" branch's RangeSigmoid output layer.
PESQRange: Tuple[float, float] = (
    1.0,  # P.862.2 uses a different input filter than P.862, and the lower bound of
    # the raw score is not -0.5 anymore. It's hard to figure out the true lower bound.
    # We are using 1.0 as a reasonable approximation.
    transform_wb_pesq_range(4.5),
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class RangeSigmoid(nn.Module):
    """Sigmoid rescaled to map into an arbitrary ``(low, high)`` interval."""

    def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
        super().__init__()
        assert isinstance(val_range, tuple) and len(val_range) == 2
        self.val_range: Tuple[float, float] = val_range
        self.sigmoid: nn.modules.Module = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        low, high = self.val_range
        return self.sigmoid(x) * (high - low) + low
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class Encoder(nn.Module):
    """Transform a 1-D waveform into a 2-D feature map (strided Conv1d + ReLU).

    Args:
        feat_dim (int, optional): Number of output feature channels. (Default: 512)
        win_len (int, optional): Conv1d kernel size; the stride is ``win_len // 2``. (Default: 32)
    """

    def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
        super().__init__()
        self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode waveforms into non-negative features.

        Args:
            x (torch.Tensor): Waveforms of shape ``(batch, time)``.

        Returns:
            (torch.Tensor): Features of shape ``(batch, channel, frame)``.
        """
        features = self.conv1d(x.unsqueeze(dim=1))
        return F.relu(features)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class SingleRNN(nn.Module):
    """One bidirectional RNN layer with a projection back to the input size."""

    def __init__(
        self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0
    ) -> None:
        super().__init__()

        self.rnn_type = rnn_type
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Resolve the RNN class (RNN / LSTM / GRU) by name from torch.nn.
        rnn_cls = getattr(nn, rnn_type)
        self.rnn: nn.modules.Module = rnn_cls(
            input_size,
            hidden_size,
            1,
            dropout=dropout,
            batch_first=True,
            bidirectional=True,
        )

        # Bidirectional output is 2 * hidden_size; project back to input_size.
        self.proj = nn.Linear(hidden_size * 2, input_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, input_size) -> (batch, seq, input_size)
        hidden_states, _ = self.rnn(x)
        return self.proj(hidden_states)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class DPRNN(nn.Module):
    """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`.

    Alternates intra-chunk ("row") and inter-chunk ("col") bidirectional RNN
    passes with residual connections, then projects to ``d_model`` channels
    and merges the 50%-overlapped chunks back into a sequence.

    Args:
        feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64)
        hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128)
        num_blocks (int, optional): Number of DPRNN layers. (Default: 6)
        rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM")
        d_model (int, optional): The number of expected features in the input. (Default: 256)
        chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100)
        chunk_stride (int, optional): Stride of chunk input for DPRNN. (Default: 50)
    """

    def __init__(
        self,
        feat_dim: int = 64,
        hidden_dim: int = 128,
        num_blocks: int = 6,
        rnn_type: str = "LSTM",
        d_model: int = 256,
        chunk_size: int = 100,
        chunk_stride: int = 50,
    ) -> None:
        super(DPRNN, self).__init__()

        self.num_blocks = num_blocks

        # One intra-chunk (row) and one inter-chunk (col) RNN + norm per block.
        self.row_rnn = nn.ModuleList([])
        self.col_rnn = nn.ModuleList([])
        self.row_norm = nn.ModuleList([])
        self.col_norm = nn.ModuleList([])
        for _ in range(num_blocks):
            self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
            self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
            # GroupNorm with a single group normalizes over all channels.
            self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
            self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
        # Final 1x1 conv projecting feat_dim -> d_model.
        self.conv = nn.Sequential(
            nn.Conv2d(feat_dim, d_model, 1),
            nn.PReLU(),
        )
        self.chunk_size = chunk_size
        self.chunk_stride = chunk_stride

    def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Pad the sequence so it splits evenly into overlapped chunks.

        Returns the padded tensor and ``rest`` (the amount of right padding
        added beyond the stride), which ``merging`` later strips off.
        """
        # input shape: (B, N, T)
        seq_len = x.shape[-1]

        # Amount needed so (stride + T) is a multiple of chunk_size.
        rest = (
            self.chunk_size
            - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
        )
        # Pad one stride on the left and (rest + stride) on the right so every
        # sample is covered by the two half-overlapped chunk streams.
        out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])

        return out, rest

    def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Split (B, N, T) into overlapped chunks -> (B, N, chunk_size, n_chunks)."""
        out, rest = self.pad_chunk(x)
        batch_size, feat_dim, seq_len = out.shape

        # Two chunk streams offset by half a chunk (50% overlap).
        segments1 = (
            out[:, :, : -self.chunk_stride]
            .contiguous()
            .view(batch_size, feat_dim, -1, self.chunk_size)
        )
        segments2 = (
            out[:, :, self.chunk_stride :]
            .contiguous()
            .view(batch_size, feat_dim, -1, self.chunk_size)
        )
        # Interleave the two streams along the chunk axis.
        out = torch.cat([segments1, segments2], dim=3)
        out = (
            out.view(batch_size, feat_dim, -1, self.chunk_size)
            .transpose(2, 3)
            .contiguous()
        )

        return out, rest

    def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
        """Overlap-add the chunked representation back into a sequence.

        Inverse of :meth:`chunking`: the two half-offset chunk streams are
        summed, and the ``rest`` padding introduced by ``pad_chunk`` is removed.
        """
        batch_size, dim, _, _ = x.shape
        out = (
            x.transpose(2, 3)
            .contiguous()
            .view(batch_size, dim, -1, self.chunk_size * 2)
        )
        # First stream: drop the leading left-pad of one stride.
        out1 = (
            out[:, :, :, : self.chunk_size]
            .contiguous()
            .view(batch_size, dim, -1)[:, :, self.chunk_stride :]
        )
        # Second stream: drop the trailing stride.
        out2 = (
            out[:, :, :, self.chunk_size :]
            .contiguous()
            .view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
        )
        # Overlap-add of the two streams.
        out = out1 + out2
        if rest > 0:
            out = out[:, :, :-rest]
        out = out.contiguous()
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply alternating intra-/inter-chunk RNN blocks to (B, N, T) input.

        Returns a tensor transposed to put the sequence axis before channels
        (shape (B, T', d_model) after merging).
        """
        x, rest = self.chunking(x)
        batch_size, _, dim1, dim2 = x.shape
        out = x
        for row_rnn, row_norm, col_rnn, col_norm in zip(
            self.row_rnn, self.row_norm, self.col_rnn, self.col_norm
        ):
            # Intra-chunk pass: fold chunks into the batch dim, run the RNN
            # along the within-chunk axis, then restore the layout.
            row_in = (
                out.permute(0, 3, 2, 1)
                .contiguous()
                .view(batch_size * dim2, dim1, -1)
                .contiguous()
            )
            row_out = row_rnn(row_in)
            row_out = (
                row_out.view(batch_size, dim2, dim1, -1)
                .permute(0, 3, 2, 1)
                .contiguous()
            )
            row_out = row_norm(row_out)
            # Residual connection.
            out = out + row_out

            # Inter-chunk pass: run the RNN across chunks at each position.
            col_in = (
                out.permute(0, 2, 3, 1)
                .contiguous()
                .view(batch_size * dim1, dim2, -1)
                .contiguous()
            )
            col_out = col_rnn(col_in)
            col_out = (
                col_out.view(batch_size, dim1, dim2, -1)
                .permute(0, 3, 1, 2)
                .contiguous()
            )
            col_out = col_norm(col_out)
            # Residual connection.
            out = out + col_out
        out = self.conv(out)
        out = self.merging(out, rest)
        out = out.transpose(1, 2).contiguous()
        return out
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class AutoPool(nn.Module):
    """Adaptive (auto-pool) weighted average along one dimension.

    A learnable scalar ``alpha`` sharpens or flattens the softmax weights:
    alpha=0 gives a plain mean, large alpha approaches a max-pool.
    """

    def __init__(self, pool_dim: int = 1) -> None:
        super().__init__()
        self.pool_dim: int = pool_dim
        self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
        self.register_parameter("alpha", nn.Parameter(torch.ones(1)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weights = self.softmax(x * self.alpha)
        return (x * weights).sum(dim=self.pool_dim)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class SquimObjective(nn.Module):
    """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
    for speech enhancement (e.g., STOI, PESQ, and SI-SDR).

    Args:
        encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
        dprnn (torch.nn.Module): DPRNN module to model sequential feature.
        branches (torch.nn.ModuleList): Transformer branches in which each branch estimate one objective metirc score.
    """

    def __init__(
        self,
        encoder: nn.Module,
        dprnn: nn.Module,
        branches: nn.ModuleList,
    ):
        super().__init__()
        self.encoder = encoder
        self.dprnn = dprnn
        self.branches = branches

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Predict one score tensor of shape ``(batch,)`` per branch.

        Args:
            x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.

        Returns:
            List(torch.Tensor): List of score Tensors, each with dimension `(batch,)`.

        Raises:
            ValueError: If ``x`` is not a 2-D tensor.
        """
        if x.ndim != 2:
            raise ValueError(
                f"The input must be a 2D Tensor. Found dimension {x.ndim}."
            )
        # Normalize each waveform by 20x its RMS before encoding.
        x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
        features = self.dprnn(self.encoder(x))
        return [branch(features).squeeze(dim=1) for branch in self.branches]
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
    """Create the per-metric branch module that follows the DPRNN model.

    Args:
        d_model (int): The number of expected features in the input.
        nhead (int): Number of heads in the multi-head attention model.
        metric (str): The metric name to predict.

    Returns:
        (nn.Module): Module predicting the corresponding metric score.
    """
    attention = nn.TransformerEncoderLayer(
        d_model, nhead, d_model * 4, dropout=0.0, batch_first=True
    )
    pooling = AutoPool()
    head_layers = [nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1)]
    if metric == "stoi":
        # STOI is bounded; squash the output with the default range sigmoid.
        head_layers.append(RangeSigmoid())
    elif metric == "pesq":
        # PESQ has its own fixed value range.
        head_layers.append(RangeSigmoid(val_range=PESQRange))
    head: nn.modules.Module = nn.Sequential(*head_layers)
    return nn.Sequential(attention, pooling, head)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def squim_objective_model(
    feat_dim: int,
    win_len: int,
    d_model: int,
    nhead: int,
    hidden_dim: int,
    num_blocks: int,
    rnn_type: str,
    chunk_size: int,
    chunk_stride: Optional[int] = None,
) -> SquimObjective:
    """Build a custom :class:`torchaudio.prototype.models.SquimObjective` model.

    Args:
        feat_dim (int, optional): The feature dimension after the Encoder module.
        win_len (int): Kernel size in the Encoder module.
        d_model (int): The number of expected features in the input.
        nhead (int): Number of heads in the multi-head attention model.
        hidden_dim (int): Hidden dimension in the RNN layer of DPRNN.
        num_blocks (int): Number of DPRNN layers.
        rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"].
        chunk_size (int): Chunk size of input for DPRNN.
        chunk_stride (int or None, optional): Stride of chunk input for DPRNN.
            Defaults to half the chunk size.
    """
    if chunk_stride is None:
        chunk_stride = chunk_size // 2
    # One prediction branch per objective metric.
    branches = nn.ModuleList(
        [_create_branch(d_model, nhead, metric) for metric in ("stoi", "pesq", "sisdr")]
    )
    encoder = Encoder(feat_dim, win_len)
    dprnn = DPRNN(
        feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride
    )
    return SquimObjective(encoder, dprnn, branches)
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def squim_objective_base() -> SquimObjective:
    """Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments."""
    defaults = {
        "feat_dim": 256,
        "win_len": 64,
        "d_model": 256,
        "nhead": 4,
        "hidden_dim": 256,
        "num_blocks": 2,
        "rnn_type": "LSTM",
        "chunk_size": 71,
    }
    return squim_objective_model(**defaults)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
@dataclass
class SquimObjectiveBundle:
    """Ties a pretrained SquimObjective checkpoint path to its training sample rate."""

    _path: str
    _sample_rate: float

    def _get_state_dict(self, dl_kwargs):
        # Weights are fetched (and cached) from the official torchaudio bucket.
        url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
        if dl_kwargs is None:
            dl_kwargs = {}
        return load_state_dict_from_url(url, **dl_kwargs)

    def get_model(self, *, dl_kwargs=None) -> SquimObjective:
        """Construct the SquimObjective model and load the pretrained weights.

        The weight file is downloaded from the internet and cached with
        :func:`torch.hub.load_state_dict_from_url`.

        Args:
            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.

        Returns:
            Variation of :py:class:`~torchaudio.models.SquimObjective`.
        """
        model = squim_objective_base()
        model.load_state_dict(self._get_state_dict(dl_kwargs))
        # Inference-only bundle: switch off train-mode behaviour.
        model.eval()
        return model

    @property
    def sample_rate(self):
        """Sample rate of the audio that the model is trained on.

        :type: float
        """
        return self._sample_rate
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
# Pretrained bundle: SquimObjective weights trained on DNS 2020, 16 kHz audio.
SQUIM_OBJECTIVE = SquimObjectiveBundle(
    "squim_objective_dns2020.pth",
    _sample_rate=16000,
)
# Attach user-facing documentation to the bundle instance (picked up by Sphinx).
SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
:cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.

The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
The weights are under `Creative Commons Attribution 4.0 International License
<https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.

Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
"""
|
mvsepless/models/bandit/core/metrics/snr.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Callable
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torchmetrics as tm
|
| 6 |
+
from torch._C import _LinAlgError
|
| 7 |
+
from torchmetrics import functional as tmF
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SafeSignalDistortionRatio(tm.SignalDistortionRatio):
    """SignalDistortionRatio that tolerates failing updates.

    SDR computation can raise on degenerate inputs (e.g. linear-algebra
    errors); failed batches are skipped instead of aborting evaluation, and
    ``compute`` returns NaN when no update succeeded.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def update(self, *args, **kwargs) -> Any:
        """Accumulate one batch, silently dropping it if the SDR computation fails."""
        try:
            super().update(*args, **kwargs)
        except Exception:
            # Best-effort metric: drop the failing batch. Narrowed from a bare
            # `except:` so KeyboardInterrupt/SystemExit still propagate.
            pass

    def compute(self) -> Any:
        """Return the aggregated SDR, or NaN if every update failed."""
        if self.total == 0:
            return torch.tensor(torch.nan)
        return super().compute()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class BaseChunkMedianSignalRatio(tm.Metric):
    """Mean-over-batch of the per-item *median* chunk-wise signal ratio.

    Splits each signal into windows, evaluates `func` (an SNR/SDR-style
    torchmetrics functional) per window, takes the nanmedian across windows,
    and accumulates the batch mean via torchmetrics sum-states.
    """

    def __init__(
        self,
        func: Callable,
        window_size: int,
        hop_size: int = None,
        zero_mean: bool = False,
    ) -> None:
        """
        Args:
            func: metric functional applied per chunk, e.g. tmF.signal_noise_ratio.
            window_size: chunk length in samples.
            hop_size: stride between chunks; defaults to window_size (no overlap).
            zero_mean: accepted for API symmetry but currently unused (see below).
        """
        super().__init__()

        # NOTE(review): zero_mean is accepted but never stored/used.
        # self.zero_mean = zero_mean
        self.func = func
        self.window_size = window_size
        if hop_size is None:
            hop_size = window_size
        self.hop_size = hop_size

        # Sum-reduced states so the metric aggregates correctly across processes.
        self.add_state("sum_snr", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate one batch; preds/target share shape (..., n_samples)."""

        n_samples = target.shape[-1]

        # Number of hops needed to cover the signal (last partial chunk is skipped below).
        n_chunks = int(np.ceil((n_samples - self.window_size) / self.hop_size) + 1)

        snr_chunk = []

        for i in range(n_chunks):
            start = i * self.hop_size

            # Skip trailing chunks shorter than a full window.
            if n_samples - start < self.window_size:
                continue

            end = start + self.window_size

            try:
                chunk_snr = self.func(preds[..., start:end], target[..., start:end])

                # print(preds.shape, chunk_snr.shape)

                # Keep only chunks whose metric is finite everywhere.
                if torch.all(torch.isfinite(chunk_snr)):
                    snr_chunk.append(chunk_snr)
            except _LinAlgError:
                # Degenerate chunk (e.g. silent target); skip it.
                pass

        # NOTE(review): if every chunk failed/was non-finite, snr_chunk is empty
        # and torch.stack raises — confirm callers guarantee at least one valid chunk.
        snr_chunk = torch.stack(snr_chunk, dim=-1)
        # Median over chunks, per batch item (nanmedian returns (values, indices)).
        snr_batch, _ = torch.nanmedian(snr_chunk, dim=-1)

        self.sum_snr += snr_batch.sum()
        self.total += snr_batch.numel()

    def compute(self) -> Any:
        """Mean of the per-item chunk-median ratios accumulated so far."""
        return self.sum_snr / self.total
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio):
    """Chunk-median SNR, using torchmetrics' functional ``signal_noise_ratio``."""

    def __init__(
        self, window_size: int, hop_size: int = None, zero_mean: bool = False
    ) -> None:
        super().__init__(
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
            func=tmF.signal_noise_ratio,
        )
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio):
    """Chunk-median SI-SNR, using ``scale_invariant_signal_noise_ratio``."""

    def __init__(
        self, window_size: int, hop_size: int = None, zero_mean: bool = False
    ) -> None:
        super().__init__(
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
            func=tmF.scale_invariant_signal_noise_ratio,
        )
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio):
    """Chunk-median SDR, using torchmetrics' functional ``signal_distortion_ratio``."""

    def __init__(
        self, window_size: int, hop_size: int = None, zero_mean: bool = False
    ) -> None:
        super().__init__(
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
            func=tmF.signal_distortion_ratio,
        )
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class ChunkMedianScaleInvariantSignalDistortionRatio(BaseChunkMedianSignalRatio):
    """Chunk-median SI-SDR, using ``scale_invariant_signal_distortion_ratio``."""

    def __init__(
        self, window_size: int, hop_size: int = None, zero_mean: bool = False
    ) -> None:
        super().__init__(
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
            func=tmF.scale_invariant_signal_distortion_ratio,
        )
|
mvsepless/models/bandit/core/model/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .bsrnn.wrapper import (
|
| 2 |
+
MultiMaskMultiSourceBandSplitRNNSimple,
|
| 3 |
+
)
|
mvsepless/models/bandit/core/model/_spectral.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torchaudio as ta
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class _SpectralComponent(nn.Module):
    """Holds a matched STFT/ISTFT transform pair built from one shared configuration."""

    def __init__(
        self,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        **kwargs,
    ) -> None:
        super().__init__()

        # Complex spectrograms only: `power` must stay None.
        assert power is None

        # Resolve the window constructor (e.g. torch.hann_window) by name.
        window_fn = torch.__dict__[window_fn]

        # Both transforms must agree on every STFT parameter for round-tripping.
        shared = dict(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )

        self.stft = ta.transforms.Spectrogram(power=power, **shared)
        self.istft = ta.transforms.InverseSpectrogram(**shared)
|
mvsepless/models/bandit/core/model/bsrnn/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC
|
| 2 |
+
from typing import Iterable, Mapping, Union
|
| 3 |
+
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
from .bandsplit import BandSplitModule
|
| 7 |
+
from .tfmodel import (
|
| 8 |
+
SeqBandModellingModule,
|
| 9 |
+
TransformerTimeFreqModule,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BandsplitCoreBase(nn.Module, ABC):
    """Abstract base for band-split separation cores.

    Subclasses are expected to provide a band-split frontend (`band_split`),
    a time-frequency model (`tf_model`), and one or more mask estimators
    (`mask_estim`).
    """

    band_split: nn.Module
    tf_model: nn.Module
    mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]]

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def mask(x, m):
        """Apply an elementwise mask `m` to the spectrogram `x`."""
        return x * m
|
mvsepless/models/bandit/core/model/bsrnn/bandsplit.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
from .utils import (
|
| 7 |
+
band_widths_from_specs,
|
| 8 |
+
check_no_gap,
|
| 9 |
+
check_no_overlap,
|
| 10 |
+
check_nonzero_bandwidth,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class NormFC(nn.Module):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
emb_dim: int,
|
| 18 |
+
bandwidth: int,
|
| 19 |
+
in_channel: int,
|
| 20 |
+
normalize_channel_independently: bool = False,
|
| 21 |
+
treat_channel_as_feature: bool = True,
|
| 22 |
+
) -> None:
|
| 23 |
+
super().__init__()
|
| 24 |
+
|
| 25 |
+
self.treat_channel_as_feature = treat_channel_as_feature
|
| 26 |
+
|
| 27 |
+
if normalize_channel_independently:
|
| 28 |
+
raise NotImplementedError
|
| 29 |
+
|
| 30 |
+
reim = 2
|
| 31 |
+
|
| 32 |
+
self.norm = nn.LayerNorm(in_channel * bandwidth * reim)
|
| 33 |
+
|
| 34 |
+
fc_in = bandwidth * reim
|
| 35 |
+
|
| 36 |
+
if treat_channel_as_feature:
|
| 37 |
+
fc_in *= in_channel
|
| 38 |
+
else:
|
| 39 |
+
assert emb_dim % in_channel == 0
|
| 40 |
+
emb_dim = emb_dim // in_channel
|
| 41 |
+
|
| 42 |
+
self.fc = nn.Linear(fc_in, emb_dim)
|
| 43 |
+
|
| 44 |
+
def forward(self, xb):
|
| 45 |
+
# xb = (batch, n_time, in_chan, reim * band_width)
|
| 46 |
+
|
| 47 |
+
batch, n_time, in_chan, ribw = xb.shape
|
| 48 |
+
xb = self.norm(xb.reshape(batch, n_time, in_chan * ribw))
|
| 49 |
+
# (batch, n_time, in_chan * reim * band_width)
|
| 50 |
+
|
| 51 |
+
if not self.treat_channel_as_feature:
|
| 52 |
+
xb = xb.reshape(batch, n_time, in_chan, ribw)
|
| 53 |
+
# (batch, n_time, in_chan, reim * band_width)
|
| 54 |
+
|
| 55 |
+
zb = self.fc(xb)
|
| 56 |
+
# (batch, n_time, emb_dim)
|
| 57 |
+
# OR
|
| 58 |
+
# (batch, n_time, in_chan, emb_dim_per_chan)
|
| 59 |
+
|
| 60 |
+
if not self.treat_channel_as_feature:
|
| 61 |
+
batch, n_time, in_chan, emb_dim_per_chan = zb.shape
|
| 62 |
+
# (batch, n_time, in_chan, emb_dim_per_chan)
|
| 63 |
+
zb = zb.reshape((batch, n_time, in_chan * emb_dim_per_chan))
|
| 64 |
+
|
| 65 |
+
return zb # (batch, n_time, emb_dim)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class BandSplitModule(nn.Module):
    """Split a complex spectrogram into frequency bands and embed each band.

    Every band [fstart, fend) is flattened (channel x re/im x bins), then
    normalized and projected to `emb_dim` by its own NormFC module.
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        in_channel: int,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        check_nonzero_bandwidth(band_specs)

        if require_no_gap:
            check_no_gap(band_specs)

        if require_no_overlap:
            check_no_overlap(band_specs)

        # Band specs are [fstart, fend) index pairs; fend is exclusive.
        self.band_specs = band_specs
        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)
        self.emb_dim = emb_dim

        # One norm + projection per band.
        self.norm_fc_modules = nn.ModuleList(
            [  # type: ignore
                NormFC(
                    emb_dim=emb_dim,
                    bandwidth=width,
                    in_channel=in_channel,
                    normalize_channel_independently=normalize_channel_independently,
                    treat_channel_as_feature=treat_channel_as_feature,
                )
                for width in self.band_widths
            ]
        )

    def forward(self, x: torch.Tensor):
        """x: complex spectrogram (batch, in_chan, n_freq, n_time) -> (batch, n_bands, n_time, emb_dim)."""
        batch, in_chan, _, n_time = x.shape

        out = torch.zeros(
            size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
        )

        # View real/imag as features: (batch, n_time, in_chan, 2, n_freq).
        xr = torch.permute(torch.view_as_real(x), (0, 3, 1, 4, 2))
        batch, n_time, in_chan, reim, _ = xr.shape

        for band_idx, norm_fc in enumerate(self.norm_fc_modules):
            fstart, fend = self.band_specs[band_idx]
            band = xr[..., fstart:fend]
            # Flatten (reim, band bins) into one feature axis.
            band = torch.reshape(band, (batch, n_time, in_chan, -1))
            out[:, band_idx, :, :] = norm_fc(band.contiguous())

        return out
|
mvsepless/models/bandit/core/model/bsrnn/core.py
ADDED
|
@@ -0,0 +1,651 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
from . import BandsplitCoreBase
|
| 8 |
+
from .bandsplit import BandSplitModule
|
| 9 |
+
from .maskestim import (
|
| 10 |
+
MaskEstimationModule,
|
| 11 |
+
OverlappingMaskEstimationModule,
|
| 12 |
+
)
|
| 13 |
+
from .tfmodel import (
|
| 14 |
+
ConvolutionalTimeFreqModule,
|
| 15 |
+
SeqBandModellingModule,
|
| 16 |
+
TransformerTimeFreqModule,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class MultiMaskBandSplitCoreBase(BandsplitCoreBase):
    """Band-split core that estimates one mask per stem (multi-source output).

    Subclasses populate `band_split`, `tf_model` and the per-stem `mask_estim`
    ModuleDict via `instantiate_bandsplit` / `instantiate_mask_estim`.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x, cond=None, compute_residual: bool = True):
        """Separate `x` into one masked spectrogram per stem.

        Args:
            x: complex spectrogram (batch, in_chan, n_freq, n_time).
            cond: optional conditioning tensor forwarded to each mask estimator.
            compute_residual: unused in this base implementation.

        Returns:
            {"spectrogram": {stem: (batch, in_chan, n_freq, n_time) tensor}}
        """
        # x = complex spectrogram (batch, in_chan, n_freq, n_time)
        # print(x.shape)
        batch, in_chan, n_freq, n_time = x.shape
        # Fold channels into the batch axis so each channel is masked independently.
        x = torch.reshape(x, (-1, 1, n_freq, n_time))

        z = self.band_split(x)  # (batch, emb_dim, n_band, n_time)

        # if torch.any(torch.isnan(z)):
        #     raise ValueError("z nan")

        # print(z)
        q = self.tf_model(z)  # (batch, emb_dim, n_band, n_time)
        # print(q)

        # if torch.any(torch.isnan(q)):
        #     raise ValueError("q nan")

        out = {}

        for stem, mem in self.mask_estim.items():
            m = mem(q, cond=cond)

            # if torch.any(torch.isnan(m)):
            #     raise ValueError("m nan", stem)

            # Apply the estimated mask, then restore the (batch, channel) layout.
            s = self.mask(x, m)
            s = torch.reshape(s, (batch, in_chan, n_freq, n_time))
            out[stem] = s

        return {"spectrogram": out}

    def instantiate_mask_estim(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        cond_dim: int,
        hidden_activation: str,
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
        mult_add_mask: bool = False,
    ):
        """Build one mask-estimation module per stem into `self.mask_estim`."""
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        # "mne:+" is a marker entry, not a real source stem; drop it.
        if "mne:+" in stems:
            stems = [s for s in stems if s != "mne:+"]

        if overlapping_band:
            # Overlapping bands need per-frequency recombination weights.
            assert freq_weights is not None
            assert n_freq is not None

            if mult_add_mask:

                # NOTE(review): MultAddMaskEstimationModule is not among this
                # module's visible imports — confirm it is in scope before
                # calling with mult_add_mask=True.
                self.mask_estim = nn.ModuleDict(
                    {
                        stem: MultAddMaskEstimationModule(
                            band_specs=band_specs,
                            freq_weights=freq_weights,
                            n_freq=n_freq,
                            emb_dim=emb_dim,
                            mlp_dim=mlp_dim,
                            in_channel=in_channel,
                            hidden_activation=hidden_activation,
                            hidden_activation_kwargs=hidden_activation_kwargs,
                            complex_mask=complex_mask,
                            use_freq_weights=use_freq_weights,
                        )
                        for stem in stems
                    }
                )
            else:
                self.mask_estim = nn.ModuleDict(
                    {
                        stem: OverlappingMaskEstimationModule(
                            band_specs=band_specs,
                            freq_weights=freq_weights,
                            n_freq=n_freq,
                            emb_dim=emb_dim,
                            mlp_dim=mlp_dim,
                            in_channel=in_channel,
                            hidden_activation=hidden_activation,
                            hidden_activation_kwargs=hidden_activation_kwargs,
                            complex_mask=complex_mask,
                            use_freq_weights=use_freq_weights,
                        )
                        for stem in stems
                    }
                )
        else:
            # Non-overlapping bands: plain per-band mask estimation.
            self.mask_estim = nn.ModuleDict(
                {
                    stem: MaskEstimationModule(
                        band_specs=band_specs,
                        emb_dim=emb_dim,
                        mlp_dim=mlp_dim,
                        cond_dim=cond_dim,
                        in_channel=in_channel,
                        hidden_activation=hidden_activation,
                        hidden_activation_kwargs=hidden_activation_kwargs,
                        complex_mask=complex_mask,
                    )
                    for stem in stems
                }
            )

    def instantiate_bandsplit(
        self,
        in_channel: int,
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        emb_dim: int = 128,
    ):
        """Build the band-split frontend into `self.band_split`."""
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class SingleMaskBandsplitCoreBase(BandsplitCoreBase):
    """Band-split core producing a single mask (single-source output)."""

    def __init__(self, **kwargs) -> None:
        super().__init__()

    def forward(self, x):
        """x: complex spectrogram (batch, in_chan, n_freq, n_time) -> masked spectrogram."""
        embeddings = self.band_split(x)        # (batch, emb_dim, n_band, n_time)
        modelled = self.tf_model(embeddings)   # (batch, emb_dim, n_band, n_time)
        estimated = self.mask_estim(modelled)  # (batch, in_chan, n_freq, n_time)
        return self.mask(x, estimated)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class SingleMaskBandsplitCoreRNN(
    SingleMaskBandsplitCoreBase,
):
    """Single-mask band-split core using an RNN time-frequency model.

    Wires BandSplitModule -> SeqBandModellingModule -> MaskEstimationModule.
    """

    def __init__(
        self,
        in_channel: int,
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__()
        # Frontend: embeds each frequency band of the complex spectrogram.
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )
        # Sequence model across time and bands (n_sqm_modules stacked RNN blocks).
        self.tf_model = SeqBandModellingModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )
        # Head: maps band embeddings back to a (complex) spectrogram mask.
        self.mask_estim = MaskEstimationModule(
            in_channel=in_channel,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class SingleMaskBandsplitCoreTransformer(
    SingleMaskBandsplitCoreBase,
):
    """Single-mask band-split core using a Transformer time-frequency model.

    Wires BandSplitModule -> TransformerTimeFreqModule -> MaskEstimationModule.
    """

    def __init__(
        self,
        in_channel: int,
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        tf_dropout: float = 0.0,
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__()
        # Frontend: embeds each frequency band of the complex spectrogram.
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )
        # Transformer sequence model across time and bands.
        self.tf_model = TransformerTimeFreqModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            dropout=tf_dropout,
        )
        # Head: maps band embeddings back to a (complex) spectrogram mask.
        self.mask_estim = MaskEstimationModule(
            in_channel=in_channel,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase):
    """Multi-stem band-split core with an RNN time-frequency model.

    One mask head per stem; optionally uses a combined multiplicative +
    additive mask (``mult_add_mask``).
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        cond_dim: int = 0,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
        mult_add_mask: bool = False,
    ) -> None:
        super().__init__()
        # Band-split front end.
        split_kwargs = dict(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )
        self.instantiate_bandsplit(**split_kwargs)

        # RNN sequence model over the (band, time) grid.
        self.tf_model = SeqBandModellingModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

        self.mult_add_mask = mult_add_mask

        # Per-stem mask heads.
        estim_kwargs = dict(
            in_channel=in_channel,
            stems=stems,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=overlapping_band,
            freq_weights=freq_weights,
            n_freq=n_freq,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )
        self.instantiate_mask_estim(**estim_kwargs)

    @staticmethod
    def _mult_add_mask(x, m):
        """Apply a (multiplicative, additive) mask pair stacked on the last axis."""
        assert m.ndim == 5
        mult, add = m[..., 0], m[..., 1]
        return x * mult + add

    def mask(self, x, m):
        """Dispatch to mult-add masking when enabled, else the base masking."""
        if not self.mult_add_mask:
            return super().mask(x, m)
        return self._mult_add_mask(x, m)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class MultiSourceMultiMaskBandSplitCoreTransformer(
    MultiMaskBandSplitCoreBase,
):
    """Multi-stem band-split core using a Transformer time-frequency model."""

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        tf_dropout: float = 0.0,
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
        rnn_type: str = "LSTM",
        cond_dim: int = 0,
        mult_add_mask: bool = False,
    ) -> None:
        super().__init__()
        # Band-split front end.
        split_kwargs = dict(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )
        self.instantiate_bandsplit(**split_kwargs)

        # Transformer over the (band, time) grid; `rnn_dim` doubles as the
        # feed-forward width.
        self.tf_model = TransformerTimeFreqModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            dropout=tf_dropout,
        )

        # Per-stem mask heads.
        estim_kwargs = dict(
            in_channel=in_channel,
            stems=stems,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=overlapping_band,
            freq_weights=freq_weights,
            n_freq=n_freq,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )
        self.instantiate_mask_estim(**estim_kwargs)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
class MultiSourceMultiMaskBandSplitCoreConv(
    MultiMaskBandSplitCoreBase,
):
    """Multi-stem band-split core using a convolutional time-frequency model."""

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        tf_dropout: float = 0.0,
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
        rnn_type: str = "LSTM",
        cond_dim: int = 0,
        mult_add_mask: bool = False,
    ) -> None:
        super().__init__()
        # Band-split front end.
        split_kwargs = dict(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )
        self.instantiate_bandsplit(**split_kwargs)

        # Convolutional stack over the (band, time) grid.
        self.tf_model = ConvolutionalTimeFreqModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            dropout=tf_dropout,
        )

        # Per-stem mask heads.
        estim_kwargs = dict(
            in_channel=in_channel,
            stems=stems,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=overlapping_band,
            freq_weights=freq_weights,
            n_freq=n_freq,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )
        self.instantiate_mask_estim(**estim_kwargs)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
class PatchingMaskBandsplitCoreBase(MultiMaskBandSplitCoreBase):
    """Band-split core base whose masks are local (kernel_freq x kernel_time)
    filters applied around each time-frequency bin, rather than pointwise masks.
    """

    def __init__(self) -> None:
        super().__init__()

    def mask(self, x, m):
        # x.shape = (batch, n_channel, n_freq, n_time)
        # m.shape = (batch, n_channel, kernel_freq, kernel_time, n_freq, n_time)
        # NOTE(review): the unpacking below implies the batch-first layout
        # above; `old_mask` below expects a kernel-first layout instead —
        # confirm which layout callers actually produce.

        _, n_channel, kernel_freq, kernel_time, n_freq, n_time = m.shape
        # "same" padding so fold() returns the original (n_freq, n_time) grid.
        padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2)

        # Extract a (kernel_freq x kernel_time) patch around every TF bin.
        xf = F.unfold(
            x,
            kernel_size=(kernel_freq, kernel_time),
            padding=padding,
            stride=(1, 1),
        )

        xf = xf.view(
            -1,
            n_channel,
            kernel_freq,
            kernel_time,
            n_freq,
            n_time,
        )

        # Pointwise product of each patch with its per-bin filter weights.
        sf = xf * m

        sf = sf.view(
            -1,
            n_channel * kernel_freq * kernel_time,
            n_freq * n_time,
        )

        # fold() sums the overlapping patch contributions back onto the grid.
        s = F.fold(
            sf,
            output_size=(n_freq, n_time),
            kernel_size=(kernel_freq, kernel_time),
            padding=padding,
            stride=(1, 1),
        ).view(
            -1,
            n_channel,
            n_freq,
            n_time,
        )

        return s

    def old_mask(self, x, m):
        # Explicit-shift reference implementation of `mask`, kept for
        # comparison/debugging. It expects a kernel-first mask layout:
        # x.shape = (batch, n_channel, n_freq, n_time)
        # m.shape = (kernel_freq, kernel_time, batch, n_channel, n_freq, n_time)

        s = torch.zeros_like(x)

        _, n_channel, n_freq, n_time = x.shape
        kernel_freq, kernel_time, _, _, _, _ = m.shape

        # print(x.shape, m.shape)

        kernel_freq_half = (kernel_freq - 1) // 2
        kernel_time_half = (kernel_time - 1) // 2

        # Accumulate each kernel tap by rolling x and masking the valid region.
        for ifreq in range(kernel_freq):
            for itime in range(kernel_time):
                df, dt = kernel_freq_half - ifreq, kernel_time_half - itime
                x = x.roll(shifts=(df, dt), dims=(2, 3))

                # if `df` > 0:
                #     x[:, :, :df, :] = 0
                # elif `df` < 0:
                #     x[:, :, df:, :] = 0

                # if `dt` > 0:
                #     x[:, :, :, :dt] = 0
                # elif `dt` < 0:
                #     x[:, :, :, dt:] = 0

                # Restrict to the slice untouched by roll wrap-around.
                fslice = slice(max(0, df), min(n_freq, n_freq + df))
                tslice = slice(max(0, dt), min(n_time, n_time + dt))

                s[:, :, fslice, tslice] += (
                    x[:, :, fslice, tslice] * m[ifreq, itime, :, :, fslice, tslice]
                )

        return s
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
class MultiSourceMultiPatchingMaskBandSplitCoreRNN(PatchingMaskBandsplitCoreBase):
    """Multi-stem patching-mask core: RNN TF-model plus one patching mask
    head per stem. Only the overlapping-band configuration is implemented.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: List[Tuple[float, float]],
        mask_kernel_freq: int,
        mask_kernel_time: int,
        conv_kernel_freq: int,
        conv_kernel_time: int,
        kernel_norm_mlp_version: int,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        overlapping_band: bool = False,
        freq_weights: Optional[List[torch.Tensor]] = None,
        n_freq: Optional[int] = None,
    ) -> None:
        super().__init__()
        # Band-split front end.
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

        # RNN sequence model over the (band, time) grid.
        self.tf_model = SeqBandModellingModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        # Only the overlapping-band variant is supported.
        if not overlapping_band:
            raise NotImplementedError

        assert freq_weights is not None
        assert n_freq is not None

        shared_kwargs = dict(
            band_specs=band_specs,
            freq_weights=freq_weights,
            n_freq=n_freq,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            in_channel=in_channel,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            mask_kernel_freq=mask_kernel_freq,
            mask_kernel_time=mask_kernel_time,
            conv_kernel_freq=conv_kernel_freq,
            conv_kernel_time=conv_kernel_time,
            kernel_norm_mlp_version=kernel_norm_mlp_version,
        )
        # One patching-mask head per stem.
        self.mask_estim = nn.ModuleDict(
            {stem: PatchingMaskEstimationModule(**shared_kwargs) for stem in stems}
        )
|
mvsepless/models/bandit/core/model/bsrnn/maskestim.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from typing import Dict, List, Optional, Tuple, Type
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn.modules import activation
|
| 7 |
+
|
| 8 |
+
from .utils import (
|
| 9 |
+
band_widths_from_specs,
|
| 10 |
+
check_no_gap,
|
| 11 |
+
check_no_overlap,
|
| 12 |
+
check_nonzero_bandwidth,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BaseNormMLP(nn.Module):
    """Shared state for per-band mask MLPs: a LayerNorm plus one hidden layer.

    Subclasses add the output projection(s) that map the hidden features
    to the actual mask.

    Args:
        emb_dim: Input embedding size (normalized and fed to the MLP).
        mlp_dim: Hidden layer width.
        bandwidth: Number of frequency bins in this band.
        in_channel: Number of audio channels the mask covers.
        hidden_activation: Name of a class in ``torch.nn.modules.activation``
            (e.g. ``"Tanh"``).
        hidden_activation_kwargs: Keyword arguments for the activation class.
        complex_mask: If True, the mask has real+imaginary components.
    """

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
    ):
        super().__init__()
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}
        self.hidden_activation_kwargs = hidden_activation_kwargs
        self.norm = nn.LayerNorm(emb_dim)
        # Resolve the activation class by name (e.g. "Tanh" -> nn.Tanh);
        # getattr raises a clear AttributeError for unknown names.
        activation_cls = getattr(activation, hidden_activation)
        self.hidden = torch.jit.script(
            nn.Sequential(
                nn.Linear(in_features=emb_dim, out_features=mlp_dim),
                activation_cls(**self.hidden_activation_kwargs),
            )
        )

        self.bandwidth = bandwidth
        self.in_channel = in_channel

        # Complex masks need two real values (re/im) per frequency bin.
        self.complex_mask = complex_mask
        self.reim = 2 if complex_mask else 1
        self.glu_mult = 2
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class NormMLP(BaseNormMLP):
    """Per-band mask head: LayerNorm -> Linear+activation -> Linear+GLU.

    Produces a (possibly complex) mask of shape
    ``(batch, in_channel, bandwidth, n_time)`` from band embeddings.
    """

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__(
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            bandwidth=bandwidth,
            in_channel=in_channel,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

        # Output projection; the GLU halves the last dim, hence the extra
        # factor of 2 in out_features.
        self.output = torch.jit.script(
            nn.Sequential(
                nn.Linear(
                    in_features=mlp_dim,
                    out_features=bandwidth * in_channel * self.reim * 2,
                ),
                nn.GLU(dim=-1),
            )
        )

    def reshape_output(self, mb):
        """Reshape a flat projection to (batch, in_channel, bandwidth, n_time)."""
        batch, n_time, _ = mb.shape
        if self.complex_mask:
            mb = mb.reshape(
                batch, n_time, self.in_channel, self.bandwidth, self.reim
            ).contiguous()
            # Interpret the trailing re/im pair as one complex value.
            mb = torch.view_as_complex(mb)  # (batch, n_time, in_channel, bandwidth)
        else:
            mb = mb.reshape(batch, n_time, self.in_channel, self.bandwidth)

        return torch.permute(mb, (0, 2, 3, 1))

    def forward(self, qb):
        """qb: (batch, n_time, emb_dim) -> mask (batch, in_channel, bandwidth, n_time)."""
        features = self.hidden(self.norm(qb))  # (batch, n_time, mlp_dim)
        projected = self.output(features)  # (batch, n_time, bandwidth * in_channel * reim)
        return self.reshape_output(projected)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class MultAddNormMLP(NormMLP):
    """NormMLP variant with two heads: a multiplicative and an additive mask."""

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channel: "int | None",
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__(
            emb_dim,
            mlp_dim,
            bandwidth,
            in_channel,
            hidden_activation,
            hidden_activation_kwargs,
            complex_mask,
        )

        # Second projection, identical in shape to `self.output`, producing
        # the additive mask component.
        self.output2 = torch.jit.script(
            nn.Sequential(
                nn.Linear(
                    in_features=mlp_dim,
                    out_features=bandwidth * in_channel * self.reim * 2,
                ),
                nn.GLU(dim=-1),
            )
        )

    def forward(self, qb):
        """Return (mult_mask, add_mask), each (batch, in_channel, bandwidth, n_time)."""
        features = self.hidden(self.norm(qb))  # (batch, n_time, mlp_dim)
        mult = self.reshape_output(self.output(features))
        add = self.reshape_output(self.output2(features))
        return mult, add
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class MaskEstimationModuleSuperBase(nn.Module):
    """Marker base class shared by all mask-estimation modules."""
    pass
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
    """Common machinery for mask estimators: one norm+MLP head per band.

    Subclasses decide how per-band masks are assembled into the
    full-spectrum mask (concatenation vs. weighted overlap-add).
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Optional[Dict] = None,
    ) -> None:
        super().__init__()

        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)

        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        if norm_mlp_kwargs is None:
            norm_mlp_kwargs = {}

        # One independent head per band; bandwidths may differ across bands.
        self.norm_mlp = nn.ModuleList(
            [
                (
                    norm_mlp_cls(
                        bandwidth=self.band_widths[b],
                        emb_dim=emb_dim,
                        mlp_dim=mlp_dim,
                        in_channel=in_channel,
                        hidden_activation=hidden_activation,
                        hidden_activation_kwargs=hidden_activation_kwargs,
                        complex_mask=complex_mask,
                        **norm_mlp_kwargs,
                    )
                )
                for b in range(self.n_bands)
            ]
        )

    def compute_masks(self, q):
        """Run each band's head; q is (batch, n_bands, n_time, emb_dim).

        Returns a list of n_bands per-band masks (one per head).
        """
        batch, n_bands, n_time, emb_dim = q.shape

        masks = []

        for b, nmlp in enumerate(self.norm_mlp):
            # print(f"maskestim/{b:02d}")
            qb = q[:, b, :, :]
            mb = nmlp(qb)
            masks.append(mb)

        return masks
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
    """Mask estimator that supports overlapping band specs.

    Each band's mask is overlap-added back onto the full frequency axis;
    optional per-band ``freq_weights`` blend the overlapping regions.
    A conditioning vector may be concatenated to the band embeddings.

    Args:
        in_channel: Number of audio channels in the mask.
        band_specs: (start_bin, end_bin) pairs; may overlap but not gap.
        freq_weights: Optional per-band weight vectors (one per band) used
            to blend overlapping regions; None disables weighting.
        n_freq: Total number of frequency bins in the output mask.
        emb_dim: Band embedding size.
        mlp_dim: Hidden width of each band head.
        cond_dim: Size of the optional conditioning vector (0 = none).
        complex_mask: Whether the heads emit complex-valued masks.
    """

    def __init__(
        self,
        in_channel: int,
        band_specs: List[Tuple[float, float]],
        freq_weights: Optional[List[torch.Tensor]],
        n_freq: int,
        emb_dim: int,
        mlp_dim: int,
        cond_dim: int = 0,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Optional[Dict] = None,
        use_freq_weights: bool = True,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)

        # if cond_dim > 0:
        #     raise NotImplementedError

        super().__init__(
            band_specs=band_specs,
            emb_dim=emb_dim + cond_dim,  # heads see embedding + conditioning
            mlp_dim=mlp_dim,
            in_channel=in_channel,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            norm_mlp_cls=norm_mlp_cls,
            norm_mlp_kwargs=norm_mlp_kwargs,
        )

        self.n_freq = n_freq
        self.band_specs = band_specs
        self.in_channel = in_channel

        if freq_weights is not None:
            # Buffers (not parameters) so they follow .to(device/dtype)
            # without receiving gradients.
            for i, fw in enumerate(freq_weights):
                self.register_buffer(f"freq_weights/{i}", fw)

            self.use_freq_weights = use_freq_weights
        else:
            self.use_freq_weights = False

        self.cond_dim = cond_dim

    def forward(self, q, cond=None):
        """q: (batch, n_bands, n_time, emb_dim) -> (batch, in_channel, n_freq, n_time)."""
        batch, n_bands, n_time, emb_dim = q.shape

        if cond is not None:
            if cond.ndim == 2:
                # (batch, cond_dim) -> broadcast over bands and time.
                cond = cond[:, None, None, :].expand(-1, n_bands, n_time, -1)
            elif cond.ndim == 3:
                # NOTE(review): a 3-dim cond is only shape-checked here; it
                # is not expanded over the band axis before the concat below
                # — confirm callers supply a band-compatible layout.
                assert cond.shape[1] == n_time
            else:
                raise ValueError(f"Invalid cond shape: {cond.shape}")

            q = torch.cat([q, cond], dim=-1)
        elif self.cond_dim > 0:
            # Heads expect cond_dim extra features but none were given:
            # pad with ones.
            cond = torch.ones(
                (batch, n_bands, n_time, self.cond_dim),
                device=q.device,
                dtype=q.dtype,
            )
            q = torch.cat([q, cond], dim=-1)
        else:
            pass

        mask_list = self.compute_masks(
            q
        )  # [n_bands * (batch, in_channel, bandwidth, n_time)]

        # Overlap-add the per-band masks onto the full frequency axis.
        masks = torch.zeros(
            (batch, self.in_channel, self.n_freq, n_time),
            device=q.device,
            dtype=mask_list[0].dtype,
        )

        for im, mask in enumerate(mask_list):
            fstart, fend = self.band_specs[im]
            if self.use_freq_weights:
                fw = self.get_buffer(f"freq_weights/{im}")[:, None]
                mask = mask * fw
            masks[:, :, fstart:fend, :] += mask

        return masks
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class MaskEstimationModule(OverlappingMaskEstimationModule):
    """Mask estimator for contiguous, non-overlapping band specs.

    Because the bands tile the spectrum exactly, the per-band masks are
    simply concatenated along the frequency axis.
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channel: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict = None,
        complex_mask: bool = True,
        **kwargs,
    ) -> None:
        # Bands must tile the spectrum exactly for the concat in forward().
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)
        check_no_overlap(band_specs)

        super().__init__(
            in_channel=in_channel,
            band_specs=band_specs,
            freq_weights=None,
            n_freq=None,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

    def forward(self, q, cond=None):
        """q: (batch, n_bands, n_time, emb_dim) -> (batch, in_channel, n_freq, n_time)."""
        per_band = self.compute_masks(
            q
        )  # [n_bands * (batch, in_channel, bandwidth, n_time)]

        # TODO: currently this requires band specs to have no gap and no overlap
        return torch.concat(per_band, dim=2)
|
mvsepless/models/bandit/core/model/bsrnn/tfmodel.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
from torch.nn.modules import rnn
|
| 7 |
+
|
| 8 |
+
import torch.backends.cuda
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TimeFrequencyModellingModule(nn.Module):
    """Base class for modules that model the (band, time) token grid."""

    def __init__(self) -> None:
        super().__init__()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ResidualRNN(nn.Module):
    """Pre-norm RNN block with a linear projection and a residual connection.

    Input/output shape: (batch, n_uncrossed, n_across, emb_dim). The RNN runs
    along the ``n_across`` axis, independently for each ``n_uncrossed`` slice.
    """

    def __init__(
        self,
        emb_dim: int,
        rnn_dim: int,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        use_batch_trick: bool = True,
        use_layer_norm: bool = True,
    ) -> None:
        super().__init__()

        self.use_layer_norm = use_layer_norm
        # GroupNorm with num_groups == num_channels normalizes each channel
        # independently; it needs a channel-first layout (handled in forward).
        self.norm = (
            nn.LayerNorm(emb_dim)
            if use_layer_norm
            else nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim)
        )

        # Look the RNN class up by name so "LSTM", "GRU", "RNN" are all valid.
        self.rnn = getattr(rnn, rnn_type)(
            input_size=emb_dim,
            hidden_size=rnn_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional,
        )

        n_directions = 2 if bidirectional else 1
        self.fc = nn.Linear(
            in_features=rnn_dim * n_directions, out_features=emb_dim
        )

        self.use_batch_trick = use_batch_trick
        if not self.use_batch_trick:
            warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")

    def forward(self, z):
        # z: (batch, n_uncrossed, n_across, emb_dim)
        residual = torch.clone(z)

        if self.use_layer_norm:
            z = self.norm(z)
        else:
            # GroupNorm expects channels in dim 1: move emb_dim there and back.
            z = torch.permute(z, (0, 3, 1, 2))
            z = self.norm(z)
            z = torch.permute(z, (0, 2, 3, 1))

        batch, n_uncrossed, n_across, emb_dim = z.shape

        if self.use_batch_trick:
            # Fold the uncrossed axis into the batch so a single RNN call
            # covers every slice at once.
            z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
            z, _ = self.rnn(z.contiguous())
            z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
        else:
            # Slow debug path: one RNN call per uncrossed slice.
            slices = [self.rnn(z[:, i, :, :])[0] for i in range(n_uncrossed)]
            z = torch.stack(slices, dim=1)

        z = self.fc(z)  # project rnn output back to emb_dim
        return z + residual
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class SeqBandModellingModule(TimeFrequencyModellingModule):
    """Alternating band/time sequence modelling built from ResidualRNN blocks.

    Sequential mode (default): 2 * n_modules blocks are applied, transposing
    the band and time axes between consecutive blocks so they alternate axes.
    Parallel mode: each of n_modules stages runs a time-RNN and a band-RNN on
    the same input and sums the two outputs.

    Input/output shape: (batch, n_bands, n_time, emb_dim).
    """

    def __init__(
        self,
        n_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        parallel_mode=False,
    ) -> None:
        super().__init__()

        def make_block():
            # One ResidualRNN with the shared hyper-parameters.
            return ResidualRNN(
                emb_dim=emb_dim,
                rnn_dim=rnn_dim,
                bidirectional=bidirectional,
                rnn_type=rnn_type,
            )

        self.seqband = nn.ModuleList([])
        if parallel_mode:
            for _ in range(n_modules):
                self.seqband.append(nn.ModuleList([make_block(), make_block()]))
        else:
            for _ in range(2 * n_modules):
                self.seqband.append(make_block())

        self.parallel_mode = parallel_mode

    def forward(self, z):
        # z: (batch, n_bands, n_time, emb_dim)
        if self.parallel_mode:
            for pair in self.seqband:
                time_rnn, band_rnn = pair[0], pair[1]
                across_time = time_rnn(z)
                across_bands = band_rnn(z.transpose(1, 2)).transpose(1, 2)
                z = across_time + across_bands
        else:
            # The transpose swaps which axis the next block models; after an
            # even number of blocks the orientation matches the input.
            for block in self.seqband:
                z = block(z)
                z = z.transpose(1, 2)

        return z  # (batch, n_bands, n_time, emb_dim)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class ResidualTransformer(nn.Module):
    """Transformer-encoder layer applied along the 3rd axis of a 4D input.

    The input is folded to (batch * n_uncrossed, n_across, emb_dim) so the
    attention runs across ``n_across`` independently per uncrossed slice.
    ``emb_dim`` must be divisible by the fixed head count of 4.
    """

    def __init__(
        self,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        # rnn_dim is reused as the feed-forward width, keeping the argument
        # names consistent with the RNN-based blocks.
        self.tf = nn.TransformerEncoderLayer(
            d_model=emb_dim, nhead=4, dim_feedforward=rnn_dim, batch_first=True
        )

        self.is_causal = not bidirectional
        # NOTE(review): dropout is stored but never passed to the encoder layer.
        self.dropout = dropout

    def forward(self, z):
        batch, n_uncrossed, n_across, emb_dim = z.shape
        folded = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
        folded = self.tf(folded, is_causal=self.is_causal)
        return torch.reshape(folded, (batch, n_uncrossed, n_across, emb_dim))
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class TransformerTimeFreqModule(TimeFrequencyModellingModule):
    """Stack of ResidualTransformer blocks alternating between axes.

    A single LayerNorm is applied up front; each of the 2 * n_modules blocks
    is followed by a transpose of the band and time axes, so the output
    orientation matches the input: (batch, n_bands, n_time, emb_dim).
    """

    def __init__(
        self,
        n_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(emb_dim)
        self.seqband = nn.ModuleList(
            ResidualTransformer(
                emb_dim=emb_dim,
                rnn_dim=rnn_dim,
                bidirectional=bidirectional,
                dropout=dropout,
            )
            for _ in range(2 * n_modules)
        )

    def forward(self, z):
        # z: (batch, n_bands, n_time, emb_dim)
        z = self.norm(z)

        for block in self.seqband:
            z = block(z)
            z = z.transpose(1, 2)  # next block attends over the other axis

        return z  # (batch, n_bands, n_time, emb_dim)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class ResidualConvolution(nn.Module):
    """Pre-norm 3x3 convolution block with Tanhshrink and a residual skip.

    Operates channel-first: input/output shape (batch, emb_dim, H, W).
    """

    def __init__(
        self,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.norm = nn.InstanceNorm2d(emb_dim, affine=True)

        # rnn_dim is reused as the hidden channel count, keeping the argument
        # names consistent with the RNN-based blocks.
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=emb_dim,
                out_channels=rnn_dim,
                kernel_size=(3, 3),
                padding="same",
                stride=(1, 1),
            ),
            nn.Tanhshrink(),
        )

        self.is_causal = not bidirectional
        # NOTE(review): is_causal and dropout are stored but never used here.
        self.dropout = dropout

        # 1x1 convolution projecting back to emb_dim channels.
        self.fc = nn.Conv2d(
            in_channels=rnn_dim,
            out_channels=emb_dim,
            kernel_size=(1, 1),
            padding="same",
            stride=(1, 1),
        )

    def forward(self, z):
        # z: (batch, emb_dim, H, W)
        shortcut = torch.clone(z)
        z = self.norm(z)
        z = self.conv(z)
        z = self.fc(z)
        return z + shortcut
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule):
    """Fully convolutional time-frequency modelling stack.

    Wraps 2 * n_modules ResidualConvolution blocks in a TorchScript-compiled
    nn.Sequential. The (batch, n_bands, n_time, emb_dim) input is moved to a
    channel-first layout for the convolutions and restored afterwards.
    """

    def __init__(
        self,
        n_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        blocks = [
            ResidualConvolution(
                emb_dim=emb_dim,
                rnn_dim=rnn_dim,
                bidirectional=bidirectional,
                dropout=dropout,
            )
            for _ in range(2 * n_modules)
        ]
        self.seqband = torch.jit.script(nn.Sequential(*blocks))

    def forward(self, z):
        # (batch, n_bands, n_time, emb_dim) -> channel-first for Conv2d
        z = torch.permute(z, (0, 3, 1, 2))
        z = self.seqband(z)  # (batch, emb_dim, n_bands, n_time)
        # restore (batch, n_bands, n_time, emb_dim)
        return torch.permute(z, (0, 2, 3, 1))
|
mvsepless/models/bandit/core/model/bsrnn/utils.py
ADDED
|
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from abc import abstractmethod
|
| 3 |
+
from typing import Any, Callable
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from librosa import hz_to_midi, midi_to_hz
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from torchaudio import functional as taF
|
| 10 |
+
from spafe.fbanks import bark_fbanks
|
| 11 |
+
from spafe.utils.converters import erb2hz, hz2bark, hz2erb
|
| 12 |
+
from torchaudio.functional.functional import _create_triangular_filterbank
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def band_widths_from_specs(band_specs):
    """Return the width (end - start) of each (start, end) band spec."""
    return [end - start for start, end in band_specs]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def check_nonzero_bandwidth(band_specs):
    """Raise ValueError if any (start, end) band has non-positive width."""
    for start, end in band_specs:
        if end <= start:
            raise ValueError("Bands cannot be zero-width")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def check_no_overlap(band_specs):
    """Raise ValueError if any band overlaps the previous one.

    Bands are half-open index ranges [start, end) sorted by start, so two
    consecutive bands may share a boundary (start == previous end); a band
    that starts strictly before the previous band's end is an overlap.

    Bug fix: the original loop never advanced ``fend_prev`` past its initial
    -1, so no overlap could ever be detected; it is now updated per band
    (mirroring check_no_gap), and the comparison is strict so the
    shared-boundary layout produced by get_band_specs_with_bandwidth passes.
    """
    fend_prev = -1
    for fstart_curr, fend_curr in band_specs:
        if fstart_curr < fend_prev:
            raise ValueError("Bands cannot overlap")
        fend_prev = fend_curr
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def check_no_gap(band_specs):
    """Check the bands start at bin 0 and raise ValueError on any gap.

    Bands are half-open index ranges sorted by start; a gap exists when a
    band starts more than one index past the previous band's end.
    """
    first_start, _ = band_specs[0]
    assert first_start == 0

    prev_end = -1
    for start, end in band_specs:
        if start - prev_end > 1:
            raise ValueError("Bands cannot leave gap")
        prev_end = end
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class BandsplitSpecification:
    """Base class mapping frequency ranges (Hz) to STFT bin index bands.

    Subclasses implement ``get_band_specs`` returning a list of half-open
    (start_bin, end_bin) tuples covering the spectrum of an ``nfft``-point
    STFT at sample rate ``fs``.
    """

    def __init__(self, nfft: int, fs: int) -> None:
        self.fs = fs
        self.nfft = nfft
        self.nyquist = fs / 2
        self.max_index = nfft // 2 + 1  # number of non-redundant STFT bins

        # Pre-computed bin indices of the commonly used split frequencies.
        self.split500 = self.hertz_to_index(500)
        self.split1k = self.hertz_to_index(1000)
        self.split2k = self.hertz_to_index(2000)
        self.split4k = self.hertz_to_index(4000)
        self.split8k = self.hertz_to_index(8000)
        self.split16k = self.hertz_to_index(16000)
        self.split20k = self.hertz_to_index(20000)

        # Shared top-of-spectrum bands appended by several subclasses.
        self.above20k = [(self.split20k, self.max_index)]
        self.above16k = [(self.split16k, self.split20k)] + self.above20k

    def index_to_hertz(self, index: int):
        """Convert an STFT bin index to its frequency in Hz."""
        return index * self.fs / self.nfft

    def hertz_to_index(self, hz: float, round: bool = True):
        """Convert a frequency in Hz to a (by default rounded) bin index."""
        index = hz * self.nfft / self.fs
        if round:
            index = int(np.round(index))
        return index

    def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
        """Tile [start_index, end_index) with bands ~``bandwidth_hz`` wide.

        Consecutive bands share a boundary; the last band is clipped at
        ``end_index``.
        """
        specs = []
        lower = start_index

        while lower < end_index:
            upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
            upper = min(upper, end_index)

            specs.append((lower, upper))
            lower = upper

        return specs

    @abstractmethod
    def get_band_specs(self):
        raise NotImplementedError
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class VocalBandsplitSpecification(BandsplitSpecification):
    """Hand-designed band layouts for vocals; select with ``version`` "1"-"7".

    Later versions use progressively finer bands at low frequencies.
    """

    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

        self.version = version

    def get_band_specs(self):
        """Dispatch to the ``version{N}`` method chosen at construction."""
        return getattr(self, f"version{self.version}")()

    def version1(self):
        # Bug fix: this was decorated @property while versions 2-7 are plain
        # methods, so get_band_specs() raised TypeError for version="1" by
        # trying to call the returned list. It is now a regular method.
        return self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.max_index, bandwidth_hz=1000
        )

    def version2(self):
        below16k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )

        return below16k + below20k + self.above20k

    def version3(self):
        below8k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )

        return below8k + below16k + self.above16k

    def version4(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )

        return below1k + below8k + below16k + self.above16k

    def version5(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )
        return below1k + below16k + below20k + self.above20k

    def version6(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + self.above16k

    def version7(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + below20k + self.above20k
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class OtherBandsplitSpecification(VocalBandsplitSpecification):
    """Band split for the "other" stem: reuses the vocal version-7 layout."""

    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft, fs, version="7")
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class BassBandsplitSpecification(BandsplitSpecification):
    """Band split for bass: very fine bands below 500 Hz, coarser above.

    NOTE(review): ``version`` is accepted but unused; kept so the signature
    matches the other stem specifications.
    """

    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        # (start_bin, end_bin, bandwidth_hz) segments, low to high frequency.
        segments = [
            (0, self.split500, 50),
            (self.split500, self.split1k, 100),
            (self.split1k, self.split4k, 500),
            (self.split4k, self.split8k, 1000),
            (self.split8k, self.split16k, 2000),
        ]

        specs = []
        for start, end, bandwidth in segments:
            specs += self.get_band_specs_with_bandwidth(
                start_index=start, end_index=end, bandwidth_hz=bandwidth
            )

        # Everything above 16 kHz is a single band.
        specs.append((self.split16k, self.max_index))
        return specs
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class DrumBandsplitSpecification(BandsplitSpecification):
    """Band split for drums: fine bands up to 1 kHz, coarser toward 16 kHz."""

    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        # (start_bin, end_bin, bandwidth_hz) segments, low to high frequency.
        segments = [
            (0, self.split1k, 50),
            (self.split1k, self.split2k, 100),
            (self.split2k, self.split4k, 250),
            (self.split4k, self.split8k, 500),
            (self.split8k, self.split16k, 1000),
        ]

        specs = []
        for start, end, bandwidth in segments:
            specs += self.get_band_specs_with_bandwidth(
                start_index=start, end_index=end, bandwidth_hz=bandwidth
            )

        # Everything above 16 kHz is a single band.
        specs.append((self.split16k, self.max_index))
        return specs
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class PerceptualBandsplitSpecification(BandsplitSpecification):
    """Band split derived from a perceptual filterbank (mel, bark, ERB, ...).

    ``fbank_fn(n_bands, fs, f_min, f_max, n_freqs)`` must return an
    (n_bands, n_freqs) filterbank matrix. Each band spec is the contiguous
    bin range covering that band's nonzero entries, and per-bin weights are
    the filterbank values normalized so every frequency bin sums to 1 across
    bands.
    """

    def __init__(
        self,
        nfft: int,
        fs: int,
        fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
        n_bands: int,
        f_min: float = 0.0,
        f_max: float = None,
    ) -> None:
        super().__init__(nfft=nfft, fs=fs)
        self.n_bands = n_bands
        if f_max is None:
            f_max = fs / 2  # default to Nyquist

        # (n_bands, max_index) filterbank matrix.
        self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)

        # Normalize so each frequency bin's weights sum to 1 across bands.
        weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True)  # (1, n_freqs)
        normalized_mel_fb = self.filterbank / weight_per_bin  # (n_mels, n_freqs)

        freq_weights = []
        band_specs = []
        for i in range(self.n_bands):
            # squeeze().tolist() yields a bare int when exactly one bin is
            # active, hence the isinstance normalization below.
            active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
            if isinstance(active_bins, int):
                active_bins = (active_bins, active_bins)
            if len(active_bins) == 0:
                continue  # band has no nonzero bins; skip it entirely
            start_index = active_bins[0]
            end_index = active_bins[-1] + 1  # half-open (start, end)
            band_specs.append((start_index, end_index))
            freq_weights.append(normalized_mel_fb[i, start_index:end_index])

        self.freq_weights = freq_weights
        self.band_specs = band_specs

    def get_band_specs(self):
        """Return the list of half-open (start_bin, end_bin) tuples."""
        return self.band_specs

    def get_freq_weights(self):
        """Return per-band weight tensors aligned with ``get_band_specs``."""
        return self.freq_weights

    def save_to_file(self, dir_path: str) -> None:
        """Pickle band specs, weights and the raw filterbank into ``dir_path``."""

        os.makedirs(dir_path, exist_ok=True)

        import pickle

        with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
            pickle.dump(
                {
                    "band_specs": self.band_specs,
                    "freq_weights": self.freq_weights,
                    "filterbank": self.filterbank,
                },
                f,
            )
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    """Mel filterbank as an (n_bands, n_freqs) matrix.

    The DC entry of the first band is forced to 1 so bin 0 always belongs to
    band 0.
    """
    bank = taF.melscale_fbanks(
        n_mels=n_bands,
        sample_rate=fs,
        f_min=f_min,
        f_max=f_max,
        n_freqs=n_freqs,
    ).T

    bank[0, 0] = 1.0

    return bank
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
class MelBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split using a mel-scale filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(nfft, fs, mel_filterbank, n_bands, f_min, f_max)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
    """Rectangular filterbank with log-spaced (musical) band centers.

    Band centers are spaced evenly in MIDI note number between f_min and
    f_max; each band spans one inter-band octave ratio below and above its
    center, so neighbouring bands overlap. Returns an (n_bands, n_freqs)
    torch tensor of 0/1 weights.

    NOTE(review): ``scale`` is accepted but never used.
    """

    nfft = 2 * (n_freqs - 1)
    df = fs / nfft  # frequency resolution per bin
    # init freqs
    f_max = f_max or fs / 2
    f_min = f_min or 0
    # NOTE(review): this unconditionally overwrites f_min (including the line
    # just above) with the first bin frequency — presumably to avoid log2(0)
    # below; confirm the caller-supplied f_min is really meant to be ignored.
    f_min = fs / nfft

    n_octaves = np.log2(f_max / f_min)
    n_octaves_per_band = n_octaves / n_bands
    bandwidth_mult = np.power(2.0, n_octaves_per_band)

    low_midi = max(0, hz_to_midi(f_min))
    high_midi = hz_to_midi(f_max)
    midi_points = np.linspace(low_midi, high_midi, n_bands)
    hz_pts = midi_to_hz(midi_points)

    # Each band extends one inter-band ratio below/above its center.
    low_pts = hz_pts / bandwidth_mult
    high_pts = hz_pts * bandwidth_mult

    low_bins = np.floor(low_pts / df).astype(int)
    high_bins = np.ceil(high_pts / df).astype(int)

    fb = np.zeros((n_bands, n_freqs))

    for i in range(n_bands):
        fb[i, low_bins[i] : high_bins[i] + 1] = 1.0

    # Extend the edge bands to cover all bins below / above their range.
    fb[0, : low_bins[0]] = 1.0
    fb[-1, high_bins[-1] + 1 :] = 1.0

    return torch.as_tensor(fb)
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split using log-spaced (musical) bands."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(nfft, fs, musical_filterbank, n_bands, f_min, f_max)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    """Bark-scale filterbank (via spafe) as an (n_bands, n_freqs) torch tensor."""
    fft_size = 2 * (n_freqs - 1)
    bank, _ = bark_fbanks.bark_filter_banks(
        nfilts=n_bands,
        nfft=fft_size,
        fs=fs,
        low_freq=f_min,
        high_freq=f_max,
        scale="constant",
    )

    return torch.as_tensor(bank)
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split using a Bark-scale filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(nfft, fs, bark_filterbank, n_bands, f_min, f_max)
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def triangular_bark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    """Triangular filterbank with band edges spaced on a Bark-like scale.

    Edge frequencies come from inverting ``bark = 6 * asinh(hz / 600)``
    (the Schroeder approximation), i.e. ``hz = 600 * sinh(bark / 6)``.
    Builds the triangles with torchaudio's private
    ``_create_triangular_filterbank`` helper and returns an
    (n_bands, n_freqs) tensor.
    """

    all_freqs = torch.linspace(0, fs // 2, n_freqs)

    # calculate mel freq bins
    m_min = hz2bark(f_min)
    m_max = hz2bark(f_max)

    m_pts = torch.linspace(m_min, m_max, n_bands + 2)
    f_pts = 600 * torch.sinh(m_pts / 6)  # inverse Schroeder Bark -> Hz

    # create filterbank
    fb = _create_triangular_filterbank(all_freqs, f_pts)

    fb = fb.T  # -> (n_bands, n_freqs)

    # Extend the first band with any nonzero weight down to bin 0 so the
    # lowest bins are covered by some band.
    first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
    first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]

    fb[first_active_band, :first_active_bin] = 1.0

    return fb
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split using a triangular Bark-scale filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(nfft, fs, triangular_bark_filterbank, n_bands, f_min, f_max)
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def minibark_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    """Bark filterbank with all weights below sqrt(0.5) zeroed out."""
    bank = bark_filterbank(n_bands, fs, f_min, f_max, n_freqs)
    bank[bank < np.sqrt(0.5)] = 0.0
    return bank
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split using the thresholded ("mini") Bark filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(nfft, fs, minibark_filterbank, n_bands, f_min, f_max)
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def erb_filterbank(
    n_bands: int,
    fs: int,
    f_min: float,
    f_max: float,
    n_freqs: int,
) -> Tensor:
    """Triangular filterbank with band edges spaced on the ERB scale.

    Edge frequencies are obtained by inverting the Glasberg & Moore
    ERB-number formula ``erb = A * log10(1 + 0.00437 * hz)``. Builds the
    triangles with torchaudio's private ``_create_triangular_filterbank``
    helper and returns an (n_bands, n_freqs) tensor.
    """
    # freq bins
    # Constant A from the ERB-number definition above.
    A = (1000 * np.log(10)) / (24.7 * 4.37)
    all_freqs = torch.linspace(0, fs // 2, n_freqs)

    # calculate mel freq bins
    m_min = hz2erb(f_min)
    m_max = hz2erb(f_max)

    m_pts = torch.linspace(m_min, m_max, n_bands + 2)
    f_pts = (torch.pow(10, (m_pts / A)) - 1) / 0.00437  # inverse ERB -> Hz

    # create filterbank
    fb = _create_triangular_filterbank(all_freqs, f_pts)

    fb = fb.T  # -> (n_bands, n_freqs)

    # Extend the first band with any nonzero weight down to bin 0 so the
    # lowest bins are covered by some band.
    first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
    first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]

    fb[first_active_band, :first_active_bin] = 1.0

    return fb
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split using an ERB-scale triangular filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(nfft, fs, erb_filterbank, n_bands, f_min, f_max)
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
if __name__ == "__main__":
    # Dump the vocal (version-7 default) band layout to CSV for inspection.
    import pandas as pd

    band_defs = []

    for spec_cls in [VocalBandsplitSpecification]:
        band_name = spec_cls.__name__.replace("BandsplitSpecification", "")
        specs = spec_cls(nfft=2048, fs=44100).get_band_specs()
        band_defs.extend(
            {"band": band_name, "band_index": i, "f_min": lo, "f_max": hi}
            for i, (lo, hi) in enumerate(specs)
        )

    pd.DataFrame(band_defs).to_csv("vox7bands.csv", index=False)
|
mvsepless/models/bandit/core/model/bsrnn/wrapper.py
ADDED
|
@@ -0,0 +1,829 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pprint import pprint
|
| 2 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
from .._spectral import _SpectralComponent
|
| 8 |
+
from .utils import (
|
| 9 |
+
BarkBandsplitSpecification,
|
| 10 |
+
BassBandsplitSpecification,
|
| 11 |
+
DrumBandsplitSpecification,
|
| 12 |
+
EquivalentRectangularBandsplitSpecification,
|
| 13 |
+
MelBandsplitSpecification,
|
| 14 |
+
MusicalBandsplitSpecification,
|
| 15 |
+
OtherBandsplitSpecification,
|
| 16 |
+
TriangularBarkBandsplitSpecification,
|
| 17 |
+
VocalBandsplitSpecification,
|
| 18 |
+
)
|
| 19 |
+
from .core import (
|
| 20 |
+
MultiSourceMultiMaskBandSplitCoreConv,
|
| 21 |
+
MultiSourceMultiMaskBandSplitCoreRNN,
|
| 22 |
+
MultiSourceMultiMaskBandSplitCoreTransformer,
|
| 23 |
+
MultiSourceMultiPatchingMaskBandSplitCoreRNN,
|
| 24 |
+
SingleMaskBandsplitCoreRNN,
|
| 25 |
+
SingleMaskBandsplitCoreTransformer,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
import pytorch_lightning as pl
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_band_specs(band_specs, n_fft, fs, n_bands=None):
    """Resolve a band-spec preset name into concrete band definitions.

    Parameters
    ----------
    band_specs : str
        Preset name. Exact names ("dnr:speech", "dnr:vox7", "musdb:vocals",
        "musdb:vox7") select the hand-crafted vocal bands; otherwise the name
        is matched by substring against the perceptual scales
        ("tribark", "bark", "erb", "musical", "mel").
    n_fft : int
        STFT size the band indices refer to.
    fs : int
        Sample rate in Hz.
    n_bands : int, optional
        Number of bands; required by every perceptual-scale preset.

    Returns
    -------
    tuple
        ``(band_specs, freq_weights, overlapping_band)`` where
        ``freq_weights`` is None for the non-overlapping vocal preset.

    Raises
    ------
    NameError
        If the preset name is not recognised.
    """
    if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]:
        bsm = VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs()
        freq_weights = None
        overlapping_band = False
    # NOTE: "tribark" must be checked before "bark" (substring match).
    elif "tribark" in band_specs:
        assert n_bands is not None
        specs = TriangularBarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "bark" in band_specs:
        assert n_bands is not None
        specs = BarkBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "erb" in band_specs:
        assert n_bands is not None
        specs = EquivalentRectangularBandsplitSpecification(
            nfft=n_fft, fs=fs, n_bands=n_bands
        )
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "musical" in band_specs:
        assert n_bands is not None
        specs = MusicalBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "mel" in band_specs:
        # Covers "dnr:mel" and any other mel-scale preset (the original
        # `band_specs == "dnr:mel" or ...` disjunct was redundant).
        assert n_bands is not None
        specs = MelBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    else:
        raise NameError(f"Unknown band specification: {band_specs!r}")

    return bsm, freq_weights, overlapping_band
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None):
    """Resolve a preset name into a stem -> band-spec mapping.

    Parameters
    ----------
    band_specs_map : str
        "musdb:all" (per-stem hand-crafted bands), "dnr:vox7" (shared vocal
        bands for speech/music/effects), or "dnr:vox7:<stem>" (vocal bands
        for one named stem).
    n_fft : int
        STFT size the band indices refer to.
    fs : int
        Sample rate in Hz.
    n_bands : int, optional
        Forwarded to :func:`get_band_specs` for the "dnr:vox7" presets.

    Returns
    -------
    tuple
        ``(band_specs_map, freq_weights, overlapping_band)``.

    Raises
    ------
    NameError
        If the preset name is not recognised.
    """
    if band_specs_map == "musdb:all":
        bsm = {
            "vocals": VocalBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
            "drums": DrumBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
            "bass": BassBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
            "other": OtherBandsplitSpecification(nfft=n_fft, fs=fs).get_band_specs(),
        }
        freq_weights = None
        overlapping_band = False
    elif band_specs_map == "dnr:vox7":
        # All three DnR stems share the same vocal band layout.
        bsm_, freq_weights, overlapping_band = get_band_specs(
            "dnr:speech", n_fft, fs, n_bands
        )
        bsm = {"speech": bsm_, "music": bsm_, "effects": bsm_}
    elif "dnr:vox7:" in band_specs_map:
        stem = band_specs_map.split(":")[-1]
        bsm_, freq_weights, overlapping_band = get_band_specs(
            "dnr:speech", n_fft, fs, n_bands
        )
        bsm = {stem: bsm_}
    else:
        raise NameError(f"Unknown band specification map: {band_specs_map!r}")

    return bsm, freq_weights, overlapping_band
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class BandSplitWrapperBase(pl.LightningModule):
    """Common Lightning base for all band-split wrappers in this module."""

    # Subclasses are expected to assign the actual separation core here.
    bsrnn: nn.Module

    def __init__(self, **kwargs):
        # kwargs are accepted (and ignored) so subclasses can forward freely.
        super().__init__()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class SingleMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
    """Base wrapper running one single-mask band-split core per stem.

    ``self.bsrnn`` must be a mapping (e.g. ``nn.ModuleDict``) from stem name
    to a core that maps a mixture spectrogram to that stem's spectrogram.
    """

    def __init__(
        self,
        band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
        fs: int = 44100,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,
    ) -> None:
        super().__init__(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        if isinstance(band_specs_map, str):
            self.band_specs_map, self.freq_weights, self.overlapping_band = (
                get_band_specs_map(band_specs_map, n_fft, fs, n_bands=n_bands)
            )
        else:
            # BUGFIX: a dict argument previously left these attributes unset,
            # crashing on the `.keys()` call below.  An explicit mapping is
            # taken as-is, with no weights and non-overlapping bands.
            self.band_specs_map = band_specs_map
            self.freq_weights = None
            self.overlapping_band = False

        self.stems = list(self.band_specs_map.keys())

    def forward(self, batch):
        """Run each stem's core on the mixture; returns (batch, output) dicts."""
        audio = batch["audio"]

        # STFTs of the references are targets only; no gradients needed.
        with torch.no_grad():
            batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}

        X = batch["spectrogram"]["mixture"]
        length = batch["audio"]["mixture"].shape[-1]

        output = {"spectrogram": {}, "audio": {}}

        for stem, bsrnn in self.bsrnn.items():
            S = bsrnn(X)
            s = self.istft(S, length)
            output["spectrogram"][stem] = S
            output["audio"][stem] = s

        return batch, output
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class MultiMaskMultiSourceBandSplitBase(BandSplitWrapperBase, _SpectralComponent):
    """Base wrapper for a single multi-mask core emitting all stems at once.

    ``self.bsrnn`` must map a mixture spectrogram (and optional condition)
    to a dict with a per-stem ``"spectrogram"`` entry.
    """

    def __init__(
        self,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,
    ) -> None:
        super().__init__(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        if isinstance(band_specs, str):
            self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
                band_specs, n_fft, fs, n_bands
            )
        else:
            # BUGFIX: an explicit band list previously left these attributes
            # unset, breaking subclasses that read them.  Take it as-is, with
            # no weights and non-overlapping bands.
            self.band_specs = band_specs
            self.freq_weights = None
            self.overlapping_band = False

        self.stems = stems

    def forward(self, batch):
        """Run the multi-mask core on the mixture; returns (batch, output)."""
        audio = batch["audio"]
        cond = batch.get("condition", None)

        # STFTs of the references are targets only; no gradients needed.
        with torch.no_grad():
            batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}

        X = batch["spectrogram"]["mixture"]
        length = batch["audio"]["mixture"].shape[-1]

        output = self.bsrnn(X, cond=cond)
        output["audio"] = {}

        for stem, S in output["spectrogram"].items():
            s = self.istft(S, length)
            output["audio"][stem] = s

        return batch, output
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class MultiMaskMultiSourceBandSplitBaseSimple(BandSplitWrapperBase, _SpectralComponent):
    """Inference-oriented variant taking a raw waveform tensor.

    Unlike :class:`MultiMaskMultiSourceBandSplitBase`, ``forward`` receives
    the mixture tensor directly and returns the stems stacked along dim 1.
    """

    def __init__(
        self,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,
    ) -> None:
        super().__init__(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        if isinstance(band_specs, str):
            self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
                band_specs, n_fft, fs, n_bands
            )
        else:
            # BUGFIX: an explicit band list previously left these attributes
            # unset, breaking subclasses that read them.
            self.band_specs = band_specs
            self.freq_weights = None
            self.overlapping_band = False

        self.stems = stems

    def forward(self, batch):
        """Separate a mixture tensor; returns stems stacked along dim 1."""
        # The analysis STFT itself needs no gradient tracking.
        with torch.no_grad():
            X = self.stft(batch)
            length = batch.shape[-1]
        output = self.bsrnn(X, cond=None)
        res = []
        for stem, S in output["spectrogram"].items():
            s = self.istft(S, length)
            res.append(s)
        res = torch.stack(res, dim=1)
        return res
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase):
    """One RNN-based single-mask band-split core per stem (one module per source)."""

    def __init__(
        self,
        in_channel: int,
        band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ) -> None:
        # STFT / band-spec setup is handled entirely by the base class.
        super().__init__(
            band_specs_map=band_specs_map,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        # One independent RNN core per source, keyed by stem name.
        self.bsrnn = nn.ModuleDict(
            {
                src: SingleMaskBandsplitCoreRNN(
                    band_specs=specs,
                    in_channel=in_channel,
                    require_no_overlap=require_no_overlap,
                    require_no_gap=require_no_gap,
                    normalize_channel_independently=normalize_channel_independently,
                    treat_channel_as_feature=treat_channel_as_feature,
                    n_sqm_modules=n_sqm_modules,
                    emb_dim=emb_dim,
                    rnn_dim=rnn_dim,
                    bidirectional=bidirectional,
                    rnn_type=rnn_type,
                    mlp_dim=mlp_dim,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                )
                for src, specs in self.band_specs_map.items()
            }
        )
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
class SingleMaskMultiSourceBandSplitTransformer(SingleMaskMultiSourceBandSplitBase):
    """One Transformer-based single-mask band-split core per stem."""

    def __init__(
        self,
        in_channel: int,
        band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        tf_dropout: float = 0.0,
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ) -> None:
        # STFT / band-spec setup is handled entirely by the base class.
        super().__init__(
            band_specs_map=band_specs_map,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        # One independent Transformer core per source, keyed by stem name.
        self.bsrnn = nn.ModuleDict(
            {
                src: SingleMaskBandsplitCoreTransformer(
                    band_specs=specs,
                    in_channel=in_channel,
                    require_no_overlap=require_no_overlap,
                    require_no_gap=require_no_gap,
                    normalize_channel_independently=normalize_channel_independently,
                    treat_channel_as_feature=treat_channel_as_feature,
                    n_sqm_modules=n_sqm_modules,
                    emb_dim=emb_dim,
                    rnn_dim=rnn_dim,
                    bidirectional=bidirectional,
                    tf_dropout=tf_dropout,
                    mlp_dim=mlp_dim,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                )
                for src, specs in self.band_specs_map.items()
            }
        )
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
    """Single shared RNN core emitting one mask per stem (dict-batch interface)."""

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        cond_dim: int = 0,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,
        use_freq_weights: bool = True,
        normalize_input: bool = False,
        mult_add_mask: bool = False,
        freeze_encoder: bool = False,
    ) -> None:
        # STFT / band-spec setup is handled by the base class.
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        # Shared core producing all stems' masks in one pass.
        self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )

        # NOTE: normalize_input is stored but not used in this class's code;
        # presumably consumed elsewhere (e.g. by a training loop) — verify.
        self.normalize_input = normalize_input
        self.cond_dim = cond_dim

        # Optionally freeze the encoder (band-split + time-freq model),
        # leaving only the mask estimation trainable.
        if freeze_encoder:
            for param in self.bsrnn.band_split.parameters():
                param.requires_grad = False

            for param in self.bsrnn.tf_model.parameters():
                param.requires_grad = False
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple):
    """Same core as :class:`MultiMaskMultiSourceBandSplitRNN`, but with the
    tensor-in / stacked-tensor-out interface of the Simple base class."""

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        cond_dim: int = 0,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,
        use_freq_weights: bool = True,
        normalize_input: bool = False,
        mult_add_mask: bool = False,
        freeze_encoder: bool = False,
    ) -> None:
        # STFT / band-spec setup is handled by the base class.
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        # Shared core producing all stems' masks in one pass.
        self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )

        # NOTE: normalize_input is stored but not used in this class's code;
        # presumably consumed elsewhere — verify.
        self.normalize_input = normalize_input
        self.cond_dim = cond_dim

        # Optionally freeze the encoder (band-split + time-freq model).
        if freeze_encoder:
            for param in self.bsrnn.band_split.parameters():
                param.requires_grad = False

            for param in self.bsrnn.tf_model.parameters():
                param.requires_grad = False
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
class MultiMaskMultiSourceBandSplitTransformer(MultiMaskMultiSourceBandSplitBase):
    """Single shared Transformer core emitting one mask per stem."""

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        cond_dim: int = 0,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,
        use_freq_weights: bool = True,
        normalize_input: bool = False,
        mult_add_mask: bool = False,
    ) -> None:
        # NOTE: normalize_input is accepted but never used or stored here
        # (unlike the RNN variants) — confirm whether that is intentional.
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        # Shared Transformer core producing all stems' masks in one pass.
        # rnn_type/rnn_dim/bidirectional are forwarded as the core expects
        # them even in its Transformer variant.
        self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
class MultiMaskMultiSourceBandSplitConv(MultiMaskMultiSourceBandSplitBase):
    """Multi-source band-split separator with a convolutional core.

    Thin wrapper: the base class handles the STFT front-end and resolves
    ``band_specs`` / ``overlapping_band`` / ``freq_weights``; this subclass
    only constructs the ``MultiSourceMultiMaskBandSplitCoreConv`` that
    estimates one mask per stem.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        cond_dim: int = 0,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,  # annotation fixed: defaults to None
        use_freq_weights: bool = True,
        normalize_input: bool = False,  # NOTE(review): accepted but never used below — confirm intended
        mult_add_mask: bool = False,
    ) -> None:
        # STFT/iSTFT setup and band-spec resolution happen in the base class.
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        # Core mask estimator. `self.band_specs`, `self.overlapping_band` and
        # `self.freq_weights` are the values resolved by the base class, not
        # the raw constructor arguments.
        self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,  # number of one-sided STFT bins
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask,
        )
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
    """Multi-source band-split separator with an RNN core that predicts
    patching masks (small freq/time kernels per TF bin) instead of
    pointwise masks.

    Thin wrapper: the base class handles the STFT front-end and band-spec
    resolution; this subclass only constructs the
    ``MultiSourceMultiPatchingMaskBandSplitCoreRNN``.
    """

    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_specs: Union[str, List[Tuple[float, float]]],
        kernel_norm_mlp_version: int = 1,
        mask_kernel_freq: int = 3,
        mask_kernel_time: int = 3,
        conv_kernel_freq: int = 1,
        conv_kernel_time: int = 1,
        fs: int = 44100,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        n_bands: Optional[int] = None,  # annotation fixed: defaults to None
    ) -> None:
        # STFT/iSTFT setup and band-spec resolution happen in the base class.
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        # Core mask estimator; band/weight attributes are the resolved
        # values computed by the base class.
        self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,  # number of one-sided STFT bins
            mask_kernel_freq=mask_kernel_freq,
            mask_kernel_time=mask_kernel_time,
            conv_kernel_freq=conv_kernel_freq,
            conv_kernel_time=conv_kernel_time,
            kernel_norm_mlp_version=kernel_norm_mlp_version,
        )
|
mvsepless/models/bandit/core/utils/__init__.py
ADDED
|
File without changes
|
mvsepless/models/bandit/core/utils/audio.py
ADDED
|
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
|
| 3 |
+
from tqdm.auto import tqdm
|
| 4 |
+
from typing import Callable, Dict, List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@torch.jit.script
def merge(
    combined: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    n_chunks: int,
    chunk_size: int,
):
    """Rearrange stacked chunk outputs of shape
    ``(batch * n_chunks, n_channel, chunk_size)`` into
    ``(batch * n_channel, chunk_size, n_chunks)`` so the chunks line up
    along the last axis for overlap-add folding.
    """
    # Split the fused batch axis back into (batch, chunks).
    by_chunk = torch.reshape(
        combined, (original_batch_size, n_chunks, n_channel, chunk_size)
    )
    # Move the chunk axis last, then fuse batch with channel.
    by_channel = torch.permute(by_chunk, (0, 2, 3, 1))
    return torch.reshape(
        by_channel, (original_batch_size * n_channel, chunk_size, n_chunks)
    )
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@torch.jit.script
def unfold(
    padded_audio: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    chunk_size: int,
    hop_size: int,
) -> torch.Tensor:
    """Slice padded audio into overlapping chunks.

    Input is ``(batch * n_channel, 1, n_padded_samples)`` (channels already
    folded into the batch axis); output is
    ``(batch * n_chunks, n_channel, chunk_size)`` — one row per chunk so
    the model can be run batched over chunks.
    """
    # F.unfold expects a 4-D input; insert a dummy height axis of size 1.
    windows = F.unfold(
        padded_audio[:, :, None, :],
        kernel_size=(1, chunk_size),
        stride=(1, hop_size),
    )

    n_chunks = windows.shape[-1]
    stacked = windows.view(original_batch_size, n_channel, chunk_size, n_chunks)
    # Bring chunks to the front and fuse them with the batch axis.
    stacked = torch.permute(stacked, (0, 3, 1, 2))
    return stacked.reshape(original_batch_size * n_chunks, n_channel, chunk_size)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@torch.jit.script
# @torch.compile
def merge_chunks_all(
    combined: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    n_samples: int,
    n_padded_samples: int,
    n_chunks: int,
    chunk_size: int,
    hop_size: int,
    edge_frame_pad_sizes: Tuple[int, int],
    standard_window: torch.Tensor,
    first_window: torch.Tensor,
    last_window: torch.Tensor,
):
    """Overlap-add chunk outputs when every chunk (edges included) is
    windowed, i.e. the input was reflect-padded by ``edge_frame_pad_sizes``
    beforehand (the ``fade_edge_frames`` path).

    ``first_window`` / ``last_window`` are unused here; they are kept so
    this function stays call-compatible with ``merge_chunks_edge``.
    """
    # (batch * chunks, chan, size) -> (batch * chan, size, chunks)
    combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)

    # Apply the same synthesis window to every chunk.
    combined = combined * standard_window[:, None].to(combined.device)

    # fold() sums the overlapping chunks back into one (1, n_padded_samples) strip.
    combined = F.fold(
        combined.to(torch.float32),
        output_size=(1, n_padded_samples),
        kernel_size=(1, chunk_size),
        stride=(1, hop_size),
    )

    combined = combined.view(original_batch_size, n_channel, n_padded_samples)

    # Strip the reflect-padding that was added for edge fading...
    pad_front, pad_back = edge_frame_pad_sizes
    combined = combined[..., pad_front:-pad_back]

    # ...and any tail padding added to reach a whole number of hops.
    combined = combined[..., :n_samples]

    return combined
|
| 89 |
+
|
| 90 |
+
# @torch.jit.script
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def merge_chunks_edge(
    combined: torch.Tensor,
    original_batch_size: int,
    n_channel: int,
    n_samples: int,
    n_padded_samples: int,
    n_chunks: int,
    chunk_size: int,
    hop_size: int,
    edge_frame_pad_sizes: Tuple[int, int],
    standard_window: torch.Tensor,
    first_window: torch.Tensor,
    last_window: torch.Tensor,
):
    """Overlap-add chunk outputs when the signal edges were NOT pre-padded:
    the first and last chunks get one-sided windows so the edges of the
    signal are not attenuated.

    ``edge_frame_pad_sizes`` is unused here; it is kept so this function
    stays call-compatible with ``merge_chunks_all``.
    """
    # (batch * chunks, chan, size) -> (batch * chan, size, chunks)
    combined = merge(combined, original_batch_size, n_channel, n_chunks, chunk_size)

    # One-sided fades on the first/last chunk, standard window elsewhere.
    combined[..., 0] = combined[..., 0] * first_window
    combined[..., -1] = combined[..., -1] * last_window
    combined[..., 1:-1] = combined[..., 1:-1] * standard_window[:, None]

    # Sum the overlapping windowed chunks back into a contiguous strip.
    combined = F.fold(
        combined,
        output_size=(1, n_padded_samples),
        kernel_size=(1, chunk_size),
        stride=(1, hop_size),
    )

    combined = combined.view(original_batch_size, n_channel, n_padded_samples)

    # Drop the tail padding added to reach a whole number of hops.
    combined = combined[..., :n_samples]

    return combined
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class BaseFader(nn.Module):
    """Run a separation model over long audio in overlapping chunks and
    merge the per-chunk outputs back into full-length stems.

    Subclasses must provide the merge windows (``standard_window``, and
    optionally ``first_window`` / ``last_window``) and the attribute
    ``edge_frame_pad_sizes`` — both ``prepare`` and ``forward`` read it.
    """

    def __init__(
        self,
        chunk_size_second: float,
        hop_size_second: float,
        fs: int,
        fade_edge_frames: bool,
        batch_size: int,
    ) -> None:
        super().__init__()

        # All sizes are kept in samples from here on.
        self.chunk_size = int(chunk_size_second * fs)
        self.hop_size = int(hop_size_second * fs)
        self.overlap_size = self.chunk_size - self.hop_size
        self.fade_edge_frames = fade_edge_frames
        self.batch_size = batch_size

    # @torch.jit.script
    def prepare(self, audio):
        """Pad audio so it divides evenly into hops; return it with the
        resulting chunk count."""

        if self.fade_edge_frames:
            # Reflect-pad so edge samples see the full set of windows.
            # NOTE: `edge_frame_pad_sizes` must be defined by the subclass.
            audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect")

        n_samples = audio.shape[-1]
        n_chunks = int(np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1)

        # Zero-pad the tail so the last chunk is full-length.
        padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size
        pad_size = padded_size - n_samples

        padded_audio = F.pad(audio, (0, pad_size))

        return padded_audio, n_chunks

    def forward(
        self,
        audio: torch.Tensor,
        model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
    ):
        """Chunk `audio` (batch, channel, samples), run `model_fn` on each
        batch of chunks, and overlap-add the per-stem outputs.

        Returns ``{"audio": {stem: tensor}}`` with each tensor restored to
        the input's dtype and device.
        """

        original_dtype = audio.dtype
        original_device = audio.device

        # Chunk bookkeeping is done on CPU to limit GPU memory use; only
        # the model call itself runs on the original device.
        audio = audio.to("cpu")

        original_batch_size, n_channel, n_samples = audio.shape
        padded_audio, n_chunks = self.prepare(audio)
        del audio
        n_padded_samples = padded_audio.shape[-1]

        # Fold channels into the batch axis: the model is fed mono chunks.
        if n_channel > 1:
            padded_audio = padded_audio.view(
                original_batch_size * n_channel, 1, n_padded_samples
            )

        unfolded_input = unfold(
            padded_audio, original_batch_size, n_channel, self.chunk_size, self.hop_size
        )

        n_total_chunks, n_channel, chunk_size = unfolded_input.shape

        n_batch = np.ceil(n_total_chunks / self.batch_size).astype(int)

        chunks_in = [
            unfolded_input[b * self.batch_size : (b + 1) * self.batch_size, ...].clone()
            for b in range(n_batch)
        ]

        # Lazily-created zero buffers, one per output stem.
        all_chunks_out = defaultdict(
            lambda: torch.zeros_like(unfolded_input, device="cpu")
        )

        # for b, cin in enumerate(tqdm(chunks_in)):
        for b, cin in enumerate(chunks_in):
            # All-zero chunks skip the model; their outputs stay zero.
            if torch.allclose(cin, torch.tensor(0.0)):
                del cin
                continue

            chunks_out = model_fn(cin.to(original_device))
            del cin
            for s, c in chunks_out.items():
                all_chunks_out[s][
                    b * self.batch_size : (b + 1) * self.batch_size, ...
                ] = c.cpu()
            del chunks_out

        del unfolded_input
        del padded_audio

        # Edge-faded inputs were reflect-padded, so every chunk uses the
        # standard window; otherwise the edge chunks need one-sided windows.
        if self.fade_edge_frames:
            fn = merge_chunks_all
        else:
            fn = merge_chunks_edge
        outputs = {}

        torch.cuda.empty_cache()

        for s, c in all_chunks_out.items():
            combined: torch.Tensor = fn(
                c,
                original_batch_size,
                n_channel,
                n_samples,
                n_padded_samples,
                n_chunks,
                self.chunk_size,
                self.hop_size,
                self.edge_frame_pad_sizes,
                self.standard_window,
                # Fall back to the standard window when the subclass did not
                # define dedicated edge windows.
                self.__dict__.get("first_window", self.standard_window),
                self.__dict__.get("last_window", self.standard_window),
            )

            outputs[s] = combined.to(dtype=original_dtype, device=original_device)

        return {"audio": outputs}
|
| 242 |
+
|
| 243 |
+
#
|
| 244 |
+
# def old_forward(
|
| 245 |
+
# self,
|
| 246 |
+
# audio: torch.Tensor,
|
| 247 |
+
# model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
|
| 248 |
+
# ):
|
| 249 |
+
#
|
| 250 |
+
# n_samples = audio.shape[-1]
|
| 251 |
+
# original_batch_size = audio.shape[0]
|
| 252 |
+
#
|
| 253 |
+
# padded_audio, n_chunks = self.prepare(audio)
|
| 254 |
+
#
|
| 255 |
+
# ndim = padded_audio.ndim
|
| 256 |
+
# broadcaster = [1 for _ in range(ndim - 1)] + [self.chunk_size]
|
| 257 |
+
#
|
| 258 |
+
# outputs = defaultdict(
|
| 259 |
+
# lambda: torch.zeros_like(
|
| 260 |
+
# padded_audio, device=audio.device, dtype=torch.float64
|
| 261 |
+
# )
|
| 262 |
+
# )
|
| 263 |
+
#
|
| 264 |
+
# all_chunks_out = []
|
| 265 |
+
# len_chunks_in = []
|
| 266 |
+
#
|
| 267 |
+
# batch_size_ = int(self.batch_size // original_batch_size)
|
| 268 |
+
# for b in range(int(np.ceil(n_chunks / batch_size_))):
|
| 269 |
+
# chunks_in = []
|
| 270 |
+
# for j in range(batch_size_):
|
| 271 |
+
# i = b * batch_size_ + j
|
| 272 |
+
# if i == n_chunks:
|
| 273 |
+
# break
|
| 274 |
+
#
|
| 275 |
+
# start = i * hop_size
|
| 276 |
+
# end = start + self.chunk_size
|
| 277 |
+
# chunk_in = padded_audio[..., start:end]
|
| 278 |
+
# chunks_in.append(chunk_in)
|
| 279 |
+
#
|
| 280 |
+
# chunks_in = torch.concat(chunks_in, dim=0)
|
| 281 |
+
# chunks_out = model_fn(chunks_in)
|
| 282 |
+
# all_chunks_out.append(chunks_out)
|
| 283 |
+
# len_chunks_in.append(len(chunks_in))
|
| 284 |
+
#
|
| 285 |
+
# for b, (chunks_out, lci) in enumerate(
|
| 286 |
+
# zip(all_chunks_out, len_chunks_in)
|
| 287 |
+
# ):
|
| 288 |
+
# for stem in chunks_out:
|
| 289 |
+
# for j in range(lci // original_batch_size):
|
| 290 |
+
# i = b * batch_size_ + j
|
| 291 |
+
#
|
| 292 |
+
# if self.fade_edge_frames:
|
| 293 |
+
# window = self.standard_window
|
| 294 |
+
# else:
|
| 295 |
+
# if i == 0:
|
| 296 |
+
# window = self.first_window
|
| 297 |
+
# elif i == n_chunks - 1:
|
| 298 |
+
# window = self.last_window
|
| 299 |
+
# else:
|
| 300 |
+
# window = self.standard_window
|
| 301 |
+
#
|
| 302 |
+
# start = i * hop_size
|
| 303 |
+
# end = start + self.chunk_size
|
| 304 |
+
#
|
| 305 |
+
# chunk_out = chunks_out[stem][j * original_batch_size: (j + 1) * original_batch_size,
|
| 306 |
+
# ...]
|
| 307 |
+
# contrib = window.view(*broadcaster) * chunk_out
|
| 308 |
+
# outputs[stem][..., start:end] = (
|
| 309 |
+
# outputs[stem][..., start:end] + contrib
|
| 310 |
+
# )
|
| 311 |
+
#
|
| 312 |
+
# if self.fade_edge_frames:
|
| 313 |
+
# pad_front, pad_back = self.edge_frame_pad_sizes
|
| 314 |
+
# outputs = {k: v[..., pad_front:-pad_back] for k, v in
|
| 315 |
+
# outputs.items()}
|
| 316 |
+
#
|
| 317 |
+
# outputs = {k: v[..., :n_samples].to(audio.dtype) for k, v in
|
| 318 |
+
# outputs.items()}
|
| 319 |
+
#
|
| 320 |
+
# return {
|
| 321 |
+
# "audio": outputs
|
| 322 |
+
# }
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class LinearFader(BaseFader):
    """Chunked inference with linear cross-fades in the overlap regions.

    Requires ``hop >= chunk / 2`` so that at most two chunks overlap at any
    sample and the up/down linear ramps sum to unity.
    """

    def __init__(
        self,
        chunk_size_second: float,
        hop_size_second: float,
        fs: int,
        fade_edge_frames: bool = False,
        batch_size: int = 1,
    ) -> None:
        # With hop < chunk/2 more than two chunks overlap and the linear
        # cross-fade no longer sums to one.
        assert hop_size_second >= chunk_size_second / 2

        super().__init__(
            chunk_size_second=chunk_size_second,
            hop_size_second=hop_size_second,
            fs=fs,
            fade_edge_frames=fade_edge_frames,
            batch_size=batch_size,
        )

        in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1]
        out_fade = torch.linspace(1.0, 0.0, self.overlap_size + 1)[1:]
        center_ones = torch.ones(self.chunk_size - 2 * self.overlap_size)
        inout_ones = torch.ones(self.overlap_size)

        # Registering as a buffer lets Lightning/.to() move it with the module.
        self.register_buffer(
            "standard_window", torch.concat([in_fade, center_ones, out_fade])
        )

        # Bug fix: BaseFader.prepare/forward read `edge_frame_pad_sizes`
        # (plural); this attribute was previously misspelled
        # `edge_frame_pad_size` and never found, so forward() raised
        # AttributeError. Keep the old (misspelled) name as an alias for
        # backward compatibility.
        self.edge_frame_pad_sizes = (self.overlap_size, self.overlap_size)
        self.edge_frame_pad_size = self.edge_frame_pad_sizes

        if not self.fade_edge_frames:
            # Edge chunks keep full gain on their outer side so the signal
            # edges are not attenuated.
            self.first_window = nn.Parameter(
                torch.concat([inout_ones, center_ones, out_fade]), requires_grad=False
            )
            self.last_window = nn.Parameter(
                torch.concat([in_fade, center_ones, inout_ones]), requires_grad=False
            )
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
class OverlapAddFader(BaseFader):
    """Chunked inference using windowed overlap-add.

    Edge frames are always faded: ``fade_edge_frames=True`` makes the base
    class reflect-pad the input by ``edge_frame_pad_sizes`` so every real
    sample sees the full set of overlapping windows.
    """

    def __init__(
        self,
        window_type: str,
        chunk_size_second: float,
        hop_size_second: float,
        fs: int,
        batch_size: int = 1,
    ) -> None:
        # An even integer number of hops per chunk and an even chunk length
        # (in samples) are required for the padding/normalization below.
        assert (chunk_size_second / hop_size_second) % 2 == 0
        assert int(chunk_size_second * fs) % 2 == 0

        super().__init__(
            chunk_size_second=chunk_size_second,
            hop_size_second=hop_size_second,
            fs=fs,
            fade_edge_frames=True,
            batch_size=batch_size,
        )

        # Normalization factor for the summed overlapping windows.
        self.hop_multiplier = self.chunk_size / (2 * self.hop_size)
        # print(f"hop multiplier: {self.hop_multiplier}")

        self.edge_frame_pad_sizes = (2 * self.overlap_size, 2 * self.overlap_size)

        # Bug fix: `torch.windows` does not exist; the window builders with
        # a `sym=` keyword live in `torch.signal.windows` (sym=False gives
        # the periodic window needed for overlap-add).
        self.register_buffer(
            "standard_window",
            getattr(torch.signal.windows, window_type)(
                self.chunk_size,
                sym=False,  # dtype=torch.float64
            )
            / self.hop_multiplier,
        )
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
if __name__ == "__main__":
    # Smoke test: overlap-add with a Hann window and an identity "model"
    # should reconstruct the input exactly.
    import torchaudio as ta

    fs = 44100
    ola = OverlapAddFader("hann", 6.0, 1.0, fs, batch_size=16)
    # NOTE(review): "$DATA_ROOT" is a literal path component here, not an
    # expanded environment variable — this only runs if such a directory
    # exists. The two adjacent string literals are concatenated by Python.
    audio_, _ = ta.load(
        "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too " "Much/vocals.wav"
    )
    audio_ = audio_[None, ...]  # add a batch dimension
    out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"]
    print(torch.allclose(out, audio_))
|
mvsepless/models/bandit/model_from_config.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os.path
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
import yaml
|
| 6 |
+
from ml_collections import ConfigDict
|
| 7 |
+
|
| 8 |
+
torch.set_float32_matmul_precision("medium")
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_model(
    config_path,
    weights_path,
    device,
):
    """Build a ``MultiMaskMultiSourceBandSplitRNNSimple`` from a YAML config
    and load its checkpoint.

    Args:
        config_path: path to a YAML file whose ``model`` section holds the
            constructor kwargs.
        weights_path: path to the ``state_dict`` checkpoint to load.
        device: device to move the model to.

    Returns:
        Tuple of (model, parsed ConfigDict).
    """
    from .core.model import MultiMaskMultiSourceBandSplitRNNSimple

    # Context manager replaces the manual open/close pair.
    with open(config_path) as f:
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))

    model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
    # Bug fix: the original referenced an undefined `code_path` (with a
    # hard-coded checkpoint filename) and ignored `weights_path` entirely.
    # map_location avoids loading CUDA tensors onto a machine without CUDA.
    d = torch.load(weights_path, map_location=device)
    model.load_state_dict(d)
    model.to(device)
    return model, config
|
mvsepless/models/bandit_v2/bandit.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torchaudio as ta
|
| 5 |
+
from torch import nn
|
| 6 |
+
import pytorch_lightning as pl
|
| 7 |
+
|
| 8 |
+
from .bandsplit import BandSplitModule
|
| 9 |
+
from .maskestim import OverlappingMaskEstimationModule
|
| 10 |
+
from .tfmodel import SeqBandModellingModule
|
| 11 |
+
from .utils import MusicalBandsplitSpecification
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class BaseEndToEndModule(pl.LightningModule):
    """Common LightningModule ancestor for the end-to-end Bandit models.

    Adds no behaviour of its own; it exists so the model classes share a
    single Lightning-aware root class.
    """

    def __init__(
        self,
    ) -> None:
        super().__init__()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class BaseBandit(BaseEndToEndModule):
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
in_channels: int,
|
| 25 |
+
fs: int,
|
| 26 |
+
band_type: str = "musical",
|
| 27 |
+
n_bands: int = 64,
|
| 28 |
+
require_no_overlap: bool = False,
|
| 29 |
+
require_no_gap: bool = True,
|
| 30 |
+
normalize_channel_independently: bool = False,
|
| 31 |
+
treat_channel_as_feature: bool = True,
|
| 32 |
+
n_sqm_modules: int = 12,
|
| 33 |
+
emb_dim: int = 128,
|
| 34 |
+
rnn_dim: int = 256,
|
| 35 |
+
bidirectional: bool = True,
|
| 36 |
+
rnn_type: str = "LSTM",
|
| 37 |
+
n_fft: int = 2048,
|
| 38 |
+
win_length: Optional[int] = 2048,
|
| 39 |
+
hop_length: int = 512,
|
| 40 |
+
window_fn: str = "hann_window",
|
| 41 |
+
wkwargs: Optional[Dict] = None,
|
| 42 |
+
power: Optional[int] = None,
|
| 43 |
+
center: bool = True,
|
| 44 |
+
normalized: bool = True,
|
| 45 |
+
pad_mode: str = "constant",
|
| 46 |
+
onesided: bool = True,
|
| 47 |
+
):
|
| 48 |
+
super().__init__()
|
| 49 |
+
|
| 50 |
+
self.in_channels = in_channels
|
| 51 |
+
|
| 52 |
+
self.instantitate_spectral(
|
| 53 |
+
n_fft=n_fft,
|
| 54 |
+
win_length=win_length,
|
| 55 |
+
hop_length=hop_length,
|
| 56 |
+
window_fn=window_fn,
|
| 57 |
+
wkwargs=wkwargs,
|
| 58 |
+
power=power,
|
| 59 |
+
normalized=normalized,
|
| 60 |
+
center=center,
|
| 61 |
+
pad_mode=pad_mode,
|
| 62 |
+
onesided=onesided,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
self.instantiate_bandsplit(
|
| 66 |
+
in_channels=in_channels,
|
| 67 |
+
band_type=band_type,
|
| 68 |
+
n_bands=n_bands,
|
| 69 |
+
require_no_overlap=require_no_overlap,
|
| 70 |
+
require_no_gap=require_no_gap,
|
| 71 |
+
normalize_channel_independently=normalize_channel_independently,
|
| 72 |
+
treat_channel_as_feature=treat_channel_as_feature,
|
| 73 |
+
emb_dim=emb_dim,
|
| 74 |
+
n_fft=n_fft,
|
| 75 |
+
fs=fs,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
self.instantiate_tf_modelling(
|
| 79 |
+
n_sqm_modules=n_sqm_modules,
|
| 80 |
+
emb_dim=emb_dim,
|
| 81 |
+
rnn_dim=rnn_dim,
|
| 82 |
+
bidirectional=bidirectional,
|
| 83 |
+
rnn_type=rnn_type,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
def instantitate_spectral(
|
| 87 |
+
self,
|
| 88 |
+
n_fft: int = 2048,
|
| 89 |
+
win_length: Optional[int] = 2048,
|
| 90 |
+
hop_length: int = 512,
|
| 91 |
+
window_fn: str = "hann_window",
|
| 92 |
+
wkwargs: Optional[Dict] = None,
|
| 93 |
+
power: Optional[int] = None,
|
| 94 |
+
normalized: bool = True,
|
| 95 |
+
center: bool = True,
|
| 96 |
+
pad_mode: str = "constant",
|
| 97 |
+
onesided: bool = True,
|
| 98 |
+
):
|
| 99 |
+
assert power is None
|
| 100 |
+
|
| 101 |
+
window_fn = torch.__dict__[window_fn]
|
| 102 |
+
|
| 103 |
+
self.stft = ta.transforms.Spectrogram(
|
| 104 |
+
n_fft=n_fft,
|
| 105 |
+
win_length=win_length,
|
| 106 |
+
hop_length=hop_length,
|
| 107 |
+
pad_mode=pad_mode,
|
| 108 |
+
pad=0,
|
| 109 |
+
window_fn=window_fn,
|
| 110 |
+
wkwargs=wkwargs,
|
| 111 |
+
power=power,
|
| 112 |
+
normalized=normalized,
|
| 113 |
+
center=center,
|
| 114 |
+
onesided=onesided,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
self.istft = ta.transforms.InverseSpectrogram(
|
| 118 |
+
n_fft=n_fft,
|
| 119 |
+
win_length=win_length,
|
| 120 |
+
hop_length=hop_length,
|
| 121 |
+
pad_mode=pad_mode,
|
| 122 |
+
pad=0,
|
| 123 |
+
window_fn=window_fn,
|
| 124 |
+
wkwargs=wkwargs,
|
| 125 |
+
normalized=normalized,
|
| 126 |
+
center=center,
|
| 127 |
+
onesided=onesided,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
def instantiate_bandsplit(
|
| 131 |
+
self,
|
| 132 |
+
in_channels: int,
|
| 133 |
+
band_type: str = "musical",
|
| 134 |
+
n_bands: int = 64,
|
| 135 |
+
require_no_overlap: bool = False,
|
| 136 |
+
require_no_gap: bool = True,
|
| 137 |
+
normalize_channel_independently: bool = False,
|
| 138 |
+
treat_channel_as_feature: bool = True,
|
| 139 |
+
emb_dim: int = 128,
|
| 140 |
+
n_fft: int = 2048,
|
| 141 |
+
fs: int = 44100,
|
| 142 |
+
):
|
| 143 |
+
assert band_type == "musical"
|
| 144 |
+
|
| 145 |
+
self.band_specs = MusicalBandsplitSpecification(
|
| 146 |
+
nfft=n_fft, fs=fs, n_bands=n_bands
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
self.band_split = BandSplitModule(
|
| 150 |
+
in_channels=in_channels,
|
| 151 |
+
band_specs=self.band_specs.get_band_specs(),
|
| 152 |
+
require_no_overlap=require_no_overlap,
|
| 153 |
+
require_no_gap=require_no_gap,
|
| 154 |
+
normalize_channel_independently=normalize_channel_independently,
|
| 155 |
+
treat_channel_as_feature=treat_channel_as_feature,
|
| 156 |
+
emb_dim=emb_dim,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
def instantiate_tf_modelling(
|
| 160 |
+
self,
|
| 161 |
+
n_sqm_modules: int = 12,
|
| 162 |
+
emb_dim: int = 128,
|
| 163 |
+
rnn_dim: int = 256,
|
| 164 |
+
bidirectional: bool = True,
|
| 165 |
+
rnn_type: str = "LSTM",
|
| 166 |
+
):
|
| 167 |
+
try:
|
| 168 |
+
self.tf_model = torch.compile(
|
| 169 |
+
SeqBandModellingModule(
|
| 170 |
+
n_modules=n_sqm_modules,
|
| 171 |
+
emb_dim=emb_dim,
|
| 172 |
+
rnn_dim=rnn_dim,
|
| 173 |
+
bidirectional=bidirectional,
|
| 174 |
+
rnn_type=rnn_type,
|
| 175 |
+
),
|
| 176 |
+
disable=True,
|
| 177 |
+
)
|
| 178 |
+
except Exception as e:
|
| 179 |
+
self.tf_model = SeqBandModellingModule(
|
| 180 |
+
n_modules=n_sqm_modules,
|
| 181 |
+
emb_dim=emb_dim,
|
| 182 |
+
rnn_dim=rnn_dim,
|
| 183 |
+
bidirectional=bidirectional,
|
| 184 |
+
rnn_type=rnn_type,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
def mask(self, x, m):
    """Apply mask ``m`` to spectrogram ``x`` by elementwise multiplication."""
    return m * x
|
| 189 |
+
|
| 190 |
+
def forward(self, batch, mode="train"):
    """Separate a mixture tensor into per-stem audio.

    The underlying model is mono, so each channel of the input is folded
    into the batch dimension, separated independently, and the channel
    layout is restored afterwards.

    Args:
        batch: (batch, channels, time) mixture tensor. (The dict path that
            the original code hinted at never worked: ``batch.shape`` was
            read before the isinstance check, so a dict crashed
            immediately. This version keeps the tensor-only contract.)
        mode: accepted for interface compatibility; unused.

    Returns:
        (batch, n_stems, channels, time) tensor of stem estimates, stacked
        in ``self.stems`` order.
    """
    if not isinstance(batch, dict):
        # Remember the original (batch, channels, time) layout, then fold
        # channels into the batch dim so the mono model can process them.
        init_shape = batch.shape
        mono = batch.view(-1, 1, batch.shape[-1])
        batch = {"mixture": {"audio": mono}}

    with torch.no_grad():
        mixture = batch["mixture"]["audio"]

        x = self.stft(mixture)
        batch["mixture"]["spectrogram"] = x

        if "sources" in batch:
            for stem in batch["sources"]:
                s = self.stft(batch["sources"][stem]["audio"])
                batch["sources"][stem]["spectrogram"] = s

    batch = self.separate(batch)

    # Restore the channel dim for every stem and stack stems into one
    # tensor instead of returning independent per-stem entries.
    stacked = [
        batch["estimates"][s]["audio"].view(-1, init_shape[1], init_shape[2])
        for s in self.stems
    ]
    return torch.stack(stacked, dim=1)
|
| 222 |
+
|
| 223 |
+
def encode(self, batch):
    """Project the mixture spectrogram into per-band embeddings.

    Returns a tuple ``(x, q, length)``: the mixture spectrogram, the
    band-sequence features produced by the band split + TF model, and the
    number of audio samples in the mixture.
    """
    spec = batch["mixture"]["spectrogram"]
    n_samples = batch["mixture"]["audio"].shape[-1]

    # band split then time-frequency modelling, fused into one expression
    features = self.tf_model(self.band_split(spec))

    return spec, features, n_samples
|
| 231 |
+
|
| 232 |
+
def separate(self, batch):
    """Estimate per-stem outputs from an encoded batch.

    Abstract hook: subclasses must override (see ``Bandit.separate`` for
    the reference implementation).
    """
    raise NotImplementedError
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class Bandit(BaseBandit):
    """Bandit separation model: band-split encoder, band-sequence RNN stack
    (built by BaseBandit), plus one overlapping mask-estimation head per stem.

    The ``*_precisions`` constructor arguments are accepted but not
    referenced anywhere in this class and are not forwarded to the base —
    presumably kept for configuration-file compatibility; TODO confirm.
    """

    def __init__(
        self,
        in_channels: int,
        stems: List[str],
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict | None = None,
        complex_mask: bool = True,
        use_freq_weights: bool = True,
        n_fft: int = 2048,
        win_length: int | None = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Dict | None = None,
        power: int | None = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        fs: int = 44100,
        stft_precisions: str = "32",
        bandsplit_precisions: str = "bf16",
        tf_model_precisions: str = "bf16",
        mask_estim_precisions: str = "bf16",
    ):
        # Band-split, TF-model and STFT configuration is handled by the base.
        super().__init__(
            in_channels=in_channels,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            fs=fs,
        )

        self.stems = stems

        self.instantiate_mask_estim(
            in_channels=in_channels,
            stems=stems,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            n_freq=n_fft // 2 + 1,  # onesided STFT bin count
            use_freq_weights=use_freq_weights,
        )

    def instantiate_mask_estim(
        self,
        in_channels: int,
        stems: List[str],
        emb_dim: int,
        mlp_dim: int,
        hidden_activation: str,
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = False,
    ):
        """Create one OverlappingMaskEstimationModule per stem.

        Sets ``self.mask_estim`` to an nn.ModuleDict keyed by stem name.
        ``n_freq`` is required (asserted below).
        """
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        assert n_freq is not None

        self.mask_estim = nn.ModuleDict(
            {
                stem: OverlappingMaskEstimationModule(
                    band_specs=self.band_specs.get_band_specs(),
                    freq_weights=self.band_specs.get_freq_weights(),
                    n_freq=n_freq,
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channels=in_channels,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    use_freq_weights=use_freq_weights,
                )
                for stem in stems
            }
        )

    def separate(self, batch):
        """Mask the mixture spectrogram per stem and invert to audio.

        Fills ``batch["estimates"][stem]`` with the masked spectrogram and
        its iSTFT (trimmed to the mixture's sample length), then returns
        the batch.
        """
        batch["estimates"] = {}

        x, q, length = self.encode(batch)

        for stem, mem in self.mask_estim.items():
            m = mem(q)

            # cast the mask to the spectrogram dtype before masking
            s = self.mask(x, m.to(x.dtype))
            s = torch.reshape(s, x.shape)
            batch["estimates"][stem] = {
                "audio": self.istft(s, length),
                "spectrogram": s,
            }

        return batch
|
mvsepless/models/bandit_v2/bandsplit.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.utils.checkpoint import checkpoint_sequential
|
| 6 |
+
|
| 7 |
+
from .utils import (
|
| 8 |
+
band_widths_from_specs,
|
| 9 |
+
check_no_gap,
|
| 10 |
+
check_no_overlap,
|
| 11 |
+
check_nonzero_bandwidth,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class NormFC(nn.Module):
    """LayerNorm + Linear projection for a single frequency band.

    Maps the flattened (channel x band-bin x real/imag) features of one
    band at each time frame to an ``emb_dim``-dimensional embedding.
    """

    def __init__(
        self,
        emb_dim: int,
        bandwidth: int,
        in_channels: int,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        """
        Args:
            emb_dim: output embedding dimension.
            bandwidth: number of frequency bins in this band.
            in_channels: number of audio channels.
            normalize_channel_independently: unsupported (raises).
            treat_channel_as_feature: must be True (raises otherwise);
                channels are folded into the feature dim.
        """
        super().__init__()

        if not treat_channel_as_feature:
            raise NotImplementedError

        if normalize_channel_independently:
            raise NotImplementedError

        self.treat_channel_as_feature = treat_channel_as_feature

        reim = 2  # real + imaginary parts of the complex spectrogram

        norm = nn.LayerNorm(in_channels * bandwidth * reim)

        # Channels are always folded into the feature dim here: the
        # treat_channel_as_feature=False path raised above, so the original
        # dead else-branch (splitting emb_dim across channels) was removed.
        fc = nn.Linear(bandwidth * reim * in_channels, emb_dim)

        self.combined = nn.Sequential(norm, fc)

    def forward(self, xb):
        # Checkpoint the norm+fc pair to trade compute for activation memory.
        return checkpoint_sequential(self.combined, 1, xb, use_reentrant=False)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class BandSplitModule(nn.Module):
    """Split a complex spectrogram into frequency bands and embed each band.

    Each band gets its own NormFC projection; the outputs are stacked into
    a (batch, n_bands, n_time, emb_dim) tensor.
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        in_channels: int,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
    ) -> None:
        """
        Args:
            band_specs: list of [fstart, fend) bin-index pairs; fend is
                exclusive.
            emb_dim: per-band embedding dimension.
            in_channels: number of audio channels.
            require_no_overlap / require_no_gap: validate the band layout.
            normalize_channel_independently / treat_channel_as_feature:
                forwarded to each NormFC.
        """
        super().__init__()

        check_nonzero_bandwidth(band_specs)

        if require_no_gap:
            check_no_gap(band_specs)

        if require_no_overlap:
            check_no_overlap(band_specs)

        self.band_specs = band_specs
        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)
        self.emb_dim = emb_dim

        # Build the per-band modules once; the original duplicated this
        # whole list construction in both try and except branches.
        modules = [
            NormFC(
                emb_dim=emb_dim,
                bandwidth=bw,
                in_channels=in_channels,
                normalize_channel_independently=normalize_channel_independently,
                treat_channel_as_feature=treat_channel_as_feature,
            )
            for bw in self.band_widths
        ]
        try:
            # disable=True keeps behaviour identical; fall back silently on
            # torch builds without compile support.
            modules = [torch.compile(m, disable=True) for m in modules]
        except Exception:
            pass
        self.norm_fc_modules = nn.ModuleList(modules)  # type: ignore

    def forward(self, x: torch.Tensor):
        """x: complex spectrogram (batch, in_chan, n_freq, n_time) ->
        (batch, n_bands, n_time, emb_dim)."""
        batch, in_chan, n_freq, n_time = x.shape

        z = torch.zeros(
            size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
        )

        # -> (batch, n_time, in_chan, n_freq) so each band slice is a
        # contiguous block of the last axis
        x = torch.permute(x, (0, 3, 1, 2)).contiguous()

        for i, nfm in enumerate(self.norm_fc_modules):
            fstart, fend = self.band_specs[i]
            xb = x[:, :, :, fstart:fend]
            xb = torch.view_as_real(xb)  # complex -> trailing (re, im) axis
            xb = torch.reshape(xb, (batch, n_time, -1))
            z[:, i, :, :] = nfm(xb)

        return z
|
mvsepless/models/bandit_v2/film.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class FiLM(nn.Module):
    """Feature-wise linear modulation: scale the input by ``gamma`` and
    shift it by ``beta``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, gamma, beta):
        # beta + gamma * x == gamma * x + beta (addition commutes)
        return beta + gamma * x
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BTFBroadcastedFiLM(nn.Module):
    """FiLM modulation broadcast over the leading (batch, band, time) axes.

    ``gamma`` and ``beta`` are indexed as ``[None, None, None, :]``, so they
    are expected to be 1-D vectors over the last (feature) axis of a 4-D
    input — presumably (batch, n_bands, n_time, emb_dim); confirm against
    callers.
    """

    def __init__(self):
        super().__init__()
        self.film = FiLM()

    def forward(self, x, gamma, beta):

        # lift gamma/beta to rank 4 so they broadcast over the leading axes
        gamma = gamma[None, None, None, :]
        beta = beta[None, None, None, :]

        return self.film(x, gamma, beta)
|
mvsepless/models/bandit_v2/maskestim.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional, Tuple, Type
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn.modules import activation
|
| 6 |
+
from torch.utils.checkpoint import checkpoint_sequential
|
| 7 |
+
|
| 8 |
+
from .utils import (
|
| 9 |
+
band_widths_from_specs,
|
| 10 |
+
check_no_gap,
|
| 11 |
+
check_no_overlap,
|
| 12 |
+
check_nonzero_bandwidth,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BaseNormMLP(nn.Module):
    """Shared state for per-band mask heads: LayerNorm + hidden projection.

    Subclasses add the output layer and the forward pass.
    """

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ):
        """
        Args:
            emb_dim: input embedding dimension (normalized by LayerNorm).
            mlp_dim: hidden layer width.
            bandwidth: number of frequency bins this head covers.
            in_channels: number of audio channels.
            hidden_activation: name of an activation class in
                ``torch.nn.modules.activation`` (e.g. "Tanh", "ReLU").
            hidden_activation_kwargs: kwargs for that activation.
            complex_mask: if True the mask carries (re, im) per bin.
        """
        super().__init__()
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}
        self.hidden_activation_kwargs = hidden_activation_kwargs
        self.norm = nn.LayerNorm(emb_dim)
        # getattr is the idiomatic module-attribute lookup and raises a
        # clearer AttributeError than __dict__ indexing on a bad name.
        activation_cls = getattr(activation, hidden_activation)
        self.hidden = nn.Sequential(
            nn.Linear(in_features=emb_dim, out_features=mlp_dim),
            activation_cls(**self.hidden_activation_kwargs),
        )

        self.bandwidth = bandwidth
        self.in_channels = in_channels

        self.complex_mask = complex_mask
        # masks carry 2 numbers per bin (re, im) when complex, else 1
        self.reim = 2 if complex_mask else 1
        self.glu_mult = 2
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class NormMLP(BaseNormMLP):
    """Per-band mask head: LayerNorm -> MLP -> GLU, reshaped to a
    (batch, in_channels, bandwidth, n_time) (optionally complex) mask."""

    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs=None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__(
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            bandwidth=bandwidth,
            in_channels=in_channels,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

        self.output = nn.Sequential(
            nn.Linear(
                in_features=mlp_dim,
                # GLU halves the last dim, hence the trailing factor of 2
                out_features=bandwidth * in_channels * self.reim * 2,
            ),
            nn.GLU(dim=-1),
        )

        # Build the fused stack once; the original duplicated this in both
        # try and except branches.
        combined = nn.Sequential(self.norm, self.hidden, self.output)
        try:
            self.combined = torch.compile(combined, disable=True)
        except Exception:
            self.combined = combined

    def reshape_output(self, mb):
        """(batch, n_time, features) -> (batch, in_channels, bandwidth, n_time)."""
        batch, n_time, _ = mb.shape
        if self.complex_mask:
            mb = mb.reshape(
                batch, n_time, self.in_channels, self.bandwidth, self.reim
            ).contiguous()
            # trailing (re, im) pair -> complex values
            mb = torch.view_as_complex(mb)  # (batch, n_time, in_channels, bandwidth)
        else:
            mb = mb.reshape(batch, n_time, self.in_channels, self.bandwidth)

        mb = torch.permute(mb, (0, 2, 3, 1))  # (batch, in_channels, bandwidth, n_time)

        return mb

    def forward(self, qb):
        # qb: (batch, n_time, emb_dim); checkpoint the norm->hidden->output
        # stack (2 segments) to save activation memory.
        mb = checkpoint_sequential(self.combined, 2, qb, use_reentrant=False)
        mb = self.reshape_output(mb)  # (batch, in_channels, bandwidth, n_time)

        return mb
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class MaskEstimationModuleSuperBase(nn.Module):
    """Root marker class for all mask-estimation modules."""

    pass
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
    """Holds one NormMLP mask head per band and evaluates them over the
    band-sequence features."""

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Optional[Dict] = None,
    ) -> None:
        """
        Args:
            band_specs: list of (fstart, fend) bin-index pairs.
            emb_dim: input embedding dimension per band.
            mlp_dim: hidden width of each mask head.
            in_channels: number of audio channels.
            hidden_activation / hidden_activation_kwargs: forwarded to each
                head.
            complex_mask: whether heads produce complex-valued masks.
            norm_mlp_cls: mask-head class (default NormMLP).
            norm_mlp_kwargs: extra kwargs for ``norm_mlp_cls``.
        """
        super().__init__()

        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)

        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        if norm_mlp_kwargs is None:
            norm_mlp_kwargs = {}

        # one mask head per band; bandwidths may differ between bands
        self.norm_mlp = nn.ModuleList(
            [
                norm_mlp_cls(
                    bandwidth=self.band_widths[b],
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channels=in_channels,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    **norm_mlp_kwargs,
                )
                for b in range(self.n_bands)
            ]
        )

    def compute_masks(self, q):
        """Return a list with one mask tensor per band.

        q: (batch, n_bands, n_time, emb_dim).
        """
        batch, n_bands, n_time, emb_dim = q.shape

        masks = []

        for b, nmlp in enumerate(self.norm_mlp):
            qb = q[:, b, :, :]
            mb = nmlp(qb)
            masks.append(mb)

        return masks

    def compute_mask(self, q, b):
        """Return the mask for band index ``b`` only."""
        batch, n_bands, n_time, emb_dim = q.shape
        qb = q[:, b, :, :]
        mb = self.norm_mlp[b](qb)
        return mb
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
    """Mask estimator for band specs that may overlap: per-band masks are
    summed into one full-resolution complex mask, optionally weighted per
    frequency bin to compensate for multiply-covered bins."""

    def __init__(
        self,
        in_channels: int,
        band_specs: List[Tuple[float, float]],
        freq_weights: List[torch.Tensor],
        n_freq: int,
        emb_dim: int,
        mlp_dim: int,
        cond_dim: int = 0,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Optional[Dict] = None,
        use_freq_weights: bool = False,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)

        # conditioning is not implemented; cond_dim > 0 is rejected, so the
        # emb_dim + cond_dim below is effectively emb_dim
        if cond_dim > 0:
            raise NotImplementedError

        super().__init__(
            band_specs=band_specs,
            emb_dim=emb_dim + cond_dim,
            mlp_dim=mlp_dim,
            in_channels=in_channels,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            norm_mlp_cls=norm_mlp_cls,
            norm_mlp_kwargs=norm_mlp_kwargs,
        )

        self.n_freq = n_freq
        self.band_specs = band_specs
        self.in_channels = in_channels

        if freq_weights is not None and use_freq_weights:
            # "/" in the buffer name prevents attribute access; the buffers
            # are retrieved via get_buffer in forward instead
            for i, fw in enumerate(freq_weights):
                self.register_buffer(f"freq_weights/{i}", fw)

            self.use_freq_weights = use_freq_weights
        else:
            self.use_freq_weights = False

    def forward(self, q):
        """q: (batch, n_bands, n_time, emb_dim) -> complex mask of shape
        (batch, in_channels, n_freq, n_time)."""
        batch, n_bands, n_time, emb_dim = q.shape

        # complex64 accumulator; overlapping bands sum into shared bins
        masks = torch.zeros(
            (batch, self.in_channels, self.n_freq, n_time),
            device=q.device,
            dtype=torch.complex64,
        )

        for im in range(n_bands):
            fstart, fend = self.band_specs[im]

            mask = self.compute_mask(q, im)

            if self.use_freq_weights:
                # per-bin weight, broadcast over the time axis
                fw = self.get_buffer(f"freq_weights/{im}")[:, None]
                mask = mask * fw
            masks[:, :, fstart:fend, :] += mask

        return masks
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class MaskEstimationModule(OverlappingMaskEstimationModule):
    """Mask estimator for non-overlapping, gap-free band specs: per-band
    masks are concatenated along frequency instead of summed.

    Extra ``**kwargs`` are accepted and silently ignored.
    """

    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        **kwargs,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)
        check_no_overlap(band_specs)
        super().__init__(
            in_channels=in_channels,
            band_specs=band_specs,
            freq_weights=None,  # concatenation needs no per-bin weighting
            n_freq=None,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

    def forward(self, q, cond=None):
        """q: (batch, n_bands, n_time, emb_dim); ``cond`` is unused."""
        masks = self.compute_masks(
            q
        )  # [n_bands * (batch, in_channels, bandwidth, n_time)]

        # TODO: currently this requires band specs to have no gap and no overlap
        masks = torch.concat(masks, dim=2)  # (batch, in_channels, n_freq, n_time)

        return masks
|
mvsepless/models/bandit_v2/tfmodel.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.backends.cuda
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn.modules import rnn
|
| 7 |
+
from torch.utils.checkpoint import checkpoint_sequential
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TimeFrequencyModellingModule(nn.Module):
    """Base class for modules that model the (band, time) feature grid."""

    def __init__(self) -> None:
        super().__init__()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ResidualRNN(nn.Module):
    """Pre-norm RNN block with a residual connection.

    LayerNorm is applied first, then an RNN runs over the 3rd axis of a
    (batch, n_uncrossed, n_across, emb_dim) tensor — the 2nd axis is folded
    into the batch (the "batch trick") — the result is projected back to
    ``emb_dim`` and added to the input.
    """

    def __init__(
        self,
        emb_dim: int,
        rnn_dim: int,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        use_batch_trick: bool = True,
        use_layer_norm: bool = True,
    ) -> None:
        """
        Args:
            emb_dim: feature dimension of the input and output.
            rnn_dim: RNN hidden size.
            bidirectional: whether the RNN is bidirectional (doubles the
                projection's input width).
            rnn_type: RNN class name in ``torch.nn`` ("LSTM", "GRU", "RNN").
            use_batch_trick: must be True; only the folded path is
                implemented.
            use_layer_norm: must be True; only the normalized path is
                implemented.
        """
        super().__init__()

        # only the normalized + batch-trick configuration is implemented
        assert use_layer_norm
        assert use_batch_trick

        self.use_layer_norm = use_layer_norm
        self.norm = nn.LayerNorm(emb_dim)
        # getattr is the idiomatic lookup of the RNN class by name
        self.rnn = getattr(rnn, rnn_type)(
            input_size=emb_dim,
            hidden_size=rnn_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional,
        )

        self.fc = nn.Linear(
            in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
        )

        self.use_batch_trick = use_batch_trick
        if not self.use_batch_trick:
            warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")

    def forward(self, z):
        # z: (batch, n_uncrossed, n_across, emb_dim)
        residual = torch.clone(z)
        z = self.norm(z)

        # fold the uncrossed axis into the batch so one RNN call covers all
        batch, n_uncrossed, n_across, emb_dim = z.shape
        z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
        z = self.rnn(z)[0]
        z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))

        z = self.fc(z)  # (batch, n_uncrossed, n_across, emb_dim)

        return z + residual
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class Transpose(nn.Module):
    """Swap two axes of the input (used to alternate time/band RNN passes)."""

    def __init__(self, dim0: int, dim1: int) -> None:
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, z):
        return torch.transpose(z, self.dim0, self.dim1)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class SeqBandModellingModule(TimeFrequencyModellingModule):
    """Alternating time/band RNN stack.

    In the default sequential mode the stack is a Sequential of
    [ResidualRNN, Transpose(1, 2)] repeated 2 * n_modules times, so
    successive RNNs alternate between the time axis and the band axis (the
    even number of transposes restores the original orientation). In
    parallel mode each of the n_modules runs a time-RNN and a band-RNN on
    the same input and sums the two results.
    """

    def __init__(
        self,
        n_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        parallel_mode=False,
    ) -> None:
        super().__init__()

        self.n_modules = n_modules

        if parallel_mode:
            self.seqband = nn.ModuleList([])
            for _ in range(n_modules):
                self.seqband.append(
                    nn.ModuleList(
                        [
                            ResidualRNN(
                                emb_dim=emb_dim,
                                rnn_dim=rnn_dim,
                                bidirectional=bidirectional,
                                rnn_type=rnn_type,
                            ),
                            ResidualRNN(
                                emb_dim=emb_dim,
                                rnn_dim=rnn_dim,
                                bidirectional=bidirectional,
                                rnn_type=rnn_type,
                            ),
                        ]
                    )
                )
        else:
            # interleave RNN blocks with Transpose(1, 2) so each RNN sees a
            # different axis as the sequence dimension
            seqband = []
            for _ in range(2 * n_modules):
                seqband += [
                    ResidualRNN(
                        emb_dim=emb_dim,
                        rnn_dim=rnn_dim,
                        bidirectional=bidirectional,
                        rnn_type=rnn_type,
                    ),
                    Transpose(1, 2),
                ]

            self.seqband = nn.Sequential(*seqband)

        self.parallel_mode = parallel_mode

    def forward(self, z):
        """z: (batch, n_bands, n_time, emb_dim) -> same shape."""
        if self.parallel_mode:
            for sbm_pair in self.seqband:
                # z: (batch, n_bands, n_time, emb_dim)
                sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
                zt = sbm_t(z)  # (batch, n_bands, n_time, emb_dim)
                zf = sbm_f(z.transpose(1, 2))  # (batch, n_time, n_bands, emb_dim)
                z = zt + zf.transpose(1, 2)
        else:
            # gradient checkpointing in n_modules segments (4 layers each)
            z = checkpoint_sequential(
                self.seqband, self.n_modules, z, use_reentrant=False
            )

        q = z
        return q  # (batch, n_bands, n_time, emb_dim)
|
mvsepless/models/bandit_v2/utils.py
ADDED
|
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from abc import abstractmethod
|
| 3 |
+
from typing import Callable
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from librosa import hz_to_midi, midi_to_hz
|
| 8 |
+
from torchaudio import functional as taF
|
| 9 |
+
|
| 10 |
+
# from spafe.fbanks import bark_fbanks
|
| 11 |
+
# from spafe.utils.converters import erb2hz, hz2bark, hz2erb
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def band_widths_from_specs(band_specs):
    """Return the width (end - start) of each (start, end) band in *band_specs*."""
    widths = []
    for start, end in band_specs:
        widths.append(end - start)
    return widths
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def check_nonzero_bandwidth(band_specs):
    """Raise ValueError if any (start, end) band has zero or negative width."""
    if any(end - start <= 0 for start, end in band_specs):
        raise ValueError("Bands cannot be zero-width")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def check_no_overlap(band_specs):
    """Raise ValueError if consecutive bands in *band_specs* overlap.

    Bands are (start, end) bin-index pairs. Adjacency is half-open in the
    sense used by ``get_band_specs_with_bandwidth``: the next band may start
    exactly at the previous band's end, so only ``start < previous_end``
    counts as an overlap.

    Bug fix: the previous implementation never advanced ``fend_prev`` inside
    the loop, so it compared every band against the initial ``-1`` and
    silently accepted genuinely overlapping bands.
    """
    fend_prev = -1
    for fstart_curr, fend_curr in band_specs:
        # Starting strictly before the previous band ends is an overlap;
        # fstart_curr == fend_prev (shared boundary) is allowed.
        if fstart_curr < fend_prev:
            raise ValueError("Bands cannot overlap")
        fend_prev = fend_curr
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def check_no_gap(band_specs):
    """Raise ValueError if consecutive bands leave a gap of more than one bin.

    The first band must start at bin 0 (enforced with ``assert``, matching
    the original contract).
    """
    first_start, _ = band_specs[0]
    assert first_start == 0

    prev_end = -1
    for start, end in band_specs:
        # A jump of more than one bin past the previous end is a gap.
        if start - prev_end > 1:
            raise ValueError("Bands cannot leave gap")
        prev_end = end
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class BandsplitSpecification:
    """Base class describing how one-sided STFT bins are grouped into bands.

    Subclasses implement :meth:`get_band_specs`, returning a list of
    (start_bin, end_bin) index pairs over the ``nfft // 2 + 1`` bins.
    """

    def __init__(self, nfft: int, fs: int) -> None:
        self.fs = fs
        self.nfft = nfft
        self.nyquist = fs / 2
        # Number of bins in a one-sided spectrum.
        self.max_index = nfft // 2 + 1

        # Bin indices of common split frequencies, pre-computed once.
        self.split500 = self.hertz_to_index(500)
        self.split1k = self.hertz_to_index(1000)
        self.split2k = self.hertz_to_index(2000)
        self.split4k = self.hertz_to_index(4000)
        self.split8k = self.hertz_to_index(8000)
        self.split16k = self.hertz_to_index(16000)
        self.split20k = self.hertz_to_index(20000)

        # Top-of-spectrum bands shared by several subclasses.
        self.above20k = [(self.split20k, self.max_index)]
        self.above16k = [(self.split16k, self.split20k)] + self.above20k

    def index_to_hertz(self, index: int):
        """Convert a bin index to its frequency in Hz."""
        return index * self.fs / self.nfft

    def hertz_to_index(self, hz: float, round: bool = True):
        """Convert a frequency in Hz to a bin index (rounded by default)."""
        index = hz * self.nfft / self.fs
        return int(np.round(index)) if round else index

    def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
        """Tile [start_index, end_index) with bands roughly *bandwidth_hz* wide.

        The final band is clipped at ``end_index``; consecutive bands share
        their boundary bin index (half-open convention).
        """
        # Bandwidth in bins is loop-invariant; hoisted out of the loop.
        step = self.hertz_to_index(bandwidth_hz)

        specs = []
        lower = start_index
        while lower < end_index:
            upper = min(int(np.floor(lower + step)), end_index)
            specs.append((lower, upper))
            lower = upper
        return specs

    @abstractmethod
    def get_band_specs(self):
        raise NotImplementedError
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class VocalBandsplitSpecification(BandsplitSpecification):
    """Hand-designed band layouts for vocal separation.

    ``version`` selects one of the ``version1`` ... ``version7`` layouts;
    higher versions use progressively finer low/mid-frequency resolution.
    """

    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

        self.version = version

    def get_band_specs(self):
        # Dispatch by name, e.g. version "7" -> self.version7().
        return getattr(self, f"version{self.version}")()

    def version1(self):
        # Bug fix: this method was decorated with @property, so
        # getattr(self, "version1") returned the *list* itself and the
        # trailing call in get_band_specs raised TypeError
        # ("'list' object is not callable"). It is now a plain method,
        # consistent with version2..version7.
        return self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.max_index, bandwidth_hz=1000
        )

    def version2(self):
        below16k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )

        return below16k + below20k + self.above20k

    def version3(self):
        below8k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )

        return below8k + below16k + self.above16k

    def version4(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )

        return below1k + below8k + below16k + self.above16k

    def version5(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )
        return below1k + below16k + below20k + self.above20k

    def version6(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + self.above16k

    def version7(self):
        # Finest layout: 100 Hz bands below 1 kHz, widening with frequency.
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + below20k + self.above20k
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class OtherBandsplitSpecification(VocalBandsplitSpecification):
    """Band split for the "other" stem: identical to vocal layout version 7."""

    def __init__(self, nfft: int, fs: int) -> None:
        # Fixed to version "7"; callers cannot override it.
        super().__init__(nfft=nfft, fs=fs, version="7")
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class BassBandsplitSpecification(BandsplitSpecification):
    """Band split for the bass stem: 50 Hz bands below 500 Hz, coarser above.

    NOTE(review): the ``version`` argument is accepted but never used —
    presumably kept for signature parity with the vocal spec; confirm.
    """

    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        # (start_bin, end_bin, bandwidth_hz) for each frequency region,
        # from fine low-frequency resolution to coarse high-frequency.
        regions = (
            (0, self.split500, 50),
            (self.split500, self.split1k, 100),
            (self.split1k, self.split4k, 500),
            (self.split4k, self.split8k, 1000),
            (self.split8k, self.split16k, 2000),
        )

        specs = []
        for lo, hi, bw in regions:
            specs.extend(
                self.get_band_specs_with_bandwidth(
                    start_index=lo, end_index=hi, bandwidth_hz=bw
                )
            )
        # Everything above 16 kHz is a single band.
        specs.append((self.split16k, self.max_index))
        return specs
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class DrumBandsplitSpecification(BandsplitSpecification):
    """Band split for the drum stem: 50 Hz bands below 1 kHz, coarser above."""

    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        # (start_bin, end_bin, bandwidth_hz) per frequency region.
        regions = (
            (0, self.split1k, 50),
            (self.split1k, self.split2k, 100),
            (self.split2k, self.split4k, 250),
            (self.split4k, self.split8k, 500),
            (self.split8k, self.split16k, 1000),
        )

        specs = []
        for lo, hi, bw in regions:
            specs.extend(
                self.get_band_specs_with_bandwidth(
                    start_index=lo, end_index=hi, bandwidth_hz=bw
                )
            )
        # Everything above 16 kHz is a single band.
        specs.append((self.split16k, self.max_index))
        return specs
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class PerceptualBandsplitSpecification(BandsplitSpecification):
    """Band split derived from a perceptual filterbank (mel, musical, ...).

    ``fbank_fn(n_bands, fs, f_min, f_max, n_freqs)`` must return an
    (n_bands, n_freqs) weight tensor. Each band spec is the contiguous bin
    range over which the corresponding filter has non-zero support, and the
    per-band ``freq_weights`` are the bin-normalized filter weights on that
    range.
    """

    def __init__(
        self,
        nfft: int,
        fs: int,
        fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
        n_bands: int,
        f_min: float = 0.0,
        f_max: float = None,
    ) -> None:
        super().__init__(nfft=nfft, fs=fs)
        self.n_bands = n_bands
        if f_max is None:
            f_max = fs / 2

        self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)

        # Normalize so the weights of all filters covering a bin sum to 1.
        total_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True)  # (1, n_freqs)
        normalized_fb = self.filterbank / total_per_bin  # (n_bands, n_freqs)

        freq_weights = []
        band_specs = []
        for band in range(self.n_bands):
            nonzero = torch.nonzero(self.filterbank[band, :]).squeeze().tolist()
            if isinstance(nonzero, int):
                # squeeze().tolist() collapses a single nonzero bin to a
                # bare int; re-wrap so indexing below works.
                nonzero = (nonzero, nonzero)
            if len(nonzero) == 0:
                # A filter with no support contributes no band.
                continue
            start = nonzero[0]
            stop = nonzero[-1] + 1  # end index is exclusive
            band_specs.append((start, stop))
            freq_weights.append(normalized_fb[band, start:stop])

        self.freq_weights = freq_weights
        self.band_specs = band_specs

    def get_band_specs(self):
        """Return the list of (start, end) bin index pairs."""
        return self.band_specs

    def get_freq_weights(self):
        """Return per-band normalized weight tensors aligned with band_specs."""
        return self.freq_weights

    def save_to_file(self, dir_path: str) -> None:
        """Pickle band specs, weights and the raw filterbank into *dir_path*."""
        os.makedirs(dir_path, exist_ok=True)

        import pickle

        payload = {
            "band_specs": self.band_specs,
            "freq_weights": self.freq_weights,
            "filterbank": self.filterbank,
        }
        with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
            pickle.dump(payload, f)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    """Return an (n_bands, n_freqs) mel filterbank matrix.

    Wraps ``torchaudio.functional.melscale_fbanks`` (which returns
    (n_freqs, n_mels)) and transposes it to (bands, bins). The DC entry of
    the first filter is forced to 1.0 so that bin 0 has non-zero weight in
    the first band.
    """
    filters = taF.melscale_fbanks(
        n_mels=n_bands,
        sample_rate=fs,
        f_min=f_min,
        f_max=f_max,
        n_freqs=n_freqs,
    )
    fb = filters.T
    fb[0, 0] = 1.0
    return fb
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class MelBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split built from the mel-scale filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(
            nfft=nfft,
            fs=fs,
            fbank_fn=mel_filterbank,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
    """Return an (n_bands, n_freqs) rectangular filterbank on a musical
    (log-frequency / MIDI) scale.

    Band centers are spaced uniformly in MIDI pitch between the lower edge
    and ``f_max``; each band spans one band-width of octaves on either side
    of its center, so neighbouring bands overlap.

    NOTE(review): ``f_min`` is effectively ignored — the lower edge is
    always forced to the first bin's frequency ``fs / nfft`` (a dead
    assignment ``f_min = f_min or 0``, immediately overwritten, has been
    removed; behavior is unchanged). ``scale`` is also accepted but unused.
    Both parameters are kept for interface compatibility.
    """
    nfft = 2 * (n_freqs - 1)
    df = fs / nfft  # bin spacing in Hz

    f_max = f_max or fs / 2
    # Lower edge pinned to the first non-DC bin, regardless of f_min.
    f_min = fs / nfft

    n_octaves = np.log2(f_max / f_min)
    n_octaves_per_band = n_octaves / n_bands
    bandwidth_mult = np.power(2.0, n_octaves_per_band)

    low_midi = max(0, hz_to_midi(f_min))
    high_midi = hz_to_midi(f_max)
    midi_points = np.linspace(low_midi, high_midi, n_bands)
    hz_pts = midi_to_hz(midi_points)

    # Each band covers [center / mult, center * mult] in Hz.
    low_pts = hz_pts / bandwidth_mult
    high_pts = hz_pts * bandwidth_mult

    low_bins = np.floor(low_pts / df).astype(int)
    high_bins = np.ceil(high_pts / df).astype(int)

    fb = np.zeros((n_bands, n_freqs))

    for i in range(n_bands):
        fb[i, low_bins[i] : high_bins[i] + 1] = 1.0

    # Extend the outermost bands to cover the spectrum edges.
    fb[0, : low_bins[0]] = 1.0
    fb[-1, high_bins[-1] + 1 :] = 1.0

    return torch.as_tensor(fb)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
    """Perceptual band split built from the musical (log-frequency) filterbank."""

    def __init__(
        self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
    ) -> None:
        super().__init__(
            nfft=nfft,
            fs=fs,
            fbank_fn=musical_filterbank,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
# def bark_filterbank(
|
| 374 |
+
# n_bands, fs, f_min, f_max, n_freqs
|
| 375 |
+
# ):
|
| 376 |
+
# nfft = 2 * (n_freqs -1)
|
| 377 |
+
# fb, _ = bark_fbanks.bark_filter_banks(
|
| 378 |
+
# nfilts=n_bands,
|
| 379 |
+
# nfft=nfft,
|
| 380 |
+
# fs=fs,
|
| 381 |
+
# low_freq=f_min,
|
| 382 |
+
# high_freq=f_max,
|
| 383 |
+
# scale="constant"
|
| 384 |
+
# )
|
| 385 |
+
|
| 386 |
+
# return torch.as_tensor(fb)
|
| 387 |
+
|
| 388 |
+
# class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
|
| 389 |
+
# def __init__(
|
| 390 |
+
# self,
|
| 391 |
+
# nfft: int,
|
| 392 |
+
# fs: int,
|
| 393 |
+
# n_bands: int,
|
| 394 |
+
# f_min: float = 0.0,
|
| 395 |
+
# f_max: float = None
|
| 396 |
+
# ) -> None:
|
| 397 |
+
# super().__init__(fbank_fn=bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
# def triangular_bark_filterbank(
|
| 401 |
+
# n_bands, fs, f_min, f_max, n_freqs
|
| 402 |
+
# ):
|
| 403 |
+
|
| 404 |
+
# all_freqs = torch.linspace(0, fs // 2, n_freqs)
|
| 405 |
+
|
| 406 |
+
# # calculate mel freq bins
|
| 407 |
+
# m_min = hz2bark(f_min)
|
| 408 |
+
# m_max = hz2bark(f_max)
|
| 409 |
+
|
| 410 |
+
# m_pts = torch.linspace(m_min, m_max, n_bands + 2)
|
| 411 |
+
# f_pts = 600 * torch.sinh(m_pts / 6)
|
| 412 |
+
|
| 413 |
+
# # create filterbank
|
| 414 |
+
# fb = _create_triangular_filterbank(all_freqs, f_pts)
|
| 415 |
+
|
| 416 |
+
# fb = fb.T
|
| 417 |
+
|
| 418 |
+
# first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
|
| 419 |
+
# first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
|
| 420 |
+
|
| 421 |
+
# fb[first_active_band, :first_active_bin] = 1.0
|
| 422 |
+
|
| 423 |
+
# return fb
|
| 424 |
+
|
| 425 |
+
# class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
|
| 426 |
+
# def __init__(
|
| 427 |
+
# self,
|
| 428 |
+
# nfft: int,
|
| 429 |
+
# fs: int,
|
| 430 |
+
# n_bands: int,
|
| 431 |
+
# f_min: float = 0.0,
|
| 432 |
+
# f_max: float = None
|
| 433 |
+
# ) -> None:
|
| 434 |
+
# super().__init__(fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
# def minibark_filterbank(
|
| 438 |
+
# n_bands, fs, f_min, f_max, n_freqs
|
| 439 |
+
# ):
|
| 440 |
+
# fb = bark_filterbank(
|
| 441 |
+
# n_bands,
|
| 442 |
+
# fs,
|
| 443 |
+
# f_min,
|
| 444 |
+
# f_max,
|
| 445 |
+
# n_freqs
|
| 446 |
+
# )
|
| 447 |
+
|
| 448 |
+
# fb[fb < np.sqrt(0.5)] = 0.0
|
| 449 |
+
|
| 450 |
+
# return fb
|
| 451 |
+
|
| 452 |
+
# class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
|
| 453 |
+
# def __init__(
|
| 454 |
+
# self,
|
| 455 |
+
# nfft: int,
|
| 456 |
+
# fs: int,
|
| 457 |
+
# n_bands: int,
|
| 458 |
+
# f_min: float = 0.0,
|
| 459 |
+
# f_max: float = None
|
| 460 |
+
# ) -> None:
|
| 461 |
+
# super().__init__(fbank_fn=minibark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
# def erb_filterbank(
|
| 465 |
+
# n_bands: int,
|
| 466 |
+
# fs: int,
|
| 467 |
+
# f_min: float,
|
| 468 |
+
# f_max: float,
|
| 469 |
+
# n_freqs: int,
|
| 470 |
+
# ) -> Tensor:
|
| 471 |
+
# # freq bins
|
| 472 |
+
# A = (1000 * np.log(10)) / (24.7 * 4.37)
|
| 473 |
+
# all_freqs = torch.linspace(0, fs // 2, n_freqs)
|
| 474 |
+
|
| 475 |
+
# # calculate mel freq bins
|
| 476 |
+
# m_min = hz2erb(f_min)
|
| 477 |
+
# m_max = hz2erb(f_max)
|
| 478 |
+
|
| 479 |
+
# m_pts = torch.linspace(m_min, m_max, n_bands + 2)
|
| 480 |
+
# f_pts = (torch.pow(10, (m_pts / A)) - 1)/ 0.00437
|
| 481 |
+
|
| 482 |
+
# # create filterbank
|
| 483 |
+
# fb = _create_triangular_filterbank(all_freqs, f_pts)
|
| 484 |
+
|
| 485 |
+
# fb = fb.T
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
# first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
|
| 489 |
+
# first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
|
| 490 |
+
|
| 491 |
+
# fb[first_active_band, :first_active_bin] = 1.0
|
| 492 |
+
|
| 493 |
+
# return fb
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
# class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
|
| 497 |
+
# def __init__(
|
| 498 |
+
# self,
|
| 499 |
+
# nfft: int,
|
| 500 |
+
# fs: int,
|
| 501 |
+
# n_bands: int,
|
| 502 |
+
# f_min: float = 0.0,
|
| 503 |
+
# f_max: float = None
|
| 504 |
+
# ) -> None:
|
| 505 |
+
# super().__init__(fbank_fn=erb_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
|
| 506 |
+
|
| 507 |
+
if __name__ == "__main__":
    # Script mode: dump the default vocal band layout to CSV for inspection.
    import pandas as pd

    rows = []

    for spec_cls in [VocalBandsplitSpecification]:
        stem = spec_cls.__name__.replace("BandsplitSpecification", "")

        specs = spec_cls(nfft=2048, fs=44100).get_band_specs()

        rows.extend(
            {"band": stem, "band_index": idx, "f_min": lo, "f_max": hi}
            for idx, (lo, hi) in enumerate(specs)
        )

    pd.DataFrame(rows).to_csv("vox7bands.csv", index=False)
|