musicgowdam committed (verified)
Commit b45d54e · Parent: bd7b928

Upload 96 files

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+demucs.png filter=lfs diff=lfs merge=lfs -text
+test.mp3 filter=lfs diff=lfs merge=lfs -text
CODE_OF_CONDUCT.md ADDED
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
  advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
  address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@fb.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq
CONTRIBUTING.md ADDED
# Contributing to Demucs

## Pull Requests

In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

Demucs is the implementation of a research paper.
Therefore, we do not plan on accepting many pull requests for new features.
We certainly welcome them for bug fixes.

## Issues

We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

## License

By contributing to this repository, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
Demucs.ipynb ADDED
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Be9yoh-ILfRr"
   },
   "source": [
    "# Hybrid Demucs\n",
    "\n",
    "Feel free to use the Colab version:\n",
    "https://colab.research.google.com/drive/1dC9nVxk3V_VPjUADsnFu8EiT-xnU1tGH?usp=sharing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 139
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 12277,
     "status": "ok",
     "timestamp": 1583778134659,
     "user": {
      "displayName": "Marllus Lustosa",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgLl2RbW64ZyWz3Y8IBku0zhHCMnt7fz7fEl0LTdA=s64",
      "userId": "14811735256675200480"
     },
     "user_tz": 180
    },
    "id": "kOjIPLlzhPfn",
    "outputId": "c75f17ec-b576-4105-bc5b-c2ac9c1018a3"
   },
   "outputs": [],
   "source": [
    "!pip install -U demucs\n",
    "# or for local development, if you have a clone of Demucs:\n",
    "# pip install -e ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "5lYOzKKCKAbJ"
   },
   "outputs": [],
   "source": [
    "# You can use the `demucs` command line to separate tracks.\n",
    "!demucs test.mp3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You can also load the pretrained models directly,\n",
    "# for instance the MDX 2021 Track A winning model:\n",
    "from demucs import pretrained\n",
    "model = pretrained.get_model('mdx')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Because `model` is a bag of 4 models, you cannot call it directly on your data,\n",
    "# but `apply_model` knows what to do with it.\n",
    "import torch\n",
    "from demucs.apply import apply_model\n",
    "x = torch.randn(1, 2, 44100 * 10)  # ten seconds of white noise for the demo\n",
    "out = apply_model(model, x)[0]  # shape is [S, C, T] with S the number of sources\n",
    "\n",
    "# So let's see where all the white noise content is going:\n",
    "for name, source in zip(model.sources, out):\n",
    "    print(name, source.std() / x.std())\n",
    "# The outputs are quite weird to be fair, not what I would have expected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now let's take a single model from the bag and test it on a pure cosine.\n",
    "freq = 440  # in Hz\n",
    "sr = model.samplerate\n",
    "t = torch.arange(10 * sr).float() / sr\n",
    "x = torch.cos(2 * 3.1416 * freq * t).expand(1, 2, -1)\n",
    "sub_model = model.models[3]\n",
    "out = sub_model(x)[0]\n",
    "\n",
    "# Same question: where does it go?\n",
    "for name, source in zip(model.sources, out):\n",
    "    print(name, source.std() / x.std())\n",
    "\n",
    "# Well, now it makes much more sense: all the energy is going\n",
    "# into the `other` source.\n",
    "# Feel free to try a lower pitch (say 80 Hz) to see what happens!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For training or more fun, refer to the Demucs README on our repo:\n",
    "# https://github.com/facebookresearch/demucs/tree/main/demucs"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyM9xpVr1M86NRcjtQ7g9tCx",
   "collapsed_sections": [],
   "name": "Demucs.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
LICENSE ADDED
MIT License

Copyright (c) Meta Platforms, Inc. and affiliates.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
MANIFEST.in ADDED
recursive-exclude env *
recursive-include conf *.yaml
include Makefile
include LICENSE
include demucs.png
include outputs.tar.gz
include test.mp3
include requirements.txt
include requirements_minimal.txt
include mypy.ini
include demucs/py.typed
include demucs/remote/*.txt
include demucs/remote/*.yaml
Makefile ADDED
all: linter tests

linter:
	flake8 demucs
	mypy demucs

tests: test_train test_eval

test_train: tests/musdb
	_DORA_TEST_PATH=/tmp/demucs python3 -m dora run --clear \
		dset.musdb=./tests/musdb dset.segment=4 dset.shift=2 epochs=2 model=demucs \
		demucs.depth=2 demucs.channels=4 test.sdr=false misc.num_workers=0 test.workers=0 \
		test.shifts=0

test_eval:
	python3 -m demucs -n demucs_unittest test.mp3
	python3 -m demucs -n demucs_unittest --two-stems=vocals test.mp3
	python3 -m demucs -n demucs_unittest --mp3 test.mp3
	python3 -m demucs -n demucs_unittest --flac --int24 test.mp3
	python3 -m demucs -n demucs_unittest --int24 --clip-mode clamp test.mp3
	python3 -m demucs -n demucs_unittest --segment 8 test.mp3
	python3 -m demucs.api -n demucs_unittest --segment 8 test.mp3
	python3 -m demucs --list-models

tests/musdb:
	test -e tests || mkdir tests
	python3 -c 'import musdb; musdb.DB("tests/tmp", download=True)'
	musdbconvert tests/tmp tests/musdb

dist:
	python3 setup.py sdist

clean:
	rm -r dist build *.egg-info

.PHONY: linter dist test_train test_eval
README.md CHANGED
@@ -1,14 +1,319 @@
- ---
- title: MUSIC SEPERATION
- emoji: 🐢
- colorFrom: indigo
- colorTo: blue
- sdk: gradio
- sdk_version: 5.49.0
- app_file: app.py
- pinned: false
- license: mit
- short_description: Demucs Music Source Separation (v4)
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
# Demucs Music Source Separation

[![Support Ukraine](https://img.shields.io/badge/Support-Ukraine-FFD500?style=flat&labelColor=005BBB)](https://opensource.fb.com/support-ukraine)
![tests badge](https://github.com/facebookresearch/demucs/workflows/tests/badge.svg)
![linter badge](https://github.com/facebookresearch/demucs/workflows/linter/badge.svg)

**Important:** As I am no longer working at Meta, **this repository is not maintained anymore**.
I've created a fork at [github.com/adefossez/demucs](https://github.com/adefossez/demucs). Note that this project is no longer actively maintained,
and only important bug fixes will be processed on the new repo. Please do not open issues for feature requests or if Demucs doesn't work perfectly for your use case :)

This is the 4th release of Demucs (v4), featuring Hybrid Transformer based source separation.
**For the classic Hybrid Demucs (v3):** [go to this commit][demucs_v3].
If you are experiencing issues and want the old Demucs back, please file an issue, and then you can get back to Demucs v3 with
`git checkout v3`. You can also go back to [Demucs v2][demucs_v2].

Demucs is a state-of-the-art music source separation model, currently capable of separating
drums, bass, and vocals from the rest of the accompaniment.
Demucs is based on a U-Net convolutional architecture inspired by [Wave-U-Net][waveunet].
The v4 version features [Hybrid Transformer Demucs][htdemucs], a hybrid spectrogram/waveform separation model using Transformers.
It is based on [Hybrid Demucs][hybrid_paper] (also provided in this repo), with the innermost layers
replaced by a cross-domain Transformer Encoder. This Transformer uses self-attention within each domain,
and cross-attention across domains.
The model achieves an SDR of 9.00 dB on the MUSDB HQ test set. Moreover, when using sparse attention
kernels to extend its receptive field and per-source fine-tuning, we achieve a state-of-the-art 9.20 dB of SDR.

Samples are available [on our sample page](https://ai.honu.io/papers/htdemucs/index.html).
Check out [our paper][htdemucs] for more information.
The model has been trained on the [MUSDB HQ][musdb] dataset plus an extra training dataset of 800 songs.
It separates the drums, bass, vocals, and other stems of any song.

As Hybrid Transformer Demucs is brand new, it is not activated by default; you can activate it in the usual
commands described hereafter with `-n htdemucs_ft`.
The single, non fine-tuned model is provided as `-n htdemucs`, and the retrained baseline
as `-n hdemucs_mmi`. The Sparse Hybrid Transformer model described in our paper is not provided, as it
requires custom CUDA code that is not ready for release yet.
We are also releasing an experimental 6-source model that adds `guitar` and `piano` sources.
Quick testing seems to show okay quality for `guitar`, but a lot of bleeding and artifacts for the `piano` source.

<p align="center">
<img src="./demucs.png" alt="Schema representing the structure of Hybrid Transformer Demucs,
with a dual U-Net structure, one branch for the temporal domain,
and one branch for the spectral domain. There is a cross-domain Transformer between the Encoders and Decoders."
width="800px"></p>

## Important news if you are already using Demucs

See the [release notes](./docs/release.md) for more details.

- 22/02/2023: added support for the [SDX 2023 Challenge](https://www.aicrowd.com/challenges/sound-demixing-challenge-2023),
  see the dedicated [doc page](./docs/sdx23.md).
- 07/12/2022: Demucs v4 now on PyPI. The **htdemucs** model is now used by default. Also releasing
  a 6-source model (adding `guitar` and `piano`, although the latter doesn't work so well at the moment).
- 16/11/2022: Added the new **Hybrid Transformer Demucs v4** models.
  Added support for the [torchaudio implementation of HDemucs](https://pytorch.org/audio/stable/tutorials/hybrid_demucs_tutorial.html).
- 30/08/2022: added reproducibility and ablation grids, along with an updated version of the paper.
- 17/08/2022: Releasing v3.0.5: set split segment length to reduce memory. Compatible with PyTorch 1.12.
- 24/02/2022: Releasing v3.0.4: split into two stems (i.e. karaoke mode).
  Export as float32 or int24.
- 17/12/2021: Releasing v3.0.3: bug fixes (thanks @keunwoochoi), memory drastically
  reduced on GPU (thanks @famzah), and new multi-core evaluation on CPU (`-j` flag).
- 12/11/2021: Releasing **Demucs v3** with hybrid domain separation. Strong improvements
  on all sources. This is the model that won the Sony MDX challenge.
- 11/05/2021: Added support for MusDB-HQ and arbitrary wav sets, for the MDX challenge. For more information
  on joining the challenge with Demucs, see [the Demucs MDX instructions](docs/mdx.md).

## Comparison with other models

We provide hereafter a summary of the different metrics presented in the paper.
You can also compare Hybrid Demucs (v3), [KUIELAB-MDX-Net][kuielab], [Spleeter][spleeter], Open-Unmix, Demucs (v1), and Conv-Tasnet on one of my favorite
songs on my [soundcloud playlist][soundcloud].

### Comparison of accuracy

`Overall SDR` is the mean of the SDR for each of the 4 sources, `MOS Quality` is a rating from 1 to 5
of the naturalness and absence of artifacts given by human listeners (5 = no artifacts), and `MOS Contamination`
is a rating from 1 to 5, with 5 being zero contamination by other sources. We refer the reader to our [paper][hybrid_paper]
for more details.

| Model | Domain | Extra data? | Overall SDR | MOS Quality | MOS Contamination |
|------------------------------|-------------|-------------------|-------------|-------------|-------------------|
| [Wave-U-Net][waveunet] | waveform | no | 3.2 | - | - |
| [Open-Unmix][openunmix] | spectrogram | no | 5.3 | - | - |
| [D3Net][d3net] | spectrogram | no | 6.0 | - | - |
| [Conv-Tasnet][demucs_v2] | waveform | no | 5.7 | - | - |
| [Demucs (v2)][demucs_v2] | waveform | no | 6.3 | 2.37 | 2.36 |
| [ResUNetDecouple+][decouple] | spectrogram | no | 6.7 | - | - |
| [KUIELAB-MDX-Net][kuielab] | hybrid | no | 7.5 | **2.86** | 2.55 |
| [Band-Split RNN][bandsplit] | spectrogram | no | **8.2** | - | - |
| **Hybrid Demucs (v3)** | hybrid | no | 7.7 | **2.83** | **3.04** |
| [MMDenseLSTM][mmdenselstm] | spectrogram | 804 songs | 6.0 | - | - |
| [D3Net][d3net] | spectrogram | 1.5k songs | 6.7 | - | - |
| [Spleeter][spleeter] | spectrogram | 25k songs | 5.9 | - | - |
| [Band-Split RNN][bandsplit] | spectrogram | 1.7k (mixes only) | **9.0** | - | - |
| **HT Demucs f.t. (v4)** | hybrid | 800 songs | **9.0** | - | - |

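For intuition about the `Overall SDR` column above: SDR is a log-ratio of reference energy to error energy, expressed in dB. A minimal sketch follows (a simplified, hypothetical helper for illustration; the published numbers use the full BSSEval/museval protocol, not this function):

```python
import numpy as np

def sdr_db(reference: np.ndarray, estimate: np.ndarray, eps: float = 1e-9) -> float:
    # Ratio (in dB) of reference energy to the energy of the estimation error.
    num = np.sum(reference ** 2)
    den = np.sum((reference - estimate) ** 2)
    return float(10 * np.log10((num + eps) / (den + eps)))
```

With this definition, a perfect estimate yields a very large SDR, and doubling the error energy costs about 3 dB.
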
## Requirements

You will need at least Python 3.8. See `requirements_minimal.txt` for the requirements for separation only,
and `environment-[cpu|cuda].yml` (or `requirements.txt`) if you want to train a new model.

### For Windows users

Every time you see `python3`, replace it with `python.exe`. You should always run commands from the
Anaconda console.

### For musicians

If you just want to use Demucs to separate tracks, you can install it with

```bash
python3 -m pip install -U demucs
```

For bleeding-edge versions, you can install directly from this repo using
```bash
python3 -m pip install -U git+https://github.com/facebookresearch/demucs#egg=demucs
```

Advanced OS support is provided on the following pages; **you must read the page for your OS before posting an issue**:
- **If you are using Windows:** [Windows support](docs/windows.md).
- **If you are using macOS:** [macOS support](docs/mac.md).
- **If you are using Linux:** [Linux support](docs/linux.md).

### For machine learning scientists

If you have anaconda installed, you can run from the root of this repository:

```bash
conda env update -f environment-cpu.yml # if you don't have GPUs
conda env update -f environment-cuda.yml # if you have GPUs
conda activate demucs
pip install -e .
```

This will create a `demucs` environment with all the dependencies installed.

You will also need to install [soundstretch/soundtouch](https://www.surina.net/soundtouch/soundstretch.html): on macOS you can do `brew install sound-touch`,
and on Ubuntu `sudo apt-get install soundstretch`. This is used for the
pitch/tempo augmentation.

### Running in Docker

Thanks to @xserrat, there is now a Docker image definition ready for using Demucs. This can ensure all libraries are correctly installed without interfering with the host OS. See his repo [Docker Facebook Demucs](https://github.com/xserrat/docker-facebook-demucs) for more information.

### Running from Colab

I made a Colab to easily separate tracks with Demucs. Note that
transfer speeds with Colab are a bit slow for large media files,
but it will allow you to use Demucs without installing anything.

[Demucs on Google Colab](https://colab.research.google.com/drive/1dC9nVxk3V_VPjUADsnFu8EiT-xnU1tGH?usp=sharing)

### Web Demo

Integrated to [Hugging Face Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/akhaliq/demucs)

### Graphical Interface

@CarlGao4 has released a GUI for Demucs: [CarlGao4/Demucs-Gui](https://github.com/CarlGao4/Demucs-Gui). Downloads for Windows and macOS are available [here](https://github.com/CarlGao4/Demucs-Gui/releases). Use the [FossHub mirror](https://fosshub.com/Demucs-GUI.html) to speed up your download.

@Anjok07 provides a self-contained GUI in [UVR (Ultimate Vocal Remover)](https://github.com/facebookresearch/demucs/issues/334) that supports Demucs.

### Other providers

Audiostrip provides free online separation with Demucs on their website [https://audiostrip.co.uk/](https://audiostrip.co.uk/).

[MVSep](https://mvsep.com/) also provides free online separation; select `Demucs3 model B` for the best quality.

[Neutone](https://neutone.space/) provides a realtime Demucs model in their free VST/AU plugin that can be used in your favorite DAW.

## Separating tracks

To try Demucs, just run the following from any folder (as long as you have properly installed it):

```bash
demucs PATH_TO_AUDIO_FILE_1 [PATH_TO_AUDIO_FILE_2 ...] # for Demucs
# If you used `pip install --user` you might need to replace demucs with python3 -m demucs
python3 -m demucs --mp3 --mp3-bitrate BITRATE PATH_TO_AUDIO_FILE_1 # output files saved as MP3
# use --mp3-preset to change encoder preset, 2 for best quality, 7 for fastest
# If your filename contains spaces, don't forget to quote it!
demucs "my music/my favorite track.mp3"
# You can select different models with `-n`. mdx_q is the quantized model, smaller but maybe a bit less accurate.
demucs -n mdx_q myfile.mp3
# If you only want to separate vocals out of an audio file, use `--two-stems=vocals` (you can also set it to drums or bass)
demucs --two-stems=vocals myfile.mp3
```

If you have a GPU but you run out of memory, please use `--segment SEGMENT` to reduce the length of each split. `SEGMENT` should be an integer giving the length of each segment in seconds.
A segment length of at least 10 is recommended (the bigger the number, the more memory is required, but quality may increase). Note that the Hybrid Transformer models only support a maximum segment length of 7.8 seconds.
Setting the environment variable `PYTORCH_NO_CUDA_MEMORY_CACHING=1` is also helpful. If this still does not help, please add `-d cpu` to the command line. See the section hereafter for more details on the memory requirements for GPU acceleration.

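If you drive Demucs from Python instead of the shell, the same memory-saving flag can be set from code; note that it must be set before PyTorch initializes CUDA (a sketch):

```python
import os

# Disable the CUDA caching allocator to lower peak GPU memory (separation
# becomes slower). Set this before importing torch / initializing CUDA.
os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
```
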
Separated tracks are stored in the `separated/MODEL_NAME/TRACK_NAME` folder. There you will find four stereo wav files sampled at 44.1 kHz: `drums.wav`, `bass.wav`,
`other.wav`, `vocals.wav` (or `.mp3` if you used the `--mp3` option).

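Given that layout, collecting the stems from Python can be sketched as follows (a hypothetical helper; only the `separated/MODEL_NAME/TRACK_NAME` layout and `.wav` extension are taken from the text above):

```python
from pathlib import Path

def stem_paths(model: str, track: str, root: str = "separated", ext: str = "wav"):
    """Return a mapping from source name to the file Demucs wrote for one track."""
    folder = Path(root) / model / track
    return {p.stem: p for p in sorted(folder.glob(f"*.{ext}"))}
```

For example, `stem_paths('htdemucs', 'my favorite track')['vocals']` would point at the separated vocals.
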
All audio formats supported by `torchaudio` can be processed (i.e. wav, mp3, flac, ogg/vorbis on Linux/macOS, etc.). On Windows, `torchaudio` has limited support, so we rely on `ffmpeg`, which should support pretty much anything.
Audio is resampled on the fly if necessary.
The output will be a wav file encoded as int16.
You can save as float32 wav files with `--float32`, or 24-bit integer wav with `--int24`.
You can pass `--mp3` to save as mp3 instead, and set the bitrate (in kbps) with `--mp3-bitrate` (default is 320).

It can happen that the output needs clipping, in particular due to some separation artifacts.
Demucs will automatically rescale each output stem so as to avoid clipping. This can however break
the relative volume between stems. If instead you prefer hard clipping, pass `--clip-mode clamp`.
You can also try to reduce the volume of the input mixture before feeding it to Demucs.

Other pre-trained models can be selected with the `-n` flag.
The list of pre-trained models is:
- `htdemucs`: first version of Hybrid Transformer Demucs. Trained on MusDB + 800 songs. Default model.
- `htdemucs_ft`: fine-tuned version of `htdemucs`; separation will take 4 times longer
  but might be a bit better. Same training set as `htdemucs`.
- `htdemucs_6s`: 6-source version of `htdemucs`, with `piano` and `guitar` added as sources.
  Note that the `piano` source does not work great at the moment.
- `hdemucs_mmi`: Hybrid Demucs v3, retrained on MusDB + 800 songs.
- `mdx`: trained only on MusDB HQ, winning model on track A at the [MDX][mdx] challenge.
- `mdx_extra`: trained with extra training data (**including the MusDB test set**), ranked 2nd on track B
  of the [MDX][mdx] challenge.
- `mdx_q`, `mdx_extra_q`: quantized versions of the previous models. Smaller download and storage,
  but quality can be slightly worse.
- `SIG`: where `SIG` is a single model from the [model zoo](docs/training.md#model-zoo).

The `--two-stems=vocals` option allows separating vocals from the rest of the accompaniment (i.e., karaoke mode).
`vocals` can be changed to any source in the selected model.
This will mix the files after separating the mix fully, so this won't be faster or use less memory.

The `--shifts=SHIFTS` option performs multiple predictions with random shifts (a.k.a. the *shift trick*) of the input and averages them. This makes prediction `SHIFTS` times
slower. Don't use it unless you have a GPU.

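The shift trick can be sketched on a 1-D signal as follows (a toy illustration, not Demucs's actual implementation; `predict` stands in for the separation model):

```python
import numpy as np

def shift_trick(x: np.ndarray, predict, shifts: int, max_shift: int, seed: int = 0) -> np.ndarray:
    # Average predictions over random circular shifts of the input,
    # undoing each shift before averaging.
    rng = np.random.default_rng(seed)
    out = np.zeros_like(x, dtype=float)
    for _ in range(shifts):
        s = int(rng.integers(0, max_shift + 1))
        out += np.roll(predict(np.roll(x, s)), -s)
    return out / shifts
```

For a shift-equivariant `predict`, the average equals a single prediction; for a real model, averaging over shifts smooths out boundary artifacts, at `shifts` times the cost.
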
The `--overlap` option controls the amount of overlap between prediction windows. The default is 0.25 (i.e. 25%), which is probably fine.
It can probably be reduced to 0.1 to improve speed a bit.

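The windowing that `--overlap` controls can be sketched as follows (a toy segmentation helper, hypothetical, not the actual Demucs code):

```python
def overlap_windows(length: int, segment: int, overlap: float = 0.25):
    """(start, end) windows of up to `segment` samples, advancing by segment * (1 - overlap)."""
    stride = max(1, int(segment * (1 - overlap)))
    windows = []
    start = 0
    while start < length:
        windows.append((start, min(start + segment, length)))
        start += stride
    return windows
```

With `segment=4` and `overlap=0.25`, consecutive windows advance by 3 samples, so overlapping regions are covered by two predictions that can be blended.
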
The `-j` flag allows specifying a number of parallel jobs (e.g. `demucs -j 2 myfile.mp3`).
This will multiply the RAM used by the same amount, so be careful!

### Memory requirements for GPU acceleration

If you want to use GPU acceleration, you will need at least 3GB of RAM on your GPU for `demucs`. However, about 7GB of RAM will be required if you use the default arguments. Add `--segment SEGMENT` to change the size of each split. If you only have 3GB of memory, set SEGMENT to 8 (though quality may be worse if this argument is too small). Setting the environment variable `PYTORCH_NO_CUDA_MEMORY_CACHING=1` can help users with even less RAM, such as 2GB (I separated a 4-minute track using only 1.5GB), but this will make the separation slower.

If you do not have enough memory on your GPU, simply add `-d cpu` to the command line to use the CPU. With Demucs, processing time should be roughly 1.5 times the duration of the track.

## Calling from another Python program

The main function provides an `opt` parameter as a simple API. You can just pass the parsed command line as this parameter:
```python
# Assume that your command is `demucs --mp3 --two-stems vocals -n mdx_extra "track with space.mp3"`
# The following code is equivalent to the command above:
import demucs.separate
demucs.separate.main(["--mp3", "--two-stems", "vocals", "-n", "mdx_extra", "track with space.mp3"])

# Or like this
import demucs.separate
import shlex
demucs.separate.main(shlex.split('--mp3 --two-stems vocals -n mdx_extra "track with space.mp3"'))
```

To use more complicated APIs, see the [API docs](docs/api.md).

272
+ ## Training Demucs
273
+
274
+ If you want to train (Hybrid) Demucs, please follow the [training doc](docs/training.md).
275
+
276
+ ## MDX Challenge reproduction
277
+
278
+ In order to reproduce the results from the Track A and Track B submissions, checkout the [MDX Hybrid Demucs submission repo][mdx_submission].
279
+
280
+
## How to cite

```
@inproceedings{rouard2022hybrid,
  title={Hybrid Transformers for Music Source Separation},
  author={Rouard, Simon and Massa, Francisco and D{\'e}fossez, Alexandre},
  booktitle={ICASSP 23},
  year={2023}
}

@inproceedings{defossez2021hybrid,
  title={Hybrid Spectrogram and Waveform Source Separation},
  author={D{\'e}fossez, Alexandre},
  booktitle={Proceedings of the ISMIR 2021 Workshop on Music Source Separation},
  year={2021}
}
```

## License

Demucs is released under the MIT license as found in the [LICENSE](LICENSE) file.

[hybrid_paper]: https://arxiv.org/abs/2111.03600
[waveunet]: https://github.com/f90/Wave-U-Net
[musdb]: https://sigsep.github.io/datasets/musdb.html
[openunmix]: https://github.com/sigsep/open-unmix-pytorch
[mmdenselstm]: https://arxiv.org/abs/1805.02410
[demucs_v2]: https://github.com/facebookresearch/demucs/tree/v2
[demucs_v3]: https://github.com/facebookresearch/demucs/tree/v3
[spleeter]: https://github.com/deezer/spleeter
[soundcloud]: https://soundcloud.com/honualx/sets/source-separation-in-the-waveform-domain
[d3net]: https://arxiv.org/abs/2010.01733
[mdx]: https://www.aicrowd.com/challenges/music-demixing-challenge-ismir-2021
[kuielab]: https://github.com/kuielab/mdx-net-submission
[decouple]: https://arxiv.org/abs/2109.05418
[mdx_submission]: https://github.com/adefossez/mdx21_demucs
[bandsplit]: https://arxiv.org/abs/2209.15174
[htdemucs]: https://arxiv.org/abs/2211.08553
conf/config.yaml ADDED
@@ -0,0 +1,304 @@
defaults:
  - _self_
  - dset: musdb44
  - svd: default
  - variant: default
  - override hydra/hydra_logging: colorlog
  - override hydra/job_logging: colorlog

dummy:
dset:
  musdb: /checkpoint/defossez/datasets/musdbhq
  musdb_samplerate: 44100
  use_musdb: true   # set to false to not use musdb as training data.
  wav:   # path to custom wav dataset
  wav2:  # second custom wav dataset
  segment: 11
  shift: 1
  train_valid: false
  full_cv: true
  samplerate: 44100
  channels: 2
  normalize: true
  metadata: ./metadata
  sources: ['drums', 'bass', 'other', 'vocals']
  valid_samples:  # valid dataset size
  backend: null   # if provided select torchaudio backend.

test:
  save: False
  best: True
  workers: 2
  every: 20
  split: true
  shifts: 1
  overlap: 0.25
  sdr: true
  metric: 'loss'  # metric used for best model selection on the valid set, can also be nsdr
  nonhq:   # path to non hq MusDB for evaluation

epochs: 360
batch_size: 64
max_batches:  # limit the number of batches per epoch, useful for debugging
              # or if your dataset is gigantic.
optim:
  lr: 3e-4
  momentum: 0.9
  beta2: 0.999
  loss: l1    # l1 or mse
  optim: adam
  weight_decay: 0
  clip_grad: 0

seed: 42
debug: false
valid_apply: true
flag:
save_every:
weights: [1., 1., 1., 1.]  # weights over each source for the training/valid loss.

augment:
  shift_same: false
  repitch:
    proba: 0.2
    max_tempo: 12
  remix:
    proba: 1
    group_size: 4
  scale:
    proba: 1
    min: 0.25
    max: 1.25
  flip: true

continue_from:  # continue from other XP, give the XP Dora signature.
continue_pretrained:  # signature of a pretrained XP, this cannot be a bag of models.
pretrained_repo:  # repo for pretrained model (default is official AWS)
continue_best: true
continue_opt: false

misc:
  num_workers: 10
  num_prints: 4
  show: false
  verbose: false

# List of decay for EMA at batch or epoch level, e.g. 0.999.
# Batch level EMA are kept on GPU for speed.
ema:
  epoch: []
  batch: []

use_train_segment: true  # to remove
model_segment:  # override the segment parameter for the model, usually 4 times the training segment.
model: demucs  # see demucs/train.py for the possibilities, and config for each model hereafter.
demucs:  # see demucs/demucs.py for a detailed description
  # Channels
  channels: 64
  growth: 2
  # Main structure
  depth: 6
  rewrite: true
  lstm_layers: 0
  # Convolutions
  kernel_size: 8
  stride: 4
  context: 1
  # Activations
  gelu: true
  glu: true
  # Normalization
  norm_groups: 4
  norm_starts: 4
  # DConv residual branch
  dconv_depth: 2
  dconv_mode: 1  # 1 = branch in encoder, 2 = in decoder, 3 = in both.
  dconv_comp: 4
  dconv_attn: 4
  dconv_lstm: 4
  dconv_init: 1e-4
  # Pre/post treatment
  resample: true
  normalize: false
  # Weight init
  rescale: 0.1

hdemucs:  # see demucs/hdemucs.py for a detailed description
  # Channels
  channels: 48
  channels_time:
  growth: 2
  # STFT
  nfft: 4096
  wiener_iters: 0
  end_iters: 0
  wiener_residual: false
  cac: true
  # Main structure
  depth: 6
  rewrite: true
  hybrid: true
  hybrid_old: false
  # Frequency Branch
  multi_freqs: []
  multi_freqs_depth: 3
  freq_emb: 0.2
  emb_scale: 10
  emb_smooth: true
  # Convolutions
  kernel_size: 8
  stride: 4
  time_stride: 2
  context: 1
  context_enc: 0
  # normalization
  norm_starts: 4
  norm_groups: 4
  # DConv residual branch
  dconv_mode: 1
  dconv_depth: 2
  dconv_comp: 4
  dconv_attn: 4
  dconv_lstm: 4
  dconv_init: 1e-3
  # Weight init
  rescale: 0.1

# Torchaudio implementation of HDemucs
torch_hdemucs:
  # Channels
  channels: 48
  growth: 2
  # STFT
  nfft: 4096
  # Main structure
  depth: 6
  freq_emb: 0.2
  emb_scale: 10
  emb_smooth: true
  # Convolutions
  kernel_size: 8
  stride: 4
  time_stride: 2
  context: 1
  context_enc: 0
  # normalization
  norm_starts: 4
  norm_groups: 4
  # DConv residual branch
  dconv_depth: 2
  dconv_comp: 4
  dconv_attn: 4
  dconv_lstm: 4
  dconv_init: 1e-3

htdemucs:  # see demucs/htdemucs.py for a detailed description
  # Channels
  channels: 48
  channels_time:
  growth: 2
  # STFT
  nfft: 4096
  wiener_iters: 0
  end_iters: 0
  wiener_residual: false
  cac: true
  # Main structure
  depth: 4
  rewrite: true
  # Frequency Branch
  multi_freqs: []
  multi_freqs_depth: 3
  freq_emb: 0.2
  emb_scale: 10
  emb_smooth: true
  # Convolutions
  kernel_size: 8
  stride: 4
  time_stride: 2
  context: 1
  context_enc: 0
  # normalization
  norm_starts: 4
  norm_groups: 4
  # DConv residual branch
  dconv_mode: 1
  dconv_depth: 2
  dconv_comp: 8
  dconv_init: 1e-3
  # Before the Transformer
  bottom_channels: 0
  # CrossTransformer
  # ------ Common to all
  # Regular parameters
  t_layers: 5
  t_hidden_scale: 4.0
  t_heads: 8
  t_dropout: 0.0
  t_layer_scale: True
  t_gelu: True
  # ------------- Positional Embedding
  t_emb: sin
  t_max_positions: 10000  # for the scaled embedding
  t_max_period: 10000.0
  t_weight_pos_embed: 1.0
  t_cape_mean_normalize: True
  t_cape_augment: True
  t_cape_glob_loc_scale: [5000.0, 1.0, 1.4]
  t_sin_random_shift: 0
  # ------------- norm before a transformer encoder
  t_norm_in: True
  t_norm_in_group: False
  # ------------- norm inside the encoder
  t_group_norm: False
  t_norm_first: True
  t_norm_out: True
  # ------------- optim
  t_weight_decay: 0.0
  t_lr:
  # ------------- sparsity
  t_sparse_self_attn: False
  t_sparse_cross_attn: False
  t_mask_type: diag
  t_mask_random_seed: 42
  t_sparse_attn_window: 400
  t_global_window: 100
  t_sparsity: 0.95
  t_auto_sparsity: False
  # Cross Encoder First (False)
  t_cross_first: False
  # Weight init
  rescale: 0.1

svd:  # see svd.py for documentation
  penalty: 0
  min_size: 0.1
  dim: 1
  niters: 2
  powm: false
  proba: 1
  conv_only: false
  convtr: false
  bs: 1

quant:  # quantization hyper params
  diffq:    # diffq penalty, typically 1e-4 or 3e-4
  qat:      # use QAT with a fixed number of bits (not as good as diffq)
  min_size: 0.2
  group_size: 8

dora:
  dir: outputs
  exclude: ["misc.*", "slurm.*", 'test.reval', 'flag', 'dset.backend']

slurm:
  time: 4320
  constraint: volta32gb
  setup: ['module load cudnn/v8.4.1.50-cuda.11.6 NCCL/2.11.4-6-cuda.11.6 cuda/11.6']

# Hydra config
hydra:
  job_logging:
    formatters:
      colorlog:
        datefmt: "%m-%d %H:%M:%S"
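The variant files below (e.g. `conf/variant/finetune.yaml`) override keys of this base config through Hydra's composition. A minimal standalone sketch of that merging behavior (this is an illustration, not Hydra itself; the `deep_merge` helper and the dict values are taken from the configs in this commit):

```python
def deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge `override` into a copy of `base`."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Nested sections are merged key by key rather than replaced wholesale.
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

# Values from config.yaml and conf/variant/finetune.yaml above:
base = {"epochs": 360, "batch_size": 64, "optim": {"lr": 3e-4, "beta2": 0.999}}
finetune = {"epochs": 4, "batch_size": 16, "optim": {"lr": 0.0006}}

merged = deep_merge(base, finetune)
assert merged["epochs"] == 4              # overridden by the variant
assert merged["optim"]["lr"] == 0.0006    # nested override
assert merged["optim"]["beta2"] == 0.999  # untouched base value survives
```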
conf/dset/aetl.yaml ADDED
@@ -0,0 +1,19 @@
# @package _global_

# automix dataset with Musdb, extra training data and the test set of Musdb.
# This used even more remixes than auto_extra_test.
dset:
  wav: /checkpoint/defossez/datasets/aetl
  samplerate: 44100
  channels: 2
epochs: 320
max_batches: 500

augment:
  shift_same: true
  scale:
    proba: 0.
  remix:
    proba: 0
  repitch:
    proba: 0
conf/dset/auto_extra_test.yaml ADDED
@@ -0,0 +1,18 @@
# @package _global_

# automix dataset with Musdb, extra training data and the test set of Musdb.
dset:
  wav: /checkpoint/defossez/datasets/automix_extra_test2
  samplerate: 44100
  channels: 2
epochs: 320
max_batches: 500

augment:
  shift_same: true
  scale:
    proba: 0.
  remix:
    proba: 0
  repitch:
    proba: 0
conf/dset/auto_mus.yaml ADDED
@@ -0,0 +1,20 @@
# @package _global_

# Automix dataset based on musdb train set.
dset:
  wav: /checkpoint/defossez/datasets/automix_musdb
  samplerate: 44100
  channels: 2
epochs: 360
max_batches: 300
test:
  every: 4

augment:
  shift_same: true
  scale:
    proba: 0.5
  remix:
    proba: 0
  repitch:
    proba: 0
conf/dset/extra44.yaml ADDED
@@ -0,0 +1,8 @@
# @package _global_

# Musdb + extra tracks
dset:
  wav: /checkpoint/defossez/datasets/allstems_44/
  samplerate: 44100
  channels: 2
epochs: 320
conf/dset/extra_mmi_goodclean.yaml ADDED
@@ -0,0 +1,12 @@
# @package _global_

# Musdb + extra tracks
dset:
  wav: /checkpoint/defossez/datasets/allstems_44/
  wav2: /checkpoint/defossez/datasets/mmi44_goodclean
  samplerate: 44100
  channels: 2
  wav2_weight: null
  wav2_valid: false
  valid_samples: 100
epochs: 1200
conf/dset/extra_test.yaml ADDED
@@ -0,0 +1,12 @@
# @package _global_

# Musdb + extra tracks + test set from musdb.
dset:
  wav: /checkpoint/defossez/datasets/allstems_test_44/
  samplerate: 44100
  channels: 2
epochs: 320
max_batches: 700
test:
  sdr: false
  every: 500
conf/dset/musdb44.yaml ADDED
@@ -0,0 +1,5 @@
# @package _global_

dset:
  samplerate: 44100
  channels: 2
conf/dset/sdx23_bleeding.yaml ADDED
@@ -0,0 +1,10 @@
# @package _global_

# Musdb + extra tracks
dset:
  wav: /shared/home/defossez/data/datasets/moisesdb23_bleeding_v1.0/
  use_musdb: false
  samplerate: 44100
  channels: 2
  backend: soundfile  # must use soundfile as some mixtures would clip with sox.
epochs: 320
conf/dset/sdx23_labelnoise.yaml ADDED
@@ -0,0 +1,10 @@
# @package _global_

# Musdb + extra tracks
dset:
  wav: /shared/home/defossez/data/datasets/moisesdb23_labelnoise_v1.0
  use_musdb: false
  samplerate: 44100
  channels: 2
  backend: soundfile  # must use soundfile as some mixtures would clip with sox.
epochs: 320
conf/svd/base.yaml ADDED
@@ -0,0 +1,14 @@
# @package _global_

svd:
  penalty: 0
  min_size: 1
  dim: 50
  niters: 4
  powm: false
  proba: 1
  conv_only: false
  convtr: false  # ideally this should be true, but some models were trained with this set to false.

optim:
  beta2: 0.9998
conf/svd/base2.yaml ADDED
@@ -0,0 +1,14 @@
# @package _global_

svd:
  penalty: 0
  min_size: 1
  dim: 100
  niters: 4
  powm: false
  proba: 1
  conv_only: false
  convtr: true

optim:
  beta2: 0.9998
conf/svd/default.yaml ADDED
@@ -0,0 +1 @@
# @package _global_
conf/variant/default.yaml ADDED
@@ -0,0 +1 @@
# @package _global_
conf/variant/example.yaml ADDED
@@ -0,0 +1,5 @@
# @package _global_

model: hdemucs
hdemucs:
  channels: 32
conf/variant/finetune.yaml ADDED
@@ -0,0 +1,19 @@
# @package _global_

epochs: 4
batch_size: 16
optim:
  lr: 0.0006
test:
  every: 1
  sdr: false
dset:
  segment: 28
  shift: 2

augment:
  scale:
    proba: 0
  shift_same: true
  remix:
    proba: 0
demucs.png ADDED

Git LFS Details

  • SHA256: 7f8a53c1bbaa6c0268d358cd4cb9c2f1128907758aeb10a79789f7bbf61ded95
  • Pointer size: 131 Bytes
  • Size of remote file: 339 kB
demucs/__init__.py ADDED
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

__version__ = "4.1.0a2"
demucs/__main__.py ADDED
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .separate import main

if __name__ == '__main__':
    main()
demucs/api.py ADDED
@@ -0,0 +1,392 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""API methods for demucs

Classes
-------
`demucs.api.Separator`: The base separator class

Functions
---------
`demucs.api.save_audio`: Save audio to a file
`demucs.api.list_models`: Get the list of available models

Examples
--------
See the end of this module (if __name__ == "__main__")
"""

import subprocess

import torch as th
import torchaudio as ta

from dora.log import fatal
from pathlib import Path
from typing import Optional, Callable, Dict, Tuple, Union

from .apply import apply_model, _replace_dict
from .audio import AudioFile, convert_audio, save_audio
from .pretrained import get_model, _parse_remote_files, REMOTE_ROOT
from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo


class LoadAudioError(Exception):
    pass


class LoadModelError(Exception):
    pass


class _NotProvided:
    pass


NotProvided = _NotProvided()


class Separator:
    def __init__(
        self,
        model: str = "htdemucs",
        repo: Optional[Path] = None,
        device: str = "cuda" if th.cuda.is_available() else "cpu",
        shifts: int = 1,
        overlap: float = 0.25,
        split: bool = True,
        segment: Optional[int] = None,
        jobs: int = 0,
        progress: bool = False,
        callback: Optional[Callable[[dict], None]] = None,
        callback_arg: Optional[dict] = None,
    ):
        """
        `class Separator`
        =================

        Parameters
        ----------
        model: Pretrained model name or signature. Default is htdemucs.
        repo: Folder containing all pre-trained models for use.
        segment: Length (in seconds) of each segment (only used if `split` is `True`). If \
            not specified, will use the command line option.
        shifts: If > 0, will shift `wav` in time by a random amount between 0 and 0.5 sec and \
            apply the opposite shift to the output. This is repeated `shifts` times and all \
            predictions are averaged. This effectively makes the model time equivariant and \
            improves SDR by up to 0.2 points. If not specified, will use the command line option.
        split: If True, the input will be broken down into small chunks (length set by `segment`) \
            and predictions will be performed individually on each and concatenated. Useful for \
            models with a large memory footprint like Tasnet. If not specified, will use the \
            command line option.
        overlap: The overlap between the splits. If not specified, will use the command line \
            option.
        device (torch.device, str, or None): If provided, device on which to execute the \
            computation, otherwise `wav.device` is assumed. When `device` is different from \
            `wav.device`, only local computations will be on `device`, while the entire tracks \
            will be stored on `wav.device`. If not specified, will use the command line option.
        jobs: Number of jobs. This can increase memory usage but will be much faster when \
            multiple cores are available. If not specified, will use the command line option.
        callback: A function that will be called when the separation of a chunk starts or \
            finishes. The argument passed to the function will be a dict. For more information, \
            please see the Callback section.
        callback_arg: A dict containing private parameters to be passed to the callback \
            function. For more information, please see the Callback section.
        progress: If true, show a progress bar.

        Callback
        --------
        The function will be called with only one positional parameter whose type is `dict`. The
        `callback_arg` will be combined with information on the current separation progress. The
        progress information will override the values in `callback_arg` if the same key is used.
        To abort the separation, raise `KeyboardInterrupt`.

        Progress information contains several keys (these keys will always exist):
        - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.
        - `shift_idx`: The index of shifts. Starts from 0.
        - `segment_offset`: The offset of the current segment. If the number is 441000, it
          doesn't mean that it is at the 441000th second of the audio, but the "frame" of the
          tensor.
        - `state`: Could be `"start"` or `"end"`.
        - `audio_length`: Length of the audio (in "frames" of the tensor).
        - `models`: Count of submodels in the model.
        """
        self._name = model
        self._repo = repo
        self._load_model()
        self.update_parameter(device=device, shifts=shifts, overlap=overlap, split=split,
                              segment=segment, jobs=jobs, progress=progress, callback=callback,
                              callback_arg=callback_arg)

    def update_parameter(
        self,
        device: Union[str, _NotProvided] = NotProvided,
        shifts: Union[int, _NotProvided] = NotProvided,
        overlap: Union[float, _NotProvided] = NotProvided,
        split: Union[bool, _NotProvided] = NotProvided,
        segment: Optional[Union[int, _NotProvided]] = NotProvided,
        jobs: Union[int, _NotProvided] = NotProvided,
        progress: Union[bool, _NotProvided] = NotProvided,
        callback: Optional[
            Union[Callable[[dict], None], _NotProvided]
        ] = NotProvided,
        callback_arg: Optional[Union[dict, _NotProvided]] = NotProvided,
    ):
        """
        Update the parameters of separation.

        Parameters
        ----------
        segment: Length (in seconds) of each segment (only used if `split` is `True`). If \
            not specified, will use the command line option.
        shifts: If > 0, will shift `wav` in time by a random amount between 0 and 0.5 sec and \
            apply the opposite shift to the output. This is repeated `shifts` times and all \
            predictions are averaged. This effectively makes the model time equivariant and \
            improves SDR by up to 0.2 points. If not specified, will use the command line option.
        split: If True, the input will be broken down into small chunks (length set by `segment`) \
            and predictions will be performed individually on each and concatenated. Useful for \
            models with a large memory footprint like Tasnet. If not specified, will use the \
            command line option.
        overlap: The overlap between the splits. If not specified, will use the command line \
            option.
        device (torch.device, str, or None): If provided, device on which to execute the \
            computation, otherwise `wav.device` is assumed. When `device` is different from \
            `wav.device`, only local computations will be on `device`, while the entire tracks \
            will be stored on `wav.device`. If not specified, will use the command line option.
        jobs: Number of jobs. This can increase memory usage but will be much faster when \
            multiple cores are available. If not specified, will use the command line option.
        callback: A function that will be called when the separation of a chunk starts or \
            finishes. The argument passed to the function will be a dict. For more information, \
            please see the Callback section.
        callback_arg: A dict containing private parameters to be passed to the callback \
            function. For more information, please see the Callback section.
        progress: If true, show a progress bar.

        Callback
        --------
        The function will be called with only one positional parameter whose type is `dict`. The
        `callback_arg` will be combined with information on the current separation progress. The
        progress information will override the values in `callback_arg` if the same key is used.
        To abort the separation, raise `KeyboardInterrupt`.

        Progress information contains several keys (these keys will always exist):
        - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0.
        - `shift_idx`: The index of shifts. Starts from 0.
        - `segment_offset`: The offset of the current segment. If the number is 441000, it
          doesn't mean that it is at the 441000th second of the audio, but the "frame" of the
          tensor.
        - `state`: Could be `"start"` or `"end"`.
        - `audio_length`: Length of the audio (in "frames" of the tensor).
        - `models`: Count of submodels in the model.
        """
        if not isinstance(device, _NotProvided):
            self._device = device
        if not isinstance(shifts, _NotProvided):
            self._shifts = shifts
        if not isinstance(overlap, _NotProvided):
            self._overlap = overlap
        if not isinstance(split, _NotProvided):
            self._split = split
        if not isinstance(segment, _NotProvided):
            self._segment = segment
        if not isinstance(jobs, _NotProvided):
            self._jobs = jobs
        if not isinstance(progress, _NotProvided):
            self._progress = progress
        if not isinstance(callback, _NotProvided):
            self._callback = callback
        if not isinstance(callback_arg, _NotProvided):
            self._callback_arg = callback_arg

    def _load_model(self):
        self._model = get_model(name=self._name, repo=self._repo)
        if self._model is None:
            raise LoadModelError("Failed to load model")
        self._audio_channels = self._model.audio_channels
        self._samplerate = self._model.samplerate

    def _load_audio(self, track: Path):
        errors = {}
        wav = None

        try:
            wav = AudioFile(track).read(streams=0, samplerate=self._samplerate,
                                        channels=self._audio_channels)
        except FileNotFoundError:
            errors["ffmpeg"] = "FFmpeg is not installed."
        except subprocess.CalledProcessError:
            errors["ffmpeg"] = "FFmpeg could not read the file."

        if wav is None:
            try:
                wav, sr = ta.load(str(track))
            except RuntimeError as err:
                errors["torchaudio"] = err.args[0]
            else:
                wav = convert_audio(wav, sr, self._samplerate, self._audio_channels)

        if wav is None:
            raise LoadAudioError(
                "\n".join(
                    "When trying to load using {}, got the following error: {}".format(
                        backend, error
                    )
                    for backend, error in errors.items()
                )
            )
        return wav

    def separate_tensor(
        self, wav: th.Tensor, sr: Optional[int] = None
    ) -> Tuple[th.Tensor, Dict[str, th.Tensor]]:
        """
        Separate a loaded tensor.

        Parameters
        ----------
        wav: Waveform of the audio. Should have 2 dimensions, the first is each audio channel, \
            while the second is the waveform of each channel. Type should be float32. \
            e.g. `tuple(wav.shape) == (2, 884000)` means the audio has 2 channels.
        sr: Sample rate of the original audio; the wave will be resampled if it doesn't match \
            the model.

        Returns
        -------
        A tuple, whose first element is the original wave and second element is a dict, whose keys
        are the names of the stems and values are the separated waves. The original wave will have
        already been resampled.

        Notes
        -----
        Use this function with caution. It does not verify its input data.
        """
        if sr is not None and sr != self.samplerate:
            wav = convert_audio(wav, sr, self._samplerate, self._audio_channels)
        ref = wav.mean(0)
        wav -= ref.mean()
        wav /= ref.std() + 1e-8
        out = apply_model(
            self._model,
            wav[None],
            segment=self._segment,
            shifts=self._shifts,
            split=self._split,
            overlap=self._overlap,
            device=self._device,
            num_workers=self._jobs,
            callback=self._callback,
            callback_arg=_replace_dict(
                self._callback_arg, ("audio_length", wav.shape[1])
            ),
            progress=self._progress,
        )
        if out is None:
            raise KeyboardInterrupt
        out *= ref.std() + 1e-8
        out += ref.mean()
        wav *= ref.std() + 1e-8
        wav += ref.mean()
        return (wav, dict(zip(self._model.sources, out[0])))

    def separate_audio_file(self, file: Path):
        """
        Separate an audio file. The method will automatically read the file.

        Parameters
        ----------
        file: Path of the file to be separated.

        Returns
        -------
        A tuple, whose first element is the original wave and second element is a dict, whose keys
        are the names of the stems and values are the separated waves. The original wave will have
        already been resampled.
        """
        return self.separate_tensor(self._load_audio(file), self.samplerate)

    @property
    def samplerate(self):
        return self._samplerate

    @property
    def audio_channels(self):
        return self._audio_channels

    @property
    def model(self):
        return self._model


def list_models(repo: Optional[Path] = None) -> Dict[str, Dict[str, Union[str, Path]]]:
    """
    List the available models. Please remember that not all the returned models can be
    successfully loaded.

    Parameters
    ----------
    repo: The repo whose models are to be listed.

    Returns
    -------
    A dict with two keys ("single" for single models and "bag" for bags of models). The values
    are lists of strings.
    """
    model_repo: ModelOnlyRepo
    if repo is None:
        models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
        model_repo = RemoteRepo(models)
        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
    else:
        if not repo.is_dir():
            fatal(f"{repo} must exist and be a directory.")
        model_repo = LocalRepo(repo)
        bag_repo = BagOnlyRepo(repo, model_repo)
    return {"single": model_repo.list_model(), "bag": bag_repo.list_model()}


if __name__ == "__main__":
    # Test API functions
    # two-stem not supported

    from .separate import get_parser

    args = get_parser().parse_args()
    separator = Separator(
        model=args.name,
        repo=args.repo,
        device=args.device,
        shifts=args.shifts,
        overlap=args.overlap,
        split=args.split,
        segment=args.segment,
        jobs=args.jobs,
        callback=print
    )
    out = args.out / args.name
    out.mkdir(parents=True, exist_ok=True)
    for file in args.tracks:
        separated = separator.separate_audio_file(file)[1]
        if args.mp3:
            ext = "mp3"
        elif args.flac:
            ext = "flac"
        else:
            ext = "wav"
        kwargs = {
            "samplerate": separator.samplerate,
            "bitrate": args.mp3_bitrate,
            "clip": args.clip_mode,
            "as_float": args.float32,
            "bits_per_sample": 24 if args.int24 else 16,
        }
        for stem, source in separated.items():
            stem = out / args.filename.format(
                track=Path(file).name.rsplit(".", 1)[0],
                trackext=Path(file).name.rsplit(".", 1)[-1],
                stem=stem,
                ext=ext,
            )
            stem.parent.mkdir(parents=True, exist_ok=True)
            save_audio(source, str(stem), **kwargs)
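The `_NotProvided` sentinel used by `update_parameter` above distinguishes "argument omitted" from an explicit `None`. A minimal standalone illustration of the pattern (the `Config` class and its `segment` field are hypothetical, not part of demucs):

```python
class _NotProvided:
    pass

NotProvided = _NotProvided()

class Config:
    def __init__(self):
        self.segment = None  # None is a meaningful value here

    def update(self, segment=NotProvided):
        # Only overwrite when the caller actually passed something, so
        # update(segment=None) and update() behave differently.
        if not isinstance(segment, _NotProvided):
            self.segment = segment

cfg = Config()
cfg.update()              # no argument: nothing changes
assert cfg.segment is None
cfg.update(segment=7)
assert cfg.segment == 7
cfg.update()              # still no change
assert cfg.segment == 7
cfg.update(segment=None)  # explicit None does overwrite
assert cfg.segment is None
```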
demucs/apply.py ADDED
@@ -0,0 +1,322 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Code to apply a model to a mix. It will handle chunking with overlaps and
interpolation between chunks, as well as the "shift trick".
"""
from concurrent.futures import ThreadPoolExecutor
import copy
import random
from threading import Lock
import typing as tp

import torch as th
from torch import nn
from torch.nn import functional as F
import tqdm

from .demucs import Demucs
from .hdemucs import HDemucs
from .htdemucs import HTDemucs
from .utils import center_trim, DummyPoolExecutor

Model = tp.Union[Demucs, HDemucs, HTDemucs]


class BagOfModels(nn.Module):
    def __init__(self, models: tp.List[Model],
                 weights: tp.Optional[tp.List[tp.List[float]]] = None,
                 segment: tp.Optional[float] = None):
        """
        Represents a bag of models with specific weights.
        You should call `apply_model` rather than calling the forward directly here for
        optimal performance.

        Args:
            models (list[nn.Module]): list of Demucs/HDemucs models.
            weights (list[list[float]]): list of weights. If None, assumed to
                be all ones, otherwise it should be a list of N lists (N number of models),
                each containing S floats (S number of sources).
            segment (None or float): overrides the `segment` attribute of each model
                (this is performed inplace, be careful if you reuse the models passed).
        """
        super().__init__()
        assert len(models) > 0
        first = models[0]
        for other in models:
            assert other.sources == first.sources
            assert other.samplerate == first.samplerate
            assert other.audio_channels == first.audio_channels
            if segment is not None:
                if not isinstance(other, HTDemucs) and segment > other.segment:
                    other.segment = segment

        self.audio_channels = first.audio_channels
        self.samplerate = first.samplerate
        self.sources = first.sources
+ self.sources = first.sources
60
+ self.models = nn.ModuleList(models)
61
+
62
+ if weights is None:
63
+ weights = [[1. for _ in first.sources] for _ in models]
64
+ else:
65
+ assert len(weights) == len(models)
66
+ for weight in weights:
67
+ assert len(weight) == len(first.sources)
68
+ self.weights = weights
69
+
70
+ @property
71
+ def max_allowed_segment(self) -> float:
72
+ max_allowed_segment = float('inf')
73
+ for model in self.models:
74
+ if isinstance(model, HTDemucs):
75
+ max_allowed_segment = min(max_allowed_segment, float(model.segment))
76
+ return max_allowed_segment
77
+
78
+ def forward(self, x):
79
+ raise NotImplementedError("Call `apply_model` on this.")
80
+
81
+
82
+ class TensorChunk:
83
+ def __init__(self, tensor, offset=0, length=None):
84
+ total_length = tensor.shape[-1]
85
+ assert offset >= 0
86
+ assert offset < total_length
87
+
88
+ if length is None:
89
+ length = total_length - offset
90
+ else:
91
+ length = min(total_length - offset, length)
92
+
93
+ if isinstance(tensor, TensorChunk):
94
+ self.tensor = tensor.tensor
95
+ self.offset = offset + tensor.offset
96
+ else:
97
+ self.tensor = tensor
98
+ self.offset = offset
99
+ self.length = length
100
+ self.device = tensor.device
101
+
102
+ @property
103
+ def shape(self):
104
+ shape = list(self.tensor.shape)
105
+ shape[-1] = self.length
106
+ return shape
107
+
108
+ def padded(self, target_length):
109
+ delta = target_length - self.length
110
+ total_length = self.tensor.shape[-1]
111
+ assert delta >= 0
112
+
113
+ start = self.offset - delta // 2
114
+ end = start + target_length
115
+
116
+ correct_start = max(0, start)
117
+ correct_end = min(total_length, end)
118
+
119
+ pad_left = correct_start - start
120
+ pad_right = end - correct_end
121
+
122
+ out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
123
+ assert out.shape[-1] == target_length
124
+ return out
125
+
126
+
127
+ def tensor_chunk(tensor_or_chunk):
128
+ if isinstance(tensor_or_chunk, TensorChunk):
129
+ return tensor_or_chunk
130
+ else:
131
+ assert isinstance(tensor_or_chunk, th.Tensor)
132
+ return TensorChunk(tensor_or_chunk)
133
+
134
+
135
+ def _replace_dict(_dict: tp.Optional[dict], *subs: tp.Tuple[tp.Hashable, tp.Any]) -> dict:
136
+ if _dict is None:
137
+ _dict = {}
138
+ else:
139
+ _dict = copy.copy(_dict)
140
+ for key, value in subs:
141
+ _dict[key] = value
142
+ return _dict
143
+
144
+
145
+ def apply_model(model: tp.Union[BagOfModels, Model],
146
+ mix: tp.Union[th.Tensor, TensorChunk],
147
+ shifts: int = 1, split: bool = True,
148
+ overlap: float = 0.25, transition_power: float = 1.,
149
+ progress: bool = False, device=None,
150
+ num_workers: int = 0, segment: tp.Optional[float] = None,
151
+ pool=None, lock=None,
152
+ callback: tp.Optional[tp.Callable[[dict], None]] = None,
153
+ callback_arg: tp.Optional[dict] = None) -> th.Tensor:
154
+ """
155
+ Apply model to a given mixture.
156
+
157
+ Args:
158
+ shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
159
+ and apply the oppositve shift to the output. This is repeated `shifts` time and
160
+ all predictions are averaged. This effectively makes the model time equivariant
161
+ and improves SDR by up to 0.2 points.
162
+ split (bool): if True, the input will be broken down in 8 seconds extracts
163
+ and predictions will be performed individually on each and concatenated.
164
+ Useful for model with large memory footprint like Tasnet.
165
+ progress (bool): if True, show a progress bar (requires split=True)
166
+ device (torch.device, str, or None): if provided, device on which to
167
+ execute the computation, otherwise `mix.device` is assumed.
168
+ When `device` is different from `mix.device`, only local computations will
169
+ be on `device`, while the entire tracks will be stored on `mix.device`.
170
+ num_workers (int): if non zero, device is 'cpu', how many threads to
171
+ use in parallel.
172
+ segment (float or None): override the model segment parameter.
173
+ """
174
+ if device is None:
175
+ device = mix.device
176
+ else:
177
+ device = th.device(device)
178
+ if pool is None:
179
+ if num_workers > 0 and device.type == 'cpu':
180
+ pool = ThreadPoolExecutor(num_workers)
181
+ else:
182
+ pool = DummyPoolExecutor()
183
+ if lock is None:
184
+ lock = Lock()
185
+ callback_arg = _replace_dict(
186
+ callback_arg, *{"model_idx_in_bag": 0, "shift_idx": 0, "segment_offset": 0}.items()
187
+ )
188
+ kwargs: tp.Dict[str, tp.Any] = {
189
+ 'shifts': shifts,
190
+ 'split': split,
191
+ 'overlap': overlap,
192
+ 'transition_power': transition_power,
193
+ 'progress': progress,
194
+ 'device': device,
195
+ 'pool': pool,
196
+ 'segment': segment,
197
+ 'lock': lock,
198
+ }
199
+ out: tp.Union[float, th.Tensor]
200
+ res: tp.Union[float, th.Tensor]
201
+ if isinstance(model, BagOfModels):
202
+ # Special treatment for bag of model.
203
+ # We explicitely apply multiple times `apply_model` so that the random shifts
204
+ # are different for each model.
205
+ estimates: tp.Union[float, th.Tensor] = 0.
206
+ totals = [0.] * len(model.sources)
207
+ callback_arg["models"] = len(model.models)
208
+ for sub_model, model_weights in zip(model.models, model.weights):
209
+ kwargs["callback"] = ((
210
+ lambda d, i=callback_arg["model_idx_in_bag"]: callback(
211
+ _replace_dict(d, ("model_idx_in_bag", i))) if callback else None)
212
+ )
213
+ original_model_device = next(iter(sub_model.parameters())).device
214
+ sub_model.to(device)
215
+
216
+ res = apply_model(sub_model, mix, **kwargs, callback_arg=callback_arg)
217
+ out = res
218
+ sub_model.to(original_model_device)
219
+ for k, inst_weight in enumerate(model_weights):
220
+ out[:, k, :, :] *= inst_weight
221
+ totals[k] += inst_weight
222
+ estimates += out
223
+ del out
224
+ callback_arg["model_idx_in_bag"] += 1
225
+
226
+ assert isinstance(estimates, th.Tensor)
227
+ for k in range(estimates.shape[1]):
228
+ estimates[:, k, :, :] /= totals[k]
229
+ return estimates
230
+
231
+ if "models" not in callback_arg:
232
+ callback_arg["models"] = 1
233
+ model.to(device)
234
+ model.eval()
235
+ assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
236
+ batch, channels, length = mix.shape
237
+ if shifts:
238
+ kwargs['shifts'] = 0
239
+ max_shift = int(0.5 * model.samplerate)
240
+ mix = tensor_chunk(mix)
241
+ assert isinstance(mix, TensorChunk)
242
+ padded_mix = mix.padded(length + 2 * max_shift)
243
+ out = 0.
244
+ for shift_idx in range(shifts):
245
+ offset = random.randint(0, max_shift)
246
+ shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
247
+ kwargs["callback"] = (
248
+ (lambda d, i=shift_idx: callback(_replace_dict(d, ("shift_idx", i)))
249
+ if callback else None)
250
+ )
251
+ res = apply_model(model, shifted, **kwargs, callback_arg=callback_arg)
252
+ shifted_out = res
253
+ out += shifted_out[..., max_shift - offset:]
254
+ out /= shifts
255
+ assert isinstance(out, th.Tensor)
256
+ return out
257
+ elif split:
258
+ kwargs['split'] = False
259
+ out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
260
+ sum_weight = th.zeros(length, device=mix.device)
261
+ if segment is None:
262
+ segment = model.segment
263
+ assert segment is not None and segment > 0.
264
+ segment_length: int = int(model.samplerate * segment)
265
+ stride = int((1 - overlap) * segment_length)
266
+ offsets = range(0, length, stride)
267
+ scale = float(format(stride / model.samplerate, ".2f"))
268
+ # We start from a triangle shaped weight, with maximal weight in the middle
269
+ # of the segment. Then we normalize and take to the power `transition_power`.
270
+ # Large values of transition power will lead to sharper transitions.
271
+ weight = th.cat([th.arange(1, segment_length // 2 + 1, device=device),
272
+ th.arange(segment_length - segment_length // 2, 0, -1, device=device)])
273
+ assert len(weight) == segment_length
274
+ # If the overlap < 50%, this will translate to linear transition when
275
+ # transition_power is 1.
276
+ weight = (weight / weight.max())**transition_power
277
+ futures = []
278
+ for offset in offsets:
279
+ chunk = TensorChunk(mix, offset, segment_length)
280
+ future = pool.submit(apply_model, model, chunk, **kwargs, callback_arg=callback_arg,
281
+ callback=(lambda d, i=offset:
282
+ callback(_replace_dict(d, ("segment_offset", i)))
283
+ if callback else None))
284
+ futures.append((future, offset))
285
+ offset += segment_length
286
+ if progress:
287
+ futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
288
+ for future, offset in futures:
289
+ try:
290
+ chunk_out = future.result() # type: th.Tensor
291
+ except Exception:
292
+ pool.shutdown(wait=True, cancel_futures=True)
293
+ raise
294
+ chunk_length = chunk_out.shape[-1]
295
+ out[..., offset:offset + segment_length] += (
296
+ weight[:chunk_length] * chunk_out).to(mix.device)
297
+ sum_weight[offset:offset + segment_length] += weight[:chunk_length].to(mix.device)
298
+ assert sum_weight.min() > 0
299
+ out /= sum_weight
300
+ assert isinstance(out, th.Tensor)
301
+ return out
302
+ else:
303
+ valid_length: int
304
+ if isinstance(model, HTDemucs) and segment is not None:
305
+ valid_length = int(segment * model.samplerate)
306
+ elif hasattr(model, 'valid_length'):
307
+ valid_length = model.valid_length(length) # type: ignore
308
+ else:
309
+ valid_length = length
310
+ mix = tensor_chunk(mix)
311
+ assert isinstance(mix, TensorChunk)
312
+ padded_mix = mix.padded(valid_length).to(device)
313
+ with lock:
314
+ if callback is not None:
315
+ callback(_replace_dict(callback_arg, ("state", "start"))) # type: ignore
316
+ with th.no_grad():
317
+ out = model(padded_mix)
318
+ with lock:
319
+ if callback is not None:
320
+ callback(_replace_dict(callback_arg, ("state", "end"))) # type: ignore
321
+ assert isinstance(out, th.Tensor)
322
+ return center_trim(out, length)
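The `split` branch of `apply_model` blends overlapping chunks with a triangular weight that is later divided out via `sum_weight`. A minimal pure-Python sketch of that overlap-add scheme (plain lists instead of tensors, made-up signal and segment sizes), showing why the final `sum_weight.min() > 0` assertion holds:

```python
def triangle_weight(segment_length, power=1.0):
    # Triangle peaking mid-segment, normalized to a max of 1, then raised
    # to `power` (sharper transitions for larger powers), as in apply_model.
    ramp = list(range(1, segment_length // 2 + 1))
    weight = ramp + list(range(segment_length - segment_length // 2, 0, -1))
    peak = max(weight)
    return [(w / peak) ** power for w in weight]

def coverage(signal_length, segment_length, overlap=0.25):
    # Accumulate the per-sample weight sum over all chunk offsets.
    stride = int((1 - overlap) * segment_length)
    weight = triangle_weight(segment_length)
    sum_weight = [0.0] * signal_length
    for offset in range(0, signal_length, stride):
        for i, w in enumerate(weight):
            if offset + i < signal_length:
                sum_weight[offset + i] += w
    return sum_weight

cover = coverage(signal_length=40, segment_length=8)
assert min(cover) > 0  # every sample is weighted, so dividing is safe
```

With `overlap=0.25` the stride is 75% of the segment, so consecutive triangles overlap and no sample is left with zero weight.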
demucs/audio.py ADDED
@@ -0,0 +1,265 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import json
+import subprocess as sp
+from pathlib import Path
+
+import lameenc
+import julius
+import numpy as np
+import torch
+import torchaudio as ta
+import typing as tp
+
+from .utils import temp_filenames
+
+
+def _read_info(path):
+    stdout_data = sp.check_output([
+        'ffprobe', "-loglevel", "panic",
+        str(path), '-print_format', 'json', '-show_format', '-show_streams'
+    ])
+    return json.loads(stdout_data.decode('utf-8'))
+
+
+class AudioFile:
+    """
+    Allows reading audio from any format supported by ffmpeg, as well as resampling or
+    converting to mono on the fly. See :method:`read` for more details.
+    """
+    def __init__(self, path: Path):
+        self.path = Path(path)
+        self._info = None
+
+    def __repr__(self):
+        features = [("path", self.path)]
+        features.append(("samplerate", self.samplerate()))
+        features.append(("channels", self.channels()))
+        features.append(("streams", len(self)))
+        features_str = ", ".join(f"{name}={value}" for name, value in features)
+        return f"AudioFile({features_str})"
+
+    @property
+    def info(self):
+        if self._info is None:
+            self._info = _read_info(self.path)
+        return self._info
+
+    @property
+    def duration(self):
+        return float(self.info['format']['duration'])
+
+    @property
+    def _audio_streams(self):
+        return [
+            index for index, stream in enumerate(self.info["streams"])
+            if stream["codec_type"] == "audio"
+        ]
+
+    def __len__(self):
+        return len(self._audio_streams)
+
+    def channels(self, stream=0):
+        return int(self.info['streams'][self._audio_streams[stream]]['channels'])
+
+    def samplerate(self, stream=0):
+        return int(self.info['streams'][self._audio_streams[stream]]['sample_rate'])
+
+    def read(self,
+             seek_time=None,
+             duration=None,
+             streams=slice(None),
+             samplerate=None,
+             channels=None):
+        """
+        Slightly more efficient implementation than stempeg,
+        in particular, this will extract all stems at once
+        rather than having to loop over one file multiple times
+        for each stream.
+
+        Args:
+            seek_time (float): seek time in seconds or None if no seeking is needed.
+            duration (float): duration in seconds to extract or None to extract until the end.
+            streams (slice, int or list): streams to extract, can be a single int, a list or
+                a slice. If it is a slice or list, the output will be of size [S, C, T]
+                with S the number of streams, C the number of channels and T the number of samples.
+                If it is an int, the output will be [C, T].
+            samplerate (int): if provided, will resample on the fly. If None, no resampling will
+                be done. Original sampling rate can be obtained with :method:`samplerate`.
+            channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that
+                as ffmpeg automatically scales by +3dB to conserve volume when playing on speakers.
+                See https://sound.stackexchange.com/a/42710.
+                Our definition of mono is simply the average of the two channels. Any other
+                value will be ignored.
+        """
+        streams = np.array(range(len(self)))[streams]
+        single = not isinstance(streams, np.ndarray)
+        if single:
+            streams = [streams]
+
+        if duration is None:
+            target_size = None
+            query_duration = None
+        else:
+            target_size = int((samplerate or self.samplerate()) * duration)
+            query_duration = float((target_size + 1) / (samplerate or self.samplerate()))
+
+        with temp_filenames(len(streams)) as filenames:
+            command = ['ffmpeg', '-y']
+            command += ['-loglevel', 'panic']
+            if seek_time:
+                command += ['-ss', str(seek_time)]
+            command += ['-i', str(self.path)]
+            for stream, filename in zip(streams, filenames):
+                command += ['-map', f'0:{self._audio_streams[stream]}']
+                if query_duration is not None:
+                    command += ['-t', str(query_duration)]
+                command += ['-threads', '1']
+                command += ['-f', 'f32le']
+                if samplerate is not None:
+                    command += ['-ar', str(samplerate)]
+                command += [filename]
+
+            sp.run(command, check=True)
+            wavs = []
+            for filename in filenames:
+                wav = np.fromfile(filename, dtype=np.float32)
+                wav = torch.from_numpy(wav)
+                wav = wav.view(-1, self.channels()).t()
+                if channels is not None:
+                    wav = convert_audio_channels(wav, channels)
+                if target_size is not None:
+                    wav = wav[..., :target_size]
+                wavs.append(wav)
+        wav = torch.stack(wavs, dim=0)
+        if single:
+            wav = wav[0]
+        return wav
+
+
+def convert_audio_channels(wav, channels=2):
+    """Convert audio to the given number of channels."""
+    *shape, src_channels, length = wav.shape
+    if src_channels == channels:
+        pass
+    elif channels == 1:
+        # Case 1:
+        # The caller asked for 1-channel audio, but the stream has multiple
+        # channels, downmix all channels.
+        wav = wav.mean(dim=-2, keepdim=True)
+    elif src_channels == 1:
+        # Case 2:
+        # The caller asked for multiple channels, but the input file has
+        # one single channel, replicate the audio over all channels.
+        wav = wav.expand(*shape, channels, length)
+    elif src_channels >= channels:
+        # Case 3:
+        # The caller asked for multiple channels, and the input file has
+        # more channels than requested. In that case return the first channels.
+        wav = wav[..., :channels, :]
+    else:
+        # Case 4: What is a reasonable choice here?
+        raise ValueError('The audio file has fewer channels than requested but is not mono.')
+    return wav
+
+
+def convert_audio(wav, from_samplerate, to_samplerate, channels) -> torch.Tensor:
+    """Convert audio from a given samplerate to a target one and target number of channels."""
+    wav = convert_audio_channels(wav, channels)
+    return julius.resample_frac(wav, from_samplerate, to_samplerate)
+
+
+def i16_pcm(wav):
+    """Convert audio to 16 bits integer PCM format."""
+    if wav.dtype.is_floating_point:
+        return (wav.clamp_(-1, 1) * (2**15 - 1)).short()
+    else:
+        return wav
+
+
+def f32_pcm(wav):
+    """Convert audio to float 32 bits PCM format."""
+    if wav.dtype.is_floating_point:
+        return wav
+    else:
+        return wav.float() / (2**15 - 1)
+
+
+def as_dtype_pcm(wav, dtype):
+    """Convert audio to either f32 pcm or i16 pcm depending on the given dtype."""
+    if wav.dtype.is_floating_point:
+        return f32_pcm(wav)
+    else:
+        return i16_pcm(wav)
+
+
+def encode_mp3(wav, path, samplerate=44100, bitrate=320, quality=2, verbose=False):
+    """Save given audio as mp3. This should work on all OSes."""
+    C, T = wav.shape
+    wav = i16_pcm(wav)
+    encoder = lameenc.Encoder()
+    encoder.set_bit_rate(bitrate)
+    encoder.set_in_sample_rate(samplerate)
+    encoder.set_channels(C)
+    encoder.set_quality(quality)  # 2-highest, 7-fastest
+    if not verbose:
+        encoder.silence()
+    wav = wav.data.cpu()
+    wav = wav.transpose(0, 1).numpy()
+    mp3_data = encoder.encode(wav.tobytes())
+    mp3_data += encoder.flush()
+    with open(path, "wb") as f:
+        f.write(mp3_data)
+
+
+def prevent_clip(wav, mode='rescale'):
+    """
+    Different strategies for avoiding raw clipping.
+    """
+    if mode is None or mode == 'none':
+        return wav
+    assert wav.dtype.is_floating_point, "too late for clipping"
+    if mode == 'rescale':
+        wav = wav / max(1.01 * wav.abs().max(), 1)
+    elif mode == 'clamp':
+        wav = wav.clamp(-0.99, 0.99)
+    elif mode == 'tanh':
+        wav = torch.tanh(wav)
+    else:
+        raise ValueError(f"Invalid mode {mode}")
+    return wav
+
+
+def save_audio(wav: torch.Tensor,
+               path: tp.Union[str, Path],
+               samplerate: int,
+               bitrate: int = 320,
+               clip: tp.Literal["rescale", "clamp", "tanh", "none"] = 'rescale',
+               bits_per_sample: tp.Literal[16, 24, 32] = 16,
+               as_float: bool = False,
+               preset: tp.Literal[2, 3, 4, 5, 6, 7] = 2):
+    """Save audio file, automatically preventing clipping if necessary
+    based on the given `clip` strategy. If the path ends in `.mp3`, this
+    will save as mp3 with the given `bitrate`. Use `preset` to set mp3 quality:
+    2 for highest quality, 7 for fastest speed.
+    """
+    wav = prevent_clip(wav, mode=clip)
+    path = Path(path)
+    suffix = path.suffix.lower()
+    if suffix == ".mp3":
+        encode_mp3(wav, path, samplerate, bitrate, preset, verbose=True)
+    elif suffix == ".wav":
+        if as_float:
+            bits_per_sample = 32
+            encoding = 'PCM_F'
+        else:
+            encoding = 'PCM_S'
+        ta.save(str(path), wav, sample_rate=samplerate,
+                encoding=encoding, bits_per_sample=bits_per_sample)
+    elif suffix == ".flac":
+        ta.save(str(path), wav, sample_rate=samplerate, bits_per_sample=bits_per_sample)
+    else:
+        raise ValueError(f"Invalid suffix for path: {suffix}")
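`prevent_clip`'s 'rescale' mode only ever attenuates: the `max(1.01 * peak, 1)` denominator leaves quiet audio untouched and divides loud audio down with 1% headroom. A small pure-Python sketch of that rule (plain floats instead of tensors):

```python
def rescale(samples, headroom=1.01):
    # Divide by the peak (plus headroom) only when the scaled peak exceeds 1;
    # otherwise the denominator is 1 and the audio is returned unchanged.
    peak = max(abs(s) for s in samples)
    scale = max(headroom * peak, 1.0)
    return [s / scale for s in samples]

loud = rescale([0.5, -2.0, 1.5])    # peak 2.0 -> divided by 2.02
quiet = rescale([0.1, -0.2, 0.05])  # peak 0.2 -> unchanged
assert max(abs(s) for s in loud) < 1.0
assert quiet == [0.1, -0.2, 0.05]
```

Unlike the 'clamp' and 'tanh' modes, this preserves the waveform's shape exactly, at the cost of changing overall loudness when the signal peaks above full scale.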
demucs/augment.py ADDED
@@ -0,0 +1,111 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Data augmentations.
+"""
+
+import random
+import torch as th
+from torch import nn
+
+
+class Shift(nn.Module):
+    """
+    Randomly shift audio in time by up to `shift` samples.
+    """
+    def __init__(self, shift=8192, same=False):
+        super().__init__()
+        self.shift = shift
+        self.same = same
+
+    def forward(self, wav):
+        batch, sources, channels, time = wav.size()
+        length = time - self.shift
+        if self.shift > 0:
+            if not self.training:
+                wav = wav[..., :length]
+            else:
+                srcs = 1 if self.same else sources
+                offsets = th.randint(self.shift, [batch, srcs, 1, 1], device=wav.device)
+                offsets = offsets.expand(-1, sources, channels, -1)
+                indexes = th.arange(length, device=wav.device)
+                wav = wav.gather(3, indexes + offsets)
+        return wav
+
+
+class FlipChannels(nn.Module):
+    """
+    Flip left-right channels.
+    """
+    def forward(self, wav):
+        batch, sources, channels, time = wav.size()
+        if self.training and wav.size(2) == 2:
+            left = th.randint(2, (batch, sources, 1, 1), device=wav.device)
+            left = left.expand(-1, -1, -1, time)
+            right = 1 - left
+            wav = th.cat([wav.gather(2, left), wav.gather(2, right)], dim=2)
+        return wav
+
+
+class FlipSign(nn.Module):
+    """
+    Random sign flip.
+    """
+    def forward(self, wav):
+        batch, sources, channels, time = wav.size()
+        if self.training:
+            signs = th.randint(2, (batch, sources, 1, 1), device=wav.device, dtype=th.float32)
+            wav = wav * (2 * signs - 1)
+        return wav
+
+
+class Remix(nn.Module):
+    """
+    Shuffle sources to make new mixes.
+    """
+    def __init__(self, proba=1, group_size=4):
+        """
+        Shuffle sources within one batch.
+        Each batch is divided into groups of size `group_size` and shuffling is done within
+        each group separately. This allows keeping the same probability distribution no matter
+        the number of GPUs. Without this grouping, using more GPUs would lead to a higher
+        probability of keeping two sources from the same track together, which can impact
+        performance.
+        """
+        super().__init__()
+        self.proba = proba
+        self.group_size = group_size
+
+    def forward(self, wav):
+        batch, streams, channels, time = wav.size()
+        device = wav.device
+
+        if self.training and random.random() < self.proba:
+            group_size = self.group_size or batch
+            if batch % group_size != 0:
+                raise ValueError(f"Batch size {batch} must be divisible by group size {group_size}")
+            groups = batch // group_size
+            wav = wav.view(groups, group_size, streams, channels, time)
+            permutations = th.argsort(th.rand(groups, group_size, streams, 1, 1, device=device),
+                                      dim=1)
+            wav = wav.gather(1, permutations.expand(-1, -1, -1, channels, time))
+            wav = wav.view(batch, streams, channels, time)
+        return wav
+
+
+class Scale(nn.Module):
+    def __init__(self, proba=1., min=0.25, max=1.25):
+        super().__init__()
+        self.proba = proba
+        self.min = min
+        self.max = max
+
+    def forward(self, wav):
+        batch, streams, channels, time = wav.size()
+        device = wav.device
+        if self.training and random.random() < self.proba:
+            scales = th.empty(batch, streams, 1, 1, device=device).uniform_(self.min, self.max)
+            wav *= scales
+        return wav
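`Remix` permutes each source independently inside fixed-size groups of tracks, so the shuffling distribution does not depend on the per-GPU batch size. A pure-Python sketch of that grouped per-source permutation (nested lists standing in for the `[batch, sources, ...]` tensor, illustrative track/source names):

```python
import random

def remix(batch, group_size):
    # batch: list of tracks, each a list of per-source stems.
    # Each source column is shuffled independently within every group,
    # mirroring Remix's per-(group, source) argsort permutation.
    assert len(batch) % group_size == 0
    out = [list(track) for track in batch]
    nsources = len(batch[0])
    for start in range(0, len(batch), group_size):
        for s in range(nsources):
            column = [batch[start + g][s] for g in range(group_size)]
            random.shuffle(column)
            for g, src in enumerate(column):
                out[start + g][s] = src
    return out

random.seed(0)
batch = [[f"t{i}_drums", f"t{i}_bass"] for i in range(4)]
mixed = remix(batch, group_size=2)
# Sources never leave their group of 2 tracks:
assert {m[0] for m in mixed[:2]} == {"t0_drums", "t1_drums"}
```

Because drums and bass are permuted independently, most resulting "tracks" combine stems from different original songs, which is exactly the new-mix augmentation the module implements.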
demucs/demucs.py ADDED
@@ -0,0 +1,447 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import typing as tp
+
+import julius
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .states import capture_init
+from .utils import center_trim, unfold
+from .transformer import LayerScale
+
+
+class BLSTM(nn.Module):
+    """
+    BiLSTM with same hidden units as input dim.
+    If `max_steps` is not None, input will be split into overlapping
+    chunks and the LSTM applied separately on each chunk.
+    """
+    def __init__(self, dim, layers=1, max_steps=None, skip=False):
+        super().__init__()
+        assert max_steps is None or max_steps % 4 == 0
+        self.max_steps = max_steps
+        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
+        self.linear = nn.Linear(2 * dim, dim)
+        self.skip = skip
+
+    def forward(self, x):
+        B, C, T = x.shape
+        y = x
+        framed = False
+        if self.max_steps is not None and T > self.max_steps:
+            width = self.max_steps
+            stride = width // 2
+            frames = unfold(x, width, stride)
+            nframes = frames.shape[2]
+            framed = True
+            x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)
+
+        x = x.permute(2, 0, 1)
+
+        x = self.lstm(x)[0]
+        x = self.linear(x)
+        x = x.permute(1, 2, 0)
+        if framed:
+            out = []
+            frames = x.reshape(B, -1, C, width)
+            limit = stride // 2
+            for k in range(nframes):
+                if k == 0:
+                    out.append(frames[:, k, :, :-limit])
+                elif k == nframes - 1:
+                    out.append(frames[:, k, :, limit:])
+                else:
+                    out.append(frames[:, k, :, limit:-limit])
+            out = torch.cat(out, -1)
+            out = out[..., :T]
+            x = out
+        if self.skip:
+            x = x + y
+        return x
+
+
+def rescale_conv(conv, reference):
+    """Rescale initial weight scale. It is unclear why it helps but it certainly does.
+    """
+    std = conv.weight.std().detach()
+    scale = (std / reference)**0.5
+    conv.weight.data /= scale
+    if conv.bias is not None:
+        conv.bias.data /= scale
+
+
+def rescale_module(module, reference):
+    for sub in module.modules():
+        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
+            rescale_conv(sub, reference)
+
+
+class DConv(nn.Module):
+    """
+    New residual branches in each encoder layer.
+    This alternates dilated convolutions, potentially with LSTMs and attention.
+    Also, before entering each residual branch, the dimension is projected on a smaller subspace,
+    e.g. of dim `channels // compress`.
+    """
+    def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
+                 norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True,
+                 kernel=3, dilate=True):
+        """
+        Args:
+            channels: input/output channels for residual branch.
+            compress: amount of channel compression inside the branch.
+            depth: number of layers in the residual branch. Each layer has its own
+                projection, and potentially LSTM and attention.
+            init: initial scale for LayerNorm.
+            norm: use GroupNorm.
+            attn: use LocalAttention.
+            heads: number of heads for the LocalAttention.
+            ndecay: number of decay controls in the LocalAttention.
+            lstm: use LSTM.
+            gelu: Use GELU activation.
+            kernel: kernel size for the (dilated) convolutions.
+            dilate: if true, use dilation, increasing with the depth.
+        """
+
+        super().__init__()
+        assert kernel % 2 == 1
+        self.channels = channels
+        self.compress = compress
+        self.depth = abs(depth)
+        dilate = depth > 0
+
+        norm_fn: tp.Callable[[int], nn.Module]
+        norm_fn = lambda d: nn.Identity()  # noqa
+        if norm:
+            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa
+
+        hidden = int(channels / compress)
+
+        act: tp.Type[nn.Module]
+        if gelu:
+            act = nn.GELU
+        else:
+            act = nn.ReLU
+
+        self.layers = nn.ModuleList([])
+        for d in range(self.depth):
+            dilation = 2 ** d if dilate else 1
+            padding = dilation * (kernel // 2)
+            mods = [
+                nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
+                norm_fn(hidden), act(),
+                nn.Conv1d(hidden, 2 * channels, 1),
+                norm_fn(2 * channels), nn.GLU(1),
+                LayerScale(channels, init),
+            ]
+            if attn:
+                mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
+            if lstm:
+                mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
+            layer = nn.Sequential(*mods)
+            self.layers.append(layer)
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = x + layer(x)
+        return x
+
+
+class LocalState(nn.Module):
+    """Local state allows to have attention based only on data (no positional embedding),
+    while setting a constraint on the time window (e.g. a decaying penalty term).
+
+    Also includes a failed experiment with trying to provide some frequency based attention.
+    """
+    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
+        super().__init__()
+        assert channels % heads == 0, (channels, heads)
+        self.heads = heads
+        self.nfreqs = nfreqs
+        self.ndecay = ndecay
+        self.content = nn.Conv1d(channels, channels, 1)
+        self.query = nn.Conv1d(channels, channels, 1)
+        self.key = nn.Conv1d(channels, channels, 1)
+        if nfreqs:
+            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
+        if ndecay:
+            self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
+            # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
+            self.query_decay.weight.data *= 0.01
+            assert self.query_decay.bias is not None  # stupid type checker
+            self.query_decay.bias.data[:] = -2
+        self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)
+
+    def forward(self, x):
+        B, C, T = x.shape
+        heads = self.heads
+        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
+        # left index are keys, right index are queries
+        delta = indexes[:, None] - indexes[None, :]
+
+        queries = self.query(x).view(B, heads, -1, T)
+        keys = self.key(x).view(B, heads, -1, T)
+        # t are keys, s are queries
+        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
+        dots /= keys.shape[2]**0.5
+        if self.nfreqs:
+            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
+            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
+            freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
+            dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
+        if self.ndecay:
+            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
+            decay_q = self.query_decay(x).view(B, heads, -1, T)
+            decay_q = torch.sigmoid(decay_q) / 2
+            decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
+            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)
+
+        # Kill self reference.
+        dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
+ dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
208
+ weights = torch.softmax(dots, dim=2)
209
+
210
+ content = self.content(x).view(B, heads, -1, T)
211
+ result = torch.einsum("bhts,bhct->bhcs", weights, content)
212
+ if self.nfreqs:
213
+ time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
214
+ result = torch.cat([result, time_sig], 2)
215
+ result = result.reshape(B, -1, T)
216
+ return x + self.proj(result)
217
+
218
+
219
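The decay branch of `LocalState.forward` adds, for each decay control `f = 1..ndecay`, a penalty of `-decay_q[f] * f * |t - s| / sqrt(ndecay)` to the attention logit between key position `t` and query position `s`, with `decay_q` squashed into (0, 0.5) by the sigmoid. A scalar pure-Python sketch of that term (hypothetical helper, assuming `decay_q` is already squashed):

```python
import math


def decay_penalty(t, s, decay_q):
    # decay_q: one weight per decay control f = 1..ndecay, each in (0, 0.5)
    # as produced by `torch.sigmoid(...) / 2` in LocalState.forward.
    ndecay = len(decay_q)
    return -sum(q * f * abs(t - s)
                for f, q in zip(range(1, ndecay + 1), decay_q)) / math.sqrt(ndecay)


# The penalty is zero at the query position and grows with distance,
# which is what constrains attention to a local window without any
# positional embedding.
q = [0.1, 0.1, 0.1, 0.1]
assert decay_penalty(3, 3, q) == 0
assert decay_penalty(0, 5, q) < decay_penalty(0, 2, q)
```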
+class Demucs(nn.Module):
+    @capture_init
+    def __init__(self,
+                 sources,
+                 # Channels
+                 audio_channels=2,
+                 channels=64,
+                 growth=2.,
+                 # Main structure
+                 depth=6,
+                 rewrite=True,
+                 lstm_layers=0,
+                 # Convolutions
+                 kernel_size=8,
+                 stride=4,
+                 context=1,
+                 # Activations
+                 gelu=True,
+                 glu=True,
+                 # Normalization
+                 norm_starts=4,
+                 norm_groups=4,
+                 # DConv residual branch
+                 dconv_mode=1,
+                 dconv_depth=2,
+                 dconv_comp=4,
+                 dconv_attn=4,
+                 dconv_lstm=4,
+                 dconv_init=1e-4,
+                 # Pre/post processing
+                 normalize=True,
+                 resample=True,
+                 # Weight init
+                 rescale=0.1,
+                 # Metadata
+                 samplerate=44100,
+                 segment=4 * 10):
+        """
+        Args:
+            sources (list[str]): list of source names.
+            audio_channels (int): stereo or mono.
+            channels (int): first convolution channels.
+            depth (int): number of layers in the encoder and in the decoder.
+            growth (float): multiply (resp. divide) the number of channels by that
+                for each layer of the encoder (resp. decoder).
+            rewrite (bool): add 1x1 convolution to each layer.
+            lstm_layers (int): number of LSTM layers, 0 = no LSTM. Deactivated
+                by default, as this is now replaced by the smaller and faster small LSTMs
+                in the DConv branches.
+            kernel_size (int): kernel size for convolutions.
+            stride (int): stride for convolutions.
+            context (int): kernel size of the convolution in the
+                decoder before the transposed convolution. If > 1,
+                will provide some context from neighboring time steps.
+            gelu: use GELU activation function.
+            glu (bool): use GLU instead of ReLU for the 1x1 rewrite conv.
+            norm_starts: layer at which group norm starts being used.
+                Decoder layers are numbered in reverse order.
+            norm_groups: number of groups for group norm.
+            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
+            dconv_depth: depth of residual DConv branch.
+            dconv_comp: compression of DConv branch.
+            dconv_attn: adds attention layers in DConv branch starting at this layer.
+            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
+            dconv_init: initial scale for the DConv branch LayerScale.
+            normalize (bool): normalizes the input audio on the fly, and scales back
+                the output by the same amount.
+            resample (bool): upsample x2 the input and downsample /2 the output.
+            rescale (float): rescale initial weights of convolutions
+                to get their standard deviation closer to `rescale`.
+            samplerate (int): stored as meta information for easing
+                future evaluations of the model.
+            segment (float): duration of the chunks of audio to ideally evaluate the model on.
+                This is used by `demucs.apply.apply_model`.
+        """
+
+        super().__init__()
+        self.audio_channels = audio_channels
+        self.sources = sources
+        self.kernel_size = kernel_size
+        self.context = context
+        self.stride = stride
+        self.depth = depth
+        self.resample = resample
+        self.channels = channels
+        self.normalize = normalize
+        self.samplerate = samplerate
+        self.segment = segment
+        self.encoder = nn.ModuleList()
+        self.decoder = nn.ModuleList()
+        self.skip_scales = nn.ModuleList()
+
+        if glu:
+            activation = nn.GLU(dim=1)
+            ch_scale = 2
+        else:
+            activation = nn.ReLU()
+            ch_scale = 1
+        if gelu:
+            act2 = nn.GELU
+        else:
+            act2 = nn.ReLU
+
+        in_channels = audio_channels
+        padding = 0
+        for index in range(depth):
+            norm_fn = lambda d: nn.Identity()  # noqa
+            if index >= norm_starts:
+                norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
+
+            encode = []
+            encode += [
+                nn.Conv1d(in_channels, channels, kernel_size, stride),
+                norm_fn(channels),
+                act2(),
+            ]
+            attn = index >= dconv_attn
+            lstm = index >= dconv_lstm
+            if dconv_mode & 1:
+                encode += [DConv(channels, depth=dconv_depth, init=dconv_init,
+                                 compress=dconv_comp, attn=attn, lstm=lstm)]
+            if rewrite:
+                encode += [
+                    nn.Conv1d(channels, ch_scale * channels, 1),
+                    norm_fn(ch_scale * channels), activation]
+            self.encoder.append(nn.Sequential(*encode))
+
+            decode = []
+            if index > 0:
+                out_channels = in_channels
+            else:
+                out_channels = len(self.sources) * audio_channels
+            if rewrite:
+                decode += [
+                    nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context),
+                    norm_fn(ch_scale * channels), activation]
+            if dconv_mode & 2:
+                decode += [DConv(channels, depth=dconv_depth, init=dconv_init,
+                                 compress=dconv_comp, attn=attn, lstm=lstm)]
+            decode += [nn.ConvTranspose1d(channels, out_channels,
+                                          kernel_size, stride, padding=padding)]
+            if index > 0:
+                decode += [norm_fn(out_channels), act2()]
+            self.decoder.insert(0, nn.Sequential(*decode))
+            in_channels = channels
+            channels = int(growth * channels)
+
+        channels = in_channels
+        if lstm_layers:
+            self.lstm = BLSTM(channels, lstm_layers)
+        else:
+            self.lstm = None
+
+        if rescale:
+            rescale_module(self, reference=rescale)
+
+    def valid_length(self, length):
+        """
+        Return the nearest valid length to use with the model so that
+        there are no time steps left over in a convolution, e.g. for all
+        layers, size of the input - kernel_size % stride = 0.
+
+        Note that inputs are automatically padded if necessary to ensure that the output
+        has the same length as the input.
+        """
+        if self.resample:
+            length *= 2
+
+        for _ in range(self.depth):
+            length = math.ceil((length - self.kernel_size) / self.stride) + 1
+            length = max(1, length)
+
+        for idx in range(self.depth):
+            length = (length - 1) * self.stride + self.kernel_size
+
+        if self.resample:
+            length = math.ceil(length / 2)
+        return int(length)
+
+    def forward(self, mix):
+        x = mix
+        length = x.shape[-1]
+
+        if self.normalize:
+            mono = mix.mean(dim=1, keepdim=True)
+            mean = mono.mean(dim=-1, keepdim=True)
+            std = mono.std(dim=-1, keepdim=True)
+            x = (x - mean) / (1e-5 + std)
+        else:
+            mean = 0
+            std = 1
+
+        delta = self.valid_length(length) - length
+        x = F.pad(x, (delta // 2, delta - delta // 2))
+
+        if self.resample:
+            x = julius.resample_frac(x, 1, 2)
+
+        saved = []
+        for encode in self.encoder:
+            x = encode(x)
+            saved.append(x)
+
+        if self.lstm:
+            x = self.lstm(x)
+
+        for decode in self.decoder:
+            skip = saved.pop(-1)
+            skip = center_trim(skip, x)
+            x = decode(x + skip)
+
+        if self.resample:
+            x = julius.resample_frac(x, 2, 1)
+        x = x * std + mean
+        x = center_trim(x, length)
+        x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
+        return x
+
+    def load_state_dict(self, state, strict=True):
+        # fix a mismatch with previous generation Demucs models.
+        for idx in range(self.depth):
+            for a in ['encoder', 'decoder']:
+                for b in ['bias', 'weight']:
+                    new = f'{a}.{idx}.3.{b}'
+                    old = f'{a}.{idx}.2.{b}'
+                    if old in state and new not in state:
+                        state[new] = state.pop(old)
+        super().load_state_dict(state, strict=strict)
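The length arithmetic in `valid_length` can be checked in isolation. A pure-Python sketch with the same defaults (a hypothetical free function, not part of this diff) traces the encoder's strided convolutions forward and the decoder's transposed convolutions back:

```python
import math


def valid_length(length, depth=6, kernel_size=8, stride=4, resample=True):
    if resample:
        length *= 2
    # Encoder: each strided conv consumes kernel_size and advances by stride.
    for _ in range(depth):
        length = math.ceil((length - kernel_size) / stride) + 1
        length = max(1, length)
    # Decoder: transposed convolutions expand the length back.
    for _ in range(depth):
        length = (length - 1) * stride + kernel_size
    if resample:
        length = math.ceil(length / 2)
    return int(length)


# The returned length is never shorter than the input, so `forward` can
# pad with `delta = valid_length(length) - length` and later center-trim.
assert valid_length(44100) >= 44100
```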
demucs/distrib.py ADDED
@@ -0,0 +1,100 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Distributed training utilities.
+"""
+import logging
+import pickle
+
+import numpy as np
+import torch
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import DataLoader, Subset
+from torch.nn.parallel.distributed import DistributedDataParallel
+
+from dora import distrib as dora_distrib
+
+logger = logging.getLogger(__name__)
+rank = 0
+world_size = 1
+
+
+def init():
+    global rank, world_size
+    if not torch.distributed.is_initialized():
+        dora_distrib.init()
+    rank = dora_distrib.rank()
+    world_size = dora_distrib.world_size()
+
+
+def average(metrics, count=1.):
+    if isinstance(metrics, dict):
+        keys, values = zip(*sorted(metrics.items()))
+        values = average(values, count)
+        return dict(zip(keys, values))
+    if world_size == 1:
+        return metrics
+    tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32)
+    tensor *= count
+    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
+    return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist()
+
+
+def wrap(model):
+    if world_size == 1:
+        return model
+    else:
+        return DistributedDataParallel(
+            model,
+            # find_unused_parameters=True,
+            device_ids=[torch.cuda.current_device()],
+            output_device=torch.cuda.current_device())
+
+
+def barrier():
+    if world_size > 1:
+        torch.distributed.barrier()
+
+
+def share(obj=None, src=0):
+    if world_size == 1:
+        return obj
+    size = torch.empty(1, device='cuda', dtype=torch.long)
+    if rank == src:
+        dump = pickle.dumps(obj)
+        size[0] = len(dump)
+    torch.distributed.broadcast(size, src=src)
+    # size variable is now set to the length of pickled obj in all processes
+
+    if rank == src:
+        buffer = torch.from_numpy(np.frombuffer(dump, dtype=np.uint8).copy()).cuda()
+    else:
+        buffer = torch.empty(size[0].item(), device='cuda', dtype=torch.uint8)
+    torch.distributed.broadcast(buffer, src=src)
+    # buffer variable is now set to pickled obj in all processes
+
+    if rank != src:
+        obj = pickle.loads(buffer.cpu().numpy().tobytes())
+    logger.debug(f"Shared object of size {len(buffer)}")
+    return obj
+
+
+def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
+    """
+    Create a dataloader properly in case of distributed training.
+    If a gradient is going to be computed you must set `shuffle=True`.
+    """
+    if world_size == 1:
+        return klass(dataset, *args, shuffle=shuffle, **kwargs)
+
+    if shuffle:
+        # train means we will compute backward, we use DistributedSampler
+        sampler = DistributedSampler(dataset)
+        # We ignore shuffle, DistributedSampler already shuffles
+        return klass(dataset, *args, **kwargs, sampler=sampler)
+    else:
+        # We make a manual shard, as DistributedSampler otherwise replicates some examples
+        dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
+        return klass(dataset, *args, shuffle=shuffle, **kwargs)
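The manual shard in `loader` (the `shuffle=False` branch) can be sketched without torch: each rank takes every `world_size`-th index starting at its own rank, so the ranks partition the dataset with no duplicates, unlike `DistributedSampler`, which pads shards to equal size. A minimal sketch with a hypothetical helper name:

```python
def shard_indices(n, rank, world_size):
    # Mirrors Subset(dataset, list(range(rank, len(dataset), world_size)))
    # from the `loader` fallback above.
    return list(range(rank, n, world_size))


# Ranks partition the indices exactly, with shard sizes differing by at
# most one, which is why it is only used for evaluation (no gradient).
shards = [shard_indices(10, r, 3) for r in range(3)]
assert sorted(i for s in shards for i in s) == list(range(10))
```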
demucs/ema.py ADDED
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Inspired from https://github.com/rwightman/pytorch-image-models
+from contextlib import contextmanager
+
+import torch
+
+from .states import swap_state
+
+
+class ModelEMA:
+    """
+    Perform EMA on a model. You can switch to the EMA weights temporarily
+    with the `swap` method.
+
+        ema = ModelEMA(model)
+        with ema.swap():
+            # compute valid metrics with averaged model.
+    """
+    def __init__(self, model, decay=0.9999, unbias=True, device='cpu'):
+        self.decay = decay
+        self.model = model
+        self.state = {}
+        self.count = 0
+        self.device = device
+        self.unbias = unbias
+
+        self._init()
+
+    def _init(self):
+        for key, val in self.model.state_dict().items():
+            if val.dtype != torch.float32:
+                continue
+            device = self.device or val.device
+            if key not in self.state:
+                self.state[key] = val.detach().to(device, copy=True)
+
+    def update(self):
+        if self.unbias:
+            self.count = self.count * self.decay + 1
+            w = 1 / self.count
+        else:
+            w = 1 - self.decay
+        for key, val in self.model.state_dict().items():
+            if val.dtype != torch.float32:
+                continue
+            device = self.device or val.device
+            self.state[key].mul_(1 - w)
+            self.state[key].add_(val.detach().to(device), alpha=w)
+
+    @contextmanager
+    def swap(self):
+        with swap_state(self.model, self.state):
+            yield
+
+    def state_dict(self):
+        return {'state': self.state, 'count': self.count}
+
+    def load_state_dict(self, state):
+        self.count = state['count']
+        for k, v in state['state'].items():
+            self.state[k].copy_(v)
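The `unbias` update above makes the very first call copy the current weights exactly (the effective weight is `1 / count`, and `count = 1` after one update), instead of mixing them with the zero-initialised average. A scalar sketch of the same rule (a hypothetical class on plain floats, standing in for the state dict):

```python
class ScalarEMA:
    def __init__(self, decay=0.9, unbias=True):
        self.decay = decay
        self.unbias = unbias
        self.count = 0
        self.value = 0.0

    def update(self, x):
        if self.unbias:
            # Same recursion as ModelEMA.update: count = count * decay + 1,
            # so the first update gets weight w = 1 and copies x exactly.
            self.count = self.count * self.decay + 1
            w = 1 / self.count
        else:
            w = 1 - self.decay
        # Scalar form of the mul_ / add_ pair above.
        self.value = self.value * (1 - w) + x * w
        return self.value


ema = ScalarEMA()
assert ema.update(5.0) == 5.0  # first update is an exact copy
```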
demucs/evaluate.py ADDED
@@ -0,0 +1,174 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Test time evaluation, either using the original SDR from [Vincent et al. 2006]
+or the newest SDR definition from the MDX 2021 competition (this one will
+be reported as `nsdr` for `new sdr`).
+"""
+
+from concurrent import futures
+import logging
+
+from dora.log import LogProgress
+import numpy as np
+import musdb
+import museval
+import torch as th
+
+from .apply import apply_model
+from .audio import convert_audio, save_audio
+from . import distrib
+from .utils import DummyPoolExecutor
+
+
+logger = logging.getLogger(__name__)
+
+
+def new_sdr(references, estimates):
+    """
+    Compute the SDR according to the MDX challenge definition.
+    Adapted from AIcrowd/music-demixing-challenge-starter-kit (MIT license)
+    """
+    assert references.dim() == 4
+    assert estimates.dim() == 4
+    delta = 1e-7  # avoid numerical errors
+    num = th.sum(th.square(references), dim=(2, 3))
+    den = th.sum(th.square(references - estimates), dim=(2, 3))
+    num += delta
+    den += delta
+    scores = 10 * th.log10(num / den)
+    return scores
+
+
+def eval_track(references, estimates, win, hop, compute_sdr=True):
+    references = references.transpose(1, 2).double()
+    estimates = estimates.transpose(1, 2).double()
+
+    new_scores = new_sdr(references.cpu()[None], estimates.cpu()[None])[0]
+
+    if not compute_sdr:
+        return None, new_scores
+    else:
+        references = references.numpy()
+        estimates = estimates.numpy()
+        scores = museval.metrics.bss_eval(
+            references, estimates,
+            compute_permutation=False,
+            window=win,
+            hop=hop,
+            framewise_filters=False,
+            bsseval_sources_version=False)[:-1]
+        return scores, new_scores
+
+
+def evaluate(solver, compute_sdr=False):
+    """
+    Evaluate model using museval.
+    compute_sdr=False means using only the MDX definition of the SDR, which
+    is much faster to evaluate.
+    """
+
+    args = solver.args
+
+    output_dir = solver.folder / "results"
+    output_dir.mkdir(exist_ok=True, parents=True)
+    json_folder = solver.folder / "results/test"
+    json_folder.mkdir(exist_ok=True, parents=True)
+
+    # we load tracks from the original musdb set
+    if args.test.nonhq is None:
+        test_set = musdb.DB(args.dset.musdb, subsets=["test"], is_wav=True)
+    else:
+        test_set = musdb.DB(args.test.nonhq, subsets=["test"], is_wav=False)
+    src_rate = args.dset.musdb_samplerate
+
+    eval_device = 'cpu'
+
+    model = solver.model
+    win = int(1. * model.samplerate)
+    hop = int(1. * model.samplerate)
+
+    indexes = range(distrib.rank, len(test_set), distrib.world_size)
+    indexes = LogProgress(logger, indexes, updates=args.misc.num_prints,
+                          name='Eval')
+    pendings = []
+
+    pool = futures.ProcessPoolExecutor if args.test.workers else DummyPoolExecutor
+    with pool(args.test.workers) as pool:
+        for index in indexes:
+            track = test_set.tracks[index]
+
+            mix = th.from_numpy(track.audio).t().float()
+            if mix.dim() == 1:
+                mix = mix[None]
+            mix = mix.to(solver.device)
+            ref = mix.mean(dim=0)  # mono mixture
+            mix = (mix - ref.mean()) / ref.std()
+            mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
+            estimates = apply_model(model, mix[None],
+                                    shifts=args.test.shifts, split=args.test.split,
+                                    overlap=args.test.overlap)[0]
+            estimates = estimates * ref.std() + ref.mean()
+            estimates = estimates.to(eval_device)
+
+            references = th.stack(
+                [th.from_numpy(track.targets[name].audio).t() for name in model.sources])
+            if references.dim() == 2:
+                references = references[:, None]
+            references = references.to(eval_device)
+            references = convert_audio(references, src_rate,
+                                       model.samplerate, model.audio_channels)
+            if args.test.save:
+                folder = solver.folder / "wav" / track.name
+                folder.mkdir(exist_ok=True, parents=True)
+                for name, estimate in zip(model.sources, estimates):
+                    save_audio(estimate.cpu(), folder / (name + ".mp3"), model.samplerate)
+
+            pendings.append((track.name, pool.submit(
+                eval_track, references, estimates, win=win, hop=hop, compute_sdr=compute_sdr)))
+
+        pendings = LogProgress(logger, pendings, updates=args.misc.num_prints,
+                               name='Eval (BSS)')
+        tracks = {}
+        for track_name, pending in pendings:
+            pending = pending.result()
+            scores, nsdrs = pending
+            tracks[track_name] = {}
+            for idx, target in enumerate(model.sources):
+                tracks[track_name][target] = {'nsdr': [float(nsdrs[idx])]}
+            if scores is not None:
+                (sdr, isr, sir, sar) = scores
+                for idx, target in enumerate(model.sources):
+                    values = {
+                        "SDR": sdr[idx].tolist(),
+                        "SIR": sir[idx].tolist(),
+                        "ISR": isr[idx].tolist(),
+                        "SAR": sar[idx].tolist()
+                    }
+                    tracks[track_name][target].update(values)
+
+    all_tracks = {}
+    for src in range(distrib.world_size):
+        all_tracks.update(distrib.share(tracks, src))
+
+    result = {}
+    metric_names = next(iter(all_tracks.values()))[model.sources[0]]
+    for metric_name in metric_names:
+        avg = 0
+        avg_of_medians = 0
+        for source in model.sources:
+            medians = [
+                np.nanmedian(all_tracks[track][source][metric_name])
+                for track in all_tracks.keys()]
+            mean = np.mean(medians)
+            median = np.median(medians)
+            result[metric_name.lower() + "_" + source] = mean
+            result[metric_name.lower() + "_med" + "_" + source] = median
+            avg += mean / len(model.sources)
+            avg_of_medians += median / len(model.sources)
+        result[metric_name.lower()] = avg
+        result[metric_name.lower() + "_med"] = avg_of_medians
+    return result
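Per source, the `new_sdr` formula above reduces to `10 * log10((Σ ref² + δ) / (Σ (ref − est)² + δ))`. A 1-D sketch on plain Python lists (a hypothetical helper, not the 4-D tensor version from the diff):

```python
import math


def new_sdr_1d(reference, estimate, delta=1e-7):
    # delta guards both terms against log10(0), as in the tensor version.
    num = sum(r * r for r in reference) + delta
    den = sum((r - e) ** 2 for r, e in zip(reference, estimate)) + delta
    return 10 * math.log10(num / den)


# A perfect estimate is capped near 10*log10(energy / delta) dB by delta,
# and predicting silence against a unit-energy reference gives 0 dB.
assert new_sdr_1d([1.0, 0.0], [1.0, 0.0]) > 60
assert abs(new_sdr_1d([1.0, 0.0], [0.0, 0.0])) < 1e-6
```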
demucs/grids/__init__.py ADDED
File without changes
demucs/grids/_explorers.py ADDED
@@ -0,0 +1,64 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from dora import Explorer
+import treetable as tt
+
+
+class MyExplorer(Explorer):
+    test_metrics = ['nsdr', 'sdr_med']
+
+    def get_grid_metrics(self):
+        """Return the metrics that should be displayed in the tracking table.
+        """
+        return [
+            tt.group("train", [
+                tt.leaf("epoch"),
+                tt.leaf("reco", ".3f"),
+            ], align=">"),
+            tt.group("valid", [
+                tt.leaf("penalty", ".1f"),
+                tt.leaf("ms", ".1f"),
+                tt.leaf("reco", ".2%"),
+                tt.leaf("breco", ".2%"),
+                tt.leaf("b_nsdr", ".2f"),
+                # tt.leaf("b_nsdr_drums", ".2f"),
+                # tt.leaf("b_nsdr_bass", ".2f"),
+                # tt.leaf("b_nsdr_other", ".2f"),
+                # tt.leaf("b_nsdr_vocals", ".2f"),
+            ], align=">"),
+            tt.group("test", [
+                tt.leaf(name, ".2f")
+                for name in self.test_metrics
+            ], align=">")
+        ]
+
+    def process_history(self, history):
+        train = {
+            'epoch': len(history),
+        }
+        valid = {}
+        test = {}
+        best_v_main = float('inf')
+        breco = float('inf')
+        for metrics in history:
+            train.update(metrics['train'])
+            valid.update(metrics['valid'])
+            if 'main' in metrics['valid']:
+                best_v_main = min(best_v_main, metrics['valid']['main']['loss'])
+                valid['bmain'] = best_v_main
+            valid['breco'] = min(breco, metrics['valid']['reco'])
+            breco = valid['breco']
+            if (metrics['valid']['loss'] == metrics['valid']['best'] or
+                    metrics['valid'].get('nsdr') == metrics['valid']['best']):
+                for k, v in metrics['valid'].items():
+                    if k.startswith('reco_'):
+                        valid['b_' + k[len('reco_'):]] = v
+                    if k.startswith('nsdr'):
+                        valid[f'b_{k}'] = v
+            if 'test' in metrics:
+                test.update(metrics['test'])
+        metrics = history[-1]
+        return {"train": train, "valid": valid, "test": test}
demucs/grids/mdx.py ADDED
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Main training for the Track A MDX models.
+"""
+
+from ._explorers import MyExplorer
+from ..train import main
+
+
+TRACK_A = ['0d19c1c6', '7ecf8ec1', 'c511e2ab', '7d865c68']
+
+
+@MyExplorer
+def explorer(launcher):
+    launcher.slurm_(
+        gpus=8,
+        time=3 * 24 * 60,
+        partition='learnlab')
+
+    # Reproduce results from MDX competition Track A
+    # This trains the first round of models. Once this is trained,
+    # you will need to schedule `mdx_refine`.
+    for sig in TRACK_A:
+        xp = main.get_xp_from_sig(sig)
+        parent = xp.cfg.continue_from
+        xp = main.get_xp_from_sig(parent)
+        launcher(xp.argv)
+        launcher(xp.argv, {'quant.diffq': 1e-4})
+        launcher(xp.argv, {'quant.diffq': 3e-4})
demucs/grids/mdx_extra.py ADDED
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Main training for the Track B MDX models.
+"""
+
+from ._explorers import MyExplorer
+from ..train import main
+
+TRACK_B = ['e51eebcc', 'a1d90b5c', '5d2d6c55', 'cfa93e08']
+
+
+@MyExplorer
+def explorer(launcher):
+    launcher.slurm_(
+        gpus=8,
+        time=3 * 24 * 60,
+        partition='learnlab')
+
+    # Reproduce results from MDX competition Track B.
+    for sig in TRACK_B:
+        while sig is not None:
+            xp = main.get_xp_from_sig(sig)
+            sig = xp.cfg.continue_from
+
+        for dset in ['extra44', 'extra_test']:
+            sub = launcher.bind(xp.argv, dset=dset)
+            sub()
+            if dset == 'extra_test':
+                sub({'quant.diffq': 1e-4})
+                sub({'quant.diffq': 3e-4})
demucs/grids/mdx_refine.py ADDED
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Main training for the Track A MDX models.
+"""
+
+from ._explorers import MyExplorer
+from .mdx import TRACK_A
+from ..train import main
+
+
+@MyExplorer
+def explorer(launcher):
+    launcher.slurm_(
+        gpus=8,
+        time=3 * 24 * 60,
+        partition='learnlab')
+
+    # Reproduce results from MDX competition Track A
+    # WARNING: all the experiments in the `mdx` grid must have completed.
+    for sig in TRACK_A:
+        xp = main.get_xp_from_sig(sig)
+        launcher(xp.argv)
+        for diffq in [1e-4, 3e-4]:
+            xp_src = main.get_xp_from_sig(xp.cfg.continue_from)
+            q_argv = [f'quant.diffq={diffq}']
+            actual_src = main.get_xp(xp_src.argv + q_argv)
+            actual_src.link.load()
+            assert len(actual_src.link.history) == actual_src.cfg.epochs
+            argv = xp.argv + q_argv + [f'continue_from="{actual_src.sig}"']
+            launcher(argv)
demucs/grids/mmi.py ADDED
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._explorers import MyExplorer
+from dora import Launcher
+
+
+@MyExplorer
+def explorer(launcher: Launcher):
+    launcher.slurm_(gpus=8, time=3 * 24 * 60, partition="devlab,learnlab,learnfair")  # 3 days
+
+    sub = launcher.bind_(
+        {
+            "dset": "extra_mmi_goodclean",
+            "test.shifts": 0,
+            "model": "htdemucs",
+            "htdemucs.dconv_mode": 3,
+            "htdemucs.depth": 4,
+            "htdemucs.t_dropout": 0.02,
+            "htdemucs.t_layers": 5,
+            "max_batches": 800,
+            "ema.epoch": [0.9, 0.95],
+            "ema.batch": [0.9995, 0.9999],
+            "dset.segment": 10,
+            "batch_size": 32,
+        }
+    )
+    sub({"model": "hdemucs"})
+    sub({"model": "hdemucs", "dset": "extra44"})
+    sub({"model": "hdemucs", "dset": "musdb44"})
+
+    sparse = {
+        'batch_size': 3 * 8,
+        'augment.remix.group_size': 3,
+        'htdemucs.t_auto_sparsity': True,
+        'htdemucs.t_sparse_self_attn': True,
+        'htdemucs.t_sparse_cross_attn': True,
+        'htdemucs.t_sparsity': 0.9,
+        "htdemucs.t_layers": 7
+    }
+
+    with launcher.job_array():
+        for transf_layers in [5, 7]:
+            for bottom_channels in [0, 512]:
+                sub = launcher.bind({
+                    "htdemucs.t_layers": transf_layers,
+                    "htdemucs.bottom_channels": bottom_channels,
+                })
+                if bottom_channels == 0 and transf_layers == 5:
+                    sub({"augment.remix.proba": 0.0})
+                    sub({
+                        "augment.repitch.proba": 0.0,
+                        # when doing repitching, we trim the output to align on the
+                        # highest change of BPM. When removing repitching,
+                        # we simulate it here to ensure the training context is the same.
+                        # Another second is lost for all experiments due to the random
+                        # shift augmentation.
+                        "dset.segment": 10 * 0.88})
+                elif bottom_channels == 512 and transf_layers == 5:
+                    sub(dset="musdb44")
+                    sub(dset="extra44")
+                    # Sparse kernel XP, currently not released as kernels are still experimental.
+                    sub(sparse, {'dset.segment': 15, "htdemucs.t_layers": 7})
+
+                for duration in [5, 10, 15]:
+                    sub({"dset.segment": duration})
demucs/grids/mmi_ft.py ADDED
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._explorers import MyExplorer
+from dora import Launcher
+from demucs import train
+
+
+def get_sub(launcher, sig):
+    xp = train.main.get_xp_from_sig(sig)
+    sub = launcher.bind(xp.argv)
+    sub()
+    sub.bind_({
+        'continue_from': sig,
+        'continue_best': True})
+    return sub
+
+
+@MyExplorer
+def explorer(launcher: Launcher):
+    launcher.slurm_(gpus=4, time=3 * 24 * 60, partition="devlab,learnlab,learnfair")  # 3 days
+    ft = {
+        'optim.lr': 1e-4,
+        'augment.remix.proba': 0,
+        'augment.scale.proba': 0,
+        'augment.shift_same': True,
+        'htdemucs.t_weight_decay': 0.05,
+        'batch_size': 8,
+        'optim.clip_grad': 5,
+        'optim.optim': 'adamw',
+        'epochs': 50,
+        'dset.wav2_valid': True,
+        'ema.epoch': [],  # let's make valid a bit faster
+    }
+    with launcher.job_array():
+        for sig in ['2899e11a']:
+            sub = get_sub(launcher, sig)
+            sub.bind_(ft)
+            for segment in [15, 18]:
+                for source in range(4):
+                    w = [0] * 4
+                    w[source] = 1
+                    sub({'weights': w, 'dset.segment': segment})
+
+        for sig in ['955717e8']:
+            sub = get_sub(launcher, sig)
+            sub.bind_(ft)
+            for segment in [10, 15]:
+                for source in range(4):
+                    w = [0] * 4
+                    w[source] = 1
+                    sub({'weights': w, 'dset.segment': segment})
demucs/grids/repro.py ADDED
@@ -0,0 +1,50 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ """
+ Easier training for reproducibility
+ """
+
+ from ._explorers import MyExplorer
+
+
+ @MyExplorer
+ def explorer(launcher):
+     launcher.slurm_(
+         gpus=8,
+         time=3 * 24 * 60,
+         partition='devlab,learnlab')
+
+     launcher.bind_({'ema.epoch': [0.9, 0.95]})
+     launcher.bind_({'ema.batch': [0.9995, 0.9999]})
+     launcher.bind_({'epochs': 600})
+
+     base = {'model': 'demucs', 'demucs.dconv_mode': 0, 'demucs.gelu': False,
+             'demucs.lstm_layers': 2}
+     newt = {'model': 'demucs', 'demucs.normalize': True}
+     hdem = {'model': 'hdemucs'}
+     svd = {'svd.penalty': 1e-5, 'svd': 'base2'}
+
+     with launcher.job_array():
+         for model in [base, newt, hdem]:
+             sub = launcher.bind(model)
+             if model is base:
+                 # Training the v2 Demucs on MusDB HQ
+                 sub(epochs=360)
+                 continue
+
+             # those two will be used in the repro_mdx_a bag of models.
+             sub(svd)
+             sub(svd, seed=43)
+             if model == newt:
+                 # Ablation study
+                 sub()
+                 abl = sub.bind(svd)
+                 abl({'ema.epoch': [], 'ema.batch': []})
+                 abl({'demucs.dconv_lstm': 10})
+                 abl({'demucs.dconv_attn': 10})
+                 abl({'demucs.dconv_attn': 10, 'demucs.dconv_lstm': 10, 'demucs.lstm_layers': 2})
+                 abl({'demucs.dconv_mode': 0})
+                 abl({'demucs.gelu': False})
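The `ema.epoch` and `ema.batch` overrides above configure exponential moving averages of the model weights at two decay rates each. The basic update those decays drive can be sketched as follows (a generic EMA step, not the project's exact implementation):

```python
def ema_update(avg, value, decay):
    """One exponential-moving-average step: keep `decay` of the old average
    and blend in (1 - decay) of the new value."""
    return decay * avg + (1 - decay) * value


# With decay 0.9 the average moves quickly toward new values;
# decays like 0.9995 (per batch) adapt far more slowly.
avg = 0.0
for v in [1.0, 1.0, 1.0]:
    avg = ema_update(avg, v, 0.9)
```

Per-batch decays must be much closer to 1 than per-epoch decays simply because there are many more batches than epochs, which is why the grid pairs `[0.9, 0.95]` with `[0.9995, 0.9999]`.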
demucs/grids/repro_ft.py ADDED
@@ -0,0 +1,46 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ """
+ Fine tuning experiments
+ """
+
+ from ._explorers import MyExplorer
+ from ..train import main
+
+
+ @MyExplorer
+ def explorer(launcher):
+     launcher.slurm_(
+         gpus=8,
+         time=300,
+         partition='devlab,learnlab')
+
+     # Mus
+     launcher.slurm_(constraint='volta32gb')
+
+     grid = "repro"
+     folder = main.dora.dir / "grids" / grid
+
+     for sig in folder.iterdir():
+         if not sig.is_symlink():
+             continue
+         xp = main.get_xp_from_sig(sig)
+         xp.link.load()
+         if len(xp.link.history) != xp.cfg.epochs:
+             continue
+         sub = launcher.bind(xp.argv, [f'continue_from="{xp.sig}"'])
+         sub.bind_({'ema.epoch': [0.9, 0.95], 'ema.batch': [0.9995, 0.9999]})
+         sub.bind_({'test.every': 1, 'test.sdr': True, 'epochs': 4})
+         sub.bind_({'dset.segment': 28, 'dset.shift': 2})
+         sub.bind_({'batch_size': 32})
+         auto = {'dset': 'auto_mus'}
+         auto.update({'augment.remix.proba': 0, 'augment.scale.proba': 0,
+                      'augment.shift_same': True})
+         sub.bind_(auto)
+         sub.bind_({'batch_size': 16})
+         sub.bind_({'optim.lr': 1e-4})
+         sub.bind_({'model_segment': 44})
+         sub()
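Note how this explorer only fine-tunes runs whose training history has reached the configured epoch count (`len(xp.link.history) != xp.cfg.epochs` skips unfinished runs). The same completeness check, sketched with plain dicts standing in for the Dora XP objects (hypothetical field names, for illustration):

```python
def is_finished(xp):
    """A run is ready for fine-tuning once it has one history entry per epoch."""
    return len(xp['history']) == xp['epochs']


runs = [
    {'sig': 'aaa', 'history': [1, 2, 3], 'epochs': 3},   # complete
    {'sig': 'bbb', 'history': [1, 2], 'epochs': 3},      # still training
]
ready = [xp['sig'] for xp in runs if is_finished(xp)]
```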
demucs/grids/sdx23.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from ._explorers import MyExplorer
+ from dora import Launcher
+
+
+ @MyExplorer
+ def explorer(launcher: Launcher):
+     launcher.slurm_(gpus=8, time=3 * 24 * 60, partition="speechgpt,learnfair",
+                     mem_per_gpu=None, constraint='')
+     launcher.bind_({"dset.use_musdb": False})
+
+     with launcher.job_array():
+         launcher(dset='sdx23_bleeding')
+         launcher(dset='sdx23_labelnoise')
demucs/hdemucs.py ADDED
@@ -0,0 +1,794 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ """
+ This code contains the spectrogram and Hybrid version of Demucs.
+ """
+ from copy import deepcopy
+ import math
+ import typing as tp
+
+ from openunmix.filtering import wiener
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from .demucs import DConv, rescale_module
+ from .states import capture_init
+ from .spec import spectro, ispectro
+
+
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
+     """Tiny wrapper around F.pad, just to allow for reflect padding on small inputs.
+     If this is the case, we insert extra 0 padding to the right before the reflection happens."""
+     x0 = x
+     length = x.shape[-1]
+     padding_left, padding_right = paddings
+     if mode == 'reflect':
+         max_pad = max(padding_left, padding_right)
+         if length <= max_pad:
+             extra_pad = max_pad - length + 1
+             extra_pad_right = min(padding_right, extra_pad)
+             extra_pad_left = extra_pad - extra_pad_right
+             paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right)
+             x = F.pad(x, (extra_pad_left, extra_pad_right))
+     out = F.pad(x, paddings, mode, value)
+     assert out.shape[-1] == length + padding_left + padding_right
+     assert (out[..., padding_left: padding_left + length] == x0).all()
+     return out
+
+
+ class ScaledEmbedding(nn.Module):
+     """
+     Boost learning rate for embeddings (with `scale`).
+     Also, can make embeddings continuous with `smooth`.
+     """
+     def __init__(self, num_embeddings: int, embedding_dim: int,
+                  scale: float = 10., smooth=False):
+         super().__init__()
+         self.embedding = nn.Embedding(num_embeddings, embedding_dim)
+         if smooth:
+             weight = torch.cumsum(self.embedding.weight.data, dim=0)
+             # when summing gaussians, the overall scale grows as sqrt(n), so we normalize by that.
+             weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
+             self.embedding.weight.data[:] = weight
+         self.embedding.weight.data /= scale
+         self.scale = scale
+
+     @property
+     def weight(self):
+         return self.embedding.weight * self.scale
+
+     def forward(self, x):
+         out = self.embedding(x) * self.scale
+         return out
+
+
+ class HEncLayer(nn.Module):
+     def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
+                  freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True,
+                  rewrite=True):
+         """Encoder layer. This is used by both the time and the frequency branches.
+
+         Args:
+             chin: number of input channels.
+             chout: number of output channels.
+             norm_groups: number of groups for group norm.
+             empty: used to make a layer with just the first conv. This is used
+                 before merging the time and freq. branches.
+             freq: this is acting on frequencies.
+             dconv: insert DConv residual branches.
+             norm: use GroupNorm.
+             context: context size for the 1x1 conv.
+             dconv_kw: dict of kwargs for the DConv class.
+             pad: pad the input. Padding is done so that the output size is
+                 always the input size / stride.
+             rewrite: add a 1x1 conv at the end of the layer.
+         """
+         super().__init__()
+         norm_fn = lambda d: nn.Identity()  # noqa
+         if norm:
+             norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
+         if pad:
+             pad = kernel_size // 4
+         else:
+             pad = 0
+         klass = nn.Conv1d
+         self.freq = freq
+         self.kernel_size = kernel_size
+         self.stride = stride
+         self.empty = empty
+         self.norm = norm
+         self.pad = pad
+         if freq:
+             kernel_size = [kernel_size, 1]
+             stride = [stride, 1]
+             pad = [pad, 0]
+             klass = nn.Conv2d
+         self.conv = klass(chin, chout, kernel_size, stride, pad)
+         if self.empty:
+             return
+         self.norm1 = norm_fn(chout)
+         self.rewrite = None
+         if rewrite:
+             self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
+             self.norm2 = norm_fn(2 * chout)
+
+         self.dconv = None
+         if dconv:
+             self.dconv = DConv(chout, **dconv_kw)
+
+     def forward(self, x, inject=None):
+         """
+         `inject` is used to inject the result from the time branch into the frequency branch,
+         when both have the same stride.
+         """
+         if not self.freq and x.dim() == 4:
+             B, C, Fr, T = x.shape
+             x = x.view(B, -1, T)
+
+         if not self.freq:
+             le = x.shape[-1]
+             if not le % self.stride == 0:
+                 x = F.pad(x, (0, self.stride - (le % self.stride)))
+         y = self.conv(x)
+         if self.empty:
+             return y
+         if inject is not None:
+             assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
+             if inject.dim() == 3 and y.dim() == 4:
+                 inject = inject[:, :, None]
+             y = y + inject
+         y = F.gelu(self.norm1(y))
+         if self.dconv:
+             if self.freq:
+                 B, C, Fr, T = y.shape
+                 y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
+             y = self.dconv(y)
+             if self.freq:
+                 y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
+         if self.rewrite:
+             z = self.norm2(self.rewrite(y))
+             z = F.glu(z, dim=1)
+         else:
+             z = y
+         return z
+
+
+ class MultiWrap(nn.Module):
+     """
+     Takes one layer and replicates it N times. Each replica will act
+     on a frequency band. All is done so that if the N replicas have the same weights,
+     then this is exactly equivalent to applying the original module on all frequencies.
+
+     This is a bit over-engineered to avoid edge artifacts when splitting
+     the frequency bands, but it is possible the naive implementation would work as well...
+     """
+     def __init__(self, layer, split_ratios):
+         """
+         Args:
+             layer: module to clone, must be either HEncLayer or HDecLayer.
+             split_ratios: list of floats indicating which ratio to keep for each band.
+         """
+         super().__init__()
+         self.split_ratios = split_ratios
+         self.layers = nn.ModuleList()
+         self.conv = isinstance(layer, HEncLayer)
+         assert not layer.norm
+         assert layer.freq
+         assert layer.pad
+         if not self.conv:
+             assert not layer.context_freq
+         for k in range(len(split_ratios) + 1):
+             lay = deepcopy(layer)
+             if self.conv:
+                 lay.conv.padding = (0, 0)
+             else:
+                 lay.pad = False
+             for m in lay.modules():
+                 if hasattr(m, 'reset_parameters'):
+                     m.reset_parameters()
+             self.layers.append(lay)
+
+     def forward(self, x, skip=None, length=None):
+         B, C, Fr, T = x.shape
+
+         ratios = list(self.split_ratios) + [1]
+         start = 0
+         outs = []
+         for ratio, layer in zip(ratios, self.layers):
+             if self.conv:
+                 pad = layer.kernel_size // 4
+                 if ratio == 1:
+                     limit = Fr
+                     frames = -1
+                 else:
+                     limit = int(round(Fr * ratio))
+                     le = limit - start
+                     if start == 0:
+                         le += pad
+                     frames = round((le - layer.kernel_size) / layer.stride + 1)
+                     limit = start + (frames - 1) * layer.stride + layer.kernel_size
+                     if start == 0:
+                         limit -= pad
+                 assert limit - start > 0, (limit, start)
+                 assert limit <= Fr, (limit, Fr)
+                 y = x[:, :, start:limit, :]
+                 if start == 0:
+                     y = F.pad(y, (0, 0, pad, 0))
+                 if ratio == 1:
+                     y = F.pad(y, (0, 0, 0, pad))
+                 outs.append(layer(y))
+                 start = limit - layer.kernel_size + layer.stride
+             else:
+                 if ratio == 1:
+                     limit = Fr
+                 else:
+                     limit = int(round(Fr * ratio))
+                 last = layer.last
+                 layer.last = True
+
+                 y = x[:, :, start:limit]
+                 s = skip[:, :, start:limit]
+                 out, _ = layer(y, s, None)
+                 if outs:
+                     outs[-1][:, :, -layer.stride:] += (
+                         out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1))
+                 out = out[:, :, layer.stride:]
+                 if ratio == 1:
+                     out = out[:, :, :-layer.stride // 2, :]
+                 if start == 0:
+                     out = out[:, :, layer.stride // 2:, :]
+                 outs.append(out)
+                 layer.last = last
+                 start = limit
+         out = torch.cat(outs, dim=2)
+         if not self.conv and not last:
+             out = F.gelu(out)
+         if self.conv:
+             return out
+         else:
+             return out, None
+
+
+ class HDecLayer(nn.Module):
+     def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
+                  freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True,
+                  context_freq=True, rewrite=True):
+         """
+         Same as HEncLayer but for decoder. See `HEncLayer` for documentation.
+         """
+         super().__init__()
+         norm_fn = lambda d: nn.Identity()  # noqa
+         if norm:
+             norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
+         if pad:
+             pad = kernel_size // 4
+         else:
+             pad = 0
+         self.pad = pad
+         self.last = last
+         self.freq = freq
+         self.chin = chin
+         self.empty = empty
+         self.stride = stride
+         self.kernel_size = kernel_size
+         self.norm = norm
+         self.context_freq = context_freq
+         klass = nn.Conv1d
+         klass_tr = nn.ConvTranspose1d
+         if freq:
+             kernel_size = [kernel_size, 1]
+             stride = [stride, 1]
+             klass = nn.Conv2d
+             klass_tr = nn.ConvTranspose2d
+         self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
+         self.norm2 = norm_fn(chout)
+         if self.empty:
+             return
+         self.rewrite = None
+         if rewrite:
+             if context_freq:
+                 self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
+             else:
+                 self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
+                                      [0, context])
+             self.norm1 = norm_fn(2 * chin)
+
+         self.dconv = None
+         if dconv:
+             self.dconv = DConv(chin, **dconv_kw)
+
+     def forward(self, x, skip, length):
+         if self.freq and x.dim() == 3:
+             B, C, T = x.shape
+             x = x.view(B, self.chin, -1, T)
+
+         if not self.empty:
+             x = x + skip
+
+             if self.rewrite:
+                 y = F.glu(self.norm1(self.rewrite(x)), dim=1)
+             else:
+                 y = x
+             if self.dconv:
+                 if self.freq:
+                     B, C, Fr, T = y.shape
+                     y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
+                 y = self.dconv(y)
+                 if self.freq:
+                     y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
+         else:
+             y = x
+             assert skip is None
+         z = self.norm2(self.conv_tr(y))
+         if self.freq:
+             if self.pad:
+                 z = z[..., self.pad:-self.pad, :]
+         else:
+             z = z[..., self.pad:self.pad + length]
+             assert z.shape[-1] == length, (z.shape[-1], length)
+         if not self.last:
+             z = F.gelu(z)
+         return z, y
+
+
+ class HDemucs(nn.Module):
+     """
+     Spectrogram and hybrid Demucs model.
+     The spectrogram model has the same structure as Demucs, except the first few layers are over the
+     frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
+     Frequency layers can still access information across time steps thanks to the DConv residual.
+
+     Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
+     as the frequency branch and then the two are combined. The opposite happens in the decoder.
+
+     Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
+     or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on the
+     Open Unmix implementation [Stöter et al. 2019].
+
+     The loss is always in the temporal domain, by backpropagating through the above
+     output methods and the iSTFT. This allows defining hybrid models nicely. However, it slightly
+     breaks Wiener filtering, as doing more iterations at test time will change the spectrogram
+     contribution, without changing the one from the waveform, which will lead to worse performance.
+     I tried using the residual option in the OpenUnmix Wiener implementation, but it didn't improve.
+     CaC on the other hand provides similar performance, and works naturally with
+     hybrid models.
+
+     This model also uses frequency embeddings to improve efficiency on convolutions
+     over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
+
+     Unlike classic Demucs, there is no resampling here, and normalization is always applied.
+     """
+     @capture_init
+     def __init__(self,
+                  sources,
+                  # Channels
+                  audio_channels=2,
+                  channels=48,
+                  channels_time=None,
+                  growth=2,
+                  # STFT
+                  nfft=4096,
+                  wiener_iters=0,
+                  end_iters=0,
+                  wiener_residual=False,
+                  cac=True,
+                  # Main structure
+                  depth=6,
+                  rewrite=True,
+                  hybrid=True,
+                  hybrid_old=False,
+                  # Frequency branch
+                  multi_freqs=None,
+                  multi_freqs_depth=2,
+                  freq_emb=0.2,
+                  emb_scale=10,
+                  emb_smooth=True,
+                  # Convolutions
+                  kernel_size=8,
+                  time_stride=2,
+                  stride=4,
+                  context=1,
+                  context_enc=0,
+                  # Normalization
+                  norm_starts=4,
+                  norm_groups=4,
+                  # DConv residual branch
+                  dconv_mode=1,
+                  dconv_depth=2,
+                  dconv_comp=4,
+                  dconv_attn=4,
+                  dconv_lstm=4,
+                  dconv_init=1e-4,
+                  # Weight init
+                  rescale=0.1,
+                  # Metadata
+                  samplerate=44100,
+                  segment=4 * 10):
+         """
+         Args:
+             sources (list[str]): list of source names.
+             audio_channels (int): input/output audio channels.
+             channels (int): initial number of hidden channels.
+             channels_time: if not None, use a different `channels` value for the time branch.
+             growth: increase the number of hidden channels by this factor at each layer.
+             nfft: number of fft bins. Note that changing this requires careful computation of
+                 various shape parameters and will not work out of the box for hybrid models.
+             wiener_iters: when using Wiener filtering, number of iterations at test time.
+             end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
+             wiener_residual: add residual source before wiener filtering.
+             cac: uses complex as channels, i.e. complex numbers are 2 channels each
+                 in input and output. No further processing is done before the ISTFT.
+             depth (int): number of layers in the encoder and in the decoder.
+             rewrite (bool): add 1x1 convolution to each layer.
+             hybrid (bool): make a hybrid time/frequency domain model, otherwise frequency only.
+             hybrid_old: some models trained for MDX had a padding bug. This replicates
+                 this bug to avoid retraining them.
+             multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
+             multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
+                 layers will be wrapped.
+             freq_emb: add frequency embedding after the first frequency layer if > 0,
+                 the actual value controls the weight of the embedding.
+             emb_scale: equivalent to scaling the embedding learning rate.
+             emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
+             kernel_size: kernel_size for encoder and decoder layers.
+             stride: stride for encoder and decoder layers.
+             time_stride: stride for the final time layer, after the merge.
+             context: context for 1x1 conv in the decoder.
+             context_enc: context for 1x1 conv in the encoder.
+             norm_starts: layer at which group norm starts being used.
+                 Decoder layers are numbered in reverse order.
+             norm_groups: number of groups for group norm.
+             dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
+             dconv_depth: depth of residual DConv branch.
+             dconv_comp: compression of DConv branch.
+             dconv_attn: adds attention layers in DConv branch starting at this layer.
+             dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
+             dconv_init: initial scale for the DConv branch LayerScale.
+             rescale: weight rescaling trick.
+
+         """
+         super().__init__()
+         self.cac = cac
+         self.wiener_residual = wiener_residual
+         self.audio_channels = audio_channels
+         self.sources = sources
+         self.kernel_size = kernel_size
+         self.context = context
+         self.stride = stride
+         self.depth = depth
+         self.channels = channels
+         self.samplerate = samplerate
+         self.segment = segment
+
+         self.nfft = nfft
+         self.hop_length = nfft // 4
+         self.wiener_iters = wiener_iters
+         self.end_iters = end_iters
+         self.freq_emb = None
+         self.hybrid = hybrid
+         self.hybrid_old = hybrid_old
+         if hybrid_old:
+             assert hybrid, "hybrid_old must come with hybrid=True"
+         if hybrid:
+             assert wiener_iters == end_iters
+
+         self.encoder = nn.ModuleList()
+         self.decoder = nn.ModuleList()
+
+         if hybrid:
+             self.tencoder = nn.ModuleList()
+             self.tdecoder = nn.ModuleList()
+
+         chin = audio_channels
+         chin_z = chin  # number of channels for the freq branch
+         if self.cac:
+             chin_z *= 2
+         chout = channels_time or channels
+         chout_z = channels
+         freqs = nfft // 2
+
+         for index in range(depth):
+             lstm = index >= dconv_lstm
+             attn = index >= dconv_attn
+             norm = index >= norm_starts
+             freq = freqs > 1
+             stri = stride
+             ker = kernel_size
+             if not freq:
+                 assert freqs == 1
+                 ker = time_stride * 2
+                 stri = time_stride
+
+             pad = True
+             last_freq = False
+             if freq and freqs <= kernel_size:
+                 ker = freqs
+                 pad = False
+                 last_freq = True
+
+             kw = {
+                 'kernel_size': ker,
+                 'stride': stri,
+                 'freq': freq,
+                 'pad': pad,
+                 'norm': norm,
+                 'rewrite': rewrite,
+                 'norm_groups': norm_groups,
+                 'dconv_kw': {
+                     'lstm': lstm,
+                     'attn': attn,
+                     'depth': dconv_depth,
+                     'compress': dconv_comp,
+                     'init': dconv_init,
+                     'gelu': True,
+                 }
+             }
+             kwt = dict(kw)
+             kwt['freq'] = 0
+             kwt['kernel_size'] = kernel_size
+             kwt['stride'] = stride
+             kwt['pad'] = True
+             kw_dec = dict(kw)
+             multi = False
+             if multi_freqs and index < multi_freqs_depth:
+                 multi = True
+                 kw_dec['context_freq'] = False
+
+             if last_freq:
+                 chout_z = max(chout, chout_z)
+                 chout = chout_z
+
+             enc = HEncLayer(chin_z, chout_z,
+                             dconv=dconv_mode & 1, context=context_enc, **kw)
+             if hybrid and freq:
+                 tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc,
+                                  empty=last_freq, **kwt)
+                 self.tencoder.append(tenc)
+
+             if multi:
+                 enc = MultiWrap(enc, multi_freqs)
+             self.encoder.append(enc)
+             if index == 0:
+                 chin = self.audio_channels * len(self.sources)
+                 chin_z = chin
+                 if self.cac:
+                     chin_z *= 2
+             dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2,
+                             last=index == 0, context=context, **kw_dec)
+             if multi:
+                 dec = MultiWrap(dec, multi_freqs)
+             if hybrid and freq:
+                 tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq,
+                                  last=index == 0, context=context, **kwt)
+                 self.tdecoder.insert(0, tdec)
+             self.decoder.insert(0, dec)
+
+             chin = chout
+             chin_z = chout_z
+             chout = int(growth * chout)
+             chout_z = int(growth * chout_z)
+             if freq:
+                 if freqs <= kernel_size:
+                     freqs = 1
+                 else:
+                     freqs //= stride
+             if index == 0 and freq_emb:
+                 self.freq_emb = ScaledEmbedding(
+                     freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
+                 self.freq_emb_scale = freq_emb
+
+         if rescale:
+             rescale_module(self, reference=rescale)
+
+     def _spec(self, x):
+         hl = self.hop_length
+         nfft = self.nfft
+         x0 = x  # noqa
+
+         if self.hybrid:
+             # We re-pad the signal in order to keep the property
+             # that the size of the output is exactly the size of the input
+             # divided by the stride (here hop_length), when divisible.
+             # This is achieved by padding by 1/4th of the kernel size (here nfft),
+             # which is not supported by torch.stft.
+             # Having all convolution operations follow this convention allows to easily
+             # align the time and frequency branches later on.
+             assert hl == nfft // 4
+             le = int(math.ceil(x.shape[-1] / hl))
+             pad = hl // 2 * 3
+             if not self.hybrid_old:
+                 x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect')
+             else:
+                 x = pad1d(x, (pad, pad + le * hl - x.shape[-1]))
+
+         z = spectro(x, nfft, hl)[..., :-1, :]
+         if self.hybrid:
+             assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
+             z = z[..., 2:2 + le]
+         return z
+
+     def _ispec(self, z, length=None, scale=0):
+         hl = self.hop_length // (4 ** scale)
+         z = F.pad(z, (0, 0, 0, 1))
+         if self.hybrid:
+             z = F.pad(z, (2, 2))
+             pad = hl // 2 * 3
+             if not self.hybrid_old:
+                 le = hl * int(math.ceil(length / hl)) + 2 * pad
+             else:
+                 le = hl * int(math.ceil(length / hl))
+             x = ispectro(z, hl, length=le)
+             if not self.hybrid_old:
+                 x = x[..., pad:pad + length]
+             else:
+                 x = x[..., :length]
+         else:
+             x = ispectro(z, hl, length)
+         return x
+
+     def _magnitude(self, z):
+         # return the magnitude of the spectrogram, except when cac is True,
+         # in which case we just move the complex dimension to the channel one.
+         if self.cac:
+             B, C, Fr, T = z.shape
+             m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
+             m = m.reshape(B, C * 2, Fr, T)
+         else:
+             m = z.abs()
+         return m
+
+     def _mask(self, z, m):
+         # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
+         # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
+         niters = self.wiener_iters
+         if self.cac:
+             B, S, C, Fr, T = m.shape
+             out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
+             out = torch.view_as_complex(out.contiguous())
+             return out
+         if self.training:
+             niters = self.end_iters
+         if niters < 0:
+             z = z[:, None]
+             return z / (1e-8 + z.abs()) * m
+         else:
+             return self._wiener(m, z, niters)
+
+     def _wiener(self, mag_out, mix_stft, niters):
+         # apply wiener filtering from OpenUnmix.
+         init = mix_stft.dtype
+         wiener_win_len = 300
+         residual = self.wiener_residual
+
+         B, S, C, Fq, T = mag_out.shape
+         mag_out = mag_out.permute(0, 4, 3, 2, 1)
+         mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
+
+         outs = []
+         for sample in range(B):
+             pos = 0
+             out = []
+             for pos in range(0, T, wiener_win_len):
+                 frame = slice(pos, pos + wiener_win_len)
+                 z_out = wiener(
+                     mag_out[sample, frame], mix_stft[sample, frame], niters,
+                     residual=residual)
+                 out.append(z_out.transpose(-1, -2))
+             outs.append(torch.cat(out, dim=0))
+         out = torch.view_as_complex(torch.stack(outs, 0))
+         out = out.permute(0, 4, 3, 2, 1).contiguous()
+         if residual:
+             out = out[:, :-1]
+         assert list(out.shape) == [B, S, C, Fq, T]
+         return out.to(init)
+
+ def forward(self, mix):
690
+ x = mix
691
+ length = x.shape[-1]
692
+
693
+ z = self._spec(mix)
694
+ mag = self._magnitude(z).to(mix.device)
695
+ x = mag
696
+
697
+ B, C, Fq, T = x.shape
698
+
699
+ # unlike previous Demucs, we always normalize because it is easier.
700
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
701
+ std = x.std(dim=(1, 2, 3), keepdim=True)
702
+ x = (x - mean) / (1e-5 + std)
703
+ # x will be the freq. branch input.
704
+
705
+ if self.hybrid:
706
+ # Prepare the time branch input.
707
+ xt = mix
708
+ meant = xt.mean(dim=(1, 2), keepdim=True)
709
+ stdt = xt.std(dim=(1, 2), keepdim=True)
710
+ xt = (xt - meant) / (1e-5 + stdt)
711
+
712
+ # okay, this is a giant mess I know...
713
+ saved = [] # skip connections, freq.
714
+ saved_t = [] # skip connections, time.
715
+ lengths = [] # saved lengths to properly remove padding, freq branch.
716
+ lengths_t = [] # saved lengths for time branch.
717
+ for idx, encode in enumerate(self.encoder):
718
+ lengths.append(x.shape[-1])
719
+ inject = None
720
+ if self.hybrid and idx < len(self.tencoder):
721
+ # we have not yet merged branches.
722
+ lengths_t.append(xt.shape[-1])
723
+ tenc = self.tencoder[idx]
+ xt = tenc(xt)
+ if not tenc.empty:
+ # save for skip connection
+ saved_t.append(xt)
+ else:
+ # tenc contains just the first conv., so that now the time and freq.
+ # branches have the same shape and can be merged.
+ inject = xt
+ x = encode(x, inject)
+ if idx == 0 and self.freq_emb is not None:
+ # add frequency embedding to allow for non-equivariant convolutions
+ # over the frequency axis.
+ frs = torch.arange(x.shape[-2], device=x.device)
+ emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
+ x = x + self.freq_emb_scale * emb
+
+ saved.append(x)
+
+ x = torch.zeros_like(x)
+ if self.hybrid:
+ xt = torch.zeros_like(x)
+ # initialize everything to zero (signal will go through u-net skips).
+
+ for idx, decode in enumerate(self.decoder):
+ skip = saved.pop(-1)
+ x, pre = decode(x, skip, lengths.pop(-1))
+ # `pre` contains the output just before the final transposed convolution,
+ # which is used when the freq. and time branches separate.
+
+ if self.hybrid:
+ offset = self.depth - len(self.tdecoder)
+ if self.hybrid and idx >= offset:
+ tdec = self.tdecoder[idx - offset]
+ length_t = lengths_t.pop(-1)
+ if tdec.empty:
+ assert pre.shape[2] == 1, pre.shape
+ pre = pre[:, :, 0]
+ xt, _ = tdec(pre, None, length_t)
+ else:
+ skip = saved_t.pop(-1)
+ xt, _ = tdec(xt, skip, length_t)
+
+ # Let's make sure we used all stored skip connections.
+ assert len(saved) == 0
+ assert len(lengths_t) == 0
+ assert len(saved_t) == 0
+
+ S = len(self.sources)
+ x = x.view(B, S, -1, Fq, T)
+ x = x * std[:, None] + mean[:, None]
+
+ # to cpu as mps doesn't support complex numbers
+ # demucs issues #435, #432
+ # NOTE: in this case z already is on cpu
+ # TODO: remove this when mps supports complex numbers
+ x_is_mps = x.device.type == "mps"
+ if x_is_mps:
+ x = x.cpu()
+
+ zout = self._mask(z, x)
+ x = self._ispec(zout, length)
+
+ # back to mps device
+ if x_is_mps:
+ x = x.to('mps')
+
+ if self.hybrid:
+ xt = xt.view(B, S, -1, length)
+ xt = xt * stdt[:, None] + meant[:, None]
+ x = xt + x
+ return x
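The decoder loop above relies on strict stack discipline: each encoder layer pushes its activation onto `saved` (or `saved_t`), and the mirrored decoder layer pops it, so both stacks must be empty once the loops finish, exactly what the asserts check. A minimal pure-Python sketch of that invariant (toy scale factors standing in for the real layers, not demucs code):

```python
# Toy U-Net skip bookkeeping: encoders push, decoders pop in reverse order.
saved = []
x = 1
for scale in (2, 3, 5):        # stand-ins for encoder layers
    x = x * scale
    saved.append(x)            # save for skip connection
for scale in (5, 3, 2):        # decoder mirrors the encoder in reverse
    skip = saved.pop(-1)       # consume the matching skip connection
    x = x // scale + 0 * skip  # toy "decode" step
assert len(saved) == 0         # same invariant as the asserts in forward()
assert x == 1
```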
demucs/htdemucs.py ADDED
@@ -0,0 +1,660 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ # First author is Simon Rouard.
+ """
+ This code contains the spectrogram and Hybrid version of Demucs.
+ """
+ import math
+
+ from openunmix.filtering import wiener
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from fractions import Fraction
+ from einops import rearrange
+
+ from .transformer import CrossTransformerEncoder
+
+ from .demucs import rescale_module
+ from .states import capture_init
+ from .spec import spectro, ispectro
+ from .hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer
+
+
+ class HTDemucs(nn.Module):
+ """
+ Spectrogram and hybrid Demucs model.
+ The spectrogram model has the same structure as Demucs, except the first few layers are over the
+ frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
+ Frequency layers can still access information across time steps thanks to the DConv residual.
+
+ Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
+ as the frequency branch and then the two are combined. The opposite happens in the decoder.
+
+ Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
+ or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
+ the Open Unmix implementation [Stoter et al. 2019].
+
+ The loss is always in the temporal domain, by backpropagating through the above
+ output methods and iSTFT. This makes it possible to define hybrid models nicely. However,
+ it slightly breaks Wiener filtering, as doing more iterations at test time will change the
+ spectrogram contribution without changing the one from the waveform, which leads to worse performance.
+ I tried using the residual option in the Open Unmix Wiener implementation, but it didn't improve.
+ CaC, on the other hand, provides similar performance and works naturally with
+ hybrid models.
+
+ This model also uses frequency embeddings to improve the efficiency of convolutions
+ over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
+
+ Unlike classic Demucs, there is no resampling here, and normalization is always applied.
+ """
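The "complex as channels" (CaC) representation mentioned in the docstring stores each complex STFT bin as two real channels (real, imag), so a C-channel complex spectrogram becomes a 2*C-channel real tensor; the inverse mapping is applied at the model output. A pure-Python illustration on a single channel of bins (the real code does this with `torch.view_as_real` / `torch.view_as_complex` on 4-D tensors):

```python
# One channel of complex STFT bins, split into two real channels and restored.
bins = [1 + 2j, 3 - 1j]
real_channels = [[z.real for z in bins],   # channel 2*c   (real parts)
                 [z.imag for z in bins]]   # channel 2*c+1 (imaginary parts)
# Inverse mapping used at the output of the model:
restored = [re + 1j * im for re, im in zip(*real_channels)]
assert restored == bins
```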
+
+ @capture_init
+ def __init__(
+ self,
+ sources,
+ # Channels
+ audio_channels=2,
+ channels=48,
+ channels_time=None,
+ growth=2,
+ # STFT
+ nfft=4096,
+ wiener_iters=0,
+ end_iters=0,
+ wiener_residual=False,
+ cac=True,
+ # Main structure
+ depth=4,
+ rewrite=True,
+ # Frequency branch
+ multi_freqs=None,
+ multi_freqs_depth=3,
+ freq_emb=0.2,
+ emb_scale=10,
+ emb_smooth=True,
+ # Convolutions
+ kernel_size=8,
+ time_stride=2,
+ stride=4,
+ context=1,
+ context_enc=0,
+ # Normalization
+ norm_starts=4,
+ norm_groups=4,
+ # DConv residual branch
+ dconv_mode=1,
+ dconv_depth=2,
+ dconv_comp=8,
+ dconv_init=1e-3,
+ # Before the Transformer
+ bottom_channels=0,
+ # Transformer
+ t_layers=5,
+ t_emb="sin",
+ t_hidden_scale=4.0,
+ t_heads=8,
+ t_dropout=0.0,
+ t_max_positions=10000,
+ t_norm_in=True,
+ t_norm_in_group=False,
+ t_group_norm=False,
+ t_norm_first=True,
+ t_norm_out=True,
+ t_max_period=10000.0,
+ t_weight_decay=0.0,
+ t_lr=None,
+ t_layer_scale=True,
+ t_gelu=True,
+ t_weight_pos_embed=1.0,
+ t_sin_random_shift=0,
+ t_cape_mean_normalize=True,
+ t_cape_augment=True,
+ t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
+ t_sparse_self_attn=False,
+ t_sparse_cross_attn=False,
+ t_mask_type="diag",
+ t_mask_random_seed=42,
+ t_sparse_attn_window=500,
+ t_global_window=100,
+ t_sparsity=0.95,
+ t_auto_sparsity=False,
+ # ------ Particular parameters
+ t_cross_first=False,
+ # Weight init
+ rescale=0.1,
+ # Metadata
+ samplerate=44100,
+ segment=10,
+ use_train_segment=True,
+ ):
+ """
+ Args:
+ sources (list[str]): list of source names.
+ audio_channels (int): input/output audio channels.
+ channels (int): initial number of hidden channels.
+ channels_time: if not None, use a different `channels` value for the time branch.
+ growth: increase the number of hidden channels by this factor at each layer.
+ nfft: number of fft bins. Note that changing this requires careful computation of
+ various shape parameters and will not work out of the box for hybrid models.
+ wiener_iters: when using Wiener filtering, number of iterations at test time.
+ end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
+ wiener_residual: add residual source before wiener filtering.
+ cac: uses complex as channels, i.e. complex numbers are 2 channels each
+ in input and output. No further processing is done before the ISTFT.
+ depth (int): number of layers in the encoder and in the decoder.
+ rewrite (bool): add 1x1 convolution to each layer.
+ multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
+ multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
+ layers will be wrapped.
+ freq_emb: add frequency embedding after the first frequency layer if > 0,
+ the actual value controls the weight of the embedding.
+ emb_scale: equivalent to scaling the embedding learning rate
+ emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
+ kernel_size: kernel_size for encoder and decoder layers.
+ stride: stride for encoder and decoder layers.
+ time_stride: stride for the final time layer, after the merge.
+ context: context for 1x1 conv in the decoder.
+ context_enc: context for 1x1 conv in the encoder.
+ norm_starts: layer at which group norm starts being used.
+ decoder layers are numbered in reverse order.
+ norm_groups: number of groups for group norm.
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
+ dconv_depth: depth of residual DConv branch.
+ dconv_comp: compression of DConv branch.
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
+ dconv_init: initial scale for the DConv branch LayerScale.
+ bottom_channels: if > 0, adds a linear layer (1x1 Conv) before and after the
+ transformer in order to change the number of channels
+ t_layers: number of layers in each branch (waveform and spec) of the transformer
+ t_emb: "sin", "cape" or "scaled"
+ t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
+ for instance if C = 384 (the number of channels in the transformer) and
+ t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
+ 384 * 4 = 1536
+ t_heads: number of heads for the transformer
+ t_dropout: dropout in the transformer
+ t_max_positions: max_positions for the "scaled" positional embedding, only
+ useful if t_emb="scaled"
+ t_norm_in: (bool) norm before adding the positional embedding and entering the
+ transformer layers
+ t_norm_in_group: (bool) if True while t_norm_in=True, the norm is over all the
+ timesteps (GroupNorm with group=1)
+ t_group_norm: (bool) if True, the norms of the Encoder Layers are over all the
+ timesteps (GroupNorm with group=1)
+ t_norm_first: (bool) if True the norm is before the attention and before the FFN
+ t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
+ t_max_period: (float) denominator in the sinusoidal embedding expression
+ t_weight_decay: (float) weight decay for the transformer
+ t_lr: (float) specific learning rate for the transformer
+ t_layer_scale: (bool) Layer Scale for the transformer
+ t_gelu: (bool) activations of the transformer are GELU if True, ReLU otherwise
+ t_weight_pos_embed: (float) weighting of the positional embedding
+ t_cape_mean_normalize: (bool) if t_emb="cape", normalization of positional embeddings
+ see: https://arxiv.org/abs/2106.03143
+ t_cape_augment: (bool) if t_emb="cape", must be True during training and False
+ during inference, see: https://arxiv.org/abs/2106.03143
+ t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
+ see: https://arxiv.org/abs/2106.03143
+ t_sparse_self_attn: (bool) if True, the self attentions are sparse
+ t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
+ unless you designed really specific masks)
+ t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
+ with '_' between: i.e. "diag_jmask_random" (note that this is permutation
+ invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
+ t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
+ that generates the random part of the mask
+ t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i) and
+ a key (j), the mask is True if |i-j| <= t_sparse_attn_window
+ t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
+ and mask[:, :t_global_window] will be True
+ t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
+ level of the random part of the mask.
+ t_cross_first: (bool) if True cross attention is the first layer of the
+ transformer (False seems to be better)
+ rescale: weight rescaling trick
+ use_train_segment: (bool) if True, the actual size that is used during
+ training is used during inference.
+ """
+ super().__init__()
+ self.cac = cac
+ self.wiener_residual = wiener_residual
+ self.audio_channels = audio_channels
+ self.sources = sources
+ self.kernel_size = kernel_size
+ self.context = context
+ self.stride = stride
+ self.depth = depth
+ self.bottom_channels = bottom_channels
+ self.channels = channels
+ self.samplerate = samplerate
+ self.segment = segment
+ self.use_train_segment = use_train_segment
+ self.nfft = nfft
+ self.hop_length = nfft // 4
+ self.wiener_iters = wiener_iters
+ self.end_iters = end_iters
+ self.freq_emb = None
+ assert wiener_iters == end_iters
+
+ self.encoder = nn.ModuleList()
+ self.decoder = nn.ModuleList()
+
+ self.tencoder = nn.ModuleList()
+ self.tdecoder = nn.ModuleList()
+
+ chin = audio_channels
+ chin_z = chin # number of channels for the freq branch
+ if self.cac:
+ chin_z *= 2
+ chout = channels_time or channels
+ chout_z = channels
+ freqs = nfft // 2
+
+ for index in range(depth):
+ norm = index >= norm_starts
+ freq = freqs > 1
+ stri = stride
+ ker = kernel_size
+ if not freq:
+ assert freqs == 1
+ ker = time_stride * 2
+ stri = time_stride
+
+ pad = True
+ last_freq = False
+ if freq and freqs <= kernel_size:
+ ker = freqs
+ pad = False
+ last_freq = True
+
+ kw = {
+ "kernel_size": ker,
+ "stride": stri,
+ "freq": freq,
+ "pad": pad,
+ "norm": norm,
+ "rewrite": rewrite,
+ "norm_groups": norm_groups,
+ "dconv_kw": {
+ "depth": dconv_depth,
+ "compress": dconv_comp,
+ "init": dconv_init,
+ "gelu": True,
+ },
+ }
+ kwt = dict(kw)
+ kwt["freq"] = 0
+ kwt["kernel_size"] = kernel_size
+ kwt["stride"] = stride
+ kwt["pad"] = True
+ kw_dec = dict(kw)
+ multi = False
+ if multi_freqs and index < multi_freqs_depth:
+ multi = True
+ kw_dec["context_freq"] = False
+
+ if last_freq:
+ chout_z = max(chout, chout_z)
+ chout = chout_z
+
+ enc = HEncLayer(
+ chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
+ )
+ if freq:
+ tenc = HEncLayer(
+ chin,
+ chout,
+ dconv=dconv_mode & 1,
+ context=context_enc,
+ empty=last_freq,
+ **kwt
+ )
+ self.tencoder.append(tenc)
+
+ if multi:
+ enc = MultiWrap(enc, multi_freqs)
+ self.encoder.append(enc)
+ if index == 0:
+ chin = self.audio_channels * len(self.sources)
+ chin_z = chin
+ if self.cac:
+ chin_z *= 2
+ dec = HDecLayer(
+ chout_z,
+ chin_z,
+ dconv=dconv_mode & 2,
+ last=index == 0,
+ context=context,
+ **kw_dec
+ )
+ if multi:
+ dec = MultiWrap(dec, multi_freqs)
+ if freq:
+ tdec = HDecLayer(
+ chout,
+ chin,
+ dconv=dconv_mode & 2,
+ empty=last_freq,
+ last=index == 0,
+ context=context,
+ **kwt
+ )
+ self.tdecoder.insert(0, tdec)
+ self.decoder.insert(0, dec)
+
+ chin = chout
+ chin_z = chout_z
+ chout = int(growth * chout)
+ chout_z = int(growth * chout_z)
+ if freq:
+ if freqs <= kernel_size:
+ freqs = 1
+ else:
+ freqs //= stride
+ if index == 0 and freq_emb:
+ self.freq_emb = ScaledEmbedding(
+ freqs, chin_z, smooth=emb_smooth, scale=emb_scale
+ )
+ self.freq_emb_scale = freq_emb
+
+ if rescale:
+ rescale_module(self, reference=rescale)
+
+ transformer_channels = channels * growth ** (depth - 1)
+ if bottom_channels:
+ self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
+ self.channel_downsampler = nn.Conv1d(
+ bottom_channels, transformer_channels, 1
+ )
+ self.channel_upsampler_t = nn.Conv1d(
+ transformer_channels, bottom_channels, 1
+ )
+ self.channel_downsampler_t = nn.Conv1d(
+ bottom_channels, transformer_channels, 1
+ )
+
+ transformer_channels = bottom_channels
+
+ if t_layers > 0:
+ self.crosstransformer = CrossTransformerEncoder(
+ dim=transformer_channels,
+ emb=t_emb,
+ hidden_scale=t_hidden_scale,
+ num_heads=t_heads,
+ num_layers=t_layers,
+ cross_first=t_cross_first,
+ dropout=t_dropout,
+ max_positions=t_max_positions,
+ norm_in=t_norm_in,
+ norm_in_group=t_norm_in_group,
+ group_norm=t_group_norm,
+ norm_first=t_norm_first,
+ norm_out=t_norm_out,
+ max_period=t_max_period,
+ weight_decay=t_weight_decay,
+ lr=t_lr,
+ layer_scale=t_layer_scale,
+ gelu=t_gelu,
+ sin_random_shift=t_sin_random_shift,
+ weight_pos_embed=t_weight_pos_embed,
+ cape_mean_normalize=t_cape_mean_normalize,
+ cape_augment=t_cape_augment,
+ cape_glob_loc_scale=t_cape_glob_loc_scale,
+ sparse_self_attn=t_sparse_self_attn,
+ sparse_cross_attn=t_sparse_cross_attn,
+ mask_type=t_mask_type,
+ mask_random_seed=t_mask_random_seed,
+ sparse_attn_window=t_sparse_attn_window,
+ global_window=t_global_window,
+ sparsity=t_sparsity,
+ auto_sparsity=t_auto_sparsity,
+ )
+ else:
+ self.crosstransformer = None
+
+ def _spec(self, x):
+ hl = self.hop_length
+ nfft = self.nfft
+ x0 = x # noqa
+
+ # We re-pad the signal in order to keep the property
+ # that the size of the output is exactly the size of the input
+ # divided by the stride (here hop_length), when divisible.
+ # This is achieved by padding by 1/4th of the kernel size (here nfft),
+ # which is not supported by torch.stft.
+ # Having all convolution operations follow this convention allows us to easily
+ # align the time and frequency branches later on.
+ assert hl == nfft // 4
+ le = int(math.ceil(x.shape[-1] / hl))
+ pad = hl // 2 * 3
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
+
+ z = spectro(x, nfft, hl)[..., :-1, :]
+ assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
+ z = z[..., 2: 2 + le]
+ return z
+
+ def _ispec(self, z, length=None, scale=0):
+ hl = self.hop_length // (4**scale)
+ z = F.pad(z, (0, 0, 0, 1))
+ z = F.pad(z, (2, 2))
+ pad = hl // 2 * 3
+ le = hl * int(math.ceil(length / hl)) + 2 * pad
+ x = ispectro(z, hl, length=le)
+ x = x[..., pad: pad + length]
+ return x
+
+ def _magnitude(self, z):
+ # return the magnitude of the spectrogram, except when cac is True,
+ # in which case we just move the complex dimension to the channel one.
+ if self.cac:
+ B, C, Fr, T = z.shape
+ m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
+ m = m.reshape(B, C * 2, Fr, T)
+ else:
+ m = z.abs()
+ return m
+
+ def _mask(self, z, m):
+ # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
+ # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
+ niters = self.wiener_iters
+ if self.cac:
+ B, S, C, Fr, T = m.shape
+ out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
+ out = torch.view_as_complex(out.contiguous())
+ return out
+ if self.training:
+ niters = self.end_iters
+ if niters < 0:
+ z = z[:, None]
+ return z / (1e-8 + z.abs()) * m
+ else:
+ return self._wiener(m, z, niters)
+
+ def _wiener(self, mag_out, mix_stft, niters):
+ # Apply Wiener filtering from Open Unmix.
+ init = mix_stft.dtype
+ wiener_win_len = 300
+ residual = self.wiener_residual
+
+ B, S, C, Fq, T = mag_out.shape
+ mag_out = mag_out.permute(0, 4, 3, 2, 1)
+ mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
+
+ outs = []
+ for sample in range(B):
+ pos = 0
+ out = []
+ for pos in range(0, T, wiener_win_len):
+ frame = slice(pos, pos + wiener_win_len)
+ z_out = wiener(
+ mag_out[sample, frame],
+ mix_stft[sample, frame],
+ niters,
+ residual=residual,
+ )
+ out.append(z_out.transpose(-1, -2))
+ outs.append(torch.cat(out, dim=0))
+ out = torch.view_as_complex(torch.stack(outs, 0))
+ out = out.permute(0, 4, 3, 2, 1).contiguous()
+ if residual:
+ out = out[:, :-1]
+ assert list(out.shape) == [B, S, C, Fq, T]
+ return out.to(init)
+
+ def valid_length(self, length: int):
+ """
+ Return a length that is appropriate for evaluation.
+ In our case, always return the training length, unless
+ it is smaller than the given length, in which case this
+ raises an error.
+ """
+ if not self.use_train_segment:
+ return length
+ training_length = int(self.segment * self.samplerate)
+ if training_length < length:
+ raise ValueError(
+ f"Given length {length} is longer than "
+ f"training length {training_length}")
+ return training_length
+
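With `use_train_segment=True`, `valid_length` pins the evaluation size to exactly `segment * samplerate` samples: shorter inputs are accepted (and later padded in `forward`), while longer ones raise. A standalone sketch of that policy with the default 44.1 kHz, 10 s configuration (plain Python, mirroring the method above outside the class):

```python
# Mirror of HTDemucs.valid_length for use_train_segment=True (a sketch).
samplerate, segment = 44100, 10
training_length = int(segment * samplerate)   # 441000 samples

def valid_length(length: int) -> int:
    if training_length < length:
        raise ValueError("input longer than training length")
    return training_length                    # always process a full segment

assert valid_length(44100) == 441000          # 1 s input -> padded to 10 s
try:
    valid_length(441001)                      # 1 sample too long -> rejected
    ok = False
except ValueError:
    ok = True
assert ok
```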
+ def forward(self, mix):
+ length = mix.shape[-1]
+ length_pre_pad = None
+ if self.use_train_segment:
+ if self.training:
+ self.segment = Fraction(mix.shape[-1], self.samplerate)
+ else:
+ training_length = int(self.segment * self.samplerate)
+ if mix.shape[-1] < training_length:
+ length_pre_pad = mix.shape[-1]
+ mix = F.pad(mix, (0, training_length - length_pre_pad))
+ z = self._spec(mix)
+ mag = self._magnitude(z).to(mix.device)
+ x = mag
+
+ B, C, Fq, T = x.shape
+
+ # unlike previous Demucs, we always normalize because it is easier.
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
+ std = x.std(dim=(1, 2, 3), keepdim=True)
+ x = (x - mean) / (1e-5 + std)
+ # x will be the freq. branch input.
+
+ # Prepare the time branch input.
+ xt = mix
+ meant = xt.mean(dim=(1, 2), keepdim=True)
+ stdt = xt.std(dim=(1, 2), keepdim=True)
+ xt = (xt - meant) / (1e-5 + stdt)
+
+ # okay, this is a giant mess I know...
+ saved = [] # skip connections, freq.
+ saved_t = [] # skip connections, time.
+ lengths = [] # saved lengths to properly remove padding, freq branch.
+ lengths_t = [] # saved lengths for time branch.
+ for idx, encode in enumerate(self.encoder):
+ lengths.append(x.shape[-1])
+ inject = None
+ if idx < len(self.tencoder):
+ # we have not yet merged branches.
+ lengths_t.append(xt.shape[-1])
+ tenc = self.tencoder[idx]
+ xt = tenc(xt)
+ if not tenc.empty:
+ # save for skip connection
+ saved_t.append(xt)
+ else:
+ # tenc contains just the first conv., so that now the time and freq.
+ # branches have the same shape and can be merged.
+ inject = xt
+ x = encode(x, inject)
+ if idx == 0 and self.freq_emb is not None:
+ # add frequency embedding to allow for non-equivariant convolutions
+ # over the frequency axis.
+ frs = torch.arange(x.shape[-2], device=x.device)
+ emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
+ x = x + self.freq_emb_scale * emb
+
+ saved.append(x)
+ if self.crosstransformer:
+ if self.bottom_channels:
+ b, c, f, t = x.shape
+ x = rearrange(x, "b c f t-> b c (f t)")
+ x = self.channel_upsampler(x)
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
+ xt = self.channel_upsampler_t(xt)
+
+ x, xt = self.crosstransformer(x, xt)
+
+ if self.bottom_channels:
+ x = rearrange(x, "b c f t-> b c (f t)")
+ x = self.channel_downsampler(x)
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
+ xt = self.channel_downsampler_t(xt)
+
+ for idx, decode in enumerate(self.decoder):
+ skip = saved.pop(-1)
+ x, pre = decode(x, skip, lengths.pop(-1))
+ # `pre` contains the output just before the final transposed convolution,
+ # which is used when the freq. and time branches separate.
+
+ offset = self.depth - len(self.tdecoder)
+ if idx >= offset:
+ tdec = self.tdecoder[idx - offset]
+ length_t = lengths_t.pop(-1)
+ if tdec.empty:
+ assert pre.shape[2] == 1, pre.shape
+ pre = pre[:, :, 0]
+ xt, _ = tdec(pre, None, length_t)
+ else:
+ skip = saved_t.pop(-1)
+ xt, _ = tdec(xt, skip, length_t)
+
+ # Let's make sure we used all stored skip connections.
+ assert len(saved) == 0
+ assert len(lengths_t) == 0
+ assert len(saved_t) == 0
+
+ S = len(self.sources)
+ x = x.view(B, S, -1, Fq, T)
+ x = x * std[:, None] + mean[:, None]
+
+ # to cpu as mps doesn't support complex numbers
+ # demucs issues #435, #432
+ # NOTE: in this case z already is on cpu
+ # TODO: remove this when mps supports complex numbers
+ x_is_mps = x.device.type == "mps"
+ if x_is_mps:
+ x = x.cpu()
+
+ zout = self._mask(z, x)
+ if self.use_train_segment:
+ if self.training:
+ x = self._ispec(zout, length)
+ else:
+ x = self._ispec(zout, training_length)
+ else:
+ x = self._ispec(zout, length)
+
+ # back to mps device
+ if x_is_mps:
+ x = x.to("mps")
+
+ if self.use_train_segment:
+ if self.training:
+ xt = xt.view(B, S, -1, length)
+ else:
+ xt = xt.view(B, S, -1, training_length)
+ else:
+ xt = xt.view(B, S, -1, length)
+ xt = xt * stdt[:, None] + meant[:, None]
+ x = xt + x
+ if length_pre_pad:
+ x = x[..., :length_pre_pad]
+ return x
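The re-padding in `_spec` above is chosen so that a centered STFT with hop `hl = nfft // 4` yields exactly `ceil(length / hl)` frames once 2 frames are trimmed from each side, which is what keeps the time and frequency branches aligned. A small arithmetic sketch of that invariant (pure Python, no actual STFT; `spec_frames` is a hypothetical helper mirroring the padding math, not part of demucs):

```python
import math

def spec_frames(length: int, nfft: int = 4096) -> int:
    # Mirror of the HTDemucs._spec padding arithmetic (a sketch).
    hl = nfft // 4
    le = math.ceil(length / hl)                       # target frame count
    pad = hl // 2 * 3
    padded = pad + (pad + le * hl - length) + length  # total padded length
    assert padded == (le + 3) * hl                    # always a multiple of hl
    frames = padded // hl + 1                         # centered STFT frame count
    assert frames == le + 4                           # same check as in _spec
    return frames - 4                                 # trim 2 frames per side

# One second at 44.1 kHz, hop 1024: ceil(44100 / 1024) = 44 frames.
assert spec_frames(44100) == 44
```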
demucs/pretrained.py ADDED
@@ -0,0 +1,98 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ """Loading pretrained models.
+ """
+
+ import logging
+ from pathlib import Path
+ import typing as tp
+
+ from dora.log import fatal, bold
+
+ from .hdemucs import HDemucs
+ from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa
+ from .states import _check_diffq
+
+ logger = logging.getLogger(__name__)
+ ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/"
+ REMOTE_ROOT = Path(__file__).parent / 'remote'
+
+ SOURCES = ["drums", "bass", "other", "vocals"]
+ DEFAULT_MODEL = 'htdemucs'
+
+
+ def demucs_unittest():
+ model = HDemucs(channels=4, sources=SOURCES)
+ return model
+
+
+ def add_model_flags(parser):
+ group = parser.add_mutually_exclusive_group(required=False)
+ group.add_argument("-s", "--sig", help="Locally trained XP signature.")
+ group.add_argument("-n", "--name", default="htdemucs",
+ help="Pretrained model name or signature. Default is htdemucs.")
+ parser.add_argument("--repo", type=Path,
+ help="Folder containing all pre-trained models for use with -n.")
+
+
+ def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]:
+ root: str = ''
+ models: tp.Dict[str, str] = {}
+ for line in remote_file_list.read_text().split('\n'):
+ line = line.strip()
+ if line.startswith('#'):
+ continue
+ elif len(line) == 0:
+ continue
+ elif line.startswith('root:'):
+ root = line.split(':', 1)[1].strip()
+ else:
+ sig = line.split('-', 1)[0]
+ assert sig not in models
+ models[sig] = ROOT_URL + root + line
+ return models
+
+
+ def get_model(name: str,
+ repo: tp.Optional[Path] = None):
+ """`name` must be a bag-of-models name or a pretrained signature
+ from the remote AWS model repo, or from the specified local repo if `repo` is not None.
+ """
+ if name == 'demucs_unittest':
+ return demucs_unittest()
+ model_repo: ModelOnlyRepo
+ if repo is None:
+ models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
+ model_repo = RemoteRepo(models)
+ bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
+ else:
+ if not repo.is_dir():
+ fatal(f"{repo} must exist and be a directory.")
+ model_repo = LocalRepo(repo)
+ bag_repo = BagOnlyRepo(repo, model_repo)
+ any_repo = AnyModelRepo(model_repo, bag_repo)
+ try:
+ model = any_repo.get_model(name)
+ except ImportError as exc:
+ if 'diffq' in exc.args[0]:
+ _check_diffq()
+ raise
+
+ model.eval()
+ return model
+
+
+ def get_model_from_args(args):
+ """
+ Load a local model package or a pre-trained model.
+ """
+ if args.name is None:
+ args.name = DEFAULT_MODEL
+ print(bold("Important: the default model was recently changed to `htdemucs`"),
+ "the latest Hybrid Transformer Demucs model. In some cases, this model can "
+ "actually perform worse than previous models. To get back the old default model "
+ "use `-n mdx_extra_q`.")
+ return get_model(name=args.name, repo=args.repo)
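`_parse_remote_files` above reads a simple line-oriented format: `#` lines are comments, `root:` lines switch the URL prefix, and every other line is a `<sig>-<hash>.th` checkpoint name keyed by its signature. A self-contained sketch of that parser working on an in-memory list instead of a file (the `parse` helper is illustrative, not the demucs function itself):

```python
# Standalone mirror of the `files.txt` format used by _parse_remote_files.
ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/"  # same constant as above

def parse(lines):
    root, models = "", {}
    for line in lines:
        line = line.strip()
        if not line or line.startswith("#"):
            continue                                  # skip blanks and comments
        if line.startswith("root:"):
            root = line.split(":", 1)[1].strip()      # switch URL prefix
        else:
            sig = line.split("-", 1)[0]               # signature before the dash
            models[sig] = ROOT_URL + root + line
    return models

models = parse(["# demo", "root: hybrid_transformer/", "955717e8-8726e21a.th"])
assert models["955717e8"].endswith("hybrid_transformer/955717e8-8726e21a.th")
```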
demucs/py.typed ADDED
File without changes
demucs/remote/files.txt ADDED
@@ -0,0 +1,32 @@
+ # MDX Models
+ root: mdx_final/
+ 0d19c1c6-0f06f20e.th
+ 5d2d6c55-db83574e.th
+ 7d865c68-3d5dd56b.th
+ 7ecf8ec1-70f50cc9.th
+ a1d90b5c-ae9d2452.th
+ c511e2ab-fe698775.th
+ cfa93e08-61801ae1.th
+ e51eebcc-c1b80bdd.th
+ 6b9c2ca1-3fd82607.th
+ b72baf4e-8778635e.th
+ 42e558d4-196e0e1b.th
+ 305bc58f-18378783.th
+ 14fc6a69-a89dd0ee.th
+ 464b36d7-e5a9386e.th
+ 7fd6ef75-a905dd85.th
+ 83fc094f-4a16d450.th
+ 1ef250f1-592467ce.th
+ 902315c2-b39ce9c9.th
+ 9a6b4851-03af0aa6.th
+ fa0cb7f9-100d8bf4.th
+ # Hybrid Transformer models
+ root: hybrid_transformer/
+ 955717e8-8726e21a.th
+ f7e0c4bc-ba3fe64a.th
+ d12395a8-e57c48e6.th
+ 92cfc3b6-ef3bcb9c.th
+ 04573f0d-f3cf25b2.th
+ 75fc33f5-1941ce65.th
+ # Experimental 6 sources model
+ 5c90dfd2-34c22ccb.th