Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- Levo_Song_Generation/SongGeneration-Runtime/ckpt/encode-s12k.pt +3 -0
- Levo_Song_Generation/SongGeneration-Runtime/ckpt/model_1rvq/model_2_fixed.safetensors +3 -0
- Levo_Song_Generation/SongGeneration-Runtime/ckpt/model_septoken/model_2.safetensors +3 -0
- Levo_Song_Generation/SongGeneration-Runtime/ckpt/models--lengyue233--content-vec-best/blobs/d8dd400e054ddf4e6be75dab5a2549db748cc99e756a097c496c099f65a4854e +3 -0
- Levo_Song_Generation/SongGeneration-Runtime/ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68/pytorch_model.bin +3 -0
- Levo_Song_Generation/SongGeneration-Runtime/ckpt/vae/autoencoder_music_1320k.ckpt +3 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/Qwen2-7B/LICENSE +202 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/Qwen2-7B/README.md +97 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/Qwen2-7B/config.json +27 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/Qwen2-7B/generation_config.json +7 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/Qwen2-7B/merges.txt +0 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/Qwen2-7B/tokenizer.json +0 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/Qwen2-7B/tokenizer_config.json +40 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/Qwen2-7B/vocab.json +0 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/__init__.py +0 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/ckpt/htdemucs.pth +3 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/ckpt/htdemucs.yaml +1 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/__init__.py +0 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/apply.py +315 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/audio.py +291 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/demucs.py +452 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/htdemucs.py +955 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/pretrained.py +34 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/spec.py +51 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/states.py +102 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/transformer.py +765 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/utils.py +125 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/run.py +109 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/hub/version.txt +1 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/config/model_1920.json +122 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/config/model_config.json +122 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/docs/autoencoders.md +357 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/docs/conditioning.md +158 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/docs/datasets.md +75 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/docs/diffusion.md +153 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/docs/pretransforms.md +43 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/scripts/ds_zero_to_pl_ckpt.py +14 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/data/__init__.py +0 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/data/dataset.py +654 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/data/utils.py +96 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/inference/__init__.py +0 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/inference/generation.py +274 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/inference/sampling.py +232 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/inference/utils.py +35 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/interface/__init__.py +0 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/interface/gradio.py +700 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/models/autoencoders.py +794 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/models/bottleneck.py +355 -0
- Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/models/conditioners.py +561 -0
.gitattributes
CHANGED
|
@@ -96,3 +96,4 @@ torchmcubes-0.1.0-cp310-cp310-win_amd64.whl filter=lfs diff=lfs merge=lfs -text
|
|
| 96 |
torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 97 |
groundingdino-0.1.0-cp310-cp310-win_amd64.whl filter=lfs diff=lfs merge=lfs -text
|
| 98 |
groundingdino-0.1.0-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 96 |
torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 97 |
groundingdino-0.1.0-cp310-cp310-win_amd64.whl filter=lfs diff=lfs merge=lfs -text
|
| 98 |
groundingdino-0.1.0-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 99 |
+
Levo_Song_Generation/SongGeneration-Runtime/ckpt/models--lengyue233--content-vec-best/blobs/d8dd400e054ddf4e6be75dab5a2549db748cc99e756a097c496c099f65a4854e filter=lfs diff=lfs merge=lfs -text
|
Levo_Song_Generation/SongGeneration-Runtime/ckpt/encode-s12k.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e250df56b035f74c1f66f15133f4c78f664d70fa0b09aa9a752b7871bb58c02f
|
| 3 |
+
size 3957949089
|
Levo_Song_Generation/SongGeneration-Runtime/ckpt/model_1rvq/model_2_fixed.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:339a16956b859a82defc02bfd32c3744d11ff942065f6ec9306dfd4400d62110
|
| 3 |
+
size 4704507596
|
Levo_Song_Generation/SongGeneration-Runtime/ckpt/model_septoken/model_2.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:430b7c1c245722fbe3893cd621b3d4a90076404596e9fb1ce987a4a0f2a4fc6f
|
| 3 |
+
size 4808167708
|
Levo_Song_Generation/SongGeneration-Runtime/ckpt/models--lengyue233--content-vec-best/blobs/d8dd400e054ddf4e6be75dab5a2549db748cc99e756a097c496c099f65a4854e
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d8dd400e054ddf4e6be75dab5a2549db748cc99e756a097c496c099f65a4854e
|
| 3 |
+
size 378342945
|
Levo_Song_Generation/SongGeneration-Runtime/ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68/pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d8dd400e054ddf4e6be75dab5a2549db748cc99e756a097c496c099f65a4854e
|
| 3 |
+
size 378342945
|
Levo_Song_Generation/SongGeneration-Runtime/ckpt/vae/autoencoder_music_1320k.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:10ccb6c83613781ad32e998a90597ba7eb9292911a224598da1fd53728eb4cd3
|
| 3 |
+
size 674920616
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/Qwen2-7B/LICENSE
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
Apache License
|
| 3 |
+
Version 2.0, January 2004
|
| 4 |
+
http://www.apache.org/licenses/
|
| 5 |
+
|
| 6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 7 |
+
|
| 8 |
+
1. Definitions.
|
| 9 |
+
|
| 10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 12 |
+
|
| 13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 14 |
+
the copyright owner that is granting the License.
|
| 15 |
+
|
| 16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 17 |
+
other entities that control, are controlled by, or are under common
|
| 18 |
+
control with that entity. For the purposes of this definition,
|
| 19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 20 |
+
direction or management of such entity, whether by contract or
|
| 21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 23 |
+
|
| 24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 25 |
+
exercising permissions granted by this License.
|
| 26 |
+
|
| 27 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 28 |
+
including but not limited to software source code, documentation
|
| 29 |
+
source, and configuration files.
|
| 30 |
+
|
| 31 |
+
"Object" form shall mean any form resulting from mechanical
|
| 32 |
+
transformation or translation of a Source form, including but
|
| 33 |
+
not limited to compiled object code, generated documentation,
|
| 34 |
+
and conversions to other media types.
|
| 35 |
+
|
| 36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 37 |
+
Object form, made available under the License, as indicated by a
|
| 38 |
+
copyright notice that is included in or attached to the work
|
| 39 |
+
(an example is provided in the Appendix below).
|
| 40 |
+
|
| 41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 42 |
+
form, that is based on (or derived from) the Work and for which the
|
| 43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 45 |
+
of this License, Derivative Works shall not include works that remain
|
| 46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 47 |
+
the Work and Derivative Works thereof.
|
| 48 |
+
|
| 49 |
+
"Contribution" shall mean any work of authorship, including
|
| 50 |
+
the original version of the Work and any modifications or additions
|
| 51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 55 |
+
means any form of electronic, verbal, or written communication sent
|
| 56 |
+
to the Licensor or its representatives, including but not limited to
|
| 57 |
+
communication on electronic mailing lists, source code control systems,
|
| 58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 60 |
+
excluding communication that is conspicuously marked or otherwise
|
| 61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 62 |
+
|
| 63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 65 |
+
subsequently incorporated within the Work.
|
| 66 |
+
|
| 67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 72 |
+
Work and such Derivative Works in Source or Object form.
|
| 73 |
+
|
| 74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 77 |
+
(except as stated in this section) patent license to make, have made,
|
| 78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 79 |
+
where such license applies only to those patent claims licensable
|
| 80 |
+
by such Contributor that are necessarily infringed by their
|
| 81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 83 |
+
institute patent litigation against any entity (including a
|
| 84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 85 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 86 |
+
or contributory patent infringement, then any patent licenses
|
| 87 |
+
granted to You under this License for that Work shall terminate
|
| 88 |
+
as of the date such litigation is filed.
|
| 89 |
+
|
| 90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 91 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 92 |
+
modifications, and in Source or Object form, provided that You
|
| 93 |
+
meet the following conditions:
|
| 94 |
+
|
| 95 |
+
(a) You must give any other recipients of the Work or
|
| 96 |
+
Derivative Works a copy of this License; and
|
| 97 |
+
|
| 98 |
+
(b) You must cause any modified files to carry prominent notices
|
| 99 |
+
stating that You changed the files; and
|
| 100 |
+
|
| 101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 102 |
+
that You distribute, all copyright, patent, trademark, and
|
| 103 |
+
attribution notices from the Source form of the Work,
|
| 104 |
+
excluding those notices that do not pertain to any part of
|
| 105 |
+
the Derivative Works; and
|
| 106 |
+
|
| 107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 108 |
+
distribution, then any Derivative Works that You distribute must
|
| 109 |
+
include a readable copy of the attribution notices contained
|
| 110 |
+
within such NOTICE file, excluding those notices that do not
|
| 111 |
+
pertain to any part of the Derivative Works, in at least one
|
| 112 |
+
of the following places: within a NOTICE text file distributed
|
| 113 |
+
as part of the Derivative Works; within the Source form or
|
| 114 |
+
documentation, if provided along with the Derivative Works; or,
|
| 115 |
+
within a display generated by the Derivative Works, if and
|
| 116 |
+
wherever such third-party notices normally appear. The contents
|
| 117 |
+
of the NOTICE file are for informational purposes only and
|
| 118 |
+
do not modify the License. You may add Your own attribution
|
| 119 |
+
notices within Derivative Works that You distribute, alongside
|
| 120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 121 |
+
that such additional attribution notices cannot be construed
|
| 122 |
+
as modifying the License.
|
| 123 |
+
|
| 124 |
+
You may add Your own copyright statement to Your modifications and
|
| 125 |
+
may provide additional or different license terms and conditions
|
| 126 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 127 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 128 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 129 |
+
the conditions stated in this License.
|
| 130 |
+
|
| 131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 133 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 134 |
+
this License, without any additional terms or conditions.
|
| 135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 136 |
+
the terms of any separate license agreement you may have executed
|
| 137 |
+
with Licensor regarding such Contributions.
|
| 138 |
+
|
| 139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 141 |
+
except as required for reasonable and customary use in describing the
|
| 142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 143 |
+
|
| 144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 145 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 148 |
+
implied, including, without limitation, any warranties or conditions
|
| 149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 151 |
+
appropriateness of using or redistributing the Work and assume any
|
| 152 |
+
risks associated with Your exercise of permissions under this License.
|
| 153 |
+
|
| 154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 155 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 156 |
+
unless required by applicable law (such as deliberate and grossly
|
| 157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 158 |
+
liable to You for damages, including any direct, indirect, special,
|
| 159 |
+
incidental, or consequential damages of any character arising as a
|
| 160 |
+
result of this License or out of the use or inability to use the
|
| 161 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 162 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 163 |
+
other commercial damages or losses), even if such Contributor
|
| 164 |
+
has been advised of the possibility of such damages.
|
| 165 |
+
|
| 166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 169 |
+
or other liability obligations and/or rights consistent with this
|
| 170 |
+
License. However, in accepting such obligations, You may act only
|
| 171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 172 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 173 |
+
defend, and hold each Contributor harmless for any liability
|
| 174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 175 |
+
of your accepting any such warranty or additional liability.
|
| 176 |
+
|
| 177 |
+
END OF TERMS AND CONDITIONS
|
| 178 |
+
|
| 179 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 180 |
+
|
| 181 |
+
To apply the Apache License to your work, attach the following
|
| 182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 183 |
+
replaced with your own identifying information. (Don't include
|
| 184 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 185 |
+
comment syntax for the file format. We also recommend that a
|
| 186 |
+
file or class name and description of purpose be included on the
|
| 187 |
+
same "printed page" as the copyright notice for easier
|
| 188 |
+
identification within third-party archives.
|
| 189 |
+
|
| 190 |
+
Copyright 2024 Alibaba Cloud
|
| 191 |
+
|
| 192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 193 |
+
you may not use this file except in compliance with the License.
|
| 194 |
+
You may obtain a copy of the License at
|
| 195 |
+
|
| 196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 197 |
+
|
| 198 |
+
Unless required by applicable law or agreed to in writing, software
|
| 199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 201 |
+
See the License for the specific language governing permissions and
|
| 202 |
+
limitations under the License.
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/Qwen2-7B/README.md
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- en
|
| 4 |
+
pipeline_tag: text-generation
|
| 5 |
+
tags:
|
| 6 |
+
- pretrained
|
| 7 |
+
license: apache-2.0
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# Qwen2-7B
|
| 11 |
+
|
| 12 |
+
## Introduction
|
| 13 |
+
|
| 14 |
+
Qwen2 is the new series of Qwen large language models. For Qwen2, we release a number of base language models and instruction-tuned language models ranging from 0.5 to 72 billion parameters, including a Mixture-of-Experts model. This repo contains the 7B Qwen2 base language model.
|
| 15 |
+
|
| 16 |
+
Compared with the state-of-the-art opensource language models, including the previous released Qwen1.5, Qwen2 has generally surpassed most opensource models and demonstrated competitiveness against proprietary models across a series of benchmarks targeting for language understanding, language generation, multilingual capability, coding, mathematics, reasoning, etc.
|
| 17 |
+
|
| 18 |
+
For more details, please refer to our [blog](https://qwenlm.github.io/blog/qwen2/), [GitHub](https://github.com/QwenLM/Qwen2), and [Documentation](https://qwen.readthedocs.io/en/latest/).
|
| 19 |
+
<br>
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
## Model Details
|
| 23 |
+
Qwen2 is a language model series including decoder language models of different model sizes. For each size, we release the base language model and the aligned chat model. It is based on the Transformer architecture with SwiGLU activation, attention QKV bias, group query attention, etc. Additionally, we have an improved tokenizer adaptive to multiple natural languages and codes.
|
| 24 |
+
|
| 25 |
+
## Requirements
|
| 26 |
+
The code of Qwen2 has been in the latest Hugging face transformers and we advise you to install `transformers>=4.37.0`, or you might encounter the following error:
|
| 27 |
+
```
|
| 28 |
+
KeyError: 'qwen2'
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
## Usage
|
| 33 |
+
|
| 34 |
+
We do not advise you to use base language models for text generation. Instead, you can apply post-training, e.g., SFT, RLHF, continued pretraining, etc., on this model.
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
### Performance
|
| 38 |
+
|
| 39 |
+
The evaluation of base models mainly focuses on the model performance of natural language understanding, general question answering, coding, mathematics, scientific knowledge, reasoning, multilingual capability, etc.
|
| 40 |
+
|
| 41 |
+
The datasets for evaluation include:
|
| 42 |
+
|
| 43 |
+
**English Tasks**: MMLU (5-shot), MMLU-Pro (5-shot), GPQA (5shot), Theorem QA (5-shot), BBH (3-shot), HellaSwag (10-shot), Winogrande (5-shot), TruthfulQA (0-shot), ARC-C (25-shot)
|
| 44 |
+
|
| 45 |
+
**Coding Tasks**: EvalPlus (0-shot) (HumanEval, MBPP, HumanEval+, MBPP+), MultiPL-E (0-shot) (Python, C++, JAVA, PHP, TypeScript, C#, Bash, JavaScript)
|
| 46 |
+
|
| 47 |
+
**Math Tasks**: GSM8K (4-shot), MATH (4-shot)
|
| 48 |
+
|
| 49 |
+
**Chinese Tasks**: C-Eval(5-shot), CMMLU (5-shot)
|
| 50 |
+
|
| 51 |
+
**Multilingual Tasks**: Multi-Exam (M3Exam 5-shot, IndoMMLU 3-shot, ruMMLU 5-shot, mMMLU 5-shot), Multi-Understanding (BELEBELE 5-shot, XCOPA 5-shot, XWinograd 5-shot, XStoryCloze 0-shot, PAWS-X 5-shot), Multi-Mathematics (MGSM 8-shot), Multi-Translation (Flores-101 5-shot)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
#### Qwen2-7B performance
|
| 56 |
+
| Datasets | Mistral-7B | Gemma-7B | Llama-3-8B | Qwen1.5-7B | Qwen2-7B |
|
| 57 |
+
| :--------| :---------: | :------------: | :------------: | :------------: | :------------: |
|
| 58 |
+
|# Params | 7.2B | 8.5B | 8.0B | 7.7B | 7.6B |
|
| 59 |
+
|# Non-emb Params | 7.0B | 7.8B | 7.0B | 6.5B | 6.5B |
|
| 60 |
+
| ***English*** | | | | | |
|
| 61 |
+
|MMLU | 64.2 | 64.6 | 66.6 | 61.0 | **70.3** |
|
| 62 |
+
|MMLU-Pro | 30.9 | 33.7 | 35.4 | 29.9 | **40.0** |
|
| 63 |
+
|GPQA | 24.7 | 25.7 | 25.8 | 26.7 | **31.8** |
|
| 64 |
+
|Theorem QA | 19.2 | 21.5 | 22.1 | 14.2 | **31.1** |
|
| 65 |
+
|BBH | 56.1 | 55.1 | 57.7 | 40.2 | **62.6** |
|
| 66 |
+
|HellaSwag | **83.2** | 82.2 | 82.1 | 78.5 | 80.7 |
|
| 67 |
+
|Winogrande | 78.4 | **79.0** | 77.4 | 71.3 | 77.0 |
|
| 68 |
+
|ARC-C | 60.0 | **61.1** | 59.3 | 54.2 | 60.6 |
|
| 69 |
+
|TruthfulQA | 42.2 | 44.8 | 44.0 | 51.1 | **54.2** |
|
| 70 |
+
| ***Coding*** | | | | | |
|
| 71 |
+
|HumanEval | 29.3 | 37.2 | 33.5 | 36.0 | **51.2** |
|
| 72 |
+
|MBPP | 51.1 | 50.6 | 53.9 | 51.6 | **65.9** |
|
| 73 |
+
|EvalPlus | 36.4 | 39.6 | 40.3 | 40.0 | **54.2** |
|
| 74 |
+
|MultiPL-E | 29.4 | 29.7 | 22.6 | 28.1 | **46.3** |
|
| 75 |
+
| ***Mathematics*** | | | | | |
|
| 76 |
+
|GSM8K | 52.2 | 46.4 | 56.0 | 62.5 | **79.9** |
|
| 77 |
+
|MATH | 13.1 | 24.3 | 20.5 | 20.3 | **44.2** |
|
| 78 |
+
| ***Chinese*** | | | | | |
|
| 79 |
+
|C-Eval | 47.4 | 43.6 | 49.5 | 74.1 | **83.2** |
|
| 80 |
+
|CMMLU | - | - | 50.8 | 73.1 | **83.9** |
|
| 81 |
+
| ***Multilingual*** | | | | | |
|
| 82 |
+
|Multi-Exam | 47.1 | 42.7 | 52.3 | 47.7 | **59.2** |
|
| 83 |
+
|Multi-Understanding | 63.3 | 58.3 | 68.6 | 67.6 | **72.0** |
|
| 84 |
+
|Multi-Mathematics | 26.3 | 39.1 | 36.3 | 37.3 | **57.5** |
|
| 85 |
+
|Multi-Translation | 23.3 | 31.2 | **31.9** | 28.4 | 31.5 |
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
## Citation
|
| 89 |
+
|
| 90 |
+
If you find our work helpful, feel free to give us a cite.
|
| 91 |
+
|
| 92 |
+
```
|
| 93 |
+
@article{qwen2,
|
| 94 |
+
title={Qwen2 Technical Report},
|
| 95 |
+
year={2024}
|
| 96 |
+
}
|
| 97 |
+
```
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/Qwen2-7B/config.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"Qwen2ForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_dropout": 0.0,
|
| 6 |
+
"bos_token_id": 151643,
|
| 7 |
+
"eos_token_id": 151643,
|
| 8 |
+
"hidden_act": "silu",
|
| 9 |
+
"hidden_size": 3584,
|
| 10 |
+
"initializer_range": 0.02,
|
| 11 |
+
"intermediate_size": 18944,
|
| 12 |
+
"max_position_embeddings": 131072,
|
| 13 |
+
"max_window_layers": 28,
|
| 14 |
+
"model_type": "qwen2",
|
| 15 |
+
"num_attention_heads": 28,
|
| 16 |
+
"num_hidden_layers": 28,
|
| 17 |
+
"num_key_value_heads": 4,
|
| 18 |
+
"rms_norm_eps": 1e-06,
|
| 19 |
+
"rope_theta": 1000000.0,
|
| 20 |
+
"sliding_window": 131072,
|
| 21 |
+
"tie_word_embeddings": false,
|
| 22 |
+
"torch_dtype": "bfloat16",
|
| 23 |
+
"transformers_version": "4.37.2",
|
| 24 |
+
"use_cache": true,
|
| 25 |
+
"use_sliding_window": false,
|
| 26 |
+
"vocab_size": 152064
|
| 27 |
+
}
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/Qwen2-7B/generation_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 151643,
|
| 3 |
+
"do_sample": false,
|
| 4 |
+
"eos_token_id": 151643,
|
| 5 |
+
"max_new_tokens": 2048,
|
| 6 |
+
"transformers_version": "4.37.0"
|
| 7 |
+
}
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/Qwen2-7B/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/Qwen2-7B/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/Qwen2-7B/tokenizer_config.json
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"added_tokens_decoder": {
|
| 4 |
+
"151643": {
|
| 5 |
+
"content": "<|endoftext|>",
|
| 6 |
+
"lstrip": false,
|
| 7 |
+
"normalized": false,
|
| 8 |
+
"rstrip": false,
|
| 9 |
+
"single_word": false,
|
| 10 |
+
"special": true
|
| 11 |
+
},
|
| 12 |
+
"151644": {
|
| 13 |
+
"content": "<|im_start|>",
|
| 14 |
+
"lstrip": false,
|
| 15 |
+
"normalized": false,
|
| 16 |
+
"rstrip": false,
|
| 17 |
+
"single_word": false,
|
| 18 |
+
"special": true
|
| 19 |
+
},
|
| 20 |
+
"151645": {
|
| 21 |
+
"content": "<|im_end|>",
|
| 22 |
+
"lstrip": false,
|
| 23 |
+
"normalized": false,
|
| 24 |
+
"rstrip": false,
|
| 25 |
+
"single_word": false,
|
| 26 |
+
"special": true
|
| 27 |
+
}
|
| 28 |
+
},
|
| 29 |
+
"additional_special_tokens": ["<|im_start|>", "<|im_end|>"],
|
| 30 |
+
"bos_token": null,
|
| 31 |
+
"chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
| 32 |
+
"clean_up_tokenization_spaces": false,
|
| 33 |
+
"eos_token": "<|endoftext|>",
|
| 34 |
+
"errors": "replace",
|
| 35 |
+
"model_max_length": 32768,
|
| 36 |
+
"pad_token": "<|endoftext|>",
|
| 37 |
+
"split_special_tokens": false,
|
| 38 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 39 |
+
"unk_token": null
|
| 40 |
+
}
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/Qwen2-7B/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/__init__.py
ADDED
|
File without changes
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/ckpt/htdemucs.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e4378974c3df2fbcf872d2aeb32218e4de376799494579655775a375d09931c2
|
| 3 |
+
size 168138881
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/ckpt/htdemucs.yaml
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
models: ['htdemucs']
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/__init__.py
ADDED
|
File without changes
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/apply.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File : apply.py
|
| 5 |
+
@Time : 2023/8/8 下午4:22
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2023, Tencent
|
| 9 |
+
@Desc : Apply
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 13 |
+
import torch
|
| 14 |
+
import os
|
| 15 |
+
import random
|
| 16 |
+
import typing as tp
|
| 17 |
+
|
| 18 |
+
import torch as th
|
| 19 |
+
from torch import nn
|
| 20 |
+
from torch.nn import functional as F
|
| 21 |
+
import tqdm
|
| 22 |
+
|
| 23 |
+
from .htdemucs import HTDemucs
|
| 24 |
+
from .audio import load_track, save_audio
|
| 25 |
+
from .utils import center_trim, DummyPoolExecutor
|
| 26 |
+
|
| 27 |
+
Model = tp.Union[HTDemucs]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class BagOfModels(nn.Module):
|
| 31 |
+
def __init__(self, models: tp.List[Model],
|
| 32 |
+
weights: tp.Optional[tp.List[tp.List[float]]] = None,
|
| 33 |
+
segment: tp.Optional[float] = None):
|
| 34 |
+
"""
|
| 35 |
+
Represents a bag of models with specific weights.
|
| 36 |
+
You should call `apply_model` rather than calling directly the forward here for
|
| 37 |
+
optimal performance.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
models (list[nn.Module]): list of Demucs/HDemucs models.
|
| 41 |
+
weights (list[list[float]]): list of weights. If None, assumed to
|
| 42 |
+
be all ones, otherwise it should be a list of N list (N number of models),
|
| 43 |
+
each containing S floats (S number of sources).
|
| 44 |
+
segment (None or float): overrides the `segment` attribute of each model
|
| 45 |
+
(this is performed inplace, be careful is you reuse the models passed).
|
| 46 |
+
"""
|
| 47 |
+
super().__init__()
|
| 48 |
+
assert len(models) > 0
|
| 49 |
+
first = models[0]
|
| 50 |
+
for other in models:
|
| 51 |
+
assert other.sources == first.sources
|
| 52 |
+
assert other.samplerate == first.samplerate
|
| 53 |
+
assert other.audio_channels == first.audio_channels
|
| 54 |
+
if segment is not None:
|
| 55 |
+
other.segment = segment
|
| 56 |
+
|
| 57 |
+
self.audio_channels = first.audio_channels
|
| 58 |
+
self.samplerate = first.samplerate
|
| 59 |
+
self.sources = first.sources
|
| 60 |
+
self.models = nn.ModuleList(models)
|
| 61 |
+
|
| 62 |
+
if weights is None:
|
| 63 |
+
weights = [[1. for _ in first.sources] for _ in models]
|
| 64 |
+
else:
|
| 65 |
+
assert len(weights) == len(models)
|
| 66 |
+
for weight in weights:
|
| 67 |
+
assert len(weight) == len(first.sources)
|
| 68 |
+
self.weights = weights
|
| 69 |
+
|
| 70 |
+
@property
|
| 71 |
+
def max_allowed_segment(self) -> float:
|
| 72 |
+
max_allowed_segment = float('inf')
|
| 73 |
+
for model in self.models:
|
| 74 |
+
if isinstance(model, HTDemucs):
|
| 75 |
+
max_allowed_segment = min(max_allowed_segment, float(model.segment))
|
| 76 |
+
return max_allowed_segment
|
| 77 |
+
|
| 78 |
+
def forward(self, x):
|
| 79 |
+
raise NotImplementedError("Call `apply_model` on this.")
|
| 80 |
+
|
| 81 |
+
def separate(self, source_file, output_dir, stem=None, device=None):
|
| 82 |
+
wav, _ = load_track(source_file, self.audio_channels, self.samplerate)
|
| 83 |
+
ref = wav.mean(0)
|
| 84 |
+
wav -= ref.mean()
|
| 85 |
+
wav /= ref.std()
|
| 86 |
+
sources = apply_model(self, wav[None], device=device, shifts=1, split=True, overlap=0.25,
|
| 87 |
+
progress=True, num_workers=0, segment=None)[0]
|
| 88 |
+
sources *= ref.std()
|
| 89 |
+
sources += ref.mean()
|
| 90 |
+
|
| 91 |
+
output_paths = []
|
| 92 |
+
name, ext = os.path.splitext(os.path.split(source_file)[-1])
|
| 93 |
+
if ext != ".flac":
|
| 94 |
+
ext = ".flac"
|
| 95 |
+
kwargs = {
|
| 96 |
+
'samplerate': self.samplerate,
|
| 97 |
+
'bitrate': 320,
|
| 98 |
+
'clip': "rescale",
|
| 99 |
+
'as_float': False,
|
| 100 |
+
'bits_per_sample': 16,
|
| 101 |
+
}
|
| 102 |
+
if stem is None:
|
| 103 |
+
for source, stem in zip(sources, self.sources):
|
| 104 |
+
output_stem_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
|
| 105 |
+
save_audio(source, output_stem_path, **kwargs)
|
| 106 |
+
output_paths.append(output_stem_path)
|
| 107 |
+
else:
|
| 108 |
+
sources = list(sources)
|
| 109 |
+
output_stem_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
|
| 110 |
+
save_audio(sources.pop(self.sources.index(stem)), output_stem_path, **kwargs)
|
| 111 |
+
other_stem = torch.zeros_like(sources[0])
|
| 112 |
+
for i in sources:
|
| 113 |
+
other_stem += i
|
| 114 |
+
output_no_stem_path = os.path.join(output_dir, f"{name}_no_{stem}{ext}")
|
| 115 |
+
save_audio(other_stem, output_no_stem_path, **kwargs)
|
| 116 |
+
output_paths = [output_stem_path, output_no_stem_path]
|
| 117 |
+
|
| 118 |
+
return output_paths
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class TensorChunk:
|
| 122 |
+
def __init__(self, tensor, offset=0, length=None):
|
| 123 |
+
total_length = tensor.shape[-1]
|
| 124 |
+
assert offset >= 0
|
| 125 |
+
assert offset < total_length
|
| 126 |
+
|
| 127 |
+
if length is None:
|
| 128 |
+
length = total_length - offset
|
| 129 |
+
else:
|
| 130 |
+
length = min(total_length - offset, length)
|
| 131 |
+
|
| 132 |
+
if isinstance(tensor, TensorChunk):
|
| 133 |
+
self.tensor = tensor.tensor
|
| 134 |
+
self.offset = offset + tensor.offset
|
| 135 |
+
else:
|
| 136 |
+
self.tensor = tensor
|
| 137 |
+
self.offset = offset
|
| 138 |
+
self.length = length
|
| 139 |
+
self.device = tensor.device
|
| 140 |
+
|
| 141 |
+
@property
|
| 142 |
+
def shape(self):
|
| 143 |
+
shape = list(self.tensor.shape)
|
| 144 |
+
shape[-1] = self.length
|
| 145 |
+
return shape
|
| 146 |
+
|
| 147 |
+
def padded(self, target_length):
|
| 148 |
+
delta = target_length - self.length
|
| 149 |
+
total_length = self.tensor.shape[-1]
|
| 150 |
+
assert delta >= 0
|
| 151 |
+
|
| 152 |
+
start = self.offset - delta // 2
|
| 153 |
+
end = start + target_length
|
| 154 |
+
|
| 155 |
+
correct_start = max(0, start)
|
| 156 |
+
correct_end = min(total_length, end)
|
| 157 |
+
|
| 158 |
+
pad_left = correct_start - start
|
| 159 |
+
pad_right = end - correct_end
|
| 160 |
+
|
| 161 |
+
out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
|
| 162 |
+
assert out.shape[-1] == target_length
|
| 163 |
+
return out
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def tensor_chunk(tensor_or_chunk):
|
| 167 |
+
if isinstance(tensor_or_chunk, TensorChunk):
|
| 168 |
+
return tensor_or_chunk
|
| 169 |
+
else:
|
| 170 |
+
assert isinstance(tensor_or_chunk, th.Tensor)
|
| 171 |
+
return TensorChunk(tensor_or_chunk)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def apply_model(model: tp.Union[BagOfModels, Model],
|
| 175 |
+
mix: tp.Union[th.Tensor, TensorChunk],
|
| 176 |
+
shifts: int = 1, split: bool = True,
|
| 177 |
+
overlap: float = 0.25, transition_power: float = 1.,
|
| 178 |
+
progress: bool = False, device=None,
|
| 179 |
+
num_workers: int = 0, segment: tp.Optional[float] = None,
|
| 180 |
+
pool=None) -> th.Tensor:
|
| 181 |
+
"""
|
| 182 |
+
Apply model to a given mixture.
|
| 183 |
+
|
| 184 |
+
Args:
|
| 185 |
+
shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
|
| 186 |
+
and apply the oppositve shift to the output. This is repeated `shifts` time and
|
| 187 |
+
all predictions are averaged. This effectively makes the model time equivariant
|
| 188 |
+
and improves SDR by up to 0.2 points.
|
| 189 |
+
split (bool): if True, the input will be broken down in 8 seconds extracts
|
| 190 |
+
and predictions will be performed individually on each and concatenated.
|
| 191 |
+
Useful for model with large memory footprint like Tasnet.
|
| 192 |
+
progress (bool): if True, show a progress bar (requires split=True)
|
| 193 |
+
device (torch.device, str, or None): if provided, device on which to
|
| 194 |
+
execute the computation, otherwise `mix.device` is assumed.
|
| 195 |
+
When `device` is different from `mix.device`, only local computations will
|
| 196 |
+
be on `device`, while the entire tracks will be stored on `mix.device`.
|
| 197 |
+
num_workers (int): if non zero, device is 'cpu', how many threads to
|
| 198 |
+
use in parallel.
|
| 199 |
+
segment (float or None): override the model segment parameter.
|
| 200 |
+
"""
|
| 201 |
+
if device is None:
|
| 202 |
+
device = mix.device
|
| 203 |
+
else:
|
| 204 |
+
device = th.device(device)
|
| 205 |
+
if pool is None:
|
| 206 |
+
if num_workers > 0 and device.type == 'cpu':
|
| 207 |
+
pool = ThreadPoolExecutor(num_workers)
|
| 208 |
+
else:
|
| 209 |
+
pool = DummyPoolExecutor()
|
| 210 |
+
kwargs: tp.Dict[str, tp.Any] = {
|
| 211 |
+
'shifts': shifts,
|
| 212 |
+
'split': split,
|
| 213 |
+
'overlap': overlap,
|
| 214 |
+
'transition_power': transition_power,
|
| 215 |
+
'progress': progress,
|
| 216 |
+
'device': device,
|
| 217 |
+
'pool': pool,
|
| 218 |
+
'segment': segment,
|
| 219 |
+
}
|
| 220 |
+
out: tp.Union[float, th.Tensor]
|
| 221 |
+
if isinstance(model, BagOfModels):
|
| 222 |
+
# Special treatment for bag of model.
|
| 223 |
+
# We explicitely apply multiple times `apply_model` so that the random shifts
|
| 224 |
+
# are different for each model.
|
| 225 |
+
estimates: tp.Union[float, th.Tensor] = 0.
|
| 226 |
+
totals = [0.] * len(model.sources)
|
| 227 |
+
for sub_model, model_weights in zip(model.models, model.weights):
|
| 228 |
+
original_model_device = next(iter(sub_model.parameters())).device
|
| 229 |
+
sub_model.to(device)
|
| 230 |
+
|
| 231 |
+
out = apply_model(sub_model, mix, **kwargs)
|
| 232 |
+
sub_model.to(original_model_device)
|
| 233 |
+
for k, inst_weight in enumerate(model_weights):
|
| 234 |
+
out[:, k, :, :] *= inst_weight
|
| 235 |
+
totals[k] += inst_weight
|
| 236 |
+
estimates += out
|
| 237 |
+
del out
|
| 238 |
+
|
| 239 |
+
assert isinstance(estimates, th.Tensor)
|
| 240 |
+
for k in range(estimates.shape[1]):
|
| 241 |
+
estimates[:, k, :, :] /= totals[k]
|
| 242 |
+
return estimates
|
| 243 |
+
|
| 244 |
+
model.to(device)
|
| 245 |
+
model.eval()
|
| 246 |
+
assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
|
| 247 |
+
batch, channels, length = mix.shape
|
| 248 |
+
if shifts:
|
| 249 |
+
kwargs['shifts'] = 0
|
| 250 |
+
max_shift = int(0.5 * model.samplerate)
|
| 251 |
+
mix = tensor_chunk(mix)
|
| 252 |
+
assert isinstance(mix, TensorChunk)
|
| 253 |
+
padded_mix = mix.padded(length + 2 * max_shift)
|
| 254 |
+
out = 0.
|
| 255 |
+
for _ in range(shifts):
|
| 256 |
+
offset = random.randint(0, max_shift)
|
| 257 |
+
shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
|
| 258 |
+
shifted_out = apply_model(model, shifted, **kwargs)
|
| 259 |
+
out += shifted_out[..., max_shift - offset:]
|
| 260 |
+
out /= shifts
|
| 261 |
+
assert isinstance(out, th.Tensor)
|
| 262 |
+
return out
|
| 263 |
+
elif split:
|
| 264 |
+
kwargs['split'] = False
|
| 265 |
+
out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
|
| 266 |
+
sum_weight = th.zeros(length, device=mix.device)
|
| 267 |
+
if segment is None:
|
| 268 |
+
segment = model.segment
|
| 269 |
+
assert segment is not None and segment > 0.
|
| 270 |
+
segment_length: int = int(model.samplerate * segment)
|
| 271 |
+
stride = int((1 - overlap) * segment_length)
|
| 272 |
+
offsets = range(0, length, stride)
|
| 273 |
+
scale = float(format(stride / model.samplerate, ".2f"))
|
| 274 |
+
# We start from a triangle shaped weight, with maximal weight in the middle
|
| 275 |
+
# of the segment. Then we normalize and take to the power `transition_power`.
|
| 276 |
+
# Large values of transition power will lead to sharper transitions.
|
| 277 |
+
weight = th.cat([th.arange(1, segment_length // 2 + 1, device=device),
|
| 278 |
+
th.arange(segment_length - segment_length // 2, 0, -1, device=device)])
|
| 279 |
+
assert len(weight) == segment_length
|
| 280 |
+
# If the overlap < 50%, this will translate to linear transition when
|
| 281 |
+
# transition_power is 1.
|
| 282 |
+
weight = (weight / weight.max())**transition_power
|
| 283 |
+
futures = []
|
| 284 |
+
for offset in offsets:
|
| 285 |
+
chunk = TensorChunk(mix, offset, segment_length)
|
| 286 |
+
future = pool.submit(apply_model, model, chunk, **kwargs)
|
| 287 |
+
futures.append((future, offset))
|
| 288 |
+
offset += segment_length
|
| 289 |
+
if progress:
|
| 290 |
+
futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
|
| 291 |
+
for future, offset in futures:
|
| 292 |
+
chunk_out = future.result()
|
| 293 |
+
chunk_length = chunk_out.shape[-1]
|
| 294 |
+
out[..., offset:offset + segment_length] += (
|
| 295 |
+
weight[:chunk_length] * chunk_out).to(mix.device)
|
| 296 |
+
sum_weight[offset:offset + segment_length] += weight[:chunk_length].to(mix.device)
|
| 297 |
+
assert sum_weight.min() > 0
|
| 298 |
+
out /= sum_weight
|
| 299 |
+
assert isinstance(out, th.Tensor)
|
| 300 |
+
return out
|
| 301 |
+
else:
|
| 302 |
+
valid_length: int
|
| 303 |
+
if isinstance(model, HTDemucs) and segment is not None:
|
| 304 |
+
valid_length = int(segment * model.samplerate)
|
| 305 |
+
elif hasattr(model, 'valid_length'):
|
| 306 |
+
valid_length = model.valid_length(length) # type: ignore
|
| 307 |
+
else:
|
| 308 |
+
valid_length = length
|
| 309 |
+
mix = tensor_chunk(mix)
|
| 310 |
+
assert isinstance(mix, TensorChunk)
|
| 311 |
+
padded_mix = mix.padded(valid_length).to(device)
|
| 312 |
+
with th.no_grad():
|
| 313 |
+
out = model(padded_mix)
|
| 314 |
+
assert isinstance(out, th.Tensor)
|
| 315 |
+
return center_trim(out, length)
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/audio.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File : audio.py
|
| 5 |
+
@Time : 2023/8/8 下午7:18
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2023, Tencent
|
| 9 |
+
@Desc : Audio
|
| 10 |
+
"""
|
| 11 |
+
import json
|
| 12 |
+
import subprocess as sp
|
| 13 |
+
import typing as tp
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
import lameenc
|
| 17 |
+
import julius
|
| 18 |
+
import torch
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torchaudio as ta
|
| 21 |
+
|
| 22 |
+
from .utils import temp_filenames
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _read_info(path):
|
| 26 |
+
stdout_data = sp.check_output([
|
| 27 |
+
'ffprobe', "-loglevel", "panic",
|
| 28 |
+
str(path), '-print_format', 'json', '-show_format', '-show_streams'
|
| 29 |
+
])
|
| 30 |
+
return json.loads(stdout_data.decode('utf-8'))
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class AudioFile:
|
| 34 |
+
"""
|
| 35 |
+
Allows to read audio from any format supported by ffmpeg, as well as resampling or
|
| 36 |
+
converting to mono on the fly. See :method:`read` for more details.
|
| 37 |
+
"""
|
| 38 |
+
def __init__(self, path: Path):
|
| 39 |
+
self.path = Path(path)
|
| 40 |
+
self._info = None
|
| 41 |
+
|
| 42 |
+
def __repr__(self):
|
| 43 |
+
features = [("path", self.path)]
|
| 44 |
+
features.append(("samplerate", self.samplerate()))
|
| 45 |
+
features.append(("channels", self.channels()))
|
| 46 |
+
features.append(("streams", len(self)))
|
| 47 |
+
features_str = ", ".join(f"{name}={value}" for name, value in features)
|
| 48 |
+
return f"AudioFile({features_str})"
|
| 49 |
+
|
| 50 |
+
@property
|
| 51 |
+
def info(self):
|
| 52 |
+
if self._info is None:
|
| 53 |
+
self._info = _read_info(self.path)
|
| 54 |
+
return self._info
|
| 55 |
+
|
| 56 |
+
@property
|
| 57 |
+
def duration(self):
|
| 58 |
+
return float(self.info['format']['duration'])
|
| 59 |
+
|
| 60 |
+
@property
|
| 61 |
+
def _audio_streams(self):
|
| 62 |
+
return [
|
| 63 |
+
index for index, stream in enumerate(self.info["streams"])
|
| 64 |
+
if stream["codec_type"] == "audio"
|
| 65 |
+
]
|
| 66 |
+
|
| 67 |
+
def __len__(self):
|
| 68 |
+
return len(self._audio_streams)
|
| 69 |
+
|
| 70 |
+
def channels(self, stream=0):
|
| 71 |
+
return int(self.info['streams'][self._audio_streams[stream]]['channels'])
|
| 72 |
+
|
| 73 |
+
def samplerate(self, stream=0):
|
| 74 |
+
return int(self.info['streams'][self._audio_streams[stream]]['sample_rate'])
|
| 75 |
+
|
| 76 |
+
def read(self,
|
| 77 |
+
seek_time=None,
|
| 78 |
+
duration=None,
|
| 79 |
+
streams=slice(None),
|
| 80 |
+
samplerate=None,
|
| 81 |
+
channels=None):
|
| 82 |
+
"""
|
| 83 |
+
Slightly more efficient implementation than stempeg,
|
| 84 |
+
in particular, this will extract all stems at once
|
| 85 |
+
rather than having to loop over one file multiple times
|
| 86 |
+
for each stream.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
seek_time (float): seek time in seconds or None if no seeking is needed.
|
| 90 |
+
duration (float): duration in seconds to extract or None to extract until the end.
|
| 91 |
+
streams (slice, int or list): streams to extract, can be a single int, a list or
|
| 92 |
+
a slice. If it is a slice or list, the output will be of size [S, C, T]
|
| 93 |
+
with S the number of streams, C the number of channels and T the number of samples.
|
| 94 |
+
If it is an int, the output will be [C, T].
|
| 95 |
+
samplerate (int): if provided, will resample on the fly. If None, no resampling will
|
| 96 |
+
be done. Original sampling rate can be obtained with :method:`samplerate`.
|
| 97 |
+
channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that
|
| 98 |
+
as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers.
|
| 99 |
+
See https://sound.stackexchange.com/a/42710.
|
| 100 |
+
Our definition of mono is simply the average of the two channels. Any other
|
| 101 |
+
value will be ignored.
|
| 102 |
+
"""
|
| 103 |
+
streams = np.array(range(len(self)))[streams]
|
| 104 |
+
single = not isinstance(streams, np.ndarray)
|
| 105 |
+
if single:
|
| 106 |
+
streams = [streams]
|
| 107 |
+
|
| 108 |
+
if duration is None:
|
| 109 |
+
target_size = None
|
| 110 |
+
query_duration = None
|
| 111 |
+
else:
|
| 112 |
+
target_size = int((samplerate or self.samplerate()) * duration)
|
| 113 |
+
query_duration = float((target_size + 1) / (samplerate or self.samplerate()))
|
| 114 |
+
|
| 115 |
+
with temp_filenames(len(streams)) as filenames:
|
| 116 |
+
command = ['ffmpeg', '-y']
|
| 117 |
+
command += ['-loglevel', 'panic']
|
| 118 |
+
if seek_time:
|
| 119 |
+
command += ['-ss', str(seek_time)]
|
| 120 |
+
command += ['-i', str(self.path)]
|
| 121 |
+
for stream, filename in zip(streams, filenames):
|
| 122 |
+
command += ['-map', f'0:{self._audio_streams[stream]}']
|
| 123 |
+
if query_duration is not None:
|
| 124 |
+
command += ['-t', str(query_duration)]
|
| 125 |
+
command += ['-threads', '1']
|
| 126 |
+
command += ['-f', 'f32le']
|
| 127 |
+
if samplerate is not None:
|
| 128 |
+
command += ['-ar', str(samplerate)]
|
| 129 |
+
command += [filename]
|
| 130 |
+
|
| 131 |
+
sp.run(command, check=True)
|
| 132 |
+
wavs = []
|
| 133 |
+
for filename in filenames:
|
| 134 |
+
wav = np.fromfile(filename, dtype=np.float32)
|
| 135 |
+
wav = torch.from_numpy(wav)
|
| 136 |
+
wav = wav.view(-1, self.channels()).t()
|
| 137 |
+
if channels is not None:
|
| 138 |
+
wav = convert_audio_channels(wav, channels)
|
| 139 |
+
if target_size is not None:
|
| 140 |
+
wav = wav[..., :target_size]
|
| 141 |
+
wavs.append(wav)
|
| 142 |
+
wav = torch.stack(wavs, dim=0)
|
| 143 |
+
if single:
|
| 144 |
+
wav = wav[0]
|
| 145 |
+
return wav
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def convert_audio_channels(wav, channels=2):
|
| 149 |
+
"""Convert audio to the given number of channels."""
|
| 150 |
+
*shape, src_channels, length = wav.shape
|
| 151 |
+
if src_channels == channels:
|
| 152 |
+
pass
|
| 153 |
+
elif channels == 1:
|
| 154 |
+
# Case 1:
|
| 155 |
+
# The caller asked 1-channel audio, but the stream have multiple
|
| 156 |
+
# channels, downmix all channels.
|
| 157 |
+
wav = wav.mean(dim=-2, keepdim=True)
|
| 158 |
+
elif src_channels == 1:
|
| 159 |
+
# Case 2:
|
| 160 |
+
# The caller asked for multiple channels, but the input file have
|
| 161 |
+
# one single channel, replicate the audio over all channels.
|
| 162 |
+
wav = wav.expand(*shape, channels, length)
|
| 163 |
+
elif src_channels >= channels:
|
| 164 |
+
# Case 3:
|
| 165 |
+
# The caller asked for multiple channels, and the input file have
|
| 166 |
+
# more channels than requested. In that case return the first channels.
|
| 167 |
+
wav = wav[..., :channels, :]
|
| 168 |
+
else:
|
| 169 |
+
# Case 4: What is a reasonable choice here?
|
| 170 |
+
raise ValueError('The audio file has less channels than requested but is not mono.')
|
| 171 |
+
return wav
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def convert_audio(wav, from_samplerate, to_samplerate, channels):
|
| 175 |
+
"""Convert audio from a given samplerate to a target one and target number of channels."""
|
| 176 |
+
wav = convert_audio_channels(wav, channels)
|
| 177 |
+
return julius.resample_frac(wav, from_samplerate, to_samplerate)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def i16_pcm(wav):
|
| 181 |
+
"""Convert audio to 16 bits integer PCM format."""
|
| 182 |
+
if wav.dtype.is_floating_point:
|
| 183 |
+
return (wav.clamp_(-1, 1) * (2**15 - 1)).short()
|
| 184 |
+
else:
|
| 185 |
+
return wav
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def f32_pcm(wav):
|
| 189 |
+
"""Convert audio to float 32 bits PCM format."""
|
| 190 |
+
if wav.dtype.is_floating_point:
|
| 191 |
+
return wav
|
| 192 |
+
else:
|
| 193 |
+
return wav.float() / (2**15 - 1)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def as_dtype_pcm(wav):
|
| 197 |
+
"""Convert audio to either f32 pcm or i16 pcm depending on the given dtype."""
|
| 198 |
+
if wav.dtype.is_floating_point:
|
| 199 |
+
return f32_pcm(wav)
|
| 200 |
+
else:
|
| 201 |
+
return i16_pcm(wav)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def encode_mp3(wav, path, samplerate=44100, bitrate=320, verbose=False):
|
| 205 |
+
"""Save given audio as mp3. This should work on all OSes."""
|
| 206 |
+
c, _ = wav.shape
|
| 207 |
+
wav = i16_pcm(wav)
|
| 208 |
+
encoder = lameenc.Encoder()
|
| 209 |
+
encoder.set_bit_rate(bitrate)
|
| 210 |
+
encoder.set_in_sample_rate(samplerate)
|
| 211 |
+
encoder.set_channels(c)
|
| 212 |
+
encoder.set_quality(2) # 2-highest, 7-fastest
|
| 213 |
+
if not verbose:
|
| 214 |
+
encoder.silence()
|
| 215 |
+
wav = wav.data.cpu()
|
| 216 |
+
wav = wav.transpose(0, 1).numpy()
|
| 217 |
+
mp3_data = encoder.encode(wav.tobytes())
|
| 218 |
+
mp3_data += encoder.flush()
|
| 219 |
+
with open(path, "wb") as f:
|
| 220 |
+
f.write(mp3_data)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def prevent_clip(wav, mode='rescale'):
|
| 224 |
+
"""
|
| 225 |
+
different strategies for avoiding raw clipping.
|
| 226 |
+
"""
|
| 227 |
+
if mode is None or mode == 'none':
|
| 228 |
+
return wav
|
| 229 |
+
assert wav.dtype.is_floating_point, "too late for clipping"
|
| 230 |
+
if mode == 'rescale':
|
| 231 |
+
wav = wav / max(1.01 * wav.abs().max(), 1)
|
| 232 |
+
elif mode == 'clamp':
|
| 233 |
+
wav = wav.clamp(-0.99, 0.99)
|
| 234 |
+
elif mode == 'tanh':
|
| 235 |
+
wav = torch.tanh(wav)
|
| 236 |
+
else:
|
| 237 |
+
raise ValueError(f"Invalid mode {mode}")
|
| 238 |
+
return wav
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def save_audio(wav: torch.Tensor,
|
| 242 |
+
path: tp.Union[str, Path],
|
| 243 |
+
samplerate: int,
|
| 244 |
+
bitrate: int = 320,
|
| 245 |
+
clip: tp.Union[str] = 'rescale',
|
| 246 |
+
bits_per_sample: tp.Union[int] = 16,
|
| 247 |
+
as_float: bool = False):
|
| 248 |
+
"""Save audio file, automatically preventing clipping if necessary
|
| 249 |
+
based on the given `clip` strategy. If the path ends in `.mp3`, this
|
| 250 |
+
will save as mp3 with the given `bitrate`.
|
| 251 |
+
"""
|
| 252 |
+
wav = prevent_clip(wav, mode=clip)
|
| 253 |
+
path = Path(path)
|
| 254 |
+
suffix = path.suffix.lower()
|
| 255 |
+
if suffix == ".mp3":
|
| 256 |
+
encode_mp3(wav, path, samplerate, bitrate, verbose=True)
|
| 257 |
+
elif suffix == ".wav":
|
| 258 |
+
if as_float:
|
| 259 |
+
bits_per_sample = 32
|
| 260 |
+
encoding = 'PCM_F'
|
| 261 |
+
else:
|
| 262 |
+
encoding = 'PCM_S'
|
| 263 |
+
ta.save(str(path), wav, sample_rate=samplerate,
|
| 264 |
+
encoding=encoding, bits_per_sample=bits_per_sample)
|
| 265 |
+
elif suffix == ".flac":
|
| 266 |
+
ta.save(str(path), wav, sample_rate=samplerate, bits_per_sample=bits_per_sample)
|
| 267 |
+
else:
|
| 268 |
+
raise ValueError(f"Invalid suffix for path: {suffix}")
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def load_track(track, audio_channels, samplerate):
|
| 272 |
+
errors = {}
|
| 273 |
+
wav = None
|
| 274 |
+
|
| 275 |
+
try:
|
| 276 |
+
wav = AudioFile(track).read(
|
| 277 |
+
streams=0,
|
| 278 |
+
samplerate=samplerate,
|
| 279 |
+
channels=audio_channels)
|
| 280 |
+
except sp.CalledProcessError:
|
| 281 |
+
errors['ffmpeg'] = 'FFmpeg could not read the file.'
|
| 282 |
+
|
| 283 |
+
if wav is None:
|
| 284 |
+
try:
|
| 285 |
+
wav, sr = ta.load(str(track))
|
| 286 |
+
except RuntimeError as err:
|
| 287 |
+
errors['torchaudio'] = err.args[0]
|
| 288 |
+
else:
|
| 289 |
+
wav = convert_audio(wav, sr, samplerate, audio_channels)
|
| 290 |
+
|
| 291 |
+
return wav, errors
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/demucs.py
ADDED
|
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File : demucs.py
|
| 5 |
+
@Time : 2023/8/8 下午4:36
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2023, Tencent
|
| 9 |
+
@Desc : Demucs
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import math
|
| 13 |
+
import typing as tp
|
| 14 |
+
|
| 15 |
+
import julius
|
| 16 |
+
import torch
|
| 17 |
+
from torch import nn
|
| 18 |
+
from torch.nn import functional as F
|
| 19 |
+
|
| 20 |
+
from .states import capture_init
|
| 21 |
+
from .utils import center_trim, unfold
|
| 22 |
+
from .transformer import LayerScale
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class BLSTM(nn.Module):
    """
    BiLSTM with same hidden units as input dim.
    If `max_steps` is not None, input will be split in overlapping
    chunks and the LSTM applied separately on each chunk.
    """
    def __init__(self, dim, layers=1, max_steps=None, skip=False):
        super().__init__()
        # Chunk width must be divisible by 4 so the half-stride / quarter-chunk
        # stitching in forward() lands on whole frames.
        assert max_steps is None or max_steps % 4 == 0
        self.max_steps = max_steps
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        # Projects the 2*dim bidirectional output back down to dim.
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x):
        # x: (batch, channels, time)
        b, c, t = x.shape
        y = x  # kept for the optional residual connection
        framed = False
        if self.max_steps is not None and t > self.max_steps:
            # Split into 50%-overlapping chunks of `max_steps` frames.
            width = self.max_steps
            stride = width // 2
            frames = unfold(x, width, stride)
            nframes = frames.shape[2]
            framed = True
            # Fold the frame dimension into the batch so the LSTM processes
            # each chunk independently.
            x = frames.permute(0, 2, 1, 3).reshape(-1, c, width)

        # nn.LSTM expects (time, batch, channels).
        x = x.permute(2, 0, 1)

        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        if framed:
            # Stitch the chunks back together, dropping a quarter-chunk of
            # context on each overlapped edge to avoid boundary artifacts.
            out = []
            frames = x.reshape(b, -1, c, width)
            limit = stride // 2
            for k in range(nframes):
                if k == 0:
                    out.append(frames[:, k, :, :-limit])
                elif k == nframes - 1:
                    out.append(frames[:, k, :, limit:])
                else:
                    out.append(frames[:, k, :, limit:-limit])
            out = torch.cat(out, -1)
            out = out[..., :t]  # trim any padding added by `unfold`
            x = out
        if self.skip:
            x = x + y  # residual connection
        return x
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def rescale_conv(conv, reference):
    """Rescale the initial weights of `conv` in place.

    Weights (and bias, if present) are divided by ``sqrt(std / reference)``,
    pulling their standard deviation toward `reference`. It is unclear why
    it helps but it certainly does.
    """
    weight_std = conv.weight.std().detach()
    factor = (weight_std / reference) ** 0.5
    conv.weight.data /= factor
    bias = conv.bias
    if bias is not None:
        bias.data /= factor
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def rescale_module(module, reference):
    """Apply `rescale_conv` to every (transposed) convolution inside `module`."""
    conv_types = (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)
    for submodule in module.modules():
        if isinstance(submodule, conv_types):
            rescale_conv(submodule, reference)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class DConv(nn.Module):
    """
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also before entering each residual branch, dimension is projected on a smaller subspace,
    e.g. of dim `channels // compress`.
    """
    def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
                 norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True,
                 kernel=3, dilate=True):
        """
        Args:
            channels: input/output channels for residual branch.
            compress: amount of channel compression inside the branch.
            depth: number of layers in the residual branch. Each layer has its own
                projection, and potentially LSTM and attention.
                NOTE(review): `abs(depth)` layers are built; a negative depth
                additionally disables dilation (see the assignment below).
            init: initial scale for LayerNorm.
            norm: use GroupNorm.
            attn: use LocalAttention.
            heads: number of heads for the LocalAttention.
            ndecay: number of decay controls in the LocalAttention.
            lstm: use LSTM.
            gelu: Use GELU activation.
            kernel: kernel size for the (dilated) convolutions.
            dilate: if true, use dilation, increasing with the depth.
                NOTE(review): this argument is currently ignored — it is
                overwritten by `dilate = depth > 0` below; confirm intent.
        """

        super().__init__()
        # Odd kernel so that `dilation * (kernel // 2)` padding preserves length.
        assert kernel % 2 == 1
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        dilate = depth > 0

        norm_fn: tp.Callable[[int], nn.Module]
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa

        # Bottleneck width of each residual branch.
        hidden = int(channels / compress)

        act: tp.Type[nn.Module]
        if gelu:
            act = nn.GELU
        else:
            act = nn.ReLU

        self.layers = nn.ModuleList([])
        for d in range(self.depth):
            # Dilation doubles with depth; padding keeps the time length fixed.
            dilation = 2 ** d if dilate else 1
            padding = dilation * (kernel // 2)
            mods = [
                nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
                norm_fn(hidden), act(),
                # Expand to 2x channels so the GLU gate halves back to `channels`.
                nn.Conv1d(hidden, 2 * channels, 1),
                norm_fn(2 * channels), nn.GLU(1),
                LayerScale(channels, init),
            ]
            if attn:
                mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
            if lstm:
                mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
            layer = nn.Sequential(*mods)
            self.layers.append(layer)

    def forward(self, x):
        # Plain residual sum over every branch.
        for layer in self.layers:
            x = x + layer(x)
        return x
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class LocalState(nn.Module):
    """Local state allows to have attention based only on data (no positional embedding),
    but while setting a constraint on the time window (e.g. decaying penalty term).

    Also a failed experiment with trying to provide some frequency based attention.
    """
    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        self.ndecay = ndecay
        # 1x1 convolutions act as the content (value), query and key projections.
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
        if ndecay:
            self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
            # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
            self.query_decay.weight.data *= 0.01
            assert self.query_decay.bias is not None  # stupid type checker
            self.query_decay.bias.data[:] = -2
        self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)

    def forward(self, x):
        # x: (batch, channels, time)
        b, _, t = x.shape
        heads = self.heads
        indexes = torch.arange(t, device=x.device, dtype=x.dtype)
        # left index are keys, right index are queries
        delta = indexes[:, None] - indexes[None, :]

        queries = self.query(x).view(b, heads, -1, t)
        keys = self.key(x).view(b, heads, -1, t)
        # t are keys, s are queries
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        # Scale by sqrt of the per-head channel dimension, as in standard attention.
        dots /= keys.shape[2]**0.5
        if self.nfreqs:
            # Cosine-basis positional bias over the key/query offset `delta`.
            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
            freq_q = self.query_freqs(x).view(b, heads, -1, t) / self.nfreqs ** 0.5
            dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
        if self.ndecay:
            # Penalize attention linearly in |offset| with a learned, sigmoid-
            # bounded strength per decay control.
            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_q = self.query_decay(x).view(b, heads, -1, t)
            decay_q = torch.sigmoid(decay_q) / 2
            decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        # Kill self reference.
        dots.masked_fill_(torch.eye(t, device=dots.device, dtype=torch.bool), -100)
        weights = torch.softmax(dots, dim=2)

        content = self.content(x).view(b, heads, -1, t)
        result = torch.einsum("bhts,bhct->bhcs", weights, content)
        if self.nfreqs:
            # Re-inject the positional signal that the attention weights saw.
            time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
            result = torch.cat([result, time_sig], 2)
        result = result.reshape(b, -1, t)
        # Residual connection around the output projection.
        return x + self.proj(result)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class Demucs(nn.Module):
    @capture_init
    def __init__(self,
                 sources,
                 # Channels
                 audio_channels=2,
                 channels=64,
                 growth=2.,
                 # Main structure
                 depth=6,
                 rewrite=True,
                 lstm_layers=0,
                 # Convolutions
                 kernel_size=8,
                 stride=4,
                 context=1,
                 # Activations
                 gelu=True,
                 glu=True,
                 # Normalization
                 norm_starts=4,
                 norm_groups=4,
                 # DConv residual branch
                 dconv_mode=1,
                 dconv_depth=2,
                 dconv_comp=4,
                 dconv_attn=4,
                 dconv_lstm=4,
                 dconv_init=1e-4,
                 # Pre/post processing
                 normalize=True,
                 resample=True,
                 # Weight init
                 rescale=0.1,
                 # Metadata
                 samplerate=44100,
                 segment=4 * 10):
        """
        Args:
            sources (list[str]): list of source names
            audio_channels (int): stereo or mono
            channels (int): first convolution channels
            depth (int): number of encoder/decoder layers
            growth (float): multiply (resp divide) number of channels by that
                for each layer of the encoder (resp decoder)
            rewrite (bool): add 1x1 convolution to each layer.
            lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
                by default, as this is now replaced by the smaller and faster small LSTMs
                in the DConv branches.
            kernel_size (int): kernel size for convolutions
            stride (int): stride for convolutions
            context (int): kernel size of the convolution in the
                decoder before the transposed convolution. If > 1,
                will provide some context from neighboring time steps.
            gelu: use GELU activation function.
            glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_attn: adds attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            normalize (bool): normalizes the input audio on the fly, and scales back
                the output by the same amount.
            resample (bool): upsample x2 the input and downsample /2 the output.
            rescale (float): rescale initial weights of convolutions
                to get their standard deviation closer to `rescale`.
            samplerate (int): stored as meta information for easing
                future evaluations of the model.
            segment (float): duration of the chunks of audio to ideally evaluate the model on.
                This is used by `demucs.apply.apply_model`.
        """

        super().__init__()
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.resample = resample
        self.channels = channels
        self.normalize = normalize
        self.samplerate = samplerate
        self.segment = segment
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.skip_scales = nn.ModuleList()

        if glu:
            activation = nn.GLU(dim=1)
            # GLU halves the channel count, so the rewrite conv doubles it first.
            ch_scale = 2
        else:
            activation = nn.ReLU()
            ch_scale = 1
        if gelu:
            act2 = nn.GELU
        else:
            act2 = nn.ReLU

        in_channels = audio_channels
        padding = 0
        for index in range(depth):
            # Group norm is only applied from layer `norm_starts` onward.
            norm_fn = lambda d: nn.Identity()  # noqa
            if index >= norm_starts:
                norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa

            encode = []
            encode += [
                nn.Conv1d(in_channels, channels, kernel_size, stride),
                norm_fn(channels),
                act2(),
            ]
            attn = index >= dconv_attn
            lstm = index >= dconv_lstm
            if dconv_mode & 1:
                encode += [DConv(channels, depth=dconv_depth, init=dconv_init,
                                 compress=dconv_comp, attn=attn, lstm=lstm)]
            if rewrite:
                encode += [
                    nn.Conv1d(channels, ch_scale * channels, 1),
                    norm_fn(ch_scale * channels), activation]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            if index > 0:
                out_channels = in_channels
            else:
                # The outermost decoder layer emits every source at once.
                out_channels = len(self.sources) * audio_channels
            if rewrite:
                decode += [
                    nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context),
                    norm_fn(ch_scale * channels), activation]
            if dconv_mode & 2:
                decode += [DConv(channels, depth=dconv_depth, init=dconv_init,
                                 compress=dconv_comp, attn=attn, lstm=lstm)]
            decode += [nn.ConvTranspose1d(channels, out_channels,
                                          kernel_size, stride, padding=padding)]
            if index > 0:
                decode += [norm_fn(out_channels), act2()]
            # Decoder layers are built innermost-first, hence insert at 0.
            self.decoder.insert(0, nn.Sequential(*decode))
            in_channels = channels
            channels = int(growth * channels)

        channels = in_channels
        if lstm_layers:
            self.lstm = BLSTM(channels, lstm_layers)
        else:
            self.lstm = None

        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there is no time steps left over in a convolution, e.g. for all
        layers, size of the input - kernel_size % stride = 0.

        Note that input are automatically padded if necessary to ensure that the output
        has the same length as the input.
        """
        if self.resample:
            length *= 2

        # Simulate the encoder strides forward...
        for _ in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(1, length)

        # ...then walk back through the decoder to get the required input length.
        for _ in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size

        if self.resample:
            length = math.ceil(length / 2)
        return int(length)

    def forward(self, mix):
        x = mix
        length = x.shape[-1]

        if self.normalize:
            # Normalize with mono statistics so stereo balance is preserved.
            mono = mix.mean(dim=1, keepdim=True)
            mean = mono.mean(dim=-1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            x = (x - mean) / (1e-5 + std)
        else:
            mean = 0
            std = 1

        # Pad symmetrically so every conv/transposed-conv pair lines up.
        delta = self.valid_length(length) - length
        x = F.pad(x, (delta // 2, delta - delta // 2))

        if self.resample:
            # Work at 2x the sample rate; undone after decoding.
            x = julius.resample_frac(x, 1, 2)

        saved = []
        for encode in self.encoder:
            x = encode(x)
            saved.append(x)

        if self.lstm:
            x = self.lstm(x)

        for decode in self.decoder:
            # U-Net skip connection with the matching encoder output.
            skip = saved.pop(-1)
            skip = center_trim(skip, x)
            x = decode(x + skip)

        if self.resample:
            x = julius.resample_frac(x, 2, 1)
        x = x * std + mean  # undo the input normalization
        x = center_trim(x, length)
        # Reshape flat channels to (batch, sources, audio_channels, time).
        x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
        return x

    def load_state_dict(self, state, strict=True):
        # fix a mismatch with previous generation Demucs models.
        for idx in range(self.depth):
            for a in ['encoder', 'decoder']:
                for b in ['bias', 'weight']:
                    # Older checkpoints stored the rewrite conv at submodule
                    # index 2; it now lives at index 3.
                    new = f'{a}.{idx}.3.{b}'
                    old = f'{a}.{idx}.2.{b}'
                    if old in state and new not in state:
                        state[new] = state.pop(old)
        super().load_state_dict(state, strict=strict)
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/htdemucs.py
ADDED
|
@@ -0,0 +1,955 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File : htdemucs.py
|
| 5 |
+
@Time : 2023/8/8 下午4:27
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2023, Tencent
|
| 9 |
+
@Desc : The spectrogram and Hybrid version of Demucs
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import math
|
| 13 |
+
import typing as tp
|
| 14 |
+
from copy import deepcopy
|
| 15 |
+
from fractions import Fraction
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch import nn
|
| 19 |
+
from torch.nn import functional as F
|
| 20 |
+
from einops import rearrange
|
| 21 |
+
from openunmix.filtering import wiener
|
| 22 |
+
|
| 23 |
+
from .transformer import CrossTransformerEncoder
|
| 24 |
+
from .demucs import DConv, rescale_module
|
| 25 |
+
from .states import capture_init
|
| 26 |
+
from .spec import spectro, ispectro
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right before the reflection happens."""
    original = x
    length = x.shape[-1]
    left, right = paddings
    if mode == 'reflect':
        # F.pad refuses reflect padding wider than the signal itself, so first
        # grow the signal with zeros until reflection becomes legal.
        needed = max(left, right)
        if length <= needed:
            extra = needed - length + 1
            zeros_right = min(right, extra)
            zeros_left = extra - zeros_right
            x = F.pad(x, (zeros_left, zeros_right))
            paddings = (left - zeros_left, right - zeros_right)
    out = F.pad(x, paddings, mode, value)
    # Sanity checks: total length and the untouched center region.
    assert out.shape[-1] == length + left + right
    assert (out[..., left: left + length] == original).all()
    return out
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ScaledEmbedding(nn.Module):
    """Embedding whose effective weights are `scale` times the stored ones.

    Storing the weights divided by `scale` boosts their learning rate by that
    factor. With `smooth`, the initial weights are replaced by a normalized
    running sum, making neighboring embeddings continuous.
    """
    def __init__(self, num_embeddings: int, embedding_dim: int,
                 scale: float = 10., smooth=False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        if smooth:
            data = self.embedding.weight.data
            smoothed = torch.cumsum(data, dim=0)
            # When summing gaussians, the scale grows as sqrt(n), so we normalize by that.
            counts = torch.arange(1, num_embeddings + 1).to(smoothed)
            data[:] = smoothed / counts.sqrt()[:, None]
        self.embedding.weight.data /= scale
        self.scale = scale

    @property
    def weight(self):
        # Effective (scaled-up) weights seen by the rest of the model.
        return self.embedding.weight * self.scale

    def forward(self, x):
        return self.embedding(x) * self.scale
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class HEncLayer(nn.Module):
    """Encoder layer used by both the time and the frequency branch.

    With `freq=True` the layer convolves along the frequency axis of a
    (batch, channels, freq, time) input using a Conv2d whose kernel and
    stride are 1 along time; otherwise it is a plain Conv1d along time.
    """
    def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=0, dconv_kw=None, pad=True,
                 rewrite=True):
        """Encoder layer. This used both by the time and the frequency branch.

        Args:
            chin: input channels.
            chout: output channels.
            empty: keep only the first convolution (no norm/DConv/rewrite).
            freq: operate along the frequency axis with Conv2d.
            dconv: add a DConv residual branch after the first conv.
            context: half-width of the optional rewrite conv kernel.
            dconv_kw: extra keyword arguments forwarded to DConv (None = defaults).
            pad: pad the input by kernel_size // 4 on the convolved axis.
            rewrite: add the 2x-channel conv + GLU at the end of the layer.
        """
        super().__init__()
        # BUGFIX: the default `dconv_kw=None` combined with the default
        # `dconv=True` used to crash below (`DConv(chout, **None)` raises
        # TypeError). Normalize None to an empty kwargs dict.
        if dconv_kw is None:
            dconv_kw = {}
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            pad = kernel_size // 4
        else:
            pad = 0
        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.norm = norm
        self.pad = pad
        if freq:
            # 2d conv moving only along the frequency axis.
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad = [pad, 0]
            klass = nn.Conv2d
        self.conv = klass(chin, chout, kernel_size, stride, pad)
        if self.empty:
            return
        self.norm1 = norm_fn(chout)
        self.rewrite = None
        if rewrite:
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chout, **dconv_kw)

    def forward(self, x, inject=None):
        """
        `inject` is used to inject the result from the time branch into the frequency branch,
        when both have the same stride.
        """
        if not self.freq and x.dim() == 4:
            # Collapse (batch, ch, freq, time) into (batch, ch * freq, time).
            b, c, fr, t = x.shape
            x = x.view(b, -1, t)

        if not self.freq:
            # Pad the time axis so its length is divisible by the stride.
            le = x.shape[-1]
            if not le % self.stride == 0:
                x = F.pad(x, (0, self.stride - (le % self.stride)))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
            if inject.dim() == 3 and y.dim() == 4:
                # Broadcast the time-branch signal over the frequency axis.
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.dconv:
            if self.freq:
                # DConv is 1d: fold the frequency axis into the batch.
                b, c, fr, t = y.shape
                y = y.permute(0, 2, 1, 3).reshape(-1, c, t)
            y = self.dconv(y)
            if self.freq:
                y = y.view(b, fr, c, t).permute(0, 2, 1, 3)
        if self.rewrite:
            z = self.norm2(self.rewrite(y))
            z = F.glu(z, dim=1)
        else:
            z = y
        return z
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class MultiWrap(nn.Module):
    """
    Takes one layer and replicate it N times. each replica will act
    on a frequency band. All is done so that if the N replica have the same weights,
    then this is exactly equivalent to applying the original module on all frequencies.
    """
    def __init__(self, layer, split_ratios):
        """
        Args:
            layer: layer to replicate (HEncLayer for the encoder path, a
                decoder layer otherwise).
            split_ratios: fractions of the frequency axis at which bands split;
                one extra replica handles the final band up to ratio 1.
        """
        super().__init__()
        self.split_ratios = split_ratios
        self.layers = nn.ModuleList()
        # Encoder replicas (HEncLayer) convolve; anything else takes the
        # decoder path in forward().
        self.conv = isinstance(layer, HEncLayer)
        assert not layer.norm
        assert layer.freq
        assert layer.pad
        if not self.conv:
            assert not layer.context_freq
        for _ in range(len(split_ratios) + 1):
            lay = deepcopy(layer)
            if self.conv:
                # Frequency padding is handled manually per band in forward().
                lay.conv.padding = (0, 0)
            else:
                lay.pad = False
            for m in lay.modules():
                if hasattr(m, 'reset_parameters'):
                    # Give each replica its own fresh initialization.
                    m.reset_parameters()
            self.layers.append(lay)

    def forward(self, x, skip=None, length=None):
        # x: (batch, channels, freq, time)
        _, _, fr, _ = x.shape

        ratios = list(self.split_ratios) + [1]
        start = 0
        outs = []
        for ratio, layer in zip(ratios, self.layers):
            if self.conv:
                # Encoder: slice a frequency band, pad its outer edges so the
                # replica sees the same context the full-band conv would.
                pad = layer.kernel_size // 4
                if ratio == 1:
                    limit = fr
                    frames = -1
                else:
                    limit = int(round(fr * ratio))
                    le = limit - start
                    if start == 0:
                        le += pad
                    # Snap the band end to a whole number of conv frames.
                    frames = round((le - layer.kernel_size) / layer.stride + 1)
                    limit = start + (frames - 1) * layer.stride + layer.kernel_size
                    if start == 0:
                        limit -= pad
                assert limit - start > 0, (limit, start)
                assert limit <= fr, (limit, fr)
                y = x[:, :, start:limit, :]
                if start == 0:
                    y = F.pad(y, (0, 0, pad, 0))
                if ratio == 1:
                    y = F.pad(y, (0, 0, 0, pad))
                outs.append(layer(y))
                # Next band starts with (kernel_size - stride) frames of overlap.
                start = limit - layer.kernel_size + layer.stride
            else:
                # Decoder: transposed conv per band, then stitch overlaps.
                if ratio == 1:
                    limit = fr
                else:
                    limit = int(round(fr * ratio))
                # Temporarily force `last` so the replica skips its final
                # activation; restored below.
                last = layer.last
                layer.last = True

                y = x[:, :, start:limit]
                s = skip[:, :, start:limit]
                out, _ = layer(y, s, None)
                if outs:
                    # Overlap-add the first stride of this band onto the tail
                    # of the previous band, removing the doubly-counted bias.
                    outs[-1][:, :, -layer.stride:] += (
                        out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1))
                    out = out[:, :, layer.stride:]
                if ratio == 1:
                    out = out[:, :, :-layer.stride // 2, :]
                if start == 0:
                    out = out[:, :, layer.stride // 2:, :]
                outs.append(out)
                layer.last = last
                start = limit
        out = torch.cat(outs, dim=2)
        if not self.conv and not last:
            out = F.gelu(out)
        if self.conv:
            return out
        else:
            return out, None
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class HDecLayer(nn.Module):
|
| 240 |
+
def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
|
| 241 |
+
freq=True, dconv=True, norm=True, context=1, dconv_kw=None, pad=True,
|
| 242 |
+
context_freq=True, rewrite=True):
|
| 243 |
+
"""
|
| 244 |
+
Same as HEncLayer but for decoder. See `HEncLayer` for documentation.
|
| 245 |
+
"""
|
| 246 |
+
super().__init__()
|
| 247 |
+
norm_fn = lambda d: nn.Identity() # noqa
|
| 248 |
+
if norm:
|
| 249 |
+
norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
|
| 250 |
+
if pad:
|
| 251 |
+
pad = kernel_size // 4
|
| 252 |
+
else:
|
| 253 |
+
pad = 0
|
| 254 |
+
self.pad = pad
|
| 255 |
+
self.last = last
|
| 256 |
+
self.freq = freq
|
| 257 |
+
self.chin = chin
|
| 258 |
+
self.empty = empty
|
| 259 |
+
self.stride = stride
|
| 260 |
+
self.kernel_size = kernel_size
|
| 261 |
+
self.norm = norm
|
| 262 |
+
self.context_freq = context_freq
|
| 263 |
+
klass = nn.Conv1d
|
| 264 |
+
klass_tr = nn.ConvTranspose1d
|
| 265 |
+
if freq:
|
| 266 |
+
kernel_size = [kernel_size, 1]
|
| 267 |
+
stride = [stride, 1]
|
| 268 |
+
klass = nn.Conv2d
|
| 269 |
+
klass_tr = nn.ConvTranspose2d
|
| 270 |
+
self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
|
| 271 |
+
self.norm2 = norm_fn(chout)
|
| 272 |
+
if self.empty:
|
| 273 |
+
return
|
| 274 |
+
self.rewrite = None
|
| 275 |
+
if rewrite:
|
| 276 |
+
if context_freq:
|
| 277 |
+
self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
|
| 278 |
+
else:
|
| 279 |
+
self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
|
| 280 |
+
[0, context])
|
| 281 |
+
self.norm1 = norm_fn(2 * chin)
|
| 282 |
+
|
| 283 |
+
self.dconv = None
|
| 284 |
+
if dconv:
|
| 285 |
+
self.dconv = DConv(chin, **dconv_kw)
|
| 286 |
+
|
| 287 |
+
def forward(self, x, skip, length):
|
| 288 |
+
if self.freq and x.dim() == 3:
|
| 289 |
+
b, c, t = x.shape
|
| 290 |
+
x = x.view(b, self.chin, -1, t)
|
| 291 |
+
|
| 292 |
+
if not self.empty:
|
| 293 |
+
x = x + skip
|
| 294 |
+
|
| 295 |
+
if self.rewrite:
|
| 296 |
+
y = F.glu(self.norm1(self.rewrite(x)), dim=1)
|
| 297 |
+
else:
|
| 298 |
+
y = x
|
| 299 |
+
if self.dconv:
|
| 300 |
+
if self.freq:
|
| 301 |
+
b, c, fr, t = y.shape
|
| 302 |
+
y = y.permute(0, 2, 1, 3).reshape(-1, c, t)
|
| 303 |
+
y = self.dconv(y)
|
| 304 |
+
if self.freq:
|
| 305 |
+
y = y.view(b, fr, c, t).permute(0, 2, 1, 3)
|
| 306 |
+
else:
|
| 307 |
+
y = x
|
| 308 |
+
assert skip is None
|
| 309 |
+
z = self.norm2(self.conv_tr(y))
|
| 310 |
+
if self.freq:
|
| 311 |
+
if self.pad:
|
| 312 |
+
z = z[..., self.pad:-self.pad, :]
|
| 313 |
+
else:
|
| 314 |
+
z = z[..., self.pad:self.pad + length]
|
| 315 |
+
assert z.shape[-1] == length, (z.shape[-1], length)
|
| 316 |
+
if not self.last:
|
| 317 |
+
z = F.gelu(z)
|
| 318 |
+
return z, y
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class HTDemucs(nn.Module):
|
| 322 |
+
"""
|
| 323 |
+
Spectrogram and hybrid Demucs model.
|
| 324 |
+
The spectrogram model has the same structure as Demucs, except the first few layers are over the
|
| 325 |
+
frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
|
| 326 |
+
Frequency layers can still access information across time steps thanks to the DConv residual.
|
| 327 |
+
|
| 328 |
+
Hybrid model have a parallel time branch. At some layer, the time branch has the same stride
|
| 329 |
+
as the frequency branch and then the two are combined. The opposite happens in the decoder.
|
| 330 |
+
|
| 331 |
+
Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]),
|
| 332 |
+
or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
|
| 333 |
+
Open Unmix implementation [Stoter et al. 2019].
|
| 334 |
+
|
| 335 |
+
The loss is always on the temporal domain, by backpropagating through the above
|
| 336 |
+
output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
|
| 337 |
+
a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
|
| 338 |
+
contribution, without changing the one from the waveform, which will lead to worse performance.
|
| 339 |
+
I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
|
| 340 |
+
CaC on the other hand provides similar performance for hybrid, and works naturally with
|
| 341 |
+
hybrid models.
|
| 342 |
+
|
| 343 |
+
This model also uses frequency embeddings are used to improve efficiency on convolutions
|
| 344 |
+
over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
|
| 345 |
+
|
| 346 |
+
Unlike classic Demucs, there is no resampling here, and normalization is always applied.
|
| 347 |
+
"""
|
| 348 |
+
|
| 349 |
+
@capture_init
|
| 350 |
+
def __init__(
|
| 351 |
+
self,
|
| 352 |
+
sources,
|
| 353 |
+
# Channels
|
| 354 |
+
audio_channels=2,
|
| 355 |
+
channels=48,
|
| 356 |
+
channels_time=None,
|
| 357 |
+
growth=2,
|
| 358 |
+
# STFT
|
| 359 |
+
nfft=4096,
|
| 360 |
+
wiener_iters=0,
|
| 361 |
+
end_iters=0,
|
| 362 |
+
wiener_residual=False,
|
| 363 |
+
cac=True,
|
| 364 |
+
# Main structure
|
| 365 |
+
depth=4,
|
| 366 |
+
rewrite=True,
|
| 367 |
+
# Frequency branch
|
| 368 |
+
multi_freqs=None,
|
| 369 |
+
multi_freqs_depth=3,
|
| 370 |
+
freq_emb=0.2,
|
| 371 |
+
emb_scale=10,
|
| 372 |
+
emb_smooth=True,
|
| 373 |
+
# Convolutions
|
| 374 |
+
kernel_size=8,
|
| 375 |
+
time_stride=2,
|
| 376 |
+
stride=4,
|
| 377 |
+
context=1,
|
| 378 |
+
context_enc=0,
|
| 379 |
+
# Normalization
|
| 380 |
+
norm_starts=4,
|
| 381 |
+
norm_groups=4,
|
| 382 |
+
# DConv residual branch
|
| 383 |
+
dconv_mode=1,
|
| 384 |
+
dconv_depth=2,
|
| 385 |
+
dconv_comp=8,
|
| 386 |
+
dconv_init=1e-3,
|
| 387 |
+
# Before the Transformer
|
| 388 |
+
bottom_channels=0,
|
| 389 |
+
# Transformer
|
| 390 |
+
t_layers=5,
|
| 391 |
+
t_emb="sin",
|
| 392 |
+
t_hidden_scale=4.0,
|
| 393 |
+
t_heads=8,
|
| 394 |
+
t_dropout=0.0,
|
| 395 |
+
t_max_positions=10000,
|
| 396 |
+
t_norm_in=True,
|
| 397 |
+
t_norm_in_group=False,
|
| 398 |
+
t_group_norm=False,
|
| 399 |
+
t_norm_first=True,
|
| 400 |
+
t_norm_out=True,
|
| 401 |
+
t_max_period=10000.0,
|
| 402 |
+
t_weight_decay=0.0,
|
| 403 |
+
t_lr=None,
|
| 404 |
+
t_layer_scale=True,
|
| 405 |
+
t_gelu=True,
|
| 406 |
+
t_weight_pos_embed=1.0,
|
| 407 |
+
t_sin_random_shift=0,
|
| 408 |
+
t_cape_mean_normalize=True,
|
| 409 |
+
t_cape_augment=True,
|
| 410 |
+
t_cape_glob_loc_scale=None,
|
| 411 |
+
t_sparse_self_attn=False,
|
| 412 |
+
t_sparse_cross_attn=False,
|
| 413 |
+
t_mask_type="diag",
|
| 414 |
+
t_mask_random_seed=42,
|
| 415 |
+
t_sparse_attn_window=500,
|
| 416 |
+
t_global_window=100,
|
| 417 |
+
t_sparsity=0.95,
|
| 418 |
+
t_auto_sparsity=False,
|
| 419 |
+
# ------ Particuliar parameters
|
| 420 |
+
t_cross_first=False,
|
| 421 |
+
# Weight init
|
| 422 |
+
rescale=0.1,
|
| 423 |
+
# Metadata
|
| 424 |
+
samplerate=44100,
|
| 425 |
+
segment=10,
|
| 426 |
+
use_train_segment=True,
|
| 427 |
+
):
|
| 428 |
+
"""
|
| 429 |
+
Args:
|
| 430 |
+
sources (list[str]): list of source names.
|
| 431 |
+
audio_channels (int): input/output audio channels.
|
| 432 |
+
channels (int): initial number of hidden channels.
|
| 433 |
+
channels_time: if not None, use a different `channels` value for the time branch.
|
| 434 |
+
growth: increase the number of hidden channels by this factor at each layer.
|
| 435 |
+
nfft: number of fft bins. Note that changing this require careful computation of
|
| 436 |
+
various shape parameters and will not work out of the box for hybrid models.
|
| 437 |
+
wiener_iters: when using Wiener filtering, number of iterations at test time.
|
| 438 |
+
end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
|
| 439 |
+
wiener_residual: add residual source before wiener filtering.
|
| 440 |
+
cac: uses complex as channels, i.e. complex numbers are 2 channels each
|
| 441 |
+
in input and output. no further processing is done before ISTFT.
|
| 442 |
+
depth (int): number of layers in the encoder and in the decoder.
|
| 443 |
+
rewrite (bool): add 1x1 convolution to each layer.
|
| 444 |
+
multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
|
| 445 |
+
multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
|
| 446 |
+
layers will be wrapped.
|
| 447 |
+
freq_emb: add frequency embedding after the first frequency layer if > 0,
|
| 448 |
+
the actual value controls the weight of the embedding.
|
| 449 |
+
emb_scale: equivalent to scaling the embedding learning rate
|
| 450 |
+
emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
|
| 451 |
+
kernel_size: kernel_size for encoder and decoder layers.
|
| 452 |
+
stride: stride for encoder and decoder layers.
|
| 453 |
+
time_stride: stride for the final time layer, after the merge.
|
| 454 |
+
context: context for 1x1 conv in the decoder.
|
| 455 |
+
context_enc: context for 1x1 conv in the encoder.
|
| 456 |
+
norm_starts: layer at which group norm starts being used.
|
| 457 |
+
decoder layers are numbered in reverse order.
|
| 458 |
+
norm_groups: number of groups for group norm.
|
| 459 |
+
dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
|
| 460 |
+
dconv_depth: depth of residual DConv branch.
|
| 461 |
+
dconv_comp: compression of DConv branch.
|
| 462 |
+
dconv_attn: adds attention layers in DConv branch starting at this layer.
|
| 463 |
+
dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
|
| 464 |
+
dconv_init: initial scale for the DConv branch LayerScale.
|
| 465 |
+
bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
|
| 466 |
+
transformer in order to change the number of channels
|
| 467 |
+
t_layers: number of layers in each branch (waveform and spec) of the transformer
|
| 468 |
+
t_emb: "sin", "cape" or "scaled"
|
| 469 |
+
t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
|
| 470 |
+
for instance if C = 384 (the number of channels in the transformer) and
|
| 471 |
+
t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
|
| 472 |
+
384 * 4 = 1536
|
| 473 |
+
t_heads: number of heads for the transformer
|
| 474 |
+
t_dropout: dropout in the transformer
|
| 475 |
+
t_max_positions: max_positions for the "scaled" positional embedding, only
|
| 476 |
+
useful if t_emb="scaled"
|
| 477 |
+
t_norm_in: (bool) norm before addinf positional embedding and getting into the
|
| 478 |
+
transformer layers
|
| 479 |
+
t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
|
| 480 |
+
timesteps (GroupNorm with group=1)
|
| 481 |
+
t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
|
| 482 |
+
timesteps (GroupNorm with group=1)
|
| 483 |
+
t_norm_first: (bool) if True the norm is before the attention and before the FFN
|
| 484 |
+
t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
|
| 485 |
+
t_max_period: (float) denominator in the sinusoidal embedding expression
|
| 486 |
+
t_weight_decay: (float) weight decay for the transformer
|
| 487 |
+
t_lr: (float) specific learning rate for the transformer
|
| 488 |
+
t_layer_scale: (bool) Layer Scale for the transformer
|
| 489 |
+
t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
|
| 490 |
+
t_weight_pos_embed: (float) weighting of the positional embedding
|
| 491 |
+
t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
|
| 492 |
+
see: https://arxiv.org/abs/2106.03143
|
| 493 |
+
t_cape_augment: (bool) if t_emb="cape", must be True during training and False
|
| 494 |
+
during the inference, see: https://arxiv.org/abs/2106.03143
|
| 495 |
+
t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
|
| 496 |
+
see: https://arxiv.org/abs/2106.03143
|
| 497 |
+
t_sparse_self_attn: (bool) if True, the self attentions are sparse
|
| 498 |
+
t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
|
| 499 |
+
unless you designed really specific masks)
|
| 500 |
+
t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
|
| 501 |
+
with '_' between: i.e. "diag_jmask_random" (note that this is permutation
|
| 502 |
+
invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
|
| 503 |
+
t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
|
| 504 |
+
that generated the random part of the mask
|
| 505 |
+
t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
|
| 506 |
+
a key (j), the mask is True id |i-j|<=t_sparse_attn_window
|
| 507 |
+
t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
|
| 508 |
+
and mask[:, :t_global_window] will be True
|
| 509 |
+
t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
|
| 510 |
+
level of the random part of the mask.
|
| 511 |
+
t_cross_first: (bool) if True cross attention is the first layer of the
|
| 512 |
+
transformer (False seems to be better)
|
| 513 |
+
rescale: weight rescaling trick
|
| 514 |
+
use_train_segment: (bool) if True, the actual size that is used during the
|
| 515 |
+
training is used during inference.
|
| 516 |
+
"""
|
| 517 |
+
super().__init__()
|
| 518 |
+
self.cac = cac
|
| 519 |
+
self.wiener_residual = wiener_residual
|
| 520 |
+
self.audio_channels = audio_channels
|
| 521 |
+
self.sources = sources
|
| 522 |
+
self.kernel_size = kernel_size
|
| 523 |
+
self.context = context
|
| 524 |
+
self.stride = stride
|
| 525 |
+
self.depth = depth
|
| 526 |
+
self.bottom_channels = bottom_channels
|
| 527 |
+
self.channels = channels
|
| 528 |
+
self.samplerate = samplerate
|
| 529 |
+
self.segment = segment
|
| 530 |
+
self.use_train_segment = use_train_segment
|
| 531 |
+
self.nfft = nfft
|
| 532 |
+
self.hop_length = nfft // 4
|
| 533 |
+
self.wiener_iters = wiener_iters
|
| 534 |
+
self.end_iters = end_iters
|
| 535 |
+
self.freq_emb = None
|
| 536 |
+
assert wiener_iters == end_iters
|
| 537 |
+
|
| 538 |
+
self.encoder = nn.ModuleList()
|
| 539 |
+
self.decoder = nn.ModuleList()
|
| 540 |
+
|
| 541 |
+
self.tencoder = nn.ModuleList()
|
| 542 |
+
self.tdecoder = nn.ModuleList()
|
| 543 |
+
|
| 544 |
+
chin = audio_channels
|
| 545 |
+
chin_z = chin # number of channels for the freq branch
|
| 546 |
+
if self.cac:
|
| 547 |
+
chin_z *= 2
|
| 548 |
+
chout = channels_time or channels
|
| 549 |
+
chout_z = channels
|
| 550 |
+
freqs = nfft // 2
|
| 551 |
+
|
| 552 |
+
for index in range(depth):
|
| 553 |
+
norm = index >= norm_starts
|
| 554 |
+
freq = freqs > 1
|
| 555 |
+
stri = stride
|
| 556 |
+
ker = kernel_size
|
| 557 |
+
if not freq:
|
| 558 |
+
assert freqs == 1
|
| 559 |
+
ker = time_stride * 2
|
| 560 |
+
stri = time_stride
|
| 561 |
+
|
| 562 |
+
pad = True
|
| 563 |
+
last_freq = False
|
| 564 |
+
if freq and freqs <= kernel_size:
|
| 565 |
+
ker = freqs
|
| 566 |
+
pad = False
|
| 567 |
+
last_freq = True
|
| 568 |
+
|
| 569 |
+
kw = {
|
| 570 |
+
"kernel_size": ker,
|
| 571 |
+
"stride": stri,
|
| 572 |
+
"freq": freq,
|
| 573 |
+
"pad": pad,
|
| 574 |
+
"norm": norm,
|
| 575 |
+
"rewrite": rewrite,
|
| 576 |
+
"norm_groups": norm_groups,
|
| 577 |
+
"dconv_kw": {
|
| 578 |
+
"depth": dconv_depth,
|
| 579 |
+
"compress": dconv_comp,
|
| 580 |
+
"init": dconv_init,
|
| 581 |
+
"gelu": True,
|
| 582 |
+
},
|
| 583 |
+
}
|
| 584 |
+
kwt = dict(kw)
|
| 585 |
+
kwt["freq"] = 0
|
| 586 |
+
kwt["kernel_size"] = kernel_size
|
| 587 |
+
kwt["stride"] = stride
|
| 588 |
+
kwt["pad"] = True
|
| 589 |
+
kw_dec = dict(kw)
|
| 590 |
+
multi = False
|
| 591 |
+
if multi_freqs and index < multi_freqs_depth:
|
| 592 |
+
multi = True
|
| 593 |
+
kw_dec["context_freq"] = False
|
| 594 |
+
|
| 595 |
+
if last_freq:
|
| 596 |
+
chout_z = max(chout, chout_z)
|
| 597 |
+
chout = chout_z
|
| 598 |
+
|
| 599 |
+
enc = HEncLayer(
|
| 600 |
+
chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
|
| 601 |
+
)
|
| 602 |
+
if freq:
|
| 603 |
+
tenc = HEncLayer(
|
| 604 |
+
chin,
|
| 605 |
+
chout,
|
| 606 |
+
dconv=dconv_mode & 1,
|
| 607 |
+
context=context_enc,
|
| 608 |
+
empty=last_freq,
|
| 609 |
+
**kwt
|
| 610 |
+
)
|
| 611 |
+
self.tencoder.append(tenc)
|
| 612 |
+
|
| 613 |
+
if multi:
|
| 614 |
+
enc = MultiWrap(enc, multi_freqs)
|
| 615 |
+
self.encoder.append(enc)
|
| 616 |
+
if index == 0:
|
| 617 |
+
chin = self.audio_channels * len(self.sources)
|
| 618 |
+
chin_z = chin
|
| 619 |
+
if self.cac:
|
| 620 |
+
chin_z *= 2
|
| 621 |
+
dec = HDecLayer(
|
| 622 |
+
chout_z,
|
| 623 |
+
chin_z,
|
| 624 |
+
dconv=dconv_mode & 2,
|
| 625 |
+
last=index == 0,
|
| 626 |
+
context=context,
|
| 627 |
+
**kw_dec
|
| 628 |
+
)
|
| 629 |
+
if multi:
|
| 630 |
+
dec = MultiWrap(dec, multi_freqs)
|
| 631 |
+
if freq:
|
| 632 |
+
tdec = HDecLayer(
|
| 633 |
+
chout,
|
| 634 |
+
chin,
|
| 635 |
+
dconv=dconv_mode & 2,
|
| 636 |
+
empty=last_freq,
|
| 637 |
+
last=index == 0,
|
| 638 |
+
context=context,
|
| 639 |
+
**kwt
|
| 640 |
+
)
|
| 641 |
+
self.tdecoder.insert(0, tdec)
|
| 642 |
+
self.decoder.insert(0, dec)
|
| 643 |
+
|
| 644 |
+
chin = chout
|
| 645 |
+
chin_z = chout_z
|
| 646 |
+
chout = int(growth * chout)
|
| 647 |
+
chout_z = int(growth * chout_z)
|
| 648 |
+
if freq:
|
| 649 |
+
if freqs <= kernel_size:
|
| 650 |
+
freqs = 1
|
| 651 |
+
else:
|
| 652 |
+
freqs //= stride
|
| 653 |
+
if index == 0 and freq_emb:
|
| 654 |
+
self.freq_emb = ScaledEmbedding(
|
| 655 |
+
freqs, chin_z, smooth=emb_smooth, scale=emb_scale
|
| 656 |
+
)
|
| 657 |
+
self.freq_emb_scale = freq_emb
|
| 658 |
+
|
| 659 |
+
if rescale:
|
| 660 |
+
rescale_module(self, reference=rescale)
|
| 661 |
+
|
| 662 |
+
transformer_channels = channels * growth ** (depth - 1)
|
| 663 |
+
if bottom_channels:
|
| 664 |
+
self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
|
| 665 |
+
self.channel_downsampler = nn.Conv1d(
|
| 666 |
+
bottom_channels, transformer_channels, 1
|
| 667 |
+
)
|
| 668 |
+
self.channel_upsampler_t = nn.Conv1d(
|
| 669 |
+
transformer_channels, bottom_channels, 1
|
| 670 |
+
)
|
| 671 |
+
self.channel_downsampler_t = nn.Conv1d(
|
| 672 |
+
bottom_channels, transformer_channels, 1
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
transformer_channels = bottom_channels
|
| 676 |
+
|
| 677 |
+
if t_layers > 0:
|
| 678 |
+
if t_cape_glob_loc_scale is None:
|
| 679 |
+
t_cape_glob_loc_scale = [5000.0, 1.0, 1.4]
|
| 680 |
+
self.crosstransformer = CrossTransformerEncoder(
|
| 681 |
+
dim=transformer_channels,
|
| 682 |
+
emb=t_emb,
|
| 683 |
+
hidden_scale=t_hidden_scale,
|
| 684 |
+
num_heads=t_heads,
|
| 685 |
+
num_layers=t_layers,
|
| 686 |
+
cross_first=t_cross_first,
|
| 687 |
+
dropout=t_dropout,
|
| 688 |
+
max_positions=t_max_positions,
|
| 689 |
+
norm_in=t_norm_in,
|
| 690 |
+
norm_in_group=t_norm_in_group,
|
| 691 |
+
group_norm=t_group_norm,
|
| 692 |
+
norm_first=t_norm_first,
|
| 693 |
+
norm_out=t_norm_out,
|
| 694 |
+
max_period=t_max_period,
|
| 695 |
+
weight_decay=t_weight_decay,
|
| 696 |
+
lr=t_lr,
|
| 697 |
+
layer_scale=t_layer_scale,
|
| 698 |
+
gelu=t_gelu,
|
| 699 |
+
sin_random_shift=t_sin_random_shift,
|
| 700 |
+
weight_pos_embed=t_weight_pos_embed,
|
| 701 |
+
cape_mean_normalize=t_cape_mean_normalize,
|
| 702 |
+
cape_augment=t_cape_augment,
|
| 703 |
+
cape_glob_loc_scale=t_cape_glob_loc_scale,
|
| 704 |
+
sparse_self_attn=t_sparse_self_attn,
|
| 705 |
+
sparse_cross_attn=t_sparse_cross_attn,
|
| 706 |
+
mask_type=t_mask_type,
|
| 707 |
+
mask_random_seed=t_mask_random_seed,
|
| 708 |
+
sparse_attn_window=t_sparse_attn_window,
|
| 709 |
+
global_window=t_global_window,
|
| 710 |
+
sparsity=t_sparsity,
|
| 711 |
+
auto_sparsity=t_auto_sparsity,
|
| 712 |
+
)
|
| 713 |
+
else:
|
| 714 |
+
self.crosstransformer = None
|
| 715 |
+
|
| 716 |
+
def _spec(self, x):
|
| 717 |
+
hl = self.hop_length
|
| 718 |
+
nfft = self.nfft
|
| 719 |
+
|
| 720 |
+
# We re-pad the signal in order to keep the property
|
| 721 |
+
# that the size of the output is exactly the size of the input
|
| 722 |
+
# divided by the stride (here hop_length), when divisible.
|
| 723 |
+
# This is achieved by padding by 1/4th of the kernel size (here nfft).
|
| 724 |
+
# which is not supported by torch.stft.
|
| 725 |
+
# Having all convolution operations follow this convention allow to easily
|
| 726 |
+
# align the time and frequency branches later on.
|
| 727 |
+
assert hl == nfft // 4
|
| 728 |
+
le = int(math.ceil(x.shape[-1] / hl))
|
| 729 |
+
pad = hl // 2 * 3
|
| 730 |
+
x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
|
| 731 |
+
|
| 732 |
+
z = spectro(x, nfft, hl)[..., :-1, :]
|
| 733 |
+
assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
|
| 734 |
+
z = z[..., 2: 2 + le]
|
| 735 |
+
return z
|
| 736 |
+
|
| 737 |
+
def _ispec(self, z, length=None, scale=0):
|
| 738 |
+
hl = self.hop_length // (4**scale)
|
| 739 |
+
z = F.pad(z, (0, 0, 0, 1))
|
| 740 |
+
z = F.pad(z, (2, 2))
|
| 741 |
+
pad = hl // 2 * 3
|
| 742 |
+
le = hl * int(math.ceil(length / hl)) + 2 * pad
|
| 743 |
+
x = ispectro(z, hl, length=le)
|
| 744 |
+
x = x[..., pad: pad + length]
|
| 745 |
+
return x
|
| 746 |
+
|
| 747 |
+
def _magnitude(self, z):
|
| 748 |
+
# return the magnitude of the spectrogram, except when cac is True,
|
| 749 |
+
# in which case we just move the complex dimension to the channel one.
|
| 750 |
+
if self.cac:
|
| 751 |
+
b, c, fr, t = z.shape
|
| 752 |
+
m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
|
| 753 |
+
m = m.reshape(b, c * 2, fr, t)
|
| 754 |
+
else:
|
| 755 |
+
m = z.abs()
|
| 756 |
+
return m
|
| 757 |
+
|
| 758 |
+
def _mask(self, z, m):
|
| 759 |
+
# Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
|
| 760 |
+
# If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
|
| 761 |
+
niters = self.wiener_iters
|
| 762 |
+
if self.cac:
|
| 763 |
+
b, s, _, fr, t = m.shape
|
| 764 |
+
out = m.view(b, s, -1, 2, fr, t).permute(0, 1, 2, 4, 5, 3)
|
| 765 |
+
out = torch.view_as_complex(out.contiguous())
|
| 766 |
+
return out
|
| 767 |
+
if self.training:
|
| 768 |
+
niters = self.end_iters
|
| 769 |
+
if niters < 0:
|
| 770 |
+
z = z[:, None]
|
| 771 |
+
return z / (1e-8 + z.abs()) * m
|
| 772 |
+
else:
|
| 773 |
+
return self._wiener(m, z, niters)
|
| 774 |
+
|
| 775 |
+
def _wiener(self, mag_out, mix_stft, niters):
|
| 776 |
+
# apply wiener filtering from OpenUnmix.
|
| 777 |
+
init = mix_stft.dtype
|
| 778 |
+
wiener_win_len = 300
|
| 779 |
+
residual = self.wiener_residual
|
| 780 |
+
|
| 781 |
+
b, s, c, fq, t = mag_out.shape
|
| 782 |
+
mag_out = mag_out.permute(0, 4, 3, 2, 1)
|
| 783 |
+
mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
|
| 784 |
+
|
| 785 |
+
outs = []
|
| 786 |
+
for sample in range(b):
|
| 787 |
+
pos = 0
|
| 788 |
+
out = []
|
| 789 |
+
for pos in range(0, t, wiener_win_len):
|
| 790 |
+
frame = slice(pos, pos + wiener_win_len)
|
| 791 |
+
z_out = wiener(
|
| 792 |
+
mag_out[sample, frame],
|
| 793 |
+
mix_stft[sample, frame],
|
| 794 |
+
niters,
|
| 795 |
+
residual=residual,
|
| 796 |
+
)
|
| 797 |
+
out.append(z_out.transpose(-1, -2))
|
| 798 |
+
outs.append(torch.cat(out, dim=0))
|
| 799 |
+
out = torch.view_as_complex(torch.stack(outs, 0))
|
| 800 |
+
out = out.permute(0, 4, 3, 2, 1).contiguous()
|
| 801 |
+
if residual:
|
| 802 |
+
out = out[:, :-1]
|
| 803 |
+
assert list(out.shape) == [b, s, c, fq, t]
|
| 804 |
+
return out.to(init)
|
| 805 |
+
|
| 806 |
+
def valid_length(self, length: int):
|
| 807 |
+
"""
|
| 808 |
+
Return a length that is appropriate for evaluation.
|
| 809 |
+
In our case, always return the training length, unless
|
| 810 |
+
it is smaller than the given length, in which case this
|
| 811 |
+
raises an error.
|
| 812 |
+
"""
|
| 813 |
+
if not self.use_train_segment:
|
| 814 |
+
return length
|
| 815 |
+
training_length = int(self.segment * self.samplerate)
|
| 816 |
+
if training_length < length:
|
| 817 |
+
raise ValueError(
|
| 818 |
+
f"Given length {length} is longer than "
|
| 819 |
+
f"training length {training_length}")
|
| 820 |
+
return training_length
|
| 821 |
+
|
| 822 |
+
def forward(self, mix):
|
| 823 |
+
length = mix.shape[-1]
|
| 824 |
+
length_pre_pad = None
|
| 825 |
+
if self.use_train_segment:
|
| 826 |
+
if self.training:
|
| 827 |
+
self.segment = Fraction(mix.shape[-1], self.samplerate)
|
| 828 |
+
else:
|
| 829 |
+
training_length = int(self.segment * self.samplerate)
|
| 830 |
+
if mix.shape[-1] < training_length:
|
| 831 |
+
length_pre_pad = mix.shape[-1]
|
| 832 |
+
mix = F.pad(mix, (0, training_length - length_pre_pad))
|
| 833 |
+
z = self._spec(mix)
|
| 834 |
+
mag = self._magnitude(z).to(mix.device)
|
| 835 |
+
x = mag
|
| 836 |
+
|
| 837 |
+
b, _, fq, t = x.shape
|
| 838 |
+
|
| 839 |
+
# unlike previous Demucs, we always normalize because it is easier.
|
| 840 |
+
mean = x.mean(dim=(1, 2, 3), keepdim=True)
|
| 841 |
+
std = x.std(dim=(1, 2, 3), keepdim=True)
|
| 842 |
+
x = (x - mean) / (1e-5 + std)
|
| 843 |
+
# x will be the freq. branch input.
|
| 844 |
+
|
| 845 |
+
# Prepare the time branch input.
|
| 846 |
+
xt = mix
|
| 847 |
+
meant = xt.mean(dim=(1, 2), keepdim=True)
|
| 848 |
+
stdt = xt.std(dim=(1, 2), keepdim=True)
|
| 849 |
+
xt = (xt - meant) / (1e-5 + stdt)
|
| 850 |
+
|
| 851 |
+
# okay, this is a giant mess I know...
|
| 852 |
+
saved = [] # skip connections, freq.
|
| 853 |
+
saved_t = [] # skip connections, time.
|
| 854 |
+
lengths = [] # saved lengths to properly remove padding, freq branch.
|
| 855 |
+
lengths_t = [] # saved lengths for time branch.
|
| 856 |
+
for idx, encode in enumerate(self.encoder):
|
| 857 |
+
lengths.append(x.shape[-1])
|
| 858 |
+
inject = None
|
| 859 |
+
if idx < len(self.tencoder):
|
| 860 |
+
# we have not yet merged branches.
|
| 861 |
+
lengths_t.append(xt.shape[-1])
|
| 862 |
+
tenc = self.tencoder[idx]
|
| 863 |
+
xt = tenc(xt)
|
| 864 |
+
if not tenc.empty:
|
| 865 |
+
# save for skip connection
|
| 866 |
+
saved_t.append(xt)
|
| 867 |
+
else:
|
| 868 |
+
# tenc contains just the first conv., so that now time and freq.
|
| 869 |
+
# branches have the same shape and can be merged.
|
| 870 |
+
inject = xt
|
| 871 |
+
x = encode(x, inject)
|
| 872 |
+
if idx == 0 and self.freq_emb is not None:
|
| 873 |
+
# add frequency embedding to allow for non equivariant convolutions
|
| 874 |
+
# over the frequency axis.
|
| 875 |
+
frs = torch.arange(x.shape[-2], device=x.device)
|
| 876 |
+
emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
|
| 877 |
+
x = x + self.freq_emb_scale * emb
|
| 878 |
+
|
| 879 |
+
saved.append(x)
|
| 880 |
+
if self.crosstransformer:
|
| 881 |
+
if self.bottom_channels:
|
| 882 |
+
_, _, f, _ = x.shape
|
| 883 |
+
x = rearrange(x, "b c f t-> b c (f t)")
|
| 884 |
+
x = self.channel_upsampler(x)
|
| 885 |
+
x = rearrange(x, "b c (f t)-> b c f t", f=f)
|
| 886 |
+
xt = self.channel_upsampler_t(xt)
|
| 887 |
+
|
| 888 |
+
x, xt = self.crosstransformer(x, xt)
|
| 889 |
+
|
| 890 |
+
if self.bottom_channels:
|
| 891 |
+
x = rearrange(x, "b c f t-> b c (f t)")
|
| 892 |
+
x = self.channel_downsampler(x)
|
| 893 |
+
x = rearrange(x, "b c (f t)-> b c f t", f=f)
|
| 894 |
+
xt = self.channel_downsampler_t(xt)
|
| 895 |
+
|
| 896 |
+
for idx, decode in enumerate(self.decoder):
|
| 897 |
+
skip = saved.pop(-1)
|
| 898 |
+
x, pre = decode(x, skip, lengths.pop(-1))
|
| 899 |
+
# `pre` contains the output just before final transposed convolution,
|
| 900 |
+
# which is used when the freq. and time branch separate.
|
| 901 |
+
|
| 902 |
+
offset = self.depth - len(self.tdecoder)
|
| 903 |
+
if idx >= offset:
|
| 904 |
+
tdec = self.tdecoder[idx - offset]
|
| 905 |
+
length_t = lengths_t.pop(-1)
|
| 906 |
+
if tdec.empty:
|
| 907 |
+
assert pre.shape[2] == 1, pre.shape
|
| 908 |
+
pre = pre[:, :, 0]
|
| 909 |
+
xt, _ = tdec(pre, None, length_t)
|
| 910 |
+
else:
|
| 911 |
+
skip = saved_t.pop(-1)
|
| 912 |
+
xt, _ = tdec(xt, skip, length_t)
|
| 913 |
+
|
| 914 |
+
# Let's make sure we used all stored skip connections.
|
| 915 |
+
assert len(saved) == 0
|
| 916 |
+
assert len(lengths_t) == 0
|
| 917 |
+
assert len(saved_t) == 0
|
| 918 |
+
|
| 919 |
+
s = len(self.sources)
|
| 920 |
+
x = x.view(b, s, -1, fq, t)
|
| 921 |
+
x = x * std[:, None] + mean[:, None]
|
| 922 |
+
|
| 923 |
+
# to cpu as mps doesnt support complex numbers
|
| 924 |
+
# demucs issue #435 ##432
|
| 925 |
+
# NOTE: in this case z already is on cpu
|
| 926 |
+
# TODO: remove this when mps supports complex numbers
|
| 927 |
+
x_is_mps = x.device.type == "mps"
|
| 928 |
+
if x_is_mps:
|
| 929 |
+
x = x.cpu()
|
| 930 |
+
|
| 931 |
+
zout = self._mask(z, x)
|
| 932 |
+
if self.use_train_segment:
|
| 933 |
+
if self.training:
|
| 934 |
+
x = self._ispec(zout, length)
|
| 935 |
+
else:
|
| 936 |
+
x = self._ispec(zout, training_length)
|
| 937 |
+
else:
|
| 938 |
+
x = self._ispec(zout, length)
|
| 939 |
+
|
| 940 |
+
# back to mps device
|
| 941 |
+
if x_is_mps:
|
| 942 |
+
x = x.to("mps")
|
| 943 |
+
|
| 944 |
+
if self.use_train_segment:
|
| 945 |
+
if self.training:
|
| 946 |
+
xt = xt.view(b, s, -1, length)
|
| 947 |
+
else:
|
| 948 |
+
xt = xt.view(b, s, -1, training_length)
|
| 949 |
+
else:
|
| 950 |
+
xt = xt.view(b, s, -1, length)
|
| 951 |
+
xt = xt * stdt[:, None] + meant[:, None]
|
| 952 |
+
x = xt + x
|
| 953 |
+
if length_pre_pad:
|
| 954 |
+
x = x[..., :length_pre_pad]
|
| 955 |
+
return x
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/pretrained.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File : pretrained.py
|
| 5 |
+
@Time : 2023/8/8 下午7:22
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2023, Tencent
|
| 9 |
+
@Desc : Loading pretrained models.
|
| 10 |
+
"""
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
import yaml
|
| 14 |
+
|
| 15 |
+
from .apply import BagOfModels
|
| 16 |
+
from .htdemucs import HTDemucs
|
| 17 |
+
from .states import load_state_dict
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def add_model_flags(parser):
|
| 21 |
+
group = parser.add_mutually_exclusive_group(required=False)
|
| 22 |
+
group.add_argument("-s", "--sig", help="Locally trained XP signature.")
|
| 23 |
+
group.add_argument("-n", "--name", default=None,
|
| 24 |
+
help="Pretrained model name or signature. Default is htdemucs.")
|
| 25 |
+
parser.add_argument("--repo", type=Path,
|
| 26 |
+
help="Folder containing all pre-trained models for use with -n.")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_model_from_yaml(yaml_file, model_file):
|
| 30 |
+
bag = yaml.safe_load(open(yaml_file))
|
| 31 |
+
model = load_state_dict(HTDemucs, model_file)
|
| 32 |
+
weights = bag.get('weights')
|
| 33 |
+
segment = bag.get('segment')
|
| 34 |
+
return BagOfModels([model], weights, segment)
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/spec.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File : spec.py
|
| 5 |
+
@Time : 2023/8/8 下午5:10
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2023, Tencent
|
| 9 |
+
@Desc : Spec
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch as th
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def spectro(x, n_fft=512, hop_length=None, pad=0):
|
| 16 |
+
*other, length = x.shape
|
| 17 |
+
x = x.reshape(-1, length)
|
| 18 |
+
is_mps = x.device.type == 'mps'
|
| 19 |
+
if is_mps:
|
| 20 |
+
x = x.cpu()
|
| 21 |
+
z = th.stft(x,
|
| 22 |
+
n_fft * (1 + pad),
|
| 23 |
+
hop_length or n_fft // 4,
|
| 24 |
+
window=th.hann_window(n_fft).to(x),
|
| 25 |
+
win_length=n_fft,
|
| 26 |
+
normalized=True,
|
| 27 |
+
center=True,
|
| 28 |
+
return_complex=True,
|
| 29 |
+
pad_mode='reflect')
|
| 30 |
+
_, freqs, frame = z.shape
|
| 31 |
+
return z.view(*other, freqs, frame)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def ispectro(z, hop_length=None, length=None, pad=0):
|
| 35 |
+
*other, freqs, frames = z.shape
|
| 36 |
+
n_fft = 2 * freqs - 2
|
| 37 |
+
z = z.view(-1, freqs, frames)
|
| 38 |
+
win_length = n_fft // (1 + pad)
|
| 39 |
+
is_mps = z.device.type == 'mps'
|
| 40 |
+
if is_mps:
|
| 41 |
+
z = z.cpu()
|
| 42 |
+
x = th.istft(z,
|
| 43 |
+
n_fft,
|
| 44 |
+
hop_length,
|
| 45 |
+
window=th.hann_window(win_length).to(z.real),
|
| 46 |
+
win_length=win_length,
|
| 47 |
+
normalized=True,
|
| 48 |
+
length=length,
|
| 49 |
+
center=True)
|
| 50 |
+
_, length = x.shape
|
| 51 |
+
return x.view(*other, length)
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/states.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File : states.py
|
| 5 |
+
@Time : 2023/8/8 下午7:01
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2023, Tencent
|
| 9 |
+
@Desc : Utilities to save and load models.
|
| 10 |
+
"""
|
| 11 |
+
import functools
|
| 12 |
+
import inspect
|
| 13 |
+
import warnings
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from fractions import Fraction
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_state_dict(net, pth_path):
|
| 21 |
+
kwargs = {'sources': ['drums', 'bass', 'other', 'vocal'], 'audio_channels': 2, 'samplerate': 44100,
|
| 22 |
+
'segment': Fraction(39, 5), 'channels': 48, 'channels_time': None, 'growth': 2, 'nfft': 4096,
|
| 23 |
+
'wiener_iters': 0, 'end_iters': 0, 'wiener_residual': False, 'cac': True, 'depth': 4, 'rewrite': True,
|
| 24 |
+
'multi_freqs': [], 'multi_freqs_depth': 3, 'freq_emb': 0.2, 'emb_scale': 10, 'emb_smooth': True,
|
| 25 |
+
'kernel_size': 8, 'stride': 4, 'time_stride': 2, 'context': 1, 'context_enc': 0, 'norm_starts': 4,
|
| 26 |
+
'norm_groups': 4, 'dconv_mode': 3, 'dconv_depth': 2, 'dconv_comp': 8, 'dconv_init': 0.001,
|
| 27 |
+
'bottom_channels': 512, 't_layers': 5, 't_hidden_scale': 4.0, 't_heads': 8, 't_dropout': 0.02,
|
| 28 |
+
't_layer_scale': True, 't_gelu': True, 't_emb': 'sin', 't_max_positions': 10000, 't_max_period': 10000.0,
|
| 29 |
+
't_weight_pos_embed': 1.0, 't_cape_mean_normalize': True, 't_cape_augment': True,
|
| 30 |
+
't_cape_glob_loc_scale': [5000.0, 1.0, 1.4], 't_sin_random_shift': 0, 't_norm_in': True,
|
| 31 |
+
't_norm_in_group': False, 't_group_norm': False, 't_norm_first': True, 't_norm_out': True,
|
| 32 |
+
't_weight_decay': 0.0, 't_lr': None, 't_sparse_self_attn': False, 't_sparse_cross_attn': False,
|
| 33 |
+
't_mask_type': 'diag', 't_mask_random_seed': 42, 't_sparse_attn_window': 400, 't_global_window': 100,
|
| 34 |
+
't_sparsity': 0.95, 't_auto_sparsity': False, 't_cross_first': False, 'rescale': 0.1}
|
| 35 |
+
model = net(**kwargs)
|
| 36 |
+
state_dict = torch.load(pth_path)
|
| 37 |
+
model.load_state_dict(state_dict)
|
| 38 |
+
return model
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def load_model(path_or_package, strict=False):
|
| 42 |
+
"""Load a model from the given serialized model, either given as a dict (already loaded)
|
| 43 |
+
or a path to a file on disk."""
|
| 44 |
+
if isinstance(path_or_package, dict):
|
| 45 |
+
package = path_or_package
|
| 46 |
+
elif isinstance(path_or_package, (str, Path)):
|
| 47 |
+
with warnings.catch_warnings():
|
| 48 |
+
warnings.simplefilter("ignore")
|
| 49 |
+
path = path_or_package
|
| 50 |
+
package = torch.load(path, 'cpu')
|
| 51 |
+
else:
|
| 52 |
+
raise ValueError(f"Invalid type for {path_or_package}.")
|
| 53 |
+
|
| 54 |
+
klass = package["klass"]
|
| 55 |
+
args = package["args"]
|
| 56 |
+
kwargs = package["kwargs"]
|
| 57 |
+
|
| 58 |
+
if strict:
|
| 59 |
+
model = klass(*args, **kwargs)
|
| 60 |
+
else:
|
| 61 |
+
sig = inspect.signature(klass)
|
| 62 |
+
for key in list(kwargs):
|
| 63 |
+
if key not in sig.parameters:
|
| 64 |
+
warnings.warn("Dropping inexistant parameter " + key)
|
| 65 |
+
del kwargs[key]
|
| 66 |
+
model = klass(*args, **kwargs)
|
| 67 |
+
|
| 68 |
+
state = package["state"]
|
| 69 |
+
|
| 70 |
+
set_state(model, state)
|
| 71 |
+
return model
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_state(model, quantizer, half=False):
|
| 75 |
+
"""Get the state from a model, potentially with quantization applied.
|
| 76 |
+
If `half` is True, model are stored as half precision, which shouldn't impact performance
|
| 77 |
+
but half the state size."""
|
| 78 |
+
if quantizer is None:
|
| 79 |
+
dtype = torch.half if half else None
|
| 80 |
+
state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()}
|
| 81 |
+
else:
|
| 82 |
+
state = quantizer.get_quantized_state()
|
| 83 |
+
state['__quantized'] = True
|
| 84 |
+
return state
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def set_state(model, state, quantizer=None):
|
| 88 |
+
"""Set the state on a given model."""
|
| 89 |
+
if state.get('__quantized'):
|
| 90 |
+
quantizer.restore_quantized_state(model, state['quantized'])
|
| 91 |
+
else:
|
| 92 |
+
model.load_state_dict(state)
|
| 93 |
+
return state
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def capture_init(init):
|
| 97 |
+
@functools.wraps(init)
|
| 98 |
+
def __init__(self, *args, **kwargs):
|
| 99 |
+
self._init_args_kwargs = (args, kwargs)
|
| 100 |
+
init(self, *args, **kwargs)
|
| 101 |
+
|
| 102 |
+
return __init__
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/transformer.py
ADDED
|
@@ -0,0 +1,765 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File : transformer.py
|
| 5 |
+
@Time : 2023/8/8 下午5:05
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2023, Tencent
|
| 9 |
+
@Desc : Transformer
|
| 10 |
+
"""
|
| 11 |
+
import math
|
| 12 |
+
import random
|
| 13 |
+
import typing as tp
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import numpy as np
|
| 19 |
+
from einops import rearrange
|
| 20 |
+
from torch.nn import TransformerEncoderLayer, MultiheadAttention, Linear, LayerNorm
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def create_sin_embedding(
|
| 24 |
+
length: int, dim: int, shift: int = 0, device="cpu", max_period=10000
|
| 25 |
+
):
|
| 26 |
+
# We aim for TBC format
|
| 27 |
+
assert dim % 2 == 0
|
| 28 |
+
pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
|
| 29 |
+
half_dim = dim // 2
|
| 30 |
+
adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
|
| 31 |
+
phase = pos / (max_period ** (adim / (half_dim - 1)))
|
| 32 |
+
return torch.cat(
|
| 33 |
+
[
|
| 34 |
+
torch.cos(phase),
|
| 35 |
+
torch.sin(phase),
|
| 36 |
+
],
|
| 37 |
+
dim=-1,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
|
| 42 |
+
"""
|
| 43 |
+
:param d_model: dimension of the model
|
| 44 |
+
:param height: height of the positions
|
| 45 |
+
:param width: width of the positions
|
| 46 |
+
:return: d_model*height*width position matrix
|
| 47 |
+
"""
|
| 48 |
+
if d_model % 4 != 0:
|
| 49 |
+
raise ValueError(
|
| 50 |
+
"Cannot use sin/cos positional encoding with "
|
| 51 |
+
"odd dimension (got dim={:d})".format(d_model)
|
| 52 |
+
)
|
| 53 |
+
pe = torch.zeros(d_model, height, width)
|
| 54 |
+
# Each dimension use half of d_model
|
| 55 |
+
d_model = int(d_model / 2)
|
| 56 |
+
div_term = torch.exp(
|
| 57 |
+
torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model)
|
| 58 |
+
)
|
| 59 |
+
pos_w = torch.arange(0.0, width).unsqueeze(1)
|
| 60 |
+
pos_h = torch.arange(0.0, height).unsqueeze(1)
|
| 61 |
+
pe[0:d_model:2, :, :] = (
|
| 62 |
+
torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
|
| 63 |
+
)
|
| 64 |
+
pe[1:d_model:2, :, :] = (
|
| 65 |
+
torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
|
| 66 |
+
)
|
| 67 |
+
pe[d_model::2, :, :] = (
|
| 68 |
+
torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
|
| 69 |
+
)
|
| 70 |
+
pe[d_model + 1:: 2, :, :] = (
|
| 71 |
+
torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
return pe[None, :].to(device)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def create_sin_embedding_cape(
|
| 78 |
+
length: int,
|
| 79 |
+
dim: int,
|
| 80 |
+
batch_size: int,
|
| 81 |
+
mean_normalize: bool,
|
| 82 |
+
augment: bool, # True during training
|
| 83 |
+
max_global_shift: float = 0.0, # delta max
|
| 84 |
+
max_local_shift: float = 0.0, # epsilon max
|
| 85 |
+
max_scale: float = 1.0,
|
| 86 |
+
device: str = "cpu",
|
| 87 |
+
max_period: float = 10000.0,
|
| 88 |
+
):
|
| 89 |
+
# We aim for TBC format
|
| 90 |
+
assert dim % 2 == 0
|
| 91 |
+
pos = 1.0 * torch.arange(length).view(-1, 1, 1) # (length, 1, 1)
|
| 92 |
+
pos = pos.repeat(1, batch_size, 1) # (length, batch_size, 1)
|
| 93 |
+
if mean_normalize:
|
| 94 |
+
pos -= torch.nanmean(pos, dim=0, keepdim=True)
|
| 95 |
+
|
| 96 |
+
if augment:
|
| 97 |
+
delta = np.random.uniform(
|
| 98 |
+
-max_global_shift, +max_global_shift, size=[1, batch_size, 1]
|
| 99 |
+
)
|
| 100 |
+
delta_local = np.random.uniform(
|
| 101 |
+
-max_local_shift, +max_local_shift, size=[length, batch_size, 1]
|
| 102 |
+
)
|
| 103 |
+
log_lambdas = np.random.uniform(
|
| 104 |
+
-np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1]
|
| 105 |
+
)
|
| 106 |
+
pos = (pos + delta + delta_local) * np.exp(log_lambdas)
|
| 107 |
+
|
| 108 |
+
pos = pos.to(device)
|
| 109 |
+
|
| 110 |
+
half_dim = dim // 2
|
| 111 |
+
adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
|
| 112 |
+
phase = pos / (max_period ** (adim / (half_dim - 1)))
|
| 113 |
+
return torch.cat(
|
| 114 |
+
[
|
| 115 |
+
torch.cos(phase),
|
| 116 |
+
torch.sin(phase),
|
| 117 |
+
],
|
| 118 |
+
dim=-1,
|
| 119 |
+
).float()
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def get_causal_mask(length):
|
| 123 |
+
pos = torch.arange(length)
|
| 124 |
+
return pos > pos[:, None]
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def get_elementary_mask(
|
| 128 |
+
t1,
|
| 129 |
+
t2,
|
| 130 |
+
mask_type,
|
| 131 |
+
sparse_attn_window,
|
| 132 |
+
global_window,
|
| 133 |
+
mask_random_seed,
|
| 134 |
+
sparsity,
|
| 135 |
+
device,
|
| 136 |
+
):
|
| 137 |
+
"""
|
| 138 |
+
When the input of the Decoder has length T1 and the output T2
|
| 139 |
+
The mask matrix has shape (T2, T1)
|
| 140 |
+
"""
|
| 141 |
+
assert mask_type in ["diag", "jmask", "random", "global"]
|
| 142 |
+
|
| 143 |
+
if mask_type == "global":
|
| 144 |
+
mask = torch.zeros(t2, t1, dtype=torch.bool)
|
| 145 |
+
mask[:, :global_window] = True
|
| 146 |
+
line_window = int(global_window * t2 / t1)
|
| 147 |
+
mask[:line_window, :] = True
|
| 148 |
+
|
| 149 |
+
if mask_type == "diag":
|
| 150 |
+
|
| 151 |
+
mask = torch.zeros(t2, t1, dtype=torch.bool)
|
| 152 |
+
rows = torch.arange(t2)[:, None]
|
| 153 |
+
cols = (
|
| 154 |
+
(t1 / t2 * rows + torch.arange(-sparse_attn_window, sparse_attn_window + 1))
|
| 155 |
+
.long()
|
| 156 |
+
.clamp(0, t1 - 1)
|
| 157 |
+
)
|
| 158 |
+
mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
|
| 159 |
+
|
| 160 |
+
elif mask_type == "jmask":
|
| 161 |
+
mask = torch.zeros(t2 + 2, t1 + 2, dtype=torch.bool)
|
| 162 |
+
rows = torch.arange(t2 + 2)[:, None]
|
| 163 |
+
t = torch.arange(0, int((2 * t1) ** 0.5 + 1))
|
| 164 |
+
t = (t * (t + 1) / 2).int()
|
| 165 |
+
t = torch.cat([-t.flip(0)[:-1], t])
|
| 166 |
+
cols = (t1 / t2 * rows + t).long().clamp(0, t1 + 1)
|
| 167 |
+
mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
|
| 168 |
+
mask = mask[1:-1, 1:-1]
|
| 169 |
+
|
| 170 |
+
elif mask_type == "random":
|
| 171 |
+
gene = torch.Generator(device=device)
|
| 172 |
+
gene.manual_seed(mask_random_seed)
|
| 173 |
+
mask = (
|
| 174 |
+
torch.rand(t1 * t2, generator=gene, device=device).reshape(t2, t1)
|
| 175 |
+
> sparsity
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
mask = mask.to(device)
|
| 179 |
+
return mask
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def get_mask(
|
| 183 |
+
t1,
|
| 184 |
+
t2,
|
| 185 |
+
mask_type,
|
| 186 |
+
sparse_attn_window,
|
| 187 |
+
global_window,
|
| 188 |
+
mask_random_seed,
|
| 189 |
+
sparsity,
|
| 190 |
+
device,
|
| 191 |
+
):
|
| 192 |
+
"""
|
| 193 |
+
Return a SparseCSRTensor mask that is a combination of elementary masks
|
| 194 |
+
mask_type can be a combination of multiple masks: for instance "diag_jmask_random"
|
| 195 |
+
"""
|
| 196 |
+
from xformers.sparse import SparseCSRTensor
|
| 197 |
+
# create a list
|
| 198 |
+
mask_types = mask_type.split("_")
|
| 199 |
+
|
| 200 |
+
all_masks = [
|
| 201 |
+
get_elementary_mask(
|
| 202 |
+
t1,
|
| 203 |
+
t2,
|
| 204 |
+
mask,
|
| 205 |
+
sparse_attn_window,
|
| 206 |
+
global_window,
|
| 207 |
+
mask_random_seed,
|
| 208 |
+
sparsity,
|
| 209 |
+
device,
|
| 210 |
+
)
|
| 211 |
+
for mask in mask_types
|
| 212 |
+
]
|
| 213 |
+
|
| 214 |
+
final_mask = torch.stack(all_masks).sum(axis=0) > 0
|
| 215 |
+
|
| 216 |
+
return SparseCSRTensor.from_dense(final_mask[None])
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class ScaledEmbedding(nn.Module):
|
| 220 |
+
def __init__(
|
| 221 |
+
self,
|
| 222 |
+
num_embeddings: int,
|
| 223 |
+
embedding_dim: int,
|
| 224 |
+
scale: float = 1.0,
|
| 225 |
+
boost: float = 3.0,
|
| 226 |
+
):
|
| 227 |
+
super().__init__()
|
| 228 |
+
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
|
| 229 |
+
self.embedding.weight.data *= scale / boost
|
| 230 |
+
self.boost = boost
|
| 231 |
+
|
| 232 |
+
@property
|
| 233 |
+
def weight(self):
|
| 234 |
+
return self.embedding.weight * self.boost
|
| 235 |
+
|
| 236 |
+
def forward(self, x):
|
| 237 |
+
return self.embedding(x) * self.boost
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class LayerScale(nn.Module):
|
| 241 |
+
"""Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
|
| 242 |
+
This rescales diagonaly residual outputs close to 0 initially, then learnt.
|
| 243 |
+
"""
|
| 244 |
+
|
| 245 |
+
def __init__(self, channels: int, init: float = 0, channel_last=False):
|
| 246 |
+
"""
|
| 247 |
+
channel_last = False corresponds to (B, C, T) tensors
|
| 248 |
+
channel_last = True corresponds to (T, B, C) tensors
|
| 249 |
+
"""
|
| 250 |
+
super().__init__()
|
| 251 |
+
self.channel_last = channel_last
|
| 252 |
+
self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
|
| 253 |
+
self.scale.data[:] = init
|
| 254 |
+
|
| 255 |
+
def forward(self, x):
|
| 256 |
+
if self.channel_last:
|
| 257 |
+
return self.scale * x
|
| 258 |
+
else:
|
| 259 |
+
return self.scale[:, None] * x
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class MyGroupNorm(nn.GroupNorm):
|
| 263 |
+
def __init__(self, *args, **kwargs):
|
| 264 |
+
super().__init__(*args, **kwargs)
|
| 265 |
+
|
| 266 |
+
def forward(self, x):
|
| 267 |
+
"""
|
| 268 |
+
x: (B, T, C)
|
| 269 |
+
if num_groups=1: Normalisation on all T and C together for each B
|
| 270 |
+
"""
|
| 271 |
+
x = x.transpose(1, 2)
|
| 272 |
+
return super().forward(x).transpose(1, 2)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class MyTransformerEncoderLayer(TransformerEncoderLayer):
|
| 276 |
+
def __init__(
|
| 277 |
+
self,
|
| 278 |
+
d_model,
|
| 279 |
+
nhead,
|
| 280 |
+
dim_feedforward=2048,
|
| 281 |
+
dropout=0.1,
|
| 282 |
+
activation=F.relu,
|
| 283 |
+
group_norm=0,
|
| 284 |
+
norm_first=False,
|
| 285 |
+
norm_out=False,
|
| 286 |
+
layer_norm_eps=1e-5,
|
| 287 |
+
layer_scale=False,
|
| 288 |
+
init_values=1e-4,
|
| 289 |
+
device=None,
|
| 290 |
+
dtype=None,
|
| 291 |
+
sparse=False,
|
| 292 |
+
mask_type="diag",
|
| 293 |
+
mask_random_seed=42,
|
| 294 |
+
sparse_attn_window=500,
|
| 295 |
+
global_window=50,
|
| 296 |
+
auto_sparsity=False,
|
| 297 |
+
sparsity=0.95,
|
| 298 |
+
batch_first=False,
|
| 299 |
+
):
|
| 300 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 301 |
+
super().__init__(
|
| 302 |
+
d_model=d_model,
|
| 303 |
+
nhead=nhead,
|
| 304 |
+
dim_feedforward=dim_feedforward,
|
| 305 |
+
dropout=dropout,
|
| 306 |
+
activation=activation,
|
| 307 |
+
layer_norm_eps=layer_norm_eps,
|
| 308 |
+
batch_first=batch_first,
|
| 309 |
+
norm_first=norm_first,
|
| 310 |
+
device=device,
|
| 311 |
+
dtype=dtype,
|
| 312 |
+
)
|
| 313 |
+
self.sparse = sparse
|
| 314 |
+
self.auto_sparsity = auto_sparsity
|
| 315 |
+
if sparse:
|
| 316 |
+
if not auto_sparsity:
|
| 317 |
+
self.mask_type = mask_type
|
| 318 |
+
self.sparse_attn_window = sparse_attn_window
|
| 319 |
+
self.global_window = global_window
|
| 320 |
+
self.sparsity = sparsity
|
| 321 |
+
if group_norm:
|
| 322 |
+
self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
|
| 323 |
+
self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
|
| 324 |
+
|
| 325 |
+
self.norm_out = None
|
| 326 |
+
if self.norm_first & norm_out:
|
| 327 |
+
self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
|
| 328 |
+
self.gamma_1 = (
|
| 329 |
+
LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
|
| 330 |
+
)
|
| 331 |
+
self.gamma_2 = (
|
| 332 |
+
LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
if sparse:
|
| 336 |
+
self.self_attn = MultiheadAttention(
|
| 337 |
+
d_model, nhead, dropout=dropout, batch_first=batch_first,
|
| 338 |
+
auto_sparsity=sparsity if auto_sparsity else 0,
|
| 339 |
+
)
|
| 340 |
+
self.__setattr__("src_mask", torch.zeros(1, 1))
|
| 341 |
+
self.mask_random_seed = mask_random_seed
|
| 342 |
+
|
| 343 |
+
def forward(self, src, src_mask=None, src_key_padding_mask=None):
|
| 344 |
+
"""
|
| 345 |
+
if batch_first = False, src shape is (t, b, c)
|
| 346 |
+
the case where batch_first=True is not covered
|
| 347 |
+
"""
|
| 348 |
+
device = src.device
|
| 349 |
+
x = src
|
| 350 |
+
t, _, _ = x.shape
|
| 351 |
+
if self.sparse and not self.auto_sparsity:
|
| 352 |
+
assert src_mask is None
|
| 353 |
+
src_mask = self.src_mask
|
| 354 |
+
if src_mask.shape[-1] != t:
|
| 355 |
+
src_mask = get_mask(
|
| 356 |
+
t,
|
| 357 |
+
t,
|
| 358 |
+
self.mask_type,
|
| 359 |
+
self.sparse_attn_window,
|
| 360 |
+
self.global_window,
|
| 361 |
+
self.mask_random_seed,
|
| 362 |
+
self.sparsity,
|
| 363 |
+
device,
|
| 364 |
+
)
|
| 365 |
+
self.__setattr__("src_mask", src_mask)
|
| 366 |
+
|
| 367 |
+
if self.norm_first:
|
| 368 |
+
x = x + self.gamma_1(
|
| 369 |
+
self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
|
| 370 |
+
)
|
| 371 |
+
x = x + self.gamma_2(self._ff_block(self.norm2(x)))
|
| 372 |
+
|
| 373 |
+
if self.norm_out:
|
| 374 |
+
x = self.norm_out(x)
|
| 375 |
+
else:
|
| 376 |
+
x = self.norm1(
|
| 377 |
+
x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask))
|
| 378 |
+
)
|
| 379 |
+
x = self.norm2(x + self.gamma_2(self._ff_block(x)))
|
| 380 |
+
|
| 381 |
+
return x
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
class CrossTransformerEncoderLayer(nn.Module):
|
| 385 |
+
def __init__(
|
| 386 |
+
self,
|
| 387 |
+
d_model: int,
|
| 388 |
+
nhead: int,
|
| 389 |
+
dim_feedforward: int = 2048,
|
| 390 |
+
dropout: float = 0.1,
|
| 391 |
+
activation=F.relu,
|
| 392 |
+
layer_norm_eps: float = 1e-5,
|
| 393 |
+
layer_scale: bool = False,
|
| 394 |
+
init_values: float = 1e-4,
|
| 395 |
+
norm_first: bool = False,
|
| 396 |
+
group_norm: bool = False,
|
| 397 |
+
norm_out: bool = False,
|
| 398 |
+
sparse=False,
|
| 399 |
+
mask_type="diag",
|
| 400 |
+
mask_random_seed=42,
|
| 401 |
+
sparse_attn_window=500,
|
| 402 |
+
global_window=50,
|
| 403 |
+
sparsity=0.95,
|
| 404 |
+
auto_sparsity=None,
|
| 405 |
+
device=None,
|
| 406 |
+
dtype=None,
|
| 407 |
+
batch_first=False,
|
| 408 |
+
):
|
| 409 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 410 |
+
super().__init__()
|
| 411 |
+
|
| 412 |
+
self.sparse = sparse
|
| 413 |
+
self.auto_sparsity = auto_sparsity
|
| 414 |
+
if sparse:
|
| 415 |
+
if not auto_sparsity:
|
| 416 |
+
self.mask_type = mask_type
|
| 417 |
+
self.sparse_attn_window = sparse_attn_window
|
| 418 |
+
self.global_window = global_window
|
| 419 |
+
self.sparsity = sparsity
|
| 420 |
+
|
| 421 |
+
self.cross_attn: nn.Module
|
| 422 |
+
self.cross_attn = MultiheadAttention(
|
| 423 |
+
d_model, nhead, dropout=dropout, batch_first=batch_first)
|
| 424 |
+
# Implementation of Feedforward model
|
| 425 |
+
self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
|
| 426 |
+
self.dropout = nn.Dropout(dropout)
|
| 427 |
+
self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)
|
| 428 |
+
|
| 429 |
+
self.norm_first = norm_first
|
| 430 |
+
self.norm1: nn.Module
|
| 431 |
+
self.norm2: nn.Module
|
| 432 |
+
self.norm3: nn.Module
|
| 433 |
+
if group_norm:
|
| 434 |
+
self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
|
| 435 |
+
self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
|
| 436 |
+
self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
|
| 437 |
+
else:
|
| 438 |
+
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
|
| 439 |
+
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
|
| 440 |
+
self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
|
| 441 |
+
|
| 442 |
+
self.norm_out = None
|
| 443 |
+
if self.norm_first & norm_out:
|
| 444 |
+
self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
|
| 445 |
+
|
| 446 |
+
self.gamma_1 = (
|
| 447 |
+
LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
|
| 448 |
+
)
|
| 449 |
+
self.gamma_2 = (
|
| 450 |
+
LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 454 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 455 |
+
|
| 456 |
+
# Legacy string support for activation function.
|
| 457 |
+
if isinstance(activation, str):
|
| 458 |
+
self.activation = self._get_activation_fn(activation)
|
| 459 |
+
else:
|
| 460 |
+
self.activation = activation
|
| 461 |
+
|
| 462 |
+
if sparse:
|
| 463 |
+
self.cross_attn = MultiheadAttention(
|
| 464 |
+
d_model, nhead, dropout=dropout, batch_first=batch_first,
|
| 465 |
+
auto_sparsity=sparsity if auto_sparsity else 0)
|
| 466 |
+
if not auto_sparsity:
|
| 467 |
+
self.__setattr__("mask", torch.zeros(1, 1))
|
| 468 |
+
self.mask_random_seed = mask_random_seed
|
| 469 |
+
|
| 470 |
+
def forward(self, q, k, mask=None):
|
| 471 |
+
"""
|
| 472 |
+
Args:
|
| 473 |
+
q: tensor of shape (T, B, C)
|
| 474 |
+
k: tensor of shape (S, B, C)
|
| 475 |
+
mask: tensor of shape (T, S)
|
| 476 |
+
|
| 477 |
+
"""
|
| 478 |
+
device = q.device
|
| 479 |
+
t, _, _ = q.shape
|
| 480 |
+
s, _, _ = k.shape
|
| 481 |
+
if self.sparse and not self.auto_sparsity:
|
| 482 |
+
assert mask is None
|
| 483 |
+
mask = self.mask
|
| 484 |
+
if mask.shape[-1] != s or mask.shape[-2] != t:
|
| 485 |
+
mask = get_mask(
|
| 486 |
+
s,
|
| 487 |
+
t,
|
| 488 |
+
self.mask_type,
|
| 489 |
+
self.sparse_attn_window,
|
| 490 |
+
self.global_window,
|
| 491 |
+
self.mask_random_seed,
|
| 492 |
+
self.sparsity,
|
| 493 |
+
device,
|
| 494 |
+
)
|
| 495 |
+
self.__setattr__("mask", mask)
|
| 496 |
+
|
| 497 |
+
if self.norm_first:
|
| 498 |
+
x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
|
| 499 |
+
x = x + self.gamma_2(self._ff_block(self.norm3(x)))
|
| 500 |
+
if self.norm_out:
|
| 501 |
+
x = self.norm_out(x)
|
| 502 |
+
else:
|
| 503 |
+
x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
|
| 504 |
+
x = self.norm2(x + self.gamma_2(self._ff_block(x)))
|
| 505 |
+
|
| 506 |
+
return x
|
| 507 |
+
|
| 508 |
+
# self-attention block
|
| 509 |
+
def _ca_block(self, q, k, attn_mask=None):
|
| 510 |
+
x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
|
| 511 |
+
return self.dropout1(x)
|
| 512 |
+
|
| 513 |
+
# feed forward block
|
| 514 |
+
def _ff_block(self, x):
|
| 515 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
| 516 |
+
return self.dropout2(x)
|
| 517 |
+
|
| 518 |
+
def _get_activation_fn(self, activation):
|
| 519 |
+
if activation == "relu":
|
| 520 |
+
return F.relu
|
| 521 |
+
elif activation == "gelu":
|
| 522 |
+
return F.gelu
|
| 523 |
+
|
| 524 |
+
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
# ----------------- MULTI-BLOCKS MODELS: -----------------------
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
class CrossTransformerEncoder(nn.Module):
|
| 531 |
+
def __init__(
|
| 532 |
+
self,
|
| 533 |
+
dim: int,
|
| 534 |
+
emb: str = "sin",
|
| 535 |
+
hidden_scale: float = 4.0,
|
| 536 |
+
num_heads: int = 8,
|
| 537 |
+
num_layers: int = 6,
|
| 538 |
+
cross_first: bool = False,
|
| 539 |
+
dropout: float = 0.0,
|
| 540 |
+
max_positions: int = 1000,
|
| 541 |
+
norm_in: bool = True,
|
| 542 |
+
norm_in_group: bool = False,
|
| 543 |
+
group_norm: int = False,
|
| 544 |
+
norm_first: bool = False,
|
| 545 |
+
norm_out: bool = False,
|
| 546 |
+
max_period: float = 10000.0,
|
| 547 |
+
weight_decay: float = 0.0,
|
| 548 |
+
lr: tp.Optional[float] = None,
|
| 549 |
+
layer_scale: bool = False,
|
| 550 |
+
gelu: bool = True,
|
| 551 |
+
sin_random_shift: int = 0,
|
| 552 |
+
weight_pos_embed: float = 1.0,
|
| 553 |
+
cape_mean_normalize: bool = True,
|
| 554 |
+
cape_augment: bool = True,
|
| 555 |
+
cape_glob_loc_scale: list = None,
|
| 556 |
+
sparse_self_attn: bool = False,
|
| 557 |
+
sparse_cross_attn: bool = False,
|
| 558 |
+
mask_type: str = "diag",
|
| 559 |
+
mask_random_seed: int = 42,
|
| 560 |
+
sparse_attn_window: int = 500,
|
| 561 |
+
global_window: int = 50,
|
| 562 |
+
auto_sparsity: bool = False,
|
| 563 |
+
sparsity: float = 0.95,
|
| 564 |
+
):
|
| 565 |
+
super().__init__()
|
| 566 |
+
"""
|
| 567 |
+
"""
|
| 568 |
+
assert dim % num_heads == 0
|
| 569 |
+
|
| 570 |
+
hidden_dim = int(dim * hidden_scale)
|
| 571 |
+
|
| 572 |
+
self.num_layers = num_layers
|
| 573 |
+
# classic parity = 1 means that if idx%2 == 1 there is a
|
| 574 |
+
# classical encoder else there is a cross encoder
|
| 575 |
+
self.classic_parity = 1 if cross_first else 0
|
| 576 |
+
self.emb = emb
|
| 577 |
+
self.max_period = max_period
|
| 578 |
+
self.weight_decay = weight_decay
|
| 579 |
+
self.weight_pos_embed = weight_pos_embed
|
| 580 |
+
self.sin_random_shift = sin_random_shift
|
| 581 |
+
if emb == "cape":
|
| 582 |
+
self.cape_mean_normalize = cape_mean_normalize
|
| 583 |
+
self.cape_augment = cape_augment
|
| 584 |
+
if cape_glob_loc_scale is None:
|
| 585 |
+
self.cape_glob_loc_scale = [5000.0, 1.0, 1.4]
|
| 586 |
+
else:
|
| 587 |
+
self.cape_glob_loc_scale = cape_glob_loc_scale
|
| 588 |
+
if emb == "scaled":
|
| 589 |
+
self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)
|
| 590 |
+
|
| 591 |
+
self.lr = lr
|
| 592 |
+
|
| 593 |
+
activation: tp.Any = F.gelu if gelu else F.relu
|
| 594 |
+
|
| 595 |
+
self.norm_in: nn.Module
|
| 596 |
+
self.norm_in_t: nn.Module
|
| 597 |
+
if norm_in:
|
| 598 |
+
self.norm_in = LayerNorm(dim)
|
| 599 |
+
self.norm_in_t = LayerNorm(dim)
|
| 600 |
+
elif norm_in_group:
|
| 601 |
+
self.norm_in = MyGroupNorm(int(norm_in_group), dim)
|
| 602 |
+
self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
|
| 603 |
+
else:
|
| 604 |
+
self.norm_in = nn.Identity()
|
| 605 |
+
self.norm_in_t = nn.Identity()
|
| 606 |
+
|
| 607 |
+
# spectrogram layers
|
| 608 |
+
self.layers = nn.ModuleList()
|
| 609 |
+
# temporal layers
|
| 610 |
+
self.layers_t = nn.ModuleList()
|
| 611 |
+
|
| 612 |
+
kwargs_common = {
|
| 613 |
+
"d_model": dim,
|
| 614 |
+
"nhead": num_heads,
|
| 615 |
+
"dim_feedforward": hidden_dim,
|
| 616 |
+
"dropout": dropout,
|
| 617 |
+
"activation": activation,
|
| 618 |
+
"group_norm": group_norm,
|
| 619 |
+
"norm_first": norm_first,
|
| 620 |
+
"norm_out": norm_out,
|
| 621 |
+
"layer_scale": layer_scale,
|
| 622 |
+
"mask_type": mask_type,
|
| 623 |
+
"mask_random_seed": mask_random_seed,
|
| 624 |
+
"sparse_attn_window": sparse_attn_window,
|
| 625 |
+
"global_window": global_window,
|
| 626 |
+
"sparsity": sparsity,
|
| 627 |
+
"auto_sparsity": auto_sparsity,
|
| 628 |
+
"batch_first": True,
|
| 629 |
+
}
|
| 630 |
+
|
| 631 |
+
kwargs_classic_encoder = dict(kwargs_common)
|
| 632 |
+
kwargs_classic_encoder.update({
|
| 633 |
+
"sparse": sparse_self_attn,
|
| 634 |
+
})
|
| 635 |
+
kwargs_cross_encoder = dict(kwargs_common)
|
| 636 |
+
kwargs_cross_encoder.update({
|
| 637 |
+
"sparse": sparse_cross_attn,
|
| 638 |
+
})
|
| 639 |
+
|
| 640 |
+
for idx in range(num_layers):
|
| 641 |
+
if idx % 2 == self.classic_parity:
|
| 642 |
+
|
| 643 |
+
self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
|
| 644 |
+
self.layers_t.append(
|
| 645 |
+
MyTransformerEncoderLayer(**kwargs_classic_encoder)
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
else:
|
| 649 |
+
self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
|
| 650 |
+
|
| 651 |
+
self.layers_t.append(
|
| 652 |
+
CrossTransformerEncoderLayer(**kwargs_cross_encoder)
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
def forward(self, x, xt):
|
| 656 |
+
_, c, fr, t1 = x.shape
|
| 657 |
+
pos_emb_2d = create_2d_sin_embedding(
|
| 658 |
+
c, fr, t1, x.device, self.max_period
|
| 659 |
+
) # (1, C, Fr, T1)
|
| 660 |
+
pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
|
| 661 |
+
x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
|
| 662 |
+
x = self.norm_in(x)
|
| 663 |
+
x = x + self.weight_pos_embed * pos_emb_2d
|
| 664 |
+
|
| 665 |
+
b, c, t2 = xt.shape
|
| 666 |
+
xt = rearrange(xt, "b c t2 -> b t2 c") # now T2, B, C
|
| 667 |
+
pos_emb = self._get_pos_embedding(t2, b, c, x.device)
|
| 668 |
+
pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
|
| 669 |
+
xt = self.norm_in_t(xt)
|
| 670 |
+
xt = xt + self.weight_pos_embed * pos_emb
|
| 671 |
+
|
| 672 |
+
for idx in range(self.num_layers):
|
| 673 |
+
if idx % 2 == self.classic_parity:
|
| 674 |
+
x = self.layers[idx](x)
|
| 675 |
+
xt = self.layers_t[idx](xt)
|
| 676 |
+
else:
|
| 677 |
+
old_x = x
|
| 678 |
+
x = self.layers[idx](x, xt)
|
| 679 |
+
xt = self.layers_t[idx](xt, old_x)
|
| 680 |
+
|
| 681 |
+
x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=t1)
|
| 682 |
+
xt = rearrange(xt, "b t2 c -> b c t2")
|
| 683 |
+
return x, xt
|
| 684 |
+
|
| 685 |
+
def _get_pos_embedding(self, t, b, c, device):
|
| 686 |
+
if self.emb == "sin":
|
| 687 |
+
shift = random.randrange(self.sin_random_shift + 1)
|
| 688 |
+
pos_emb = create_sin_embedding(
|
| 689 |
+
t, c, shift=shift, device=device, max_period=self.max_period
|
| 690 |
+
)
|
| 691 |
+
elif self.emb == "cape":
|
| 692 |
+
if self.training:
|
| 693 |
+
pos_emb = create_sin_embedding_cape(
|
| 694 |
+
t,
|
| 695 |
+
c,
|
| 696 |
+
b,
|
| 697 |
+
device=device,
|
| 698 |
+
max_period=self.max_period,
|
| 699 |
+
mean_normalize=self.cape_mean_normalize,
|
| 700 |
+
augment=self.cape_augment,
|
| 701 |
+
max_global_shift=self.cape_glob_loc_scale[0],
|
| 702 |
+
max_local_shift=self.cape_glob_loc_scale[1],
|
| 703 |
+
max_scale=self.cape_glob_loc_scale[2],
|
| 704 |
+
)
|
| 705 |
+
else:
|
| 706 |
+
pos_emb = create_sin_embedding_cape(
|
| 707 |
+
t,
|
| 708 |
+
c,
|
| 709 |
+
b,
|
| 710 |
+
device=device,
|
| 711 |
+
max_period=self.max_period,
|
| 712 |
+
mean_normalize=self.cape_mean_normalize,
|
| 713 |
+
augment=False,
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
elif self.emb == "scaled":
|
| 717 |
+
pos = torch.arange(t, device=device)
|
| 718 |
+
pos_emb = self.position_embeddings(pos)[:, None]
|
| 719 |
+
|
| 720 |
+
return pos_emb
|
| 721 |
+
|
| 722 |
+
def make_optim_group(self):
|
| 723 |
+
group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
|
| 724 |
+
if self.lr is not None:
|
| 725 |
+
group["lr"] = self.lr
|
| 726 |
+
return group
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
def scaled_query_key_softmax(q, k, att_mask):
|
| 730 |
+
from xformers.ops import masked_matmul
|
| 731 |
+
q = q / (k.size(-1)) ** 0.5
|
| 732 |
+
att = masked_matmul(q, k.transpose(-2, -1), att_mask)
|
| 733 |
+
att = torch.nn.functional.softmax(att, -1)
|
| 734 |
+
return att
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
def scaled_dot_product_attention(q, k, v, att_mask, dropout):
|
| 738 |
+
att = scaled_query_key_softmax(q, k, att_mask=att_mask)
|
| 739 |
+
att = dropout(att)
|
| 740 |
+
y = att @ v
|
| 741 |
+
return y
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
def _compute_buckets(x, r):
|
| 745 |
+
qq = torch.einsum('btf,bfhi->bhti', x, r)
|
| 746 |
+
qq = torch.cat([qq, -qq], dim=-1)
|
| 747 |
+
buckets = qq.argmax(dim=-1)
|
| 748 |
+
|
| 749 |
+
return buckets.permute(0, 2, 1).byte().contiguous()
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None):
|
| 753 |
+
# assert False, "The code for the custom sparse kernel is not ready for release yet."
|
| 754 |
+
from xformers.ops import find_locations, sparse_memory_efficient_attention
|
| 755 |
+
n_hashes = 32
|
| 756 |
+
proj_size = 4
|
| 757 |
+
query, key, value = [x.contiguous() for x in [query, key, value]]
|
| 758 |
+
with torch.no_grad():
|
| 759 |
+
r = torch.randn(1, query.shape[-1], n_hashes, proj_size // 2, device=query.device)
|
| 760 |
+
bucket_query = _compute_buckets(query, r)
|
| 761 |
+
bucket_key = _compute_buckets(key, r)
|
| 762 |
+
row_offsets, column_indices = find_locations(
|
| 763 |
+
bucket_query, bucket_key, sparsity, infer_sparsity)
|
| 764 |
+
return sparse_memory_efficient_attention(
|
| 765 |
+
query, key, value, row_offsets, column_indices, attn_bias)
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/models/utils.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File : utils.py
|
| 5 |
+
@Time : 2023/8/8 下午4:26
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2023, Tencent
|
| 9 |
+
@Desc : utils
|
| 10 |
+
"""
|
| 11 |
+
from contextlib import contextmanager
|
| 12 |
+
import math
|
| 13 |
+
import os
|
| 14 |
+
import tempfile
|
| 15 |
+
import typing as tp
|
| 16 |
+
import json
|
| 17 |
+
import subprocess
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from torch.nn import functional as F
|
| 21 |
+
from torch.utils.data import Subset
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def unfold(a, kernel_size, stride):
|
| 25 |
+
"""Given input of size [*OT, T], output Tensor of size [*OT, F, K]
|
| 26 |
+
with K the kernel size, by extracting frames with the given stride.
|
| 27 |
+
|
| 28 |
+
This will pad the input so that `F = ceil(T / K)`.
|
| 29 |
+
|
| 30 |
+
see https://github.com/pytorch/pytorch/issues/60466
|
| 31 |
+
"""
|
| 32 |
+
*shape, length = a.shape
|
| 33 |
+
n_frames = math.ceil(length / stride)
|
| 34 |
+
tgt_length = (n_frames - 1) * stride + kernel_size
|
| 35 |
+
a = F.pad(a, (0, tgt_length - length))
|
| 36 |
+
strides = list(a.stride())
|
| 37 |
+
assert strides[-1] == 1, 'data should be contiguous'
|
| 38 |
+
strides = strides[:-1] + [stride, 1]
|
| 39 |
+
return a.as_strided([*shape, n_frames, kernel_size], strides)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
|
| 43 |
+
"""
|
| 44 |
+
Center trim `tensor` with respect to `reference`, along the last dimension.
|
| 45 |
+
`reference` can also be a number, representing the length to trim to.
|
| 46 |
+
If the size difference != 0 mod 2, the extra sample is removed on the right side.
|
| 47 |
+
"""
|
| 48 |
+
ref_size: int
|
| 49 |
+
if isinstance(reference, torch.Tensor):
|
| 50 |
+
ref_size = reference.size(-1)
|
| 51 |
+
else:
|
| 52 |
+
ref_size = reference
|
| 53 |
+
delta = tensor.size(-1) - ref_size
|
| 54 |
+
if delta < 0:
|
| 55 |
+
raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
|
| 56 |
+
if delta:
|
| 57 |
+
tensor = tensor[..., delta // 2:-(delta - delta // 2)]
|
| 58 |
+
return tensor
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def pull_metric(history: tp.List[dict], name: str):
|
| 62 |
+
out = []
|
| 63 |
+
for metrics in history:
|
| 64 |
+
metric = metrics
|
| 65 |
+
for part in name.split("."):
|
| 66 |
+
metric = metric[part]
|
| 67 |
+
out.append(metric)
|
| 68 |
+
return out
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def sizeof_fmt(num: float, suffix: str = 'B'):
|
| 72 |
+
"""
|
| 73 |
+
Given `num` bytes, return human readable size.
|
| 74 |
+
Taken from https://stackoverflow.com/a/1094933
|
| 75 |
+
"""
|
| 76 |
+
for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
|
| 77 |
+
if abs(num) < 1024.0:
|
| 78 |
+
return "%3.1f%s%s" % (num, unit, suffix)
|
| 79 |
+
num /= 1024.0
|
| 80 |
+
return "%.1f%s%s" % (num, 'Yi', suffix)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@contextmanager
|
| 84 |
+
def temp_filenames(count: int, delete=True):
|
| 85 |
+
names = []
|
| 86 |
+
try:
|
| 87 |
+
for _ in range(count):
|
| 88 |
+
names.append(tempfile.NamedTemporaryFile(delete=False).name)
|
| 89 |
+
yield names
|
| 90 |
+
finally:
|
| 91 |
+
if delete:
|
| 92 |
+
for name in names:
|
| 93 |
+
os.unlink(name)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def random_subset(dataset, max_samples: int, seed: int = 42):
|
| 97 |
+
if max_samples >= len(dataset):
|
| 98 |
+
return dataset
|
| 99 |
+
|
| 100 |
+
generator = torch.Generator().manual_seed(seed)
|
| 101 |
+
perm = torch.randperm(len(dataset), generator=generator)
|
| 102 |
+
return Subset(dataset, perm[:max_samples].tolist())
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class DummyPoolExecutor:
|
| 106 |
+
class DummyResult:
|
| 107 |
+
def __init__(self, func, *args, **kwargs):
|
| 108 |
+
self.func = func
|
| 109 |
+
self.args = args
|
| 110 |
+
self.kwargs = kwargs
|
| 111 |
+
|
| 112 |
+
def result(self):
|
| 113 |
+
return self.func(*self.args, **self.kwargs)
|
| 114 |
+
|
| 115 |
+
def __init__(self, workers=0):
|
| 116 |
+
pass
|
| 117 |
+
|
| 118 |
+
def submit(self, func, *args, **kwargs):
|
| 119 |
+
return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
|
| 120 |
+
|
| 121 |
+
def __enter__(self):
|
| 122 |
+
return self
|
| 123 |
+
|
| 124 |
+
def __exit__(self, exc_type, exc_value, exc_tb):
|
| 125 |
+
return
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/demucs/run.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File : layers.py
|
| 5 |
+
@Time : 2024/4/22 下午2:40
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2024, Tencent
|
| 9 |
+
"""
|
| 10 |
+
import os
|
| 11 |
+
import json
|
| 12 |
+
import time
|
| 13 |
+
import logging
|
| 14 |
+
import argparse
|
| 15 |
+
from datetime import datetime
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from models.apply import BagOfModels
|
| 21 |
+
from models.pretrained import get_model_from_yaml
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Separator:
|
| 25 |
+
def __init__(self, dm_model_path, dm_config_path, gpu_id=0) -> None:
|
| 26 |
+
if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
|
| 27 |
+
self.device = torch.device(f"cuda:{gpu_id}")
|
| 28 |
+
else:
|
| 29 |
+
self.device = torch.device("cpu")
|
| 30 |
+
self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
|
| 31 |
+
|
| 32 |
+
def init_demucs_model(self, model_path, config_path) -> BagOfModels:
|
| 33 |
+
model = get_model_from_yaml(config_path, model_path)
|
| 34 |
+
model.to(self.device)
|
| 35 |
+
model.eval()
|
| 36 |
+
return model
|
| 37 |
+
|
| 38 |
+
def run(self, audio_path, output_dir, ext=".flac"):
|
| 39 |
+
name, _ = os.path.splitext(os.path.split(audio_path)[-1])
|
| 40 |
+
output_paths = []
|
| 41 |
+
for stem in self.demucs_model.sources:
|
| 42 |
+
output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
|
| 43 |
+
if os.path.exists(output_path):
|
| 44 |
+
output_paths.append(output_path)
|
| 45 |
+
if len(output_paths) == 4:
|
| 46 |
+
drums_path, bass_path, other_path, vocal_path = output_paths
|
| 47 |
+
else:
|
| 48 |
+
drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
|
| 49 |
+
data_dict = {
|
| 50 |
+
"vocal_path": vocal_path,
|
| 51 |
+
"bgm_path": [drums_path, bass_path, other_path]
|
| 52 |
+
}
|
| 53 |
+
return data_dict
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def json_io(input_json, output_json, model_dir, dst_dir, gpu_id=0):
|
| 57 |
+
current_datetime = datetime.now()
|
| 58 |
+
current_datetime_str = current_datetime.strftime('%Y-%m-%d-%H:%M')
|
| 59 |
+
logging.basicConfig(filename=os.path.join(dst_dir, f'logger-separate-{os.path.split(input_json)[1]}-{current_datetime_str}.log'), level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 60 |
+
|
| 61 |
+
sp = Separator(os.path.join(model_dir, "htdemucs.pth"), os.path.join(model_dir, "htdemucs.yaml"), gpu_id=gpu_id)
|
| 62 |
+
with open(input_json, "r") as fp:
|
| 63 |
+
lines = fp.readlines()
|
| 64 |
+
t1 = time.time()
|
| 65 |
+
success_num = 0
|
| 66 |
+
fail_num = 0
|
| 67 |
+
total_num = len(lines)
|
| 68 |
+
sep_items = []
|
| 69 |
+
for line in lines:
|
| 70 |
+
item = json.loads(line)
|
| 71 |
+
flac_file = item["path"]
|
| 72 |
+
try:
|
| 73 |
+
fix_data = sp.run(flac_file, dst_dir)
|
| 74 |
+
except Exception as e:
|
| 75 |
+
fail_num += 1
|
| 76 |
+
logging.error(f"process-{success_num + fail_num}/{total_num}|success-{success_num}|fail-{fail_num}|{item['idx']} process fail for {str(e)}")
|
| 77 |
+
continue
|
| 78 |
+
|
| 79 |
+
item["vocal_path"] = fix_data["vocal_path"]
|
| 80 |
+
item["bgm_path"] = fix_data["bgm_path"]
|
| 81 |
+
sep_items.append(item)
|
| 82 |
+
success_num += 1
|
| 83 |
+
logging.debug(f"process-{success_num + fail_num}/{total_num}|success-{success_num}|fail-{fail_num}|{item['idx']} process success")
|
| 84 |
+
|
| 85 |
+
with open(output_json, "w", encoding='utf-8') as fw:
|
| 86 |
+
for item in sep_items:
|
| 87 |
+
fw.write(json.dumps(item, ensure_ascii=False) + "\n")
|
| 88 |
+
|
| 89 |
+
t2 = time.time()
|
| 90 |
+
logging.debug(f"total cost {round(t2-t1, 3)}s")
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
if __name__ == "__main__":
|
| 94 |
+
parser = argparse.ArgumentParser(description='')
|
| 95 |
+
parser.add_argument("-m", dest="model_dir")
|
| 96 |
+
parser.add_argument("-d", dest="dst_dir")
|
| 97 |
+
parser.add_argument("-j", dest="input_json")
|
| 98 |
+
parser.add_argument("-o", dest="output_json")
|
| 99 |
+
parser.add_argument("-gid", dest="gpu_id", default=0, type=int)
|
| 100 |
+
args = parser.parse_args()
|
| 101 |
+
|
| 102 |
+
if not args.dst_dir:
|
| 103 |
+
dst_dir = os.path.join(os.getcwd(), "separate_result")
|
| 104 |
+
os.makedirs(dst_dir, exist_ok=True)
|
| 105 |
+
else:
|
| 106 |
+
dst_dir = os.path.join(args.dst_dir, "separate_result")
|
| 107 |
+
os.makedirs(dst_dir, exist_ok=True)
|
| 108 |
+
|
| 109 |
+
json_io(args.input_json, args.output_json, args.model_dir, dst_dir, gpu_id=args.gpu_id)
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/hub/version.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
1
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/config/model_1920.json
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "autoencoder",
|
| 3 |
+
"sample_size": 403200,
|
| 4 |
+
"sample_rate": 48000,
|
| 5 |
+
"audio_channels": 2,
|
| 6 |
+
"model": {
|
| 7 |
+
"encoder": {
|
| 8 |
+
"type": "oobleck",
|
| 9 |
+
"config": {
|
| 10 |
+
"in_channels": 2,
|
| 11 |
+
"channels": 128,
|
| 12 |
+
"c_mults": [1, 2, 4, 8, 16],
|
| 13 |
+
"strides": [2, 4, 4, 6, 10],
|
| 14 |
+
"latent_dim": 128,
|
| 15 |
+
"use_snake": true
|
| 16 |
+
}
|
| 17 |
+
},
|
| 18 |
+
"decoder": {
|
| 19 |
+
"type": "oobleck",
|
| 20 |
+
"config": {
|
| 21 |
+
"out_channels": 2,
|
| 22 |
+
"channels": 128,
|
| 23 |
+
"c_mults": [1, 2, 4, 8, 16],
|
| 24 |
+
"strides": [2, 4, 4, 6, 10],
|
| 25 |
+
"latent_dim": 64,
|
| 26 |
+
"use_snake": true,
|
| 27 |
+
"final_tanh": false
|
| 28 |
+
}
|
| 29 |
+
},
|
| 30 |
+
"bottleneck": {
|
| 31 |
+
"type": "vae"
|
| 32 |
+
},
|
| 33 |
+
"latent_dim": 64,
|
| 34 |
+
"downsampling_ratio": 1920,
|
| 35 |
+
"io_channels": 2
|
| 36 |
+
},
|
| 37 |
+
"training": {
|
| 38 |
+
"learning_rate": 1.5e-4,
|
| 39 |
+
"warmup_steps": 0,
|
| 40 |
+
"use_ema": true,
|
| 41 |
+
"optimizer_configs": {
|
| 42 |
+
"autoencoder": {
|
| 43 |
+
"optimizer": {
|
| 44 |
+
"type": "AdamW",
|
| 45 |
+
"config": {
|
| 46 |
+
"betas": [0.8, 0.99],
|
| 47 |
+
"lr": 1.5e-4,
|
| 48 |
+
"weight_decay": 1e-3
|
| 49 |
+
}
|
| 50 |
+
},
|
| 51 |
+
"scheduler": {
|
| 52 |
+
"type": "InverseLR",
|
| 53 |
+
"config": {
|
| 54 |
+
"inv_gamma": 200000,
|
| 55 |
+
"power": 0.5,
|
| 56 |
+
"warmup": 0.999
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
},
|
| 60 |
+
"discriminator": {
|
| 61 |
+
"optimizer": {
|
| 62 |
+
"type": "AdamW",
|
| 63 |
+
"config": {
|
| 64 |
+
"betas": [0.8, 0.99],
|
| 65 |
+
"lr": 3e-4,
|
| 66 |
+
"weight_decay": 1e-3
|
| 67 |
+
}
|
| 68 |
+
},
|
| 69 |
+
"scheduler": {
|
| 70 |
+
"type": "InverseLR",
|
| 71 |
+
"config": {
|
| 72 |
+
"inv_gamma": 200000,
|
| 73 |
+
"power": 0.5,
|
| 74 |
+
"warmup": 0.999
|
| 75 |
+
}
|
| 76 |
+
}
|
| 77 |
+
}
|
| 78 |
+
},
|
| 79 |
+
"loss_configs": {
|
| 80 |
+
"discriminator": {
|
| 81 |
+
"type": "encodec",
|
| 82 |
+
"config": {
|
| 83 |
+
"filters": 64,
|
| 84 |
+
"n_ffts": [2048, 1024, 512, 256, 128],
|
| 85 |
+
"hop_lengths": [512, 256, 128, 64, 32],
|
| 86 |
+
"win_lengths": [2048, 1024, 512, 256, 128]
|
| 87 |
+
},
|
| 88 |
+
"weights": {
|
| 89 |
+
"adversarial": 0.1,
|
| 90 |
+
"feature_matching": 5.0
|
| 91 |
+
}
|
| 92 |
+
},
|
| 93 |
+
"spectral": {
|
| 94 |
+
"type": "mrstft",
|
| 95 |
+
"config": {
|
| 96 |
+
"fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
|
| 97 |
+
"hop_sizes": [512, 256, 128, 64, 32, 16, 8],
|
| 98 |
+
"win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
|
| 99 |
+
"perceptual_weighting": true
|
| 100 |
+
},
|
| 101 |
+
"weights": {
|
| 102 |
+
"mrstft": 1.0
|
| 103 |
+
}
|
| 104 |
+
},
|
| 105 |
+
"time": {
|
| 106 |
+
"type": "l1",
|
| 107 |
+
"weights": {
|
| 108 |
+
"l1": 0.0
|
| 109 |
+
}
|
| 110 |
+
},
|
| 111 |
+
"bottleneck": {
|
| 112 |
+
"type": "kl",
|
| 113 |
+
"weights": {
|
| 114 |
+
"kl": 1e-4
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
},
|
| 118 |
+
"demo": {
|
| 119 |
+
"demo_every": 2000
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
}
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/config/model_config.json
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "autoencoder",
|
| 3 |
+
"sample_size": 409600,
|
| 4 |
+
"sample_rate": 44100,
|
| 5 |
+
"audio_channels": 2,
|
| 6 |
+
"model": {
|
| 7 |
+
"encoder": {
|
| 8 |
+
"type": "oobleck",
|
| 9 |
+
"config": {
|
| 10 |
+
"in_channels": 2,
|
| 11 |
+
"channels": 128,
|
| 12 |
+
"c_mults": [1, 2, 4, 8, 16],
|
| 13 |
+
"strides": [2, 4, 4, 8, 8],
|
| 14 |
+
"latent_dim": 128,
|
| 15 |
+
"use_snake": true
|
| 16 |
+
}
|
| 17 |
+
},
|
| 18 |
+
"decoder": {
|
| 19 |
+
"type": "oobleck",
|
| 20 |
+
"config": {
|
| 21 |
+
"out_channels": 2,
|
| 22 |
+
"channels": 128,
|
| 23 |
+
"c_mults": [1, 2, 4, 8, 16],
|
| 24 |
+
"strides": [2, 4, 4, 8, 8],
|
| 25 |
+
"latent_dim": 64,
|
| 26 |
+
"use_snake": true,
|
| 27 |
+
"final_tanh": false
|
| 28 |
+
}
|
| 29 |
+
},
|
| 30 |
+
"bottleneck": {
|
| 31 |
+
"type": "vae"
|
| 32 |
+
},
|
| 33 |
+
"latent_dim": 64,
|
| 34 |
+
"downsampling_ratio": 2048,
|
| 35 |
+
"io_channels": 2
|
| 36 |
+
},
|
| 37 |
+
"training": {
|
| 38 |
+
"learning_rate": 1.5e-4,
|
| 39 |
+
"warmup_steps": 0,
|
| 40 |
+
"use_ema": true,
|
| 41 |
+
"optimizer_configs": {
|
| 42 |
+
"autoencoder": {
|
| 43 |
+
"optimizer": {
|
| 44 |
+
"type": "AdamW",
|
| 45 |
+
"config": {
|
| 46 |
+
"betas": [0.8, 0.99],
|
| 47 |
+
"lr": 1.5e-4,
|
| 48 |
+
"weight_decay": 1e-3
|
| 49 |
+
}
|
| 50 |
+
},
|
| 51 |
+
"scheduler": {
|
| 52 |
+
"type": "InverseLR",
|
| 53 |
+
"config": {
|
| 54 |
+
"inv_gamma": 200000,
|
| 55 |
+
"power": 0.5,
|
| 56 |
+
"warmup": 0.999
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
},
|
| 60 |
+
"discriminator": {
|
| 61 |
+
"optimizer": {
|
| 62 |
+
"type": "AdamW",
|
| 63 |
+
"config": {
|
| 64 |
+
"betas": [0.8, 0.99],
|
| 65 |
+
"lr": 3e-4,
|
| 66 |
+
"weight_decay": 1e-3
|
| 67 |
+
}
|
| 68 |
+
},
|
| 69 |
+
"scheduler": {
|
| 70 |
+
"type": "InverseLR",
|
| 71 |
+
"config": {
|
| 72 |
+
"inv_gamma": 200000,
|
| 73 |
+
"power": 0.5,
|
| 74 |
+
"warmup": 0.999
|
| 75 |
+
}
|
| 76 |
+
}
|
| 77 |
+
}
|
| 78 |
+
},
|
| 79 |
+
"loss_configs": {
|
| 80 |
+
"discriminator": {
|
| 81 |
+
"type": "encodec",
|
| 82 |
+
"config": {
|
| 83 |
+
"filters": 64,
|
| 84 |
+
"n_ffts": [2048, 1024, 512, 256, 128],
|
| 85 |
+
"hop_lengths": [512, 256, 128, 64, 32],
|
| 86 |
+
"win_lengths": [2048, 1024, 512, 256, 128]
|
| 87 |
+
},
|
| 88 |
+
"weights": {
|
| 89 |
+
"adversarial": 0.1,
|
| 90 |
+
"feature_matching": 5.0
|
| 91 |
+
}
|
| 92 |
+
},
|
| 93 |
+
"spectral": {
|
| 94 |
+
"type": "mrstft",
|
| 95 |
+
"config": {
|
| 96 |
+
"fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
|
| 97 |
+
"hop_sizes": [512, 256, 128, 64, 32, 16, 8],
|
| 98 |
+
"win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
|
| 99 |
+
"perceptual_weighting": true
|
| 100 |
+
},
|
| 101 |
+
"weights": {
|
| 102 |
+
"mrstft": 1.0
|
| 103 |
+
}
|
| 104 |
+
},
|
| 105 |
+
"time": {
|
| 106 |
+
"type": "l1",
|
| 107 |
+
"weights": {
|
| 108 |
+
"l1": 0.0
|
| 109 |
+
}
|
| 110 |
+
},
|
| 111 |
+
"bottleneck": {
|
| 112 |
+
"type": "kl",
|
| 113 |
+
"weights": {
|
| 114 |
+
"kl": 1e-4
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
},
|
| 118 |
+
"demo": {
|
| 119 |
+
"demo_every": 2000
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
}
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/docs/autoencoders.md
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Autoencoders
|
| 2 |
+
At a high level, autoencoders are models constructed of two parts: an *encoder*, and a *decoder*.
|
| 3 |
+
|
| 4 |
+
The *encoder* takes in an sequence (such as mono or stereo audio) and outputs a compressed representation of that sequence as a d-channel "latent sequence", usually heavily downsampled by a constant factor.
|
| 5 |
+
|
| 6 |
+
The *decoder* takes in a d-channel latent sequence and upsamples it back to the original input sequence length, reversing the compression of the encoder.
|
| 7 |
+
|
| 8 |
+
Autoencoders are trained with a combination of reconstruction and adversarial losses in order to create a compact and invertible representation of raw audio data that allows downstream models to work in a data-compressed "latent space", with various desirable and controllable properties such as reduced sequence length, noise resistance, and discretization.
|
| 9 |
+
|
| 10 |
+
The autoencoder architectures defined in `stable-audio-tools` are largely fully-convolutional, which allows autoencoders trained on small lengths to be applied to arbitrary-length sequences. For example, an autoencoder trained on 1-second samples could be used to encode 45-second inputs to a latent diffusion model.
|
| 11 |
+
|
| 12 |
+
# Model configs
|
| 13 |
+
The model config file for an autoencoder should set the `model_type` to `autoencoder`, and the `model` object should have the following properties:
|
| 14 |
+
|
| 15 |
+
- `encoder`
|
| 16 |
+
- Configuration for the autoencoder's encoder half
|
| 17 |
+
- `decoder`
|
| 18 |
+
- Configuration for the autoencoder's decoder half
|
| 19 |
+
- `latent_dim`
|
| 20 |
+
- Latent dimension of the autoencoder, used by inference scripts and downstream models
|
| 21 |
+
- `downsampling_ratio`
|
| 22 |
+
- Downsampling ratio between the input sequence and the latent sequence, used by inference scripts and downstream models
|
| 23 |
+
- `io_channels`
|
| 24 |
+
- Number of input and output channels for the autoencoder when they're the same, used by inference scripts and downstream models
|
| 25 |
+
- `bottleneck`
|
| 26 |
+
- Configuration for the autoencoder's bottleneck
|
| 27 |
+
- Optional
|
| 28 |
+
- `pretransform`
|
| 29 |
+
- A pretransform definition for the autoencoder, such as wavelet decomposition or another autoencoder
|
| 30 |
+
- See [pretransforms.md](pretransforms.md) for more information
|
| 31 |
+
- Optional
|
| 32 |
+
- `in_channels`
|
| 33 |
+
- Specifies the number of input channels for the autoencoder, when it's different from `io_channels`, such as in a mono-to-stereo model
|
| 34 |
+
- Optional
|
| 35 |
+
- `out_channels`
|
| 36 |
+
- Specifies the number of output channels for the autoencoder, when it's different from `io_channels`
|
| 37 |
+
- Optional
|
| 38 |
+
|
| 39 |
+
# Training configs
|
| 40 |
+
The `training` config in the autoencoder model config file should have the following properties:
|
| 41 |
+
- `learning_rate`
|
| 42 |
+
- The learning rate to use during training
|
| 43 |
+
- `use_ema`
|
| 44 |
+
- If true, a copy of the model weights is maintained during training and updated as an exponential moving average of the trained model's weights.
|
| 45 |
+
- Optional. Default: `false`
|
| 46 |
+
- `warmup_steps`
|
| 47 |
+
- The number of training steps before turning on adversarial losses
|
| 48 |
+
- Optional. Default: `0`
|
| 49 |
+
- `encoder_freeze_on_warmup`
|
| 50 |
+
- If true, freezes the encoder after the warmup steps have completed, so adversarial training only affects the decoder.
|
| 51 |
+
- Optional. Default: `false`
|
| 52 |
+
- `loss_configs`
|
| 53 |
+
- Configurations for the loss function calculation
|
| 54 |
+
- Optional
|
| 55 |
+
- `optimizer_configs`
|
| 56 |
+
- Configuration for optimizers and schedulers
|
| 57 |
+
- Optional
|
| 58 |
+
|
| 59 |
+
## Loss configs
|
| 60 |
+
There are few different types of losses that are used for autoencoder training, including spectral losses, time-domain losses, adversarial losses, and bottleneck-specific losses.
|
| 61 |
+
|
| 62 |
+
Hyperparameters for these losses as well as loss weighting factors can be configured in the `loss_configs` property in the `training` config.
|
| 63 |
+
|
| 64 |
+
### Spectral losses
|
| 65 |
+
Multi-resolution STFT losses are the main reconstruction loss used for our audio autoencoders. We use the [auraloss](https://github.com/csteinmetz1/auraloss/tree/main/auraloss) library for our spectral loss functions.
|
| 66 |
+
|
| 67 |
+
For mono autoencoders (`io_channels` == 1), we use the [MultiResolutionSTFTLoss](https://github.com/csteinmetz1/auraloss/blob/1576b0cd6e927abc002b23cf3bfc455b660f663c/auraloss/freq.py#L329) module.
|
| 68 |
+
|
| 69 |
+
For stereo autoencoders (`io_channels` == 2), we use the [SumAndDifferenceSTFTLoss](https://github.com/csteinmetz1/auraloss/blob/1576b0cd6e927abc002b23cf3bfc455b660f663c/auraloss/freq.py#L533) module.
|
| 70 |
+
|
| 71 |
+
#### Example config
|
| 72 |
+
```json
|
| 73 |
+
"spectral": {
|
| 74 |
+
"type": "mrstft",
|
| 75 |
+
"config": {
|
| 76 |
+
"fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
|
| 77 |
+
"hop_sizes": [512, 256, 128, 64, 32, 16, 8],
|
| 78 |
+
"win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
|
| 79 |
+
"perceptual_weighting": true
|
| 80 |
+
},
|
| 81 |
+
"weights": {
|
| 82 |
+
"mrstft": 1.0
|
| 83 |
+
}
|
| 84 |
+
}
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
### Time-domain loss
|
| 88 |
+
We compute the L1 distance between the original audio and the decoded audio to provide a time-domain loss.
|
| 89 |
+
|
| 90 |
+
#### Example config
|
| 91 |
+
```json
|
| 92 |
+
"time": {
|
| 93 |
+
"type": "l1",
|
| 94 |
+
"weights": {
|
| 95 |
+
"l1": 0.1
|
| 96 |
+
}
|
| 97 |
+
}
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
### Adversarial losses
|
| 101 |
+
Adversarial losses bring in an ensemble of discriminator models to discriminate between real and fake audio, providing a signal to the autoencoder on perceptual discrepancies to fix.
|
| 102 |
+
|
| 103 |
+
We largely rely on the [multi-scale STFT discriminator](https://github.com/facebookresearch/encodec/blob/0e2d0aed29362c8e8f52494baf3e6f99056b214f/encodec/msstftd.py#L99) from the EnCodec repo
|
| 104 |
+
|
| 105 |
+
#### Example config
|
| 106 |
+
```json
|
| 107 |
+
"discriminator": {
|
| 108 |
+
"type": "encodec",
|
| 109 |
+
"config": {
|
| 110 |
+
"filters": 32,
|
| 111 |
+
"n_ffts": [2048, 1024, 512, 256, 128],
|
| 112 |
+
"hop_lengths": [512, 256, 128, 64, 32],
|
| 113 |
+
"win_lengths": [2048, 1024, 512, 256, 128]
|
| 114 |
+
},
|
| 115 |
+
"weights": {
|
| 116 |
+
"adversarial": 0.1,
|
| 117 |
+
"feature_matching": 5.0
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
## Demo config
|
| 123 |
+
The only property to set for autoencoder training demos is the `demo_every` property, determining the number of steps between demos.
|
| 124 |
+
|
| 125 |
+
### Example config
|
| 126 |
+
```json
|
| 127 |
+
"demo": {
|
| 128 |
+
"demo_every": 2000
|
| 129 |
+
}
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
# Encoder and decoder types
|
| 133 |
+
Encoders and decoders are defined separately in the model configuration, so encoders and decoders from different model architectures and libraries can be used interchangeably.
|
| 134 |
+
|
| 135 |
+
## Oobleck
|
| 136 |
+
Oobleck is Harmonai's in-house autoencoder architecture, implementing features from a variety of other autoencoder architectures.
|
| 137 |
+
|
| 138 |
+
### Example config
|
| 139 |
+
```json
|
| 140 |
+
"encoder": {
|
| 141 |
+
"type": "oobleck",
|
| 142 |
+
"config": {
|
| 143 |
+
"in_channels": 2,
|
| 144 |
+
"channels": 128,
|
| 145 |
+
"c_mults": [1, 2, 4, 8],
|
| 146 |
+
"strides": [2, 4, 8, 8],
|
| 147 |
+
"latent_dim": 128,
|
| 148 |
+
"use_snake": true
|
| 149 |
+
}
|
| 150 |
+
},
|
| 151 |
+
"decoder": {
|
| 152 |
+
"type": "oobleck",
|
| 153 |
+
"config": {
|
| 154 |
+
"out_channels": 2,
|
| 155 |
+
"channels": 128,
|
| 156 |
+
"c_mults": [1, 2, 4, 8],
|
| 157 |
+
"strides": [2, 4, 8, 8],
|
| 158 |
+
"latent_dim": 64,
|
| 159 |
+
"use_snake": true,
|
| 160 |
+
"use_nearest_upsample": false
|
| 161 |
+
}
|
| 162 |
+
}
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
## DAC
|
| 166 |
+
This is the Encoder and Decoder definitions from the `descript-audio-codec` repo. It's a simple fully-convolutional autoencoder with channels doubling every level. The encoder and decoder configs are passed directly into the constructors for the DAC [Encoder](https://github.com/descriptinc/descript-audio-codec/blob/c7cfc5d2647e26471dc394f95846a0830e7bec34/dac/model/dac.py#L64) and [Decoder](https://github.com/descriptinc/descript-audio-codec/blob/c7cfc5d2647e26471dc394f95846a0830e7bec34/dac/model/dac.py#L115).
|
| 167 |
+
|
| 168 |
+
**Note: This does not include the DAC quantizer, and does not load pre-trained DAC models, this is just the encoder and decoder definitions.**
|
| 169 |
+
|
| 170 |
+
### Example config
|
| 171 |
+
```json
|
| 172 |
+
"encoder": {
|
| 173 |
+
"type": "dac",
|
| 174 |
+
"config": {
|
| 175 |
+
"in_channels": 2,
|
| 176 |
+
"latent_dim": 32,
|
| 177 |
+
"d_model": 128,
|
| 178 |
+
"strides": [2, 4, 4, 4]
|
| 179 |
+
}
|
| 180 |
+
},
|
| 181 |
+
"decoder": {
|
| 182 |
+
"type": "dac",
|
| 183 |
+
"config": {
|
| 184 |
+
"out_channels": 2,
|
| 185 |
+
"latent_dim": 32,
|
| 186 |
+
"channels": 1536,
|
| 187 |
+
"rates": [4, 4, 4, 2]
|
| 188 |
+
}
|
| 189 |
+
}
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
## SEANet
|
| 193 |
+
This is the SEANetEncoder and SEANetDecoder definitions from Meta's EnCodec repo. This is the same encoder and decoder architecture used in the EnCodec models used in MusicGen, without the quantizer.
|
| 194 |
+
|
| 195 |
+
The encoder and decoder configs are passed directly into the [SEANetEncoder](https://github.com/facebookresearch/encodec/blob/0e2d0aed29362c8e8f52494baf3e6f99056b214f/encodec/modules/seanet.py#L66C12-L66C12) and [SEANetDecoder](https://github.com/facebookresearch/encodec/blob/0e2d0aed29362c8e8f52494baf3e6f99056b214f/encodec/modules/seanet.py#L147) classes directly, though we reverse the input order of the strides (ratios) in the encoder to make it consistent with the order in the decoder.
|
| 196 |
+
|
| 197 |
+
### Example config
|
| 198 |
+
```json
|
| 199 |
+
"encoder": {
|
| 200 |
+
"type": "seanet",
|
| 201 |
+
"config": {
|
| 202 |
+
"channels": 2,
|
| 203 |
+
"dimension": 128,
|
| 204 |
+
"n_filters": 64,
|
| 205 |
+
"ratios": [4, 4, 8, 8],
|
| 206 |
+
"n_residual_layers": 1,
|
| 207 |
+
"dilation_base": 2,
|
| 208 |
+
"lstm": 2,
|
| 209 |
+
"norm": "weight_norm"
|
| 210 |
+
}
|
| 211 |
+
},
|
| 212 |
+
"decoder": {
|
| 213 |
+
"type": "seanet",
|
| 214 |
+
"config": {
|
| 215 |
+
"channels": 2,
|
| 216 |
+
"dimension": 64,
|
| 217 |
+
"n_filters": 64,
|
| 218 |
+
"ratios": [4, 4, 8, 8],
|
| 219 |
+
"n_residual_layers": 1,
|
| 220 |
+
"dilation_base": 2,
|
| 221 |
+
"lstm": 2,
|
| 222 |
+
"norm": "weight_norm"
|
| 223 |
+
}
|
| 224 |
+
}
|
| 225 |
+
```
|
| 226 |
+
|
| 227 |
+
# Bottlenecks
|
| 228 |
+
In our terminology, the "bottleneck" of an autoencoder is a module placed between the encoder and decoder to enforce particular constraints on the latent space the encoder creates.
|
| 229 |
+
|
| 230 |
+
Bottlenecks have a similar interface to the autoencoder with `encode()` and `decode()` functions defined. Some bottlenecks return extra information in addition to the output latent series, such as quantized token indices, or additional losses to be considered during training.
|
| 231 |
+
|
| 232 |
+
To define a bottleneck for the autoencoder, you can provide the `bottleneck` object in the autoencoder's model configuration, with one of the following bottleneck types.
|
| 233 |
+
|
| 234 |
+
## VAE
|
| 235 |
+
|
| 236 |
+
The Variational Autoencoder (VAE) bottleneck splits the encoder's output in half along the channel dimension, treats the two halves as the "mean" and "scale" parameters for VAE sampling, and performs the latent sampling. At a basic level, the "scale" values determine the amount of noise to add to the "mean" latents, which creates a noise-resistant latent space where more of the latent space decodes to perceptually "valid" audio. This is particularly helpful for diffusion models where the outputs of the diffusion sampling process leave a bit of Gaussian error noise.
|
| 237 |
+
|
| 238 |
+
**Note: For the VAE bottleneck to work, the output dimension of the encoder must be twice the size of the input dimension for the decoder.**
|
| 239 |
+
|
| 240 |
+
### Example config
|
| 241 |
+
```json
|
| 242 |
+
"bottleneck": {
|
| 243 |
+
"type": "vae"
|
| 244 |
+
}
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
### Extra info
|
| 248 |
+
The VAE bottleneck also returns a `kl` value in the encoder info. This is the [KL divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) between encoded/sampled latent space and a Gaussian distribution. By including this value as a loss value to optimize, we push our latent distribution closer to a normal distribution, potentially trading off reconstruction quality.
|
| 249 |
+
|
| 250 |
+
### Example loss config
|
| 251 |
+
```json
|
| 252 |
+
"bottleneck": {
|
| 253 |
+
"type": "kl",
|
| 254 |
+
"weights": {
|
| 255 |
+
"kl": 1e-4
|
| 256 |
+
}
|
| 257 |
+
}
|
| 258 |
+
```
|
| 259 |
+
|
| 260 |
+
## Tanh
|
| 261 |
+
This bottleneck applies the tanh function to the latent series, "soft-clipping" the latent values to be between -1 and 1. This is a quick and dirty way to enforce a limit on the variance of the latent space, but training these models can be unstable as it's seemingly easy for the latent space to saturate the values to -1 or 1 and never recover.
|
| 262 |
+
|
| 263 |
+
### Example config
|
| 264 |
+
```json
|
| 265 |
+
"bottleneck": {
|
| 266 |
+
"type": "tanh"
|
| 267 |
+
}
|
| 268 |
+
```
|
| 269 |
+
|
| 270 |
+
## Wasserstein
|
| 271 |
+
The Wasserstein bottleneck implements the WAE-MMD regularization method from the [Wasserstein Auto-Encoders](https://arxiv.org/abs/1711.01558) paper, calculating the Maximum Mean Discrepancy (MMD) between the latent space and a Gaussian distribution. Including this value as a loss value to optimize leads to a more Gaussian latent space, but does not require stochastic sampling as with a VAE, so the encoder is deterministic.
|
| 272 |
+
|
| 273 |
+
The Wasserstein bottleneck also exposes the `noise_augment_dim` property, which concatenates `noise_augment_dim` channels of Gaussian noise to the latent series before passing into the decoder. This adds some stochasticity to the latents which can be helpful for adversarial training, while keeping the encoder outputs deterministic.
|
| 274 |
+
|
| 275 |
+
**Note: The MMD calculation is very VRAM-intensive for longer sequence lengths, so training a Wasserstein autoencoder is best done on autoencoders with a decent downsampling factor, or on short sequence lengths. For inference, the MMD calculation is disabled.**
|
| 276 |
+
|
| 277 |
+
### Example config
|
| 278 |
+
```json
|
| 279 |
+
"bottleneck": {
|
| 280 |
+
"type": "wasserstein"
|
| 281 |
+
}
|
| 282 |
+
```
|
| 283 |
+
|
| 284 |
+
### Extra info
|
| 285 |
+
This bottleneck adds the `mmd` value to the encoder info, representing the Maximum Mean Discrepancy.
|
| 286 |
+
|
| 287 |
+
### Example loss config
|
| 288 |
+
```json
|
| 289 |
+
"bottleneck": {
|
| 290 |
+
"type": "mmd",
|
| 291 |
+
"weights": {
|
| 292 |
+
"mmd": 100
|
| 293 |
+
}
|
| 294 |
+
}
|
| 295 |
+
```
|
| 296 |
+
|
| 297 |
+
## L2 normalization (Spherical autoencoder)
|
| 298 |
+
The L2 normalization bottleneck normalizes the latents across the channel-dimension, projecting the latents to a d-dimensional hypersphere. This acts as a form of latent space normalization.
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
### Example config
|
| 302 |
+
```json
|
| 303 |
+
"bottleneck": {
|
| 304 |
+
"type": "l2_norm"
|
| 305 |
+
}
|
| 306 |
+
```
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
## RVQ
|
| 310 |
+
Residual vector quantization (RVQ) is currently the leading method for learning discrete neural audio codecs (tokenizers for audio). In vector quantization, each item in the latent sequence is individually "snapped" to the nearest vector in a discrete "codebook" of learned vectors. The index of the vector in the codebook can then be used as a token index for things like autoregressive transformers. Residual vector quantization improves the precision of normal vector quantization by adding additional codebooks. For a deeper dive into RVQ, check out [this blog post by Dr. Scott Hawley](https://drscotthawley.github.io/blog/posts/2023-06-12-RVQ.html).
|
| 311 |
+
|
| 312 |
+
This RVQ bottleneck uses [lucidrains' implementation](https://github.com/lucidrains/vector-quantize-pytorch/tree/master) from the `vector-quantize-pytorch` repo, which provides a lot of different quantizer options. The bottleneck config is passed through to the `ResidualVQ` [constructor](https://github.com/lucidrains/vector-quantize-pytorch/blob/0c6cea24ce68510b607f2c9997e766d9d55c085b/vector_quantize_pytorch/residual_vq.py#L26).
|
| 313 |
+
|
| 314 |
+
**Note: This RVQ implementation uses manual replacement of codebook vectors to reduce codebook collapse. This does not work with multi-GPU training as the random replacement is not synchronized across devices.**
|
| 315 |
+
|
| 316 |
+
### Example config
|
| 317 |
+
```json
|
| 318 |
+
"bottleneck": {
|
| 319 |
+
"type": "rvq",
|
| 320 |
+
"config": {
|
| 321 |
+
"num_quantizers": 4,
|
| 322 |
+
"codebook_size": 2048,
|
| 323 |
+
"dim": 1024,
|
| 324 |
+
        "decay": 0.99
|
| 325 |
+
}
|
| 326 |
+
}
|
| 327 |
+
```
|
| 328 |
+
|
| 329 |
+
## DAC RVQ
|
| 330 |
+
This is the residual vector quantization implementation from the `descript-audio-codec` repo. It differs from the above implementation in that it does not use manual replacements to improve codebook usage, but instead uses learnable linear layers to project the latents down to a lower-dimensional space before performing the individual quantization operations. This means it's compatible with distributed training.
|
| 331 |
+
|
| 332 |
+
The bottleneck config is passed directly into the `ResidualVectorQuantize` [constructor](https://github.com/descriptinc/descript-audio-codec/blob/c7cfc5d2647e26471dc394f95846a0830e7bec34/dac/nn/quantize.py#L97).
|
| 333 |
+
|
| 334 |
+
The `quantize_on_decode` property is also exposed, which moves the quantization process to the decoder. This should not be used during training, but is helpful when training latent diffusion models that use the quantization process as a way to remove error after the diffusion sampling process.
|
| 335 |
+
|
| 336 |
+
### Example config
|
| 337 |
+
```json
|
| 338 |
+
"bottleneck": {
|
| 339 |
+
"type": "dac_rvq",
|
| 340 |
+
"config": {
|
| 341 |
+
"input_dim": 64,
|
| 342 |
+
"n_codebooks": 9,
|
| 343 |
+
"codebook_dim": 32,
|
| 344 |
+
"codebook_size": 1024,
|
| 345 |
+
"quantizer_dropout": 0.5
|
| 346 |
+
}
|
| 347 |
+
}
|
| 348 |
+
```
|
| 349 |
+
|
| 350 |
+
### Extra info
|
| 351 |
+
The DAC RVQ bottleneck also adds the following properties to the `info` object:
|
| 352 |
+
- `pre_quantizer`
|
| 353 |
+
- The pre-quantization latent series, useful in combination with `quantize_on_decode` for training latent diffusion models.
|
| 354 |
+
- `vq/commitment_loss`
|
| 355 |
+
- Commitment loss for the quantizer
|
| 356 |
+
- `vq/codebook_loss`
|
| 357 |
+
- Codebook loss for the quantizer
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/docs/conditioning.md
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Conditioning
|
| 2 |
+
Conditioning, in the context of `stable-audio-tools` is the use of additional signals in a model that are used to add an additional level of control over the model's behavior. For example, we can condition the outputs of a diffusion model on a text prompt, creating a text-to-audio model.
|
| 3 |
+
|
| 4 |
+
# Conditioning types
|
| 5 |
+
There are a few different kinds of conditioning depending on the conditioning signal being used.
|
| 6 |
+
|
| 7 |
+
## Cross attention
|
| 8 |
+
Cross attention is a type of conditioning that allows us to find correlations between two sequences of potentially different lengths. For example, cross attention allows us to find correlations between a sequence of features from a text encoder and a sequence of high-level audio features.
|
| 9 |
+
|
| 10 |
+
Signals used for cross-attention conditioning should be of the shape `[batch, sequence, channels]`.
|
| 11 |
+
|
| 12 |
+
## Global conditioning
|
| 13 |
+
Global conditioning is the use of a single n-dimensional tensor to provide conditioning information that pertains to the whole sequence being conditioned. For example, this could be the single embedding output of a CLAP model, or a learned class embedding.
|
| 14 |
+
|
| 15 |
+
Signals used for global conditioning should be of the shape `[batch, channels]`.
|
| 16 |
+
|
| 17 |
+
## Prepend conditioning
|
| 18 |
+
Prepend conditioning involves prepending the conditioning tokens to the data tokens in the model, allowing for the information to be interpreted through the model's self-attention mechanism.
|
| 19 |
+
|
| 20 |
+
This kind of conditioning is currently only supported by Transformer-based models such as diffusion transformers.
|
| 21 |
+
|
| 22 |
+
Signals used for prepend conditioning should be of the shape `[batch, sequence, channels]`.
|
| 23 |
+
|
| 24 |
+
## Input concatenation
|
| 25 |
+
Input concatenation applies a spatial conditioning signal to the model that correlates in the sequence dimension with the model's input, and is of the same length. The conditioning signal will be concatenated with the model's input data along the channel dimension. This can be used for things like inpainting information, melody conditioning, or for creating a diffusion autoencoder.
|
| 26 |
+
|
| 27 |
+
Signals used for input concatenation conditioning should be of the shape `[batch, channels, sequence]` and must be the same length as the model's input.
|
| 28 |
+
|
| 29 |
+
# Conditioners and conditioning configs
|
| 30 |
+
`stable-audio-tools` uses Conditioner modules to translate human-readable metadata such as text prompts or a number of seconds into tensors that the model can take as input.
|
| 31 |
+
|
| 32 |
+
Each conditioner has a corresponding `id` that it expects to find in the conditioning dictionary provided during training or inference. Each conditioner takes in the relevant conditioning data and returns a tuple containing the corresponding tensor and a mask.
|
| 33 |
+
|
| 34 |
+
The ConditionedDiffusionModelWrapper manages the translation between the user-provided metadata dictionary (e.g. `{"prompt": "a beautiful song", "seconds_start": 22, "seconds_total": 193}`) and the dictionary of different conditioning types that the model uses (e.g. `{"cross_attn_cond": ...}`).
|
| 35 |
+
|
| 36 |
+
To apply conditioning to a model, you must provide a `conditioning` configuration in the model's config. At the moment, we only support conditioning diffusion models through the `diffusion_cond` model type.
|
| 37 |
+
|
| 38 |
+
The `conditioning` configuration should contain a `configs` array, which allows you to define multiple conditioning signals.
|
| 39 |
+
|
| 40 |
+
Each item in `configs` array should define the `id` for the corresponding metadata, the type of conditioner to be used, and the config for that conditioner.
|
| 41 |
+
|
| 42 |
+
The `cond_dim` property is used to enforce the same dimension on all conditioning inputs, however that can be overridden with an explicit `output_dim` property on any of the individual configs.
|
| 43 |
+
|
| 44 |
+
## Example config
|
| 45 |
+
```json
|
| 46 |
+
"conditioning": {
|
| 47 |
+
"configs": [
|
| 48 |
+
{
|
| 49 |
+
"id": "prompt",
|
| 50 |
+
"type": "t5",
|
| 51 |
+
"config": {
|
| 52 |
+
"t5_model_name": "t5-base",
|
| 53 |
+
"max_length": 77,
|
| 54 |
+
"project_out": true
|
| 55 |
+
}
|
| 56 |
+
}
|
| 57 |
+
],
|
| 58 |
+
"cond_dim": 768
|
| 59 |
+
}
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
# Conditioners
|
| 63 |
+
|
| 64 |
+
## Text encoders
|
| 65 |
+
|
| 66 |
+
### `t5`
|
| 67 |
+
This uses a frozen [T5](https://huggingface.co/docs/transformers/model_doc/t5) text encoder from the `transformers` library to encode text prompts into a sequence of text features.
|
| 68 |
+
|
| 69 |
+
The `t5_model_name` property determines which T5 model is loaded from the `transformers` library.
|
| 70 |
+
|
| 71 |
+
The `max_length` property determines the maximum number of tokens that the text encoder will take in, as well as the sequence length of the output text features.
|
| 72 |
+
|
| 73 |
+
If you set `enable_grad` to `true`, the T5 model will be un-frozen and saved with the model checkpoint, allowing you to fine-tune the T5 model.
|
| 74 |
+
|
| 75 |
+
T5 encodings are only compatible with cross attention conditioning.
|
| 76 |
+
|
| 77 |
+
#### Example config
|
| 78 |
+
```json
|
| 79 |
+
{
|
| 80 |
+
"id": "prompt",
|
| 81 |
+
"type": "t5",
|
| 82 |
+
"config": {
|
| 83 |
+
"t5_model_name": "t5-base",
|
| 84 |
+
"max_length": 77,
|
| 85 |
+
"project_out": true
|
| 86 |
+
}
|
| 87 |
+
}
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
### `clap_text`
|
| 91 |
+
This loads the text encoder from a [CLAP](https://github.com/LAION-AI/CLAP) model, which can provide either a sequence of text features, or a single multimodal text/audio embedding.
|
| 92 |
+
|
| 93 |
+
The CLAP model must be provided with a local file path, set in the `clap_ckpt_path` property, along with the correct `audio_model_type` and `enable_fusion` properties for the provided model.
|
| 94 |
+
|
| 95 |
+
If the `use_text_features` property is set to `true`, the conditioner output will be a sequence of text features, instead of a single multimodal embedding. This allows for more fine-grained text information to be used by the model, at the cost of losing the ability to prompt with CLAP audio embeddings.
|
| 96 |
+
|
| 97 |
+
By default, if `use_text_features` is true, the last layer of the CLAP text encoder's features are returned. You can return the text features of earlier layers by specifying the index of the layer to return in the `feature_layer_ix` property. For example, you can return the text features of the next-to-last layer of the CLAP model by setting `feature_layer_ix` to `-2`.
|
| 98 |
+
|
| 99 |
+
If you set `enable_grad` to `true`, the CLAP model will be un-frozen and saved with the model checkpoint, allowing you to fine-tune the CLAP model.
|
| 100 |
+
|
| 101 |
+
CLAP text embeddings are compatible with global conditioning and cross attention conditioning. If `use_text_features` is set to `true`, the features are not compatible with global conditioning.
|
| 102 |
+
|
| 103 |
+
#### Example config
|
| 104 |
+
```json
|
| 105 |
+
{
|
| 106 |
+
"id": "prompt",
|
| 107 |
+
"type": "clap_text",
|
| 108 |
+
"config": {
|
| 109 |
+
"clap_ckpt_path": "/path/to/clap/model.ckpt",
|
| 110 |
+
"audio_model_type": "HTSAT-base",
|
| 111 |
+
"enable_fusion": true,
|
| 112 |
+
"use_text_features": true,
|
| 113 |
+
"feature_layer_ix": -2
|
| 114 |
+
}
|
| 115 |
+
}
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
## Number encoders
|
| 119 |
+
|
| 120 |
+
### `int`
|
| 121 |
+
The IntConditioner takes in a list of integers in a given range, and returns a discrete learned embedding for each of those integers.
|
| 122 |
+
|
| 123 |
+
The `min_val` and `max_val` properties set the range of the embedding values. Input integers are clamped to this range.
|
| 124 |
+
|
| 125 |
+
This can be used for things like discrete timing embeddings, or learned class embeddings.
|
| 126 |
+
|
| 127 |
+
Int embeddings are compatible with global conditioning and cross attention conditioning.
|
| 128 |
+
|
| 129 |
+
#### Example config
|
| 130 |
+
```json
|
| 131 |
+
{
|
| 132 |
+
"id": "seconds_start",
|
| 133 |
+
"type": "int",
|
| 134 |
+
"config": {
|
| 135 |
+
"min_val": 0,
|
| 136 |
+
"max_val": 512
|
| 137 |
+
}
|
| 138 |
+
}
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
### `number`
|
| 142 |
+
The NumberConditioner takes in a list of floats in a given range, and returns a continuous Fourier embedding of the provided floats.
|
| 143 |
+
|
| 144 |
+
The `min_val` and `max_val` properties set the range of the float values. This is the range used to normalize the input float values.
|
| 145 |
+
|
| 146 |
+
Number embeddings are compatible with global conditioning and cross attention conditioning.
|
| 147 |
+
|
| 148 |
+
#### Example config
|
| 149 |
+
```json
|
| 150 |
+
{
|
| 151 |
+
"id": "seconds_total",
|
| 152 |
+
"type": "number",
|
| 153 |
+
"config": {
|
| 154 |
+
"min_val": 0,
|
| 155 |
+
"max_val": 512
|
| 156 |
+
}
|
| 157 |
+
}
|
| 158 |
+
```
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/docs/datasets.md
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Datasets
|
| 2 |
+
`stable-audio-tools` supports loading data from local file storage, as well as loading audio files and JSON files in the [WebDataset](https://github.com/webdataset/webdataset/tree/main/webdataset) format from Amazon S3 buckets.
|
| 3 |
+
|
| 4 |
+
# Dataset configs
|
| 5 |
+
To specify the dataset used for training, you must provide a dataset config JSON file to `train.py`.
|
| 6 |
+
|
| 7 |
+
The dataset config consists of a `dataset_type` property specifying the type of data loader to use, a `datasets` array to provide multiple data sources, and a `random_crop` property, which decides if the cropped audio from the training samples is from a random place in the audio file, or always from the beginning.
|
| 8 |
+
|
| 9 |
+
## Local audio files
|
| 10 |
+
To use a local directory of audio samples, set the `dataset_type` property in your dataset config to `"audio_dir"`, and provide a list of objects to the `datasets` property including the `path` property, which should be the path to your directory of audio samples.
|
| 11 |
+
|
| 12 |
+
This will load all of the compatible audio files from the provided directory and all subdirectories.
|
| 13 |
+
|
| 14 |
+
### Example config
|
| 15 |
+
```json
|
| 16 |
+
{
|
| 17 |
+
"dataset_type": "audio_dir",
|
| 18 |
+
"datasets": [
|
| 19 |
+
{
|
| 20 |
+
"id": "my_audio",
|
| 21 |
+
"path": "/path/to/audio/dataset/"
|
| 22 |
+
}
|
| 23 |
+
],
|
| 24 |
+
"random_crop": true
|
| 25 |
+
}
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
## S3 WebDataset
|
| 29 |
+
To load audio files and related metadata from .tar files in the WebDataset format hosted in Amazon S3 buckets, you can set the `dataset_type` property to `s3`, and provide the `datasets` parameter with a list of objects containing the AWS S3 path to the shared S3 bucket prefix of the WebDataset .tar files. The S3 bucket will be searched recursively given the path, and assumes any .tar files found contain audio files and corresponding JSON files where the related files differ only in file extension (e.g. "000001.flac", "000001.json", "00002.flac", "00002.json", etc.)
|
| 30 |
+
|
| 31 |
+
### Example config
|
| 32 |
+
```json
|
| 33 |
+
{
|
| 34 |
+
"dataset_type": "s3",
|
| 35 |
+
"datasets": [
|
| 36 |
+
{
|
| 37 |
+
"id": "s3-test",
|
| 38 |
+
"s3_path": "s3://my-bucket/datasets/webdataset/audio/"
|
| 39 |
+
}
|
| 40 |
+
],
|
| 41 |
+
"random_crop": true
|
| 42 |
+
}
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
# Custom metadata
|
| 46 |
+
To customize the metadata provided to the conditioners during model training, you can provide a separate custom metadata module to the dataset config. This metadata module should be a Python file that must contain a function called `get_custom_metadata` that takes in two parameters, `info`, and `audio`, and returns a dictionary.
|
| 47 |
+
|
| 48 |
+
For local training, the `info` parameter will contain a few pieces of information about the loaded audio file, such as the path, and information about how the audio was cropped from the original training sample. For WebDataset datasets, it will also contain the metadata from the related JSON files.
|
| 49 |
+
|
| 50 |
+
The `audio` parameter contains the audio sample that will be passed to the model at training time. This lets you analyze the audio for extra properties that you can then pass in as extra conditioning signals.
|
| 51 |
+
|
| 52 |
+
The dictionary returned from the `get_custom_metadata` function will have its properties added to the `metadata` object used at training time. For more information on how conditioning works, please see the [Conditioning documentation](./conditioning.md)
|
| 53 |
+
|
| 54 |
+
## Example config and custom metadata module
|
| 55 |
+
```json
|
| 56 |
+
{
|
| 57 |
+
"dataset_type": "audio_dir",
|
| 58 |
+
"datasets": [
|
| 59 |
+
{
|
| 60 |
+
"id": "my_audio",
|
| 61 |
+
"path": "/path/to/audio/dataset/",
|
| 62 |
+
            "custom_metadata_module": "/path/to/custom_metadata.py"
|
| 63 |
+
}
|
| 64 |
+
],
|
| 65 |
+
"random_crop": true
|
| 66 |
+
}
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
`custom_metadata.py`:
|
| 70 |
+
```py
|
| 71 |
+
def get_custom_metadata(info, audio):
|
| 72 |
+
|
| 73 |
+
# Pass in the relative path of the audio file as the prompt
|
| 74 |
+
return {"prompt": info["relpath"]}
|
| 75 |
+
```
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/docs/diffusion.md
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Diffusion
|
| 2 |
+
|
| 3 |
+
Diffusion models learn to denoise data
|
| 4 |
+
|
| 5 |
+
# Model configs
|
| 6 |
+
The model config file for a diffusion model should set the `model_type` to `diffusion_cond` if the model uses conditioning, or `diffusion_uncond` if it does not, and the `model` object should have the following properties:
|
| 7 |
+
|
| 8 |
+
- `diffusion`
|
| 9 |
+
- The configuration for the diffusion model itself. See below for more information on the diffusion model config
|
| 10 |
+
- `pretransform`
|
| 11 |
+
- The configuration of the diffusion model's [pretransform](pretransforms.md), such as an autoencoder for latent diffusion.
|
| 12 |
+
- Optional
|
| 13 |
+
- `conditioning`
|
| 14 |
+
- The configuration of the various [conditioning](conditioning.md) modules for the diffusion model
|
| 15 |
+
- Only required for `diffusion_cond`
|
| 16 |
+
- `io_channels`
|
| 17 |
+
- The base number of input/output channels for the diffusion model
|
| 18 |
+
- Used by inference scripts to determine the shape of the noise to generate for the diffusion model
|
| 19 |
+
|
| 20 |
+
# Diffusion configs
|
| 21 |
+
- `type`
|
| 22 |
+
- The underlying model type for the transformer
|
| 23 |
+
  - For conditioned diffusion models, must be one of `dit` ([Diffusion Transformer](#diffusion-transformers-dit)), `DAU1d` ([Dance Diffusion U-Net](#dance-diffusion-u-net)), or `adp_cfg_1d` ([audio-diffusion-pytorch U-Net](#audio-diffusion-pytorch-u-net-adp))
|
| 24 |
+
- Unconditioned diffusion models can also use `adp_1d`
|
| 25 |
+
- `cross_attention_cond_ids`
|
| 26 |
+
- Conditioner ids for conditioning information to be used as cross-attention input
|
| 27 |
+
- If multiple ids are specified, the conditioning tensors will be concatenated along the sequence dimension
|
| 28 |
+
- `global_cond_ids`
|
| 29 |
+
- Conditioner ids for conditioning information to be used as global conditioning input
|
| 30 |
+
- If multiple ids are specified, the conditioning tensors will be concatenated along the channel dimension
|
| 31 |
+
- `prepend_cond_ids`
|
| 32 |
+
- Conditioner ids for conditioning information to be prepended to the model input
|
| 33 |
+
- If multiple ids are specified, the conditioning tensors will be concatenated along the sequence dimension
|
| 34 |
+
- Only works with diffusion transformer models
|
| 35 |
+
- `input_concat_ids`
|
| 36 |
+
- Conditioner ids for conditioning information to be concatenated to the model input
|
| 37 |
+
- If multiple ids are specified, the conditioning tensors will be concatenated along the channel dimension
|
| 38 |
+
- If the conditioning tensors are not the same length as the model input, they will be interpolated along the sequence dimension to be the same length.
|
| 39 |
+
- The interpolation algorithm is model-dependent, but usually uses nearest-neighbor resampling.
|
| 40 |
+
- `config`
|
| 41 |
+
- The configuration for the model backbone itself
|
| 42 |
+
- Model-dependent
|
| 43 |
+
|
| 44 |
+
# Training configs
|
| 45 |
+
The `training` config in the diffusion model config file should have the following properties:
|
| 46 |
+
|
| 47 |
+
- `learning_rate`
|
| 48 |
+
- The learning rate to use during training
|
| 49 |
+
- Defaults to constant learning rate, can be overridden with `optimizer_configs`
|
| 50 |
+
- `use_ema`
|
| 51 |
+
- If true, a copy of the model weights is maintained during training and updated as an exponential moving average of the trained model's weights.
|
| 52 |
+
- Optional. Default: `true`
|
| 53 |
+
- `log_loss_info`
|
| 54 |
+
- If true, additional diffusion loss info will be gathered across all GPUs and displayed during training
|
| 55 |
+
- Optional. Default: `false`
|
| 56 |
+
- `loss_configs`
|
| 57 |
+
- Configurations for the loss function calculation
|
| 58 |
+
- Optional
|
| 59 |
+
- `optimizer_configs`
|
| 60 |
+
- Configuration for optimizers and schedulers
|
| 61 |
+
- Optional, overrides `learning_rate`
|
| 62 |
+
- `demo`
|
| 63 |
+
- Configuration for the demos during training, including conditioning information
|
| 64 |
+
|
| 65 |
+
## Example config
|
| 66 |
+
```json
|
| 67 |
+
"training": {
|
| 68 |
+
"use_ema": true,
|
| 69 |
+
"log_loss_info": false,
|
| 70 |
+
"optimizer_configs": {
|
| 71 |
+
"diffusion": {
|
| 72 |
+
"optimizer": {
|
| 73 |
+
"type": "AdamW",
|
| 74 |
+
"config": {
|
| 75 |
+
"lr": 5e-5,
|
| 76 |
+
"betas": [0.9, 0.999],
|
| 77 |
+
"weight_decay": 1e-3
|
| 78 |
+
}
|
| 79 |
+
},
|
| 80 |
+
"scheduler": {
|
| 81 |
+
"type": "InverseLR",
|
| 82 |
+
"config": {
|
| 83 |
+
"inv_gamma": 1000000,
|
| 84 |
+
"power": 0.5,
|
| 85 |
+
"warmup": 0.99
|
| 86 |
+
}
|
| 87 |
+
}
|
| 88 |
+
}
|
| 89 |
+
},
|
| 90 |
+
"demo": { ... }
|
| 91 |
+
}
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
# Demo configs
|
| 95 |
+
The `demo` config in the diffusion model training config should have the following properties:
|
| 96 |
+
- `demo_every`
|
| 97 |
+
- How many training steps between demos
|
| 98 |
+
- `demo_steps`
|
| 99 |
+
- Number of diffusion timesteps to run for the demos
|
| 100 |
+
- `num_demos`
|
| 101 |
+
- This is the number of examples to generate in each demo
|
| 102 |
+
- `demo_cond`
|
| 103 |
+
- For conditioned diffusion models, this is the conditioning metadata to provide to each example, provided as a list
|
| 104 |
+
- NOTE: List must be the same length as `num_demos`
|
| 105 |
+
- `demo_cfg_scales`
|
| 106 |
+
- For conditioned diffusion models, this provides a list of classifier-free guidance (CFG) scales to render during the demos. This can be helpful to get an idea of how the model responds to different conditioning strengths as training continues.
|
| 107 |
+
|
| 108 |
+
## Example config
|
| 109 |
+
```json
|
| 110 |
+
"demo": {
|
| 111 |
+
"demo_every": 2000,
|
| 112 |
+
"demo_steps": 250,
|
| 113 |
+
"num_demos": 4,
|
| 114 |
+
"demo_cond": [
|
| 115 |
+
{"prompt": "A beautiful piano arpeggio", "seconds_start": 0, "seconds_total": 80},
|
| 116 |
+
{"prompt": "A tropical house track with upbeat melodies, a driving bassline, and cheery vibes", "seconds_start": 0, "seconds_total": 250},
|
| 117 |
+
{"prompt": "A cool 80s glam rock song with driving drums and distorted guitars", "seconds_start": 0, "seconds_total": 180},
|
| 118 |
+
{"prompt": "A grand orchestral arrangement", "seconds_start": 0, "seconds_total": 190}
|
| 119 |
+
],
|
| 120 |
+
"demo_cfg_scales": [3, 6, 9]
|
| 121 |
+
}
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
# Model types
|
| 125 |
+
|
| 126 |
+
A variety of different model types can be used as the underlying backbone for a diffusion model. At the moment, this includes variants of U-Net and Transformer models.
|
| 127 |
+
|
| 128 |
+
## Diffusion Transformers (DiT)
|
| 129 |
+
|
| 130 |
+
Transformers tend to consistently outperform U-Nets in terms of model quality, but are much more memory- and compute-intensive and work best on shorter sequences such as latent encodings of audio.
|
| 131 |
+
|
| 132 |
+
### Continuous Transformer
|
| 133 |
+
|
| 134 |
+
This is our custom implementation of a transformer model, based on the `x-transformers` implementation, but with efficiency improvements such as fused QKV layers, and Flash Attention 2 support.
|
| 135 |
+
|
| 136 |
+
### `x-transformers`
|
| 137 |
+
|
| 138 |
+
This model type uses the `ContinuousTransformerWrapper` class from the https://github.com/lucidrains/x-transformers repository as the diffusion transformer backbone.
|
| 139 |
+
|
| 140 |
+
`x-transformers` is a great baseline transformer implementation with lots of options for various experimental settings.
|
| 141 |
+
It's great for testing out experimental features without implementing them yourself, but the implementations might not be fully optimized, and breaking changes may be introduced without much warning.
|
| 142 |
+
|
| 143 |
+
## Diffusion U-Net
|
| 144 |
+
|
| 145 |
+
U-Nets use a hierarchical architecture to gradually downsample the input data before more heavy processing is performed, then upsample the data again, using skip connections to pass data across the downsampling "valley" (the "U" in the name) to the upsampling layer at the same resolution.
|
| 146 |
+
|
| 147 |
+
### audio-diffusion-pytorch U-Net (ADP)
|
| 148 |
+
|
| 149 |
+
This model type uses a modified implementation of the `UNetCFG1D` class from version 0.0.94 of the `https://github.com/archinetai/audio-diffusion-pytorch` repo, with added Flash Attention support.
|
| 150 |
+
|
| 151 |
+
### Dance Diffusion U-Net
|
| 152 |
+
|
| 153 |
+
This is a reimplementation of the U-Net used in [Dance Diffusion](https://github.com/Harmonai-org/sample-generator). It has minimal conditioning support, only really supporting global conditioning. Mostly used for unconditional diffusion models.
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/docs/pretransforms.md
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Pretransforms
|
| 2 |
+
Many models require some fixed transform to be applied to the input audio before the audio is passed in to the trainable layers of the model, as well as a corresponding inverse transform to be applied to the outputs of the model. We refer to these as "pretransforms".
|
| 3 |
+
|
| 4 |
+
At the moment, `stable-audio-tools` supports two pretransforms, frozen autoencoders for latent diffusion models and wavelet decompositions.
|
| 5 |
+
|
| 6 |
+
Pretransforms have a similar interface to autoencoders with "encode" and "decode" functions defined for each pretransform.
|
| 7 |
+
|
| 8 |
+
## Autoencoder pretransform
|
| 9 |
+
To define a model with an autoencoder pretransform, you can define the "pretransform" property in the model config, with the `type` property set to `autoencoder`. The `config` property should be an autoencoder model definition.
|
| 10 |
+
|
| 11 |
+
Example:
|
| 12 |
+
```json
|
| 13 |
+
"pretransform": {
|
| 14 |
+
"type": "autoencoder",
|
| 15 |
+
"config": {
|
| 16 |
+
"encoder": {
|
| 17 |
+
...
|
| 18 |
+
},
|
| 19 |
+
"decoder": {
|
| 20 |
+
...
|
| 21 |
+
}
|
| 22 |
+
...normal autoencoder configuration
|
| 23 |
+
}
|
| 24 |
+
}
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
### Latent rescaling
|
| 28 |
+
The original [Latent Diffusion paper](https://arxiv.org/abs/2112.10752) found that rescaling the latent series to unit variance before performing diffusion improved quality. To this end, we expose a `scale` property on autoencoder pretransforms that will take care of this rescaling. The scale should be set to the original standard deviation of the latents, which can be determined experimentally, or by looking at the `latent_std` value during training. The pretransform code will divide by this scale factor in the `encode` function and multiply by this scale in the `decode` function.
|
| 29 |
+
|
| 30 |
+
## Wavelet pretransform
|
| 31 |
+
`stable-audio-tools` also exposes wavelet decomposition as a pretransform. Wavelet decomposition is a quick way to trade off sequence length for channels in autoencoders, while maintaining a multi-band implicit bias.
|
| 32 |
+
|
| 33 |
+
Wavelet pretransforms take the following properties:
|
| 34 |
+
|
| 35 |
+
- `channels`
|
| 36 |
+
- The number of input and output audio channels for the wavelet transform
|
| 37 |
+
- `levels`
|
| 38 |
+
- The number of successive wavelet decompositions to perform. Each level doubles the channel count and halves the sequence length
|
| 39 |
+
- `wavelet`
|
| 40 |
+
- The specific wavelet from [PyWavelets](https://pywavelets.readthedocs.io/en/latest/ref/wavelets.html) to use, currently limited to `"bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"`
|
| 41 |
+
|
| 42 |
+
## Future work
|
| 43 |
+
We hope to add more filters and transforms to this list, including PQMF and STFT transforms.
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/scripts/ds_zero_to_pl_ckpt.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--save_path", type=str, help="Path to the zero checkpoint")
    parser.add_argument("--output_path", type=str, help="Path to the output checkpoint", default="lightning_model.pt")
    args = parser.parse_args()

    # Import lazily, inside the entry point, so merely importing this module
    # does not require lightning to be installed.  The original import was
    # commented out, which made the call below an unconditional NameError.
    from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

    # lightning deepspeed has saved a directory instead of a file;
    # collapse it into a single fp32 PyTorch Lightning checkpoint.
    convert_zero_checkpoint_to_fp32_state_dict(args.save_path, args.output_path)
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/data/__init__.py
ADDED
|
File without changes
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/data/dataset.py
ADDED
|
@@ -0,0 +1,654 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import numpy as np
|
| 3 |
+
import io
|
| 4 |
+
import os
|
| 5 |
+
import posixpath
|
| 6 |
+
import random
|
| 7 |
+
import re
|
| 8 |
+
import subprocess
|
| 9 |
+
import time
|
| 10 |
+
import torch
|
| 11 |
+
import torchaudio
|
| 12 |
+
import webdataset as wds
|
| 13 |
+
|
| 14 |
+
from aeiou.core import is_silence
|
| 15 |
+
from os import path
|
| 16 |
+
from pedalboard.io import AudioFile
|
| 17 |
+
from torchaudio import transforms as T
|
| 18 |
+
from typing import Optional, Callable, List
|
| 19 |
+
|
| 20 |
+
from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T
|
| 21 |
+
|
| 22 |
+
AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus")
|
| 23 |
+
|
| 24 |
+
# fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
|
| 25 |
+
|
| 26 |
+
def fast_scandir(
    dir: str,  # top-level directory at which to begin scanning
    ext: list,  # list of allowed file extensions
):
    """Very fast `glob` alternative, from https://stackoverflow.com/a/59803793/4259243

    Recursively scans `dir` and returns `(subfolders, files)`, where `files`
    contains only non-hidden files whose extension is in `ext`.  Unreadable
    directories and broken symlinks are silently skipped (best effort).
    """
    subfolders, files = [], []
    # Normalize extensions to begin with a period so they match os.path.splitext output.
    ext = ['.' + x if x[0] != '.' else x for x in ext]
    try:
        for f in os.scandir(dir):
            try:
                if f.is_dir():
                    subfolders.append(f.path)
                elif f.is_file():
                    file_ext = os.path.splitext(f.name)[1].lower()
                    is_hidden = os.path.basename(f.path).startswith(".")
                    if file_ext in ext and not is_hidden:
                        files.append(f.path)
            except OSError:
                # e.g. "too many levels of symbolic links" — skip this entry.
                # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
                continue
    except OSError:
        # e.g. "permission denied" on the directory itself — return what we have.
        pass

    # Recurse into each discovered subfolder (local name avoids shadowing `dir`).
    for subdir in list(subfolders):
        sf, f = fast_scandir(subdir, ext)
        subfolders.extend(sf)
        files.extend(f)
    return subfolders, files
|
| 55 |
+
|
| 56 |
+
def keyword_scandir(
    dir: str,  # top-level directory at which to begin scanning
    ext: list,  # list of allowed file extensions
    keywords: list,  # list of keywords to search for in the file name
):
    """Very fast `glob` alternative, from https://stackoverflow.com/a/59803793/4259243

    Like `fast_scandir`, but a file is kept only when its name (compared
    case-insensitively) contains at least one of `keywords` and none of a
    small set of banned archive-junk words.
    """
    subfolders, files = [], []
    # Lower-case the keywords once so matching is case-insensitive.
    keywords = [kw.lower() for kw in keywords]
    # Normalize extensions to begin with a period.
    ext = ['.' + e if e[0] != '.' else e for e in ext]
    banned_words = ["paxheader", "__macosx"]
    try:  # hope to avoid 'permission denied' by this try
        for entry in os.scandir(dir):
            try:  # 'hope to avoid too many levels of symbolic links' error
                if entry.is_dir():
                    subfolders.append(entry.path)
                    continue
                if not entry.is_file():
                    continue
                hidden = entry.name.split("/")[-1][0] == '.'
                ext_ok = os.path.splitext(entry.name)[1].lower() in ext
                lowered = entry.name.lower()
                keyword_ok = any(kw in lowered for kw in keywords)
                banned = any(word in lowered for word in banned_words)
                if ext_ok and keyword_ok and not banned and not hidden and not os.path.basename(entry.path).startswith("._"):
                    files.append(entry.path)
            except:
                pass
    except:
        pass

    # Recurse into every discovered subfolder.
    for sub in list(subfolders):
        more_subs, more_files = keyword_scandir(sub, ext, keywords)
        subfolders.extend(more_subs)
        files.extend(more_files)
    return subfolders, files
|
| 93 |
+
|
| 94 |
+
def get_audio_filenames(
    paths: list,  # directories in which to search
    keywords=None,  # optional keywords that must appear in the file names
    exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus']
):
    """Recursively get a list of audio filenames under one or more directories.

    `paths` may be a single directory string or a list of directories.  When
    `keywords` is given, only files whose names contain one of the keywords
    are returned (see `keyword_scandir`); otherwise all files with a matching
    extension are returned (see `fast_scandir`).
    """
    filenames = []
    # Accept a bare string as a convenience for a single search root.
    # (Was `type(paths) is str`; isinstance is the idiomatic check.)
    if isinstance(paths, str):
        paths = [paths]
    for path in paths:  # get a list of relevant filenames
        if keywords is not None:
            subfolders, files = keyword_scandir(path, exts, keywords)
        else:
            subfolders, files = fast_scandir(path, exts)
        filenames.extend(files)
    return filenames
|
| 110 |
+
|
| 111 |
+
class LocalDatasetConfig:
    """Configuration for one local (on-disk) audio dataset.

    Holds a dataset `id`, the root `path` to scan for audio files, and an
    optional `custom_metadata_fn` used to augment per-sample metadata at
    load time.
    """

    def __init__(
        self,
        id: str,
        path: str,
        custom_metadata_fn: Optional[Callable[[str], str]] = None
    ):
        # Plain attribute assignment; this class is a simple value holder.
        self.id = id
        self.path = path
        self.custom_metadata_fn = custom_metadata_fn
|
| 121 |
+
|
| 122 |
+
class SampleDataset(torch.utils.data.Dataset):
    """Map-style dataset over local audio files.

    Each item is loaded from disk, resampled to `sample_rate`, pad/cropped to
    `sample_size` samples, augmented, channel-forced, and returned as an
    `(audio_tensor, info_dict)` pair.  Files that fail to load are replaced
    by a uniformly random other item.
    """

    def __init__(
        self,
        configs,
        sample_size=65536,
        sample_rate=48000,
        keywords=None,
        random_crop=True,
        force_channels="stereo"
    ):
        super().__init__()
        self.filenames = []

        # Waveform-level augmentations applied to every sample.
        self.augs = torch.nn.Sequential(
            PhaseFlipper(),
        )

        self.root_paths = []

        self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop)

        self.force_channels = force_channels

        # Channel forcing: up-mix to stereo or down-mix to mono as requested;
        # identity otherwise.
        self.encoding = torch.nn.Sequential(
            Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
            Mono() if self.force_channels == "mono" else torch.nn.Identity(),
        )

        self.sr = sample_rate

        # Maps dataset root path -> user-supplied metadata function.
        self.custom_metadata_fns = {}

        for config in configs:
            self.root_paths.append(config.path)
            self.filenames.extend(get_audio_filenames(config.path, keywords))
            if config.custom_metadata_fn is not None:
                self.custom_metadata_fns[config.path] = config.custom_metadata_fn

        print(f'Found {len(self.filenames)} files')

    def load_file(self, filename):
        """Load one audio file from disk and resample it to the dataset rate."""
        ext = filename.split(".")[-1]

        if ext == "mp3":
            # mp3 files are decoded via pedalboard's AudioFile.
            with AudioFile(filename) as f:
                audio = torch.from_numpy(f.read(f.frames))
                in_sr = f.samplerate
        else:
            audio, in_sr = torchaudio.load(filename, format=ext)

        if in_sr != self.sr:
            audio = T.Resample(in_sr, self.sr)(audio)

        return audio

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        audio_filename = self.filenames[idx]
        try:
            start_time = time.time()
            audio = self.load_file(audio_filename)

            audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio)

            # Run augmentations on this sample (including random crop)
            if self.augs is not None:
                audio = self.augs(audio)

            audio = audio.clamp(-1, 1)

            # Encode the file to assist in prediction
            if self.encoding is not None:
                audio = self.encoding(audio)

            info = {"path": audio_filename}

            # Record the path relative to whichever dataset root contains it.
            for root_path in self.root_paths:
                if root_path in audio_filename:
                    info["relpath"] = path.relpath(audio_filename, root_path)

            info["timestamps"] = (t_start, t_end)
            info["seconds_start"] = seconds_start
            info["seconds_total"] = seconds_total
            info["padding_mask"] = padding_mask
            info["load_time"] = time.time() - start_time

            # Give each matching custom metadata function a chance to extend
            # `info`; a function may veto the sample by setting "__reject__".
            for custom_md_path, custom_metadata_fn in self.custom_metadata_fns.items():
                if custom_md_path in audio_filename:
                    info.update(custom_metadata_fn(info, audio))

                if "__reject__" in info and info["__reject__"]:
                    # Sample rejected; draw a random replacement.
                    return self[random.randrange(len(self))]

            return (audio, info)
        except Exception as e:
            print(f'Couldn\'t load file {audio_filename}: {e}')
            return self[random.randrange(len(self))]
|
| 230 |
+
|
| 231 |
+
def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Return function over iterator that groups key, value pairs into samples.

    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    """
    sample = None
    for item in data:
        assert isinstance(item, dict)
        fname = item["fname"]
        value = item["data"]
        prefix, suffix = keys(fname)
        if wds.tariterators.trace:
            print(prefix, suffix, sample.keys() if isinstance(sample, dict) else None)
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        # A new prefix means the previous sample is complete; flush it.
        if sample is None or prefix != sample["__key__"]:
            if wds.tariterators.valid_sample(sample):
                yield sample
            sample = {"__key__": prefix, "__url__": item["__url__"]}
        # Unlike upstream webdataset, a duplicate name only warns (no raise).
        if suffix in sample:
            print(f"{fname}: duplicate file name in tar file {suffix} {sample.keys()}")
        if suffixes is None or suffix in suffixes:
            sample[suffix] = value
    # Flush the final in-progress sample, if valid.
    if wds.tariterators.valid_sample(sample):
        yield sample

# Override webdataset's grouping so duplicate names warn instead of raising.
wds.tariterators.group_by_keys = group_by_keys
|
| 263 |
+
|
| 264 |
+
# S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
|
| 265 |
+
|
| 266 |
+
def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None):
    """
    Returns a list of full S3 paths to files in a given S3 bucket and directory path.
    """
    # Ensure dataset_path ends with a trailing slash
    if dataset_path != '' and not dataset_path.endswith('/'):
        dataset_path += '/'
    # Build the S3 URL with posixpath so separators are always '/'
    bucket_path = posixpath.join(s3_url_prefix or '', dataset_path)

    # Assemble the `aws s3 ls` command
    cmd = ['aws', 's3', 'ls', bucket_path]
    if profile is not None:
        cmd.extend(['--profile', profile])
    if recursive:
        # Add the --recursive flag if requested
        cmd.append('--recursive')

    # Run the listing and split its stdout into stripped, non-empty lines
    run_ls = subprocess.run(cmd, capture_output=True, check=True)
    lines = run_ls.stdout.decode('utf-8').split('\n')
    lines = [x.strip() for x in lines if x]

    # Drop the leading "<date> <time> <size> " prefix that `aws s3 ls` emits
    ts_pattern = r'^\S+\s+\S+\s+\d+\s+'
    lines = [re.sub(ts_pattern, '', x) if re.match(ts_pattern, x) else x for x in lines]

    # Turn each entry into a full S3 path, skipping directory entries
    contents = [posixpath.join(s3_url_prefix or '', x) for x in lines if not x.endswith('/')]

    # Apply the substring filter, if specified
    if filter:
        contents = [x for x in contents if filter in x]

    # With --recursive, aws echoes the directory prefix back; strip it out
    if recursive:
        main_dir = "/".join(bucket_path.split('/')[3:])
        contents = [x.replace(f'{main_dir}', '').replace('//', '/') for x in contents]

    # Print debugging information, if requested
    if debug:
        print("contents = \n", contents)
    return contents
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def get_all_s3_urls(
    names=None,         # list of all valid [LAION AudioDataset] dataset names
    subsets=None,       # list of subsets you want from those datasets, e.g. ['train','valid']
    s3_url_prefix=None, # prefix for those dataset names
    recursive=True,     # recursively list all tar files in all subdirs
    filter_str='tar',   # only grab files with this substring
    debug=False,        # print debugging info -- note: info displayed likely to change at dev's whims
    profiles=None,      # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'}
):
    """Get urls of shards (tar files) for multiple datasets in one s3 bucket.

    Fixes over the previous revision:
    - `names=[]`, `subsets=['']` and `profiles={}` were mutable default
      arguments shared across calls; they are now None-sentinels.
    - `"\ "`, `"\("`, `"\)"` were invalid escape sequences (DeprecationWarning,
      SyntaxError in future Python); the intended backslash-escaping is now
      written explicitly. Runtime output is unchanged.

    Returns:
        List of "pipe:aws s3 ... cp <path> -" command strings suitable for
        webdataset shard loading.
    """
    names = [] if names is None else names
    subsets = [''] if subsets is None else subsets
    profiles = {} if profiles is None else profiles

    urls = []
    for name in names:
        # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list
        if s3_url_prefix is None:
            contents_str = name
        else:
            # Construct the S3 path using the s3_url_prefix and the current name value
            contents_str = posixpath.join(s3_url_prefix, name)
        if debug:
            print(f"get_all_s3_urls: {contents_str}:")
        for subset in subsets:
            subset_str = posixpath.join(contents_str, subset)
            if debug:
                print(f"subset_str = {subset_str}")
            # Get the list of tar files in the current subset directory
            profile = profiles.get(name, None)
            tar_list = get_s3_contents(
                subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile)
            for tar in tar_list:
                # Escape spaces and parentheses in the tar filename for use in the shell command
                tar = tar.replace(" ", "\\ ").replace(
                    "(", "\\(").replace(")", "\\)")
                # Construct the S3 path to the current tar file
                s3_path = posixpath.join(name, subset, tar) + " -"
                # Construct the AWS CLI command to download the current tar file
                if s3_url_prefix is None:
                    request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}"
                else:
                    request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}"
                if profiles.get(name):
                    request_str += f" --profile {profiles.get(name)}"
                if debug:
                    print("request_str = ", request_str)
                # Add the constructed URL to the list of URLs
                urls.append(request_str)
    return urls
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    # Returning True tells webdataset pipeline stages to skip the failing
    # sample/shard rather than abort the whole pipeline.
    print(f"Handling webdataset error ({repr(exn)}). Ignoring.")
    return True
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def is_valid_sample(sample):
    """Return True if a webdataset sample has both metadata and audio, is not
    silent, and is not flagged as rejected.

    Fix: the previous revision unconditionally indexed sample["audio"] and
    sample["json"], raising KeyError for samples missing either key instead of
    filtering them out. We now short-circuit before indexing.
    """
    has_json = "json" in sample
    has_audio = "audio" in sample
    if not (has_json and has_audio):
        return False

    if is_silence(sample["audio"]):
        return False

    # "__reject__" is an optional metadata flag; absent means "keep".
    return not sample["json"].get("__reject__", False)
|
| 376 |
+
|
| 377 |
+
class S3DatasetConfig:
    """Configuration for a webdataset made of tar shards stored on S3.

    Attributes:
        id: human-readable dataset identifier.
        path: the S3 path for this dataset.
        custom_metadata_fn: optional callable applied to each sample's metadata.
        profile: optional AWS CLI profile name used when listing/downloading.
        urls: shard URL list, populated by load_data_urls().
    """

    def __init__(
        self,
        id: str,
        s3_path: str,
        custom_metadata_fn: Optional[Callable[[str], str]] = None,
        profile: Optional[str] = None,
    ):
        self.id = id
        self.path = s3_path
        self.custom_metadata_fn = custom_metadata_fn
        self.profile = profile
        self.urls = []

    def load_data_urls(self):
        """List the S3 path and cache the resulting shard URLs in self.urls."""
        profile_map = {self.path: self.profile} if self.profile else {}
        self.urls = get_all_s3_urls(
            names=[self.path],
            s3_url_prefix=None,
            recursive=True,
            profiles=profile_map,
        )
        return self.urls
|
| 400 |
+
|
| 401 |
+
class LocalWebDatasetConfig:
    """Configuration for a webdataset made of tar shards on the local filesystem.

    Attributes:
        id: human-readable dataset identifier.
        path: root directory to scan for .tar shards.
        custom_metadata_fn: optional callable applied to each sample's metadata.
        urls: shard path list, populated by load_data_urls().
    """

    def __init__(
        self,
        id: str,
        path: str,
        custom_metadata_fn: Optional[Callable[[str], str]] = None,
        profile: Optional[str] = None,
    ):
        # NOTE(review): `profile` is accepted for signature parity with
        # S3DatasetConfig but is not stored or used here.
        self.id = id
        self.path = path
        self.custom_metadata_fn = custom_metadata_fn
        self.urls = []

    def load_data_urls(self):
        """Recursively scan self.path for .tar files and cache them in self.urls."""
        _, tar_files = fast_scandir(self.path, ["tar"])
        self.urls = tar_files
        return self.urls
|
| 419 |
+
|
| 420 |
+
def audio_decoder(key, value):
    """Webdataset decoder hook: load raw bytes as audio when the key's file
    extension is one of AUDIO_KEYS, otherwise decline (return None)."""
    extension = key.rsplit(".", 1)[-1]
    if extension not in AUDIO_KEYS:
        return None
    # torchaudio.load returns a (waveform, sample_rate) tuple
    return torchaudio.load(io.BytesIO(value))
|
| 428 |
+
|
| 429 |
+
def collation_fn(samples):
    """Collate a list of per-sample tuples into per-field batches.

    Numeric scalars and numpy arrays are batched into numpy arrays, torch
    tensors are stacked, and anything else is left as a tuple of the
    per-sample values.
    """
    collated = []
    for field_values in zip(*samples):
        head = field_values[0]
        if isinstance(head, torch.Tensor):
            collated.append(torch.stack(field_values))
        elif isinstance(head, (int, float, np.ndarray)):
            collated.append(np.array(field_values))
        else:
            # Unknown types (strings, dicts, ...) pass through unbatched
            collated.append(field_values)
    return collated
|
| 443 |
+
|
| 444 |
+
class WebDatasetDataLoader():
    """Builds a webdataset (wds) pipeline over tar shards (S3 or local) and
    wraps it in a wds.WebLoader.

    The pipeline: resample shards -> extract tar samples -> decode audio ->
    preprocess (resample/pad-crop/channel-force/phase-augment) -> filter
    invalid samples -> emit (audio, json) tuples -> batch.
    """

    def __init__(
        self,
        datasets: List[S3DatasetConfig],
        batch_size,
        sample_size,
        sample_rate=48000,
        num_workers=8,
        epoch_steps=1000,
        random_crop=True,
        force_channels="stereo",
        augment_phase=True,
        **data_loader_kwargs
    ):
        # data_loader_kwargs are forwarded verbatim to wds.WebLoader
        self.datasets = datasets

        self.sample_size = sample_size
        self.sample_rate = sample_rate
        self.random_crop = random_crop
        self.force_channels = force_channels
        self.augment_phase = augment_phase

        # Each config lists its own shard URLs (S3 listing or local scan)
        urls = [dataset.load_data_urls() for dataset in datasets]

        # Flatten the list of lists of URLs
        urls = [url for dataset_urls in urls for url in dataset_urls]

        # Shuffle the urls
        random.shuffle(urls)

        self.dataset = wds.DataPipeline(
            wds.ResampledShards(urls),
            wds.tarfile_to_samples(handler=log_and_continue),
            wds.decode(audio_decoder, handler=log_and_continue),
            wds.map(self.wds_preprocess, handler=log_and_continue),
            wds.select(is_valid_sample),
            wds.to_tuple("audio", "json", handler=log_and_continue),
            #wds.shuffle(bufsize=1000, initial=5000),
            wds.batched(batch_size, partial=False, collation_fn=collation_fn),
            # Each worker sees a share of the epoch so the total stays epoch_steps
        ).with_epoch(epoch_steps//num_workers if num_workers > 0 else epoch_steps)

        self.data_loader = wds.WebLoader(self.dataset, num_workers=num_workers, **data_loader_kwargs)

    def wds_preprocess(self, sample):
        """Preprocess one decoded webdataset sample in place.

        Finds the audio entry, resamples to self.sample_rate, pads/crops to
        self.sample_size, applies channel forcing and phase augmentation, and
        records timing metadata into sample["json"]. Returns None when the
        sample has no audio, which tells webdataset to skip it.
        """
        found_key, rewrite_key = '', ''
        for k, v in sample.items(): # print the all entries in dict
            for akey in AUDIO_KEYS:
                if k.endswith(akey):
                    # to rename long/weird key with its simpler counterpart
                    found_key, rewrite_key = k, akey
                    break
            if '' != found_key:
                break
        if '' == found_key:  # got no audio!
            return None  # try returning None to tell WebDataset to skip this one

        # Decoded entry is a (waveform, sample_rate) tuple from torchaudio.load
        audio, in_sr = sample[found_key]
        if in_sr != self.sample_rate:
            resample_tf = T.Resample(in_sr, self.sample_rate)
            audio = resample_tf(audio)

        if self.sample_size is not None:
            # Pad/crop and get the relative timestamp
            pad_crop = PadCrop_Normalized_T(
                self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate)
            audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop(
                audio)
            sample["json"]["seconds_start"] = seconds_start
            sample["json"]["seconds_total"] = seconds_total
            sample["json"]["padding_mask"] = padding_mask
        else:
            # No fixed sample size: the whole clip spans the [0, 1] interval
            t_start, t_end = 0, 1

        # Check if audio is length zero, initialize to a single zero if so
        if audio.shape[-1] == 0:
            audio = torch.zeros(1, 1)

        # Make the audio stereo and augment by randomly inverting phase
        augs = torch.nn.Sequential(
            Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
            Mono() if self.force_channels == "mono" else torch.nn.Identity(),
            PhaseFlipper() if self.augment_phase else torch.nn.Identity()
        )

        audio = augs(audio)

        sample["json"]["timestamps"] = (t_start, t_end)

        if "text" in sample["json"]:
            sample["json"]["prompt"] = sample["json"]["text"]

        # Check for custom metadata functions
        for dataset in self.datasets:
            if dataset.custom_metadata_fn is None:
                continue

            # Match the sample's originating shard to its dataset config
            if dataset.path in sample["__url__"]:
                custom_metadata = dataset.custom_metadata_fn(sample["json"], audio)
                sample["json"].update(custom_metadata)

        if found_key != rewrite_key:  # rename long/weird key with its simpler counterpart
            del sample[found_key]

        sample["audio"] = audio

        # Add audio to the metadata as well for conditioning
        sample["json"]["audio"] = audio

        return sample
|
| 555 |
+
|
| 556 |
+
def _load_custom_metadata_fn(module_path):
    """Load the `get_custom_metadata` callable from a Python file path.

    Returns None when no path is given. Extracted because both dataset-type
    branches previously duplicated this module-loading logic verbatim.
    """
    if module_path is None:
        return None
    spec = importlib.util.spec_from_file_location("metadata_module", module_path)
    metadata_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(metadata_module)
    return metadata_module.get_custom_metadata


def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4):
    """Build a training dataloader from a dataset configuration dict.

    Supported "dataset_type" values:
      - "audio_dir": local audio directories -> SampleDataset + torch DataLoader
      - "s3" / "wds": webdataset tar shards (S3 or local) -> WebDatasetDataLoader

    Args:
        dataset_config: dict with "dataset_type" and a "datasets" list.
        batch_size: samples per batch.
        sample_size: length in samples of each training example.
        sample_rate: target sample rate.
        audio_channels: 1 forces mono; anything else forces stereo.
        num_workers: dataloader worker count.
    """
    dataset_type = dataset_config.get("dataset_type", None)

    assert dataset_type is not None, "Dataset type must be specified in dataset config"

    # Channel forcing mirrors the Stereo()/Mono() augmentations applied downstream
    if audio_channels == 1:
        force_channels = "mono"
    else:
        force_channels = "stereo"

    if dataset_type == "audio_dir":

        audio_dir_configs = dataset_config.get("datasets", None)

        assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"

        configs = []

        for audio_dir_config in audio_dir_configs:
            audio_dir_path = audio_dir_config.get("path", None)
            assert audio_dir_path is not None, "Path must be set for local audio directory configuration"

            custom_metadata_fn = _load_custom_metadata_fn(
                audio_dir_config.get("custom_metadata_module", None))

            configs.append(
                LocalDatasetConfig(
                    id=audio_dir_config["id"],
                    path=audio_dir_path,
                    custom_metadata_fn=custom_metadata_fn
                )
            )

        train_set = SampleDataset(
            configs,
            sample_rate=sample_rate,
            sample_size=sample_size,
            random_crop=dataset_config.get("random_crop", True),
            force_channels=force_channels
        )

        return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
                                           num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn)

    elif dataset_type in ["s3", "wds"]:  # Support "s3" type for backwards compatibility
        wds_configs = []

        for wds_config in dataset_config["datasets"]:

            custom_metadata_fn = _load_custom_metadata_fn(
                wds_config.get("custom_metadata_module", None))

            if "s3_path" in wds_config:

                wds_configs.append(
                    S3DatasetConfig(
                        id=wds_config["id"],
                        s3_path=wds_config["s3_path"],
                        custom_metadata_fn=custom_metadata_fn,
                        profile=wds_config.get("profile", None),
                    )
                )

            elif "path" in wds_config:

                wds_configs.append(
                    LocalWebDatasetConfig(
                        id=wds_config["id"],
                        path=wds_config["path"],
                        custom_metadata_fn=custom_metadata_fn
                    )
                )

        return WebDatasetDataLoader(
            wds_configs,
            sample_rate=sample_rate,
            sample_size=sample_size,
            batch_size=batch_size,
            random_crop=dataset_config.get("random_crop", True),
            num_workers=num_workers,
            persistent_workers=True,
            force_channels=force_channels,
            epoch_steps=dataset_config.get("epoch_steps", 2000)
        ).data_loader
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/data/utils.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from torch import nn
|
| 6 |
+
from typing import Tuple
|
| 7 |
+
|
| 8 |
+
class PadCrop(nn.Module):
    """Pad or crop a (channels, samples) signal to exactly n_samples.

    When randomize is True the crop start is chosen uniformly at random;
    otherwise the crop always starts at sample 0. Shorter signals are
    zero-padded on the right.
    """

    def __init__(self, n_samples, randomize=True):
        super().__init__()
        self.n_samples = n_samples
        self.randomize = randomize

    def __call__(self, signal):
        channels, length = signal.shape
        if self.randomize:
            start = torch.randint(0, max(0, length - self.n_samples) + 1, []).item()
        else:
            start = 0
        end = start + self.n_samples
        output = signal.new_zeros([channels, self.n_samples])
        copy_len = min(length, self.n_samples)
        output[:, :copy_len] = signal[:, start:end]
        return output
|
| 21 |
+
|
| 22 |
+
class PadCrop_Normalized_T(nn.Module):
    """Pad or crop audio to n_samples, returning the chunk plus timing metadata.

    Fix: the return annotation previously declared a 5-tuple while the method
    returns 6 items (the padding mask was missing from the annotation).
    """

    def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
        super().__init__()

        self.n_samples = n_samples
        self.sample_rate = sample_rate
        self.randomize = randomize

    def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int, torch.Tensor]:
        """Pad/crop `source` (channels, samples) to self.n_samples.

        Returns:
            chunk: (channels, n_samples) tensor, zero-padded if needed.
            t_start, t_end: chunk boundaries normalized to the padded timeline.
            seconds_start: chunk offset in whole seconds.
            seconds_total: full source duration in whole seconds (rounded up).
            padding_mask: (n_samples,) tensor, 1 where real audio, 0 where pad.
        """
        n_channels, n_samples = source.shape

        # Largest valid crop offset (0 when the source is shorter than n_samples)
        upper_bound = max(0, n_samples - self.n_samples)

        # If randomize is False, always start at the beginning of the audio
        offset = 0
        if(self.randomize and n_samples > self.n_samples):
            offset = random.randint(0, upper_bound)

        # Calculate the start and end times of the chunk
        t_start = offset / (upper_bound + self.n_samples)
        t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)

        # Create the chunk
        chunk = source.new_zeros([n_channels, self.n_samples])

        # Copy the audio into the chunk
        chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples]

        # Calculate the start and end times of the chunk in seconds
        seconds_start = math.floor(offset / self.sample_rate)
        seconds_total = math.ceil(n_samples / self.sample_rate)

        # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
        padding_mask = torch.zeros([self.n_samples])
        padding_mask[:min(n_samples, self.n_samples)] = 1

        return (
            chunk,
            t_start,
            t_end,
            seconds_start,
            seconds_total,
            padding_mask
        )
|
| 71 |
+
|
| 72 |
+
class PhaseFlipper(nn.Module):
    """Randomly invert the phase (sign) of a signal with probability p."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def __call__(self, signal):
        flip = random.random() < self.p
        return -signal if flip else signal
|
| 79 |
+
|
| 80 |
+
class Mono(nn.Module):
    """Downmix a multi-channel signal to mono by averaging over channels.

    1-D signals are assumed to already be mono and pass through unchanged.
    """

    def __call__(self, signal):
        if signal.dim() > 1:
            return torch.mean(signal, dim=0, keepdims=True)
        return signal
|
| 83 |
+
|
| 84 |
+
class Stereo(nn.Module):
    """Force a signal to exactly two channels.

    Mono input is duplicated; more than two channels are truncated to the
    first two; already-stereo (and higher-rank) input passes through.
    """

    def __call__(self, signal):
        ndim = signal.dim()
        if ndim == 1:
            # (s,) -> (2, s): duplicate the single channel
            return signal.unsqueeze(0).repeat(2, 1)
        if ndim == 2:
            n_channels = signal.shape[0]
            if n_channels == 1:
                # (1, s) -> (2, s)
                return signal.repeat(2, 1)
            if n_channels > 2:
                # (c, s) -> (2, s): keep only the first two channels
                return signal[:2, :]
        return signal
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/inference/__init__.py
ADDED
|
File without changes
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/inference/generation.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import typing as tp
|
| 4 |
+
import math
|
| 5 |
+
from torchaudio import transforms as T
|
| 6 |
+
|
| 7 |
+
from .utils import prepare_audio
|
| 8 |
+
from .sampling import sample, sample_k, sample_rf
|
| 9 |
+
from ..data.utils import PadCrop
|
| 10 |
+
|
| 11 |
+
def generate_diffusion_uncond(
        model,
        steps: int = 250,
        batch_size: int = 1,
        sample_size: int = 2097152,
        seed: int = -1,
        device: str = "cuda",
        init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
        init_noise_level: float = 1.0,
        return_latents = False,
        **sampler_kwargs
        ) -> torch.Tensor:
    """Generate audio unconditionally with a diffusion model.

    Args:
        model: diffusion model wrapper (provides .model, .pretransform,
            .io_channels, .sample_rate, .diffusion_objective).
        steps: number of diffusion steps.
        batch_size: number of samples to generate.
        sample_size: output length in audio samples.
        seed: RNG seed, or -1 for a random seed.
        device: device to generate on.
        init_audio: optional (sample_rate, audio) tuple for variations.
        init_noise_level: noise level applied when init_audio is given.
        return_latents: return latents instead of decoded audio.
        **sampler_kwargs: forwarded to the sampler.

    Raises:
        ValueError: if the model's diffusion objective is unrecognized.
    """
    # The length of the output in audio samples
    audio_sample_size = sample_size

    # If this is latent diffusion, change sample_size instead to the downsampled latent size
    if model.pretransform is not None:
        sample_size = sample_size // model.pretransform.downsampling_ratio

    # Seed
    # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed.
    seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
    print(seed)
    torch.manual_seed(seed)
    # Define the initial noise immediately after setting the seed
    noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)

    if init_audio is not None:
        # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio.
        in_sr, init_audio = init_audio

        io_channels = model.io_channels

        # For latent models, set the io_channels to the autoencoder's io_channels
        if model.pretransform is not None:
            io_channels = model.pretransform.io_channels

        # Prepare the initial audio for use by the model
        init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)

        # For latent models, encode the initial audio into latents
        if model.pretransform is not None:
            init_audio = model.pretransform.encode(init_audio)

        init_audio = init_audio.repeat(batch_size, 1, 1)
    else:
        # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch.
        init_audio = None
        init_noise_level = None

    # Inpainting mask
    if init_audio is not None:
        # variations
        sampler_kwargs["sigma_max"] = init_noise_level
        mask = None
    else:
        mask = None

    # Now the generative AI part:
    diff_objective = model.diffusion_objective

    if diff_objective == "v":
        # k-diffusion denoising process go!
        sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, device=device)
    elif diff_objective == "rectified_flow":
        sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, device=device)
    else:
        # Fix: an unrecognized objective previously fell through and raised a
        # confusing NameError on `sampled`; fail fast with a clear message.
        raise ValueError(f"Unknown diffusion objective: {diff_objective}")

    # Denoising process done.
    # If this is latent diffusion, decode latents back into audio
    if model.pretransform is not None and not return_latents:
        sampled = model.pretransform.decode(sampled)

    # Return audio
    return sampled
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def generate_diffusion_cond(
        model,
        steps: int = 250,
        cfg_scale=6,
        conditioning: dict = None,
        conditioning_tensors: tp.Optional[dict] = None,
        negative_conditioning: dict = None,
        negative_conditioning_tensors: tp.Optional[dict] = None,
        batch_size: int = 1,
        sample_size: int = 2097152,
        sample_rate: int = 48000,
        seed: int = -1,
        device: str = "cuda",
        init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
        init_noise_level: float = 1.0,
        mask_args: dict = None,
        return_latents = False,
        **sampler_kwargs
        ) -> torch.Tensor:
    """
    Generate audio from a prompt using a diffusion model.

    Args:
        model: The diffusion model to use for generation.
        steps: The number of diffusion steps to use.
        cfg_scale: Classifier-free guidance scale
        conditioning: A dictionary of conditioning parameters to use for generation.
        conditioning_tensors: A dictionary of precomputed conditioning tensors to use for generation.
        negative_conditioning: A dictionary of negative conditioning parameters (for CFG steering away).
        negative_conditioning_tensors: Precomputed negative conditioning tensors.
        batch_size: The batch size to use for generation.
        sample_size: The length of the audio to generate, in samples.
        sample_rate: The sample rate of the audio to generate (Deprecated, now pulled from the model directly)
        seed: The random seed to use for generation, or -1 to use a random seed.
        device: The device to use for generation.
        init_audio: A tuple of (sample_rate, audio) to use as the initial audio for generation.
        init_noise_level: The noise level to use when generating from an initial audio sample.
        mask_args: Soft inpainting mask parameters (cropfrom/pastefrom/pasteto/maskstart/...); ignored without init_audio.
        return_latents: Whether to return the latents used for generation instead of the decoded audio.
        **sampler_kwargs: Additional keyword arguments to pass to the sampler.
    """

    # The length of the output in audio samples
    audio_sample_size = sample_size

    # If this is latent diffusion, change sample_size instead to the downsampled latent size
    if model.pretransform is not None:
        sample_size = sample_size // model.pretransform.downsampling_ratio

    # Seed
    # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed.
    seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
    print(seed)
    torch.manual_seed(seed)
    # Define the initial noise immediately after setting the seed
    noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)

    # Disable TF32/reduced-precision paths for reproducible generation
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
    torch.backends.cudnn.benchmark = False

    # Conditioning
    assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors"
    if conditioning_tensors is None:
        conditioning_tensors = model.conditioner(conditioning, device)
    conditioning_inputs = model.get_conditioning_inputs(conditioning_tensors)

    if negative_conditioning is not None or negative_conditioning_tensors is not None:

        if negative_conditioning_tensors is None:
            negative_conditioning_tensors = model.conditioner(negative_conditioning, device)

        negative_conditioning_tensors = model.get_conditioning_inputs(negative_conditioning_tensors, negative=True)
    else:
        negative_conditioning_tensors = {}

    if init_audio is not None:
        # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio.
        in_sr, init_audio = init_audio

        io_channels = model.io_channels

        # For latent models, set the io_channels to the autoencoder's io_channels
        if model.pretransform is not None:
            io_channels = model.pretransform.io_channels

        # Prepare the initial audio for use by the model
        init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)

        # For latent models, encode the initial audio into latents
        if model.pretransform is not None:
            init_audio = model.pretransform.encode(init_audio)

        init_audio = init_audio.repeat(batch_size, 1, 1)
    else:
        # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch.
        init_audio = None
        init_noise_level = None
        mask_args = None

    # Inpainting mask
    if init_audio is not None and mask_args is not None:
        # Cut and paste init_audio according to cropfrom, pastefrom, pasteto
        # This is helpful for forward and reverse outpainting
        cropfrom = math.floor(mask_args["cropfrom"]/100.0 * sample_size)
        pastefrom = math.floor(mask_args["pastefrom"]/100.0 * sample_size)
        pasteto = math.ceil(mask_args["pasteto"]/100.0 * sample_size)
        assert pastefrom < pasteto, "Paste From should be less than Paste To"
        croplen = pasteto - pastefrom
        if cropfrom + croplen > sample_size:
            croplen = sample_size - cropfrom
        cropto = cropfrom + croplen
        pasteto = pastefrom + croplen
        cutpaste = init_audio.new_zeros(init_audio.shape)
        cutpaste[:, :, pastefrom:pasteto] = init_audio[:,:,cropfrom:cropto]
        #print(cropfrom, cropto, pastefrom, pasteto)
        init_audio = cutpaste
        # Build a soft mask (list of floats 0 to 1, the size of the latent) from the given args
        mask = build_mask(sample_size, mask_args)
        mask = mask.to(device)
    elif init_audio is not None and mask_args is None:
        # variations
        sampler_kwargs["sigma_max"] = init_noise_level
        mask = None
    else:
        mask = None

    # Cast inputs to the model's parameter dtype so mixed-precision checkpoints work
    model_dtype = next(model.model.parameters()).dtype
    noise = noise.type(model_dtype)
    conditioning_inputs = {k: v.type(model_dtype) if v is not None else v for k, v in conditioning_inputs.items()}
    # Now the generative AI part:
    # k-diffusion denoising process go!

    diff_objective = model.diffusion_objective

    if diff_objective == "v":
        # k-diffusion denoising process go!
        sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device)
    elif diff_objective == "rectified_flow":

        # These k-diffusion-only kwargs are not understood by the RF sampler
        if "sigma_min" in sampler_kwargs:
            del sampler_kwargs["sigma_min"]

        if "sampler_type" in sampler_kwargs:
            del sampler_kwargs["sampler_type"]

        sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device)

    # v-diffusion:
    #sampled = sample(model.model, noise, steps, 0, **conditioning_tensors, embedding_scale=cfg_scale)
    # Free intermediates before decoding to reduce peak GPU memory
    del noise
    del conditioning_tensors
    del conditioning_inputs
    torch.cuda.empty_cache()
    # Denoising process done.
    # If this is latent diffusion, decode latents back into audio
    if model.pretransform is not None and not return_latents:
        #cast sampled latents to pretransform dtype
        sampled = sampled.to(next(model.pretransform.parameters()).dtype)
        sampled = model.pretransform.decode(sampled)

    # Return audio
    return sampled
|
| 251 |
+
|
| 252 |
+
# builds a softmask given the parameters
|
| 253 |
+
# returns array of values 0 to 1, size sample_size, where 0 means noise / fresh generation, 1 means keep the input audio,
|
| 254 |
+
# and anything between is a mixture of old/new
|
| 255 |
+
# ideally 0.5 is half/half mixture but i haven't figured this out yet
|
| 256 |
+
# Construct a soft inpainting mask for a latent of length `sample_size`.
# Values run from 0 (generate fresh) to 1 (keep the input audio); the edges
# of the kept region are feathered with half Hann windows.
def build_mask(sample_size, mask_args):
    """Return a 1-D float mask of length ``sample_size`` built from ``mask_args``.

    ``mask_args`` holds percentages (0-100) for ``maskstart``/``maskend`` and
    the left/right softness widths, plus a ``marination`` factor in [0, 1)
    that uniformly scales the mask down so the inpainting constraint is
    released during the final denoising rounds.
    """
    def pct(key):
        return mask_args[key] / 100.0 * sample_size

    start_idx = math.floor(pct("maskstart"))
    end_idx = math.ceil(pct("maskend"))
    soft_left = round(pct("softnessL"))
    soft_right = round(pct("softnessR"))
    marination = mask_args["marination"]

    # Half Hann windows feather the transitions: rising edge on the left,
    # falling edge on the right.
    rise = torch.hann_window(soft_left * 2, periodic=False)[:soft_left]
    fall = torch.hann_window(soft_right * 2, periodic=False)[soft_right:]

    mask = torch.zeros((sample_size))
    mask[start_idx:end_idx] = 1
    mask[start_idx:start_idx + soft_left] = rise
    mask[end_idx - soft_right:end_idx] = fall

    # Marination finishes the inpainting early in the schedule, letting the
    # final denoising rounds alter the kept audio.
    if marination > 0:
        mask = mask * (1 - marination)
    return mask
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/inference/sampling.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
from tqdm import trange, tqdm
|
| 4 |
+
|
| 5 |
+
import k_diffusion as K
|
| 6 |
+
|
| 7 |
+
# Define the noise schedule and sampling loop
|
| 8 |
+
def get_alphas_sigmas(t):
    """Map timestep(s) ``t`` in [0, 1] to (alpha, sigma) scaling factors.

    ``alpha`` scales the clean signal and ``sigma`` scales the noise; they
    trace a quarter circle, so alpha**2 + sigma**2 == 1 for every t.
    """
    angle = t * math.pi / 2
    return torch.cos(angle), torch.sin(angle)
|
| 12 |
+
|
| 13 |
+
def alpha_sigma_to_t(alpha, sigma):
    """Recover the timestep from (alpha, sigma) scaling factors.

    Inverse of ``get_alphas_sigmas``: the angle of the (alpha, sigma) point
    on the unit quarter circle, rescaled to [0, 1].
    """
    angle = torch.atan2(sigma, alpha)
    return angle / math.pi * 2
|
| 17 |
+
|
| 18 |
+
def t_to_alpha_sigma(t):
    """Return the (alpha, sigma) scaling pair for timestep(s) ``t`` in [0, 1].

    Same quarter-circle parameterization as ``get_alphas_sigmas``: alpha is
    the clean-signal scale and sigma is the noise scale.
    """
    half_pi_t = t * math.pi / 2
    return torch.cos(half_pi_t), torch.sin(half_pi_t)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@torch.no_grad()
def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args):
    """Draw samples from a rectified-flow model using the explicit Euler method.

    Args:
        model: Callable ``model(x, t, **extra_args)`` returning the velocity
            field at state ``x`` and per-sample timesteps ``t``.
        x: Starting noise tensor of shape ``[batch, ...]``.
        steps: Number of Euler integration steps.
        sigma_max: Start time of the schedule; integration runs down to 0.
        **extra_args: Extra keyword arguments forwarded to ``model``
            (e.g. conditioning tensors).

    Returns:
        The sample after integrating from ``sigma_max`` to 0.
    """
    # Noise schedule: steps+1 evenly spaced times from sigma_max down to 0.
    # (Removed an unused per-batch `ts` helper and dead commented-out code;
    # `total=steps` lets tqdm render a proper progress bar for the zip.)
    t = torch.linspace(sigma_max, 0, steps + 1)

    for t_curr, t_prev in tqdm(zip(t[:-1], t[1:]), total=steps):
        # Broadcast the scalar timestep across the batch dimension.
        t_curr_tensor = t_curr * torch.ones(
            (x.shape[0],), dtype=x.dtype, device=x.device
        )
        dt = t_prev - t_curr  # negative: we solve backwards in our formulation
        x = x + dt * model(x, t_curr_tensor, **extra_args)

    # After the last step, x is the denoised output.
    return x
|
| 46 |
+
|
| 47 |
+
@torch.no_grad()
def sample(model, x, steps, eta, **extra_args):
    """Draws samples from a model given starting noise. v-diffusion.

    Args:
        model: v-objective denoiser called as ``model(x, t, **extra_args)``;
            expected to return the predicted velocity ``v``.
        x: Starting noise tensor of shape [batch, ...].
        steps: Number of sampling steps. NOTE(review): must be >= 1, otherwise
            ``pred`` is unbound at the final ``return``.
        eta: 0 gives deterministic DDIM sampling; larger values mix in fresh
            noise each step (ancestral sampling).
        **extra_args: Forwarded to ``model`` (e.g. conditioning tensors).

    Returns:
        The final denoised prediction ``pred``.
    """
    # Per-batch vector of ones used to broadcast the scalar timestep.
    ts = x.new_ones([x.shape[0]])

    # Create the noise schedule: timesteps from 1 down toward 0, with the
    # final 0 entry dropped.
    t = torch.linspace(1, 0, steps + 1)[:-1]

    alphas, sigmas = get_alphas_sigmas(t)

    # The sampling loop
    for i in trange(steps):

        # Get the model output (v, the predicted velocity)
        with torch.cuda.amp.autocast():
            v = model(x, ts * t[i], **extra_args).float()

        # Predict the noise and the denoised image
        pred = x * alphas[i] - v * sigmas[i]
        eps = x * sigmas[i] + v * alphas[i]

        # If we are not on the last timestep, compute the noisy image for the
        # next timestep.
        if i < steps - 1:
            # If eta > 0, adjust the scaling factor for the predicted noise
            # downward according to the amount of additional noise to add
            ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
                (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
            adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()

            # Recombine the predicted noise and predicted denoised image in the
            # correct proportions for the next step
            x = pred * alphas[i + 1] + eps * adjusted_sigma

            # Add the correct amount of fresh noise
            if eta:
                x += torch.randn_like(x) * ddim_sigma

    # If we are on the last timestep, output the denoised image
    return pred
|
| 87 |
+
|
| 88 |
+
# Soft mask inpainting is just shrinking hard (binary) mask inpainting
|
| 89 |
+
# Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
|
| 90 |
+
# Soft mask inpainting is just shrinking hard (binary) mask inpainting.
def get_bmask(i, steps, mask):
    """Binarize a soft mask for denoising step ``i`` out of ``steps``.

    The threshold rises from 1/steps to 1 as the step index grows, so
    positions with larger soft-mask values flip to 1 (keep input) later in
    the schedule.
    """
    threshold = (i + 1) / steps
    return torch.where(mask <= threshold, 1, 0)
|
| 95 |
+
|
| 96 |
+
def make_cond_model_fn(model, cond_fn):
    """Wrap ``model`` so each denoised estimate is nudged by a guidance gradient.

    ``cond_fn(x, sigma, denoised=..., **kwargs)`` must return a gradient
    tensor; it is scaled by sigma**2 (broadcast to x's rank) and added to the
    detached denoiser output. Gradients are enabled only inside the wrapper.
    """
    def guided_model_fn(x, sigma, **kwargs):
        with torch.enable_grad():
            x = x.detach().requires_grad_()
            denoised = model(x, sigma, **kwargs)
            grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
            return denoised.detach() + grad * K.utils.append_dims(sigma**2, x.ndim)
    return guided_model_fn
|
| 105 |
+
|
| 106 |
+
# Uses k-diffusion from https://github.com/crowsonkb/k-diffusion
|
| 107 |
+
# init_data is init_audio as latents (if this is latent diffusion)
|
| 108 |
+
# For sampling, set both init_data and mask to None
|
| 109 |
+
# For variations, set init_data
|
| 110 |
+
# For inpainting, set both init_data & mask
|
| 111 |
+
# Uses k-diffusion from https://github.com/crowsonkb/k-diffusion
# init_data is init_audio as latents (if this is latent diffusion)
# For sampling, set both init_data and mask to None
# For variations, set init_data
# For inpainting, set both init_data & mask
def sample_k(
        model_fn,
        noise,
        init_data=None,
        mask=None,
        steps=100,
        sampler_type="dpmpp-2m-sde",
        sigma_min=0.5,
        sigma_max=50,
        rho=1.0, device="cuda",
        callback=None,
        cond_fn=None,
        **extra_args
    ):
    """Sample from a v-objective diffusion model via a k-diffusion sampler.

    Args:
        model_fn: Raw v-objective model; wrapped in ``K.external.VDenoiser``.
        noise: Unit-variance starting noise, shape [batch, channels, length].
        init_data: Optional init latents (variations / inpainting).
        mask: Optional soft mask for inpainting (see ``get_bmask``).
        steps: Number of denoising steps.
        sampler_type: Which k-diffusion sampler to dispatch to.
        sigma_min / sigma_max / rho: Polyexponential sigma schedule params.
        device: Device for the sigma schedule.
        callback: Optional per-step callback (k-diffusion callback dict).
        cond_fn: Optional guidance gradient fn (see ``make_cond_model_fn``).
        **extra_args: Forwarded to the model (conditioning tensors etc.).

    Returns:
        The sampled latents/audio tensor.
        NOTE(review): an unrecognized ``sampler_type`` falls through every
        branch and returns None.
    """
    denoiser = K.external.VDenoiser(model_fn)

    if cond_fn is not None:
        denoiser = make_cond_model_fn(denoiser, cond_fn)

    # Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has
    sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
    # Scale the initial noise by sigma
    noise = noise * sigmas[0]

    wrapped_callback = callback

    if mask is None and init_data is not None:
        # VARIATION (no inpainting)
        # set the initial latent to the init_data, and noise it with initial sigma
        x = init_data + noise
    elif mask is not None and init_data is not None:
        # INPAINTING
        bmask = get_bmask(0, steps, mask)
        # initial noising
        input_noised = init_data + noise
        # set the initial latent to a mix of init_data and noise, based on step 0's binary mask
        x = input_noised * bmask + noise * (1-bmask)
        # define the inpainting callback function (Note: side effects, it mutates x)
        # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105
        # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        # This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)`
        def inpainting_callback(args):
            i = args["i"]
            x = args["x"]
            sigma = args["sigma"]
            #denoised = args["denoised"]
            # noise the init_data input with this step's appropriate amount of noise
            input_noised = init_data + torch.randn_like(init_data) * sigma
            # shrinking hard mask
            bmask = get_bmask(i, steps, mask)
            # mix input_noise with x, using binary mask
            new_x = input_noised * bmask + x * (1-bmask)
            # mutate x in place so the sampler sees the re-masked latent
            x[:,:,:] = new_x[:,:,:]
        # wrap together the inpainting callback and the user-submitted callback.
        if callback is None:
            wrapped_callback = inpainting_callback
        else:
            wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
    else:
        # SAMPLING
        # set the initial latent to noise
        x = noise


    with torch.cuda.amp.autocast():
        # Dispatch to the requested k-diffusion sampler.
        if sampler_type == "k-heun":
            return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-lms":
            return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpmpp-2s-ancestral":
            return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-2":
            return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-fast":
            return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-adaptive":
            return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "dpmpp-2m-sde":
            return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "dpmpp-3m-sde":
            return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
|
| 194 |
+
|
| 195 |
+
# Uses discrete Euler sampling for rectified flow models
|
| 196 |
+
# init_data is init_audio as latents (if this is latent diffusion)
|
| 197 |
+
# For sampling, set both init_data and mask to None
|
| 198 |
+
# For variations, set init_data
|
| 199 |
+
# For inpainting, set both init_data & mask
|
| 200 |
+
# Uses discrete Euler sampling for rectified flow models
# init_data is init_audio as latents (if this is latent diffusion)
# For sampling, set init_data to None
# For variations, set init_data
def sample_rf(
        model_fn,
        noise,
        init_data=None,
        steps=100,
        sigma_max=1,
        device="cuda",
        callback=None,
        cond_fn=None,
        **extra_args
    ):
    """Draw samples from a rectified-flow model with discrete Euler steps.

    Args:
        model_fn: Velocity model called as ``model_fn(x, t, **extra_args)``.
        noise: Starting noise tensor, shape [batch, ...].
        init_data: Optional init latents (variations); the start point is the
            interpolation ``init_data * (1 - sigma_max) + noise * sigma_max``.
        steps: Number of Euler integration steps.
        sigma_max: Start time of the schedule, clamped to at most 1.
        device: Unused here; kept for interface parity with ``sample_k``.
        callback: Currently unused (TODO below: thread through to the sampler).
        cond_fn: Optional guidance gradient function; when given, the sampled
            model is wrapped via ``make_cond_model_fn``.
        **extra_args: Forwarded to ``model_fn`` (e.g. conditioning tensors).

    Returns:
        The sampled tensor after integrating from ``sigma_max`` to 0.
    """
    if sigma_max > 1:
        sigma_max = 1

    if cond_fn is not None:
        # Bug fix: this previously did `denoiser = make_cond_model_fn(denoiser,
        # cond_fn)`, raising NameError (`denoiser` was never defined) and never
        # using the wrapper. Wrap the function that is actually sampled.
        model_fn = make_cond_model_fn(model_fn, cond_fn)

    wrapped_callback = callback

    if init_data is not None:
        # VARIATION (no inpainting)
        # Interpolate the init data and the noise for init audio
        x = init_data * (1 - sigma_max) + noise * sigma_max
    else:
        # SAMPLING
        # set the initial latent to noise
        x = noise

    with torch.cuda.amp.autocast():
        # TODO: Add callback support
        #return sample_discrete_euler(model_fn, x, steps, sigma_max, callback=wrapped_callback, **extra_args)
        return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args)
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/inference/utils.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..data.utils import PadCrop
|
| 2 |
+
|
| 3 |
+
from torchaudio import transforms as T
|
| 4 |
+
|
| 5 |
+
def set_audio_channels(audio, target_channels):
    """Coerce a [batch, channels, samples] tensor to 1 or 2 channels.

    A mono target averages across the channel axis; a stereo target
    duplicates a mono channel or drops channels beyond the first two. Any
    other target leaves the tensor untouched.
    """
    if target_channels == 1:
        # Downmix to mono by averaging the channel axis.
        return audio.mean(1, keepdim=True)
    if target_channels == 2:
        n_channels = audio.shape[1]
        if n_channels == 1:
            # Duplicate the single channel to fake stereo.
            return audio.repeat(1, 2, 1)
        if n_channels > 2:
            # Keep only the first two channels.
            return audio[:, :2, :]
    return audio
|
| 16 |
+
|
| 17 |
+
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Move, resample, pad/crop, batch, and channel-coerce audio for the model.

    Returns a [1, target_channels, target_length] tensor at ``target_sr`` on
    ``device``. Accepts 1-D, 2-D, or 3-D input; a batch dimension is added
    for 1-D/2-D input.
    """
    audio = audio.to(device)

    # Resample only when the input rate differs from the target rate.
    if in_sr != target_sr:
        resampler = T.Resample(in_sr, target_sr).to(device)
        audio = resampler(audio)

    # Deterministic (non-random) pad or crop to the expected length.
    audio = PadCrop(target_length, randomize=False)(audio)

    # Promote to [batch, channels, samples].
    if audio.dim() == 1:
        audio = audio.unsqueeze(0).unsqueeze(0)
    elif audio.dim() == 2:
        audio = audio.unsqueeze(0)

    return set_audio_channels(audio, target_channels)
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/interface/__init__.py
ADDED
|
File without changes
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/interface/gradio.py
ADDED
|
@@ -0,0 +1,700 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import platform
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import json
|
| 7 |
+
import torch
|
| 8 |
+
import torchaudio
|
| 9 |
+
|
| 10 |
+
from aeiou.viz import audio_spectrogram_image
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
from safetensors.torch import load_file
|
| 13 |
+
from torch.nn import functional as F
|
| 14 |
+
from torchaudio import transforms as T
|
| 15 |
+
|
| 16 |
+
from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond
|
| 17 |
+
from ..models.factory import create_model_from_config
|
| 18 |
+
from ..models.pretrained import get_pretrained_model
|
| 19 |
+
from ..models.utils import load_ckpt_state_dict
|
| 20 |
+
from ..inference.utils import prepare_audio
|
| 21 |
+
from ..training.utils import copy_state_dict
|
| 22 |
+
|
| 23 |
+
# Module-level state shared by the Gradio handlers below.
model = None  # set by load_model(); read by the generate_* functions
sample_rate = 32000  # audio sample rate in Hz; overwritten from the model config in load_model()
sample_size = 1920000  # samples per generation; overwritten from the model config in load_model()
|
| 26 |
+
|
| 27 |
+
def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
    """Load the global diffusion model from a pretrained name or config+checkpoint.

    Side effects: sets the module globals ``model``, ``sample_rate`` and
    ``sample_size``; moves the model to ``device`` in eval mode with
    gradients disabled (and optionally fp16).

    Returns:
        ``(model, model_config)``. NOTE(review): if neither ``pretrained_name``
        nor a config/ckpt pair is provided, the globals are read from a
        ``model_config`` that may be None — confirm callers always pass one.
    """
    global model, sample_rate, sample_size

    if pretrained_name is not None:
        print(f"Loading pretrained model {pretrained_name}")
        model, model_config = get_pretrained_model(pretrained_name)

    elif model_config is not None and model_ckpt_path is not None:
        print(f"Creating model from config")
        model = create_model_from_config(model_config)

        print(f"Loading model checkpoint from {model_ckpt_path}")
        # Load checkpoint (copy_state_dict tolerates partial/prefix mismatches)
        copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path))
        #model.load_state_dict(load_ckpt_state_dict(model_ckpt_path))

    sample_rate = model_config["sample_rate"]
    sample_size = model_config["sample_size"]

    if pretransform_ckpt_path is not None:
        print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}")
        # strict=False: the checkpoint may cover only part of the pretransform
        model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False)
        print(f"Done loading pretransform")

    # Inference-only: eval mode, no gradients.
    model.to(device).eval().requires_grad_(False)

    if model_half:
        model.to(torch.float16)

    print(f"Done loading model")

    return model, model_config
|
| 59 |
+
|
| 60 |
+
def generate_cond(
        prompt,
        negative_prompt=None,
        seconds_start=0,
        seconds_total=30,
        cfg_scale=6.0,
        steps=250,
        preview_every=None,
        seed=-1,
        sampler_type="dpmpp-3m-sde",
        sigma_min=0.03,
        sigma_max=1000,
        cfg_rescale=0.0,
        use_init=False,
        init_audio=None,
        init_noise_level=1.0,
        mask_cropfrom=None,
        mask_pastefrom=None,
        mask_pasteto=None,
        mask_maskstart=None,
        mask_maskend=None,
        mask_softnessL=None,
        mask_softnessR=None,
        mask_marination=None,
        batch_size=1
    ):
    """Gradio handler: generate audio from a text prompt with conditional diffusion.

    Reads the module globals ``model``/``sample_rate``/``sample_size`` and the
    global ``preview_images`` list; writes the result to ``output.wav`` in the
    working directory.

    ``init_audio`` is a Gradio audio value: a ``(sample_rate, int16 ndarray)``
    tuple. The ``mask_*`` arguments (percentages) enable inpainting when
    ``mask_cropfrom`` is not None.

    Returns:
        ``("output.wav", [spectrogram, *preview_images])`` for the Gradio UI.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

    print(f"Prompt: {prompt}")

    global preview_images
    preview_images = []
    if preview_every == 0:
        # 0 means "no previews" in the UI; normalize to None.
        preview_every = None

    # Return fake stereo audio
    conditioning = [{"prompt": prompt, "seconds_start": seconds_start, "seconds_total": seconds_total}] * batch_size

    if negative_prompt:
        negative_conditioning = [{"prompt": negative_prompt, "seconds_start": seconds_start, "seconds_total": seconds_total}] * batch_size
    else:
        negative_conditioning = None

    #Get the device from the model
    device = next(model.parameters()).device

    seed = int(seed)

    if not use_init:
        init_audio = None

    input_sample_size = sample_size

    if init_audio is not None:
        in_sr, init_audio = init_audio
        # Turn into torch tensor, converting from int16 to float32
        init_audio = torch.from_numpy(init_audio).float().div(32767)

        if init_audio.dim() == 1:
            init_audio = init_audio.unsqueeze(0) # [1, n]
        elif init_audio.dim() == 2:
            init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n]

        if in_sr != sample_rate:
            resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device)
            init_audio = resample_tf(init_audio)

        audio_length = init_audio.shape[-1]

        if audio_length > sample_size:
            # Round the length up to the next multiple of the model's minimum
            # input length so long init audio fits the model.
            input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length

        init_audio = (sample_rate, init_audio)

    def progress_callback(callback_info):
        # Per-step callback: every `preview_every` steps, decode the current
        # denoised latents and stash a spectrogram for the gallery.
        global preview_images
        denoised = callback_info["denoised"]
        current_step = callback_info["i"]
        sigma = callback_info["sigma"]

        if (current_step - 1) % preview_every == 0:
            if model.pretransform is not None:
                denoised = model.pretransform.decode(denoised)
            denoised = rearrange(denoised, "b d n -> d (b n)")
            denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
            audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
            preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})"))

    # If inpainting, send mask args
    # This will definitely change in the future
    if mask_cropfrom is not None:
        mask_args = {
            "cropfrom": mask_cropfrom,
            "pastefrom": mask_pastefrom,
            "pasteto": mask_pasteto,
            "maskstart": mask_maskstart,
            "maskend": mask_maskend,
            "softnessL": mask_softnessL,
            "softnessR": mask_softnessR,
            "marination": mask_marination,
        }
    else:
        mask_args = None

    # Do the audio generation
    audio = generate_diffusion_cond(
        model,
        conditioning=conditioning,
        negative_conditioning=negative_conditioning,
        steps=steps,
        cfg_scale=cfg_scale,
        batch_size=batch_size,
        sample_size=input_sample_size,
        sample_rate=sample_rate,
        seed=seed,
        device=device,
        sampler_type=sampler_type,
        sigma_min=sigma_min,
        sigma_max=sigma_max,
        init_audio=init_audio,
        init_noise_level=init_noise_level,
        mask_args = mask_args,
        callback = progress_callback if preview_every is not None else None,
        scale_phi = cfg_rescale
    )

    # Convert to WAV file: concatenate the batch along time, peak-normalize,
    # and quantize to int16.
    audio = rearrange(audio, "b d n -> d (b n)")
    audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
    torchaudio.save("output.wav", audio, sample_rate)

    # Let's look at a nice spectrogram too
    audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)

    return ("output.wav", [audio_spectrogram, *preview_images])
|
| 199 |
+
|
| 200 |
+
def generate_uncond(
        steps=250,
        seed=-1,
        sampler_type="dpmpp-3m-sde",
        sigma_min=0.03,
        sigma_max=1000,
        use_init=False,
        init_audio=None,
        init_noise_level=1.0,
        batch_size=1,
        preview_every=None
        ):
    """Gradio handler: generate audio with unconditional diffusion.

    Mirrors ``generate_cond`` minus text conditioning and inpainting. Reads
    the module globals ``model``/``sample_rate``/``sample_size``; writes the
    result to ``output.wav``. ``init_audio`` is a Gradio
    ``(sample_rate, int16 ndarray)`` tuple.

    Returns:
        ``("output.wav", [spectrogram, *preview_images])`` for the Gradio UI.
    """
    global preview_images

    preview_images = []

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

    #Get the device from the model
    device = next(model.parameters()).device

    seed = int(seed)

    if not use_init:
        init_audio = None

    input_sample_size = sample_size

    if init_audio is not None:
        in_sr, init_audio = init_audio
        # Turn into torch tensor, converting from int16 to float32
        init_audio = torch.from_numpy(init_audio).float().div(32767)

        if init_audio.dim() == 1:
            init_audio = init_audio.unsqueeze(0) # [1, n]
        elif init_audio.dim() == 2:
            init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n]

        if in_sr != sample_rate:
            resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device)
            init_audio = resample_tf(init_audio)

        audio_length = init_audio.shape[-1]

        if audio_length > sample_size:
            # Round the length up to the next multiple of the model's minimum
            # input length so long init audio fits the model.
            input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length

        init_audio = (sample_rate, init_audio)

    def progress_callback(callback_info):
        # Per-step callback: every `preview_every` steps, decode the current
        # denoised latents and stash a spectrogram for the gallery.
        global preview_images
        denoised = callback_info["denoised"]
        current_step = callback_info["i"]
        sigma = callback_info["sigma"]

        if (current_step - 1) % preview_every == 0:

            if model.pretransform is not None:
                denoised = model.pretransform.decode(denoised)

            denoised = rearrange(denoised, "b d n -> d (b n)")

            denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()

            audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)

            preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})"))

    audio = generate_diffusion_uncond(
        model,
        steps=steps,
        batch_size=batch_size,
        sample_size=input_sample_size,
        seed=seed,
        device=device,
        sampler_type=sampler_type,
        sigma_min=sigma_min,
        sigma_max=sigma_max,
        init_audio=init_audio,
        init_noise_level=init_noise_level,
        callback = progress_callback if preview_every is not None else None
    )

    # Concatenate the batch along time, peak-normalize, quantize to int16.
    audio = rearrange(audio, "b d n -> d (b n)")

    audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    torchaudio.save("output.wav", audio, sample_rate)

    audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)

    return ("output.wav", [audio_spectrogram, *preview_images])
|
| 296 |
+
|
| 297 |
+
def generate_lm(
        temperature=1.0,
        top_p=0.95,
        top_k=0,
        batch_size=1,
    ):
    """Gradio handler: generate audio with the autoregressive (LM) model.

    Reads the module globals ``model``/``sample_rate``/``sample_size``;
    writes the result to ``output.wav`` in the working directory.

    Args:
        temperature: Sampling temperature passed to the model.
        top_p: Nucleus sampling probability mass.
        top_k: Top-k cutoff (0 disables it).
        batch_size: Number of clips generated and concatenated along time.

    Returns:
        ``("output.wav", [spectrogram])`` for the Gradio UI.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

    # (Removed an unused `device = next(model.parameters()).device` local —
    # generate_audio runs on the model's own device.)
    audio = model.generate_audio(
        batch_size=batch_size,
        max_gen_len = sample_size//model.pretransform.downsampling_ratio,
        conditioning=None,
        temp=temperature,
        top_p=top_p,
        top_k=top_k,
        use_cache=True
    )

    # Concatenate the batch along time, peak-normalize, quantize to int16.
    audio = rearrange(audio, "b d n -> d (b n)")

    audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    torchaudio.save("output.wav", audio, sample_rate)

    audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)

    return ("output.wav", [audio_spectrogram])
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def create_uncond_sampling_ui(model_config):
    """Build the gradio controls for unconditional diffusion sampling and wire them to generate_uncond."""
    generate_button = gr.Button("Generate", variant='primary', scale=1)

    with gr.Row(equal_height=False):
        with gr.Column():
            with gr.Row():
                # Steps slider
                steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")

            with gr.Accordion("Sampler params", open=False):

                # Seed
                seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1")

                # Sampler params
                with gr.Row():
                    sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde")
                    sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min")
                    sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max")

            # Optional init audio to start sampling from (img2img-style).
            with gr.Accordion("Init audio", open=False):
                init_audio_checkbox = gr.Checkbox(label="Use init audio")
                init_audio_input = gr.Audio(label="Init audio")
                init_noise_level_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.01, value=0.1, label="Init noise level")

        with gr.Column():
            audio_output = gr.Audio(label="Output audio", interactive=False)
            audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
            send_to_init_button = gr.Button("Send to init audio", scale=1)
            # Copy the generated audio back into the init-audio slot for iterative refinement.
            send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input])

    generate_button.click(fn=generate_uncond,
        inputs=[
            steps_slider,
            seed_textbox,
            sampler_type_dropdown,
            sigma_min_slider,
            sigma_max_slider,
            init_audio_checkbox,
            init_audio_input,
            init_noise_level_slider,
        ],
        outputs=[
            audio_output,
            audio_spectrogram_output
        ],
        api_name="generate")
|
| 379 |
+
|
| 380 |
+
def create_sampling_ui(model_config, inpainting=False):
    """Build the conditioned-diffusion sampling controls and wire them to generate_cond.

    When `inpainting` is True, extra mask controls are shown and passed through
    as additional inputs; otherwise a plain init-audio accordion is used.
    """
    with gr.Row():
        with gr.Column(scale=6):
            prompt = gr.Textbox(show_label=False, placeholder="Prompt")
            negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt")
        generate_button = gr.Button("Generate", variant='primary', scale=1)

    model_conditioning_config = model_config["model"].get("conditioning", None)

    has_seconds_start = False
    has_seconds_total = False

    # Only show timing sliders if the model is conditioned on them.
    if model_conditioning_config is not None:
        for conditioning_config in model_conditioning_config["configs"]:
            if conditioning_config["id"] == "seconds_start":
                has_seconds_start = True
            if conditioning_config["id"] == "seconds_total":
                has_seconds_total = True

    with gr.Row(equal_height=False):
        with gr.Column():
            with gr.Row(visible = has_seconds_start or has_seconds_total):
                # Timing controls
                seconds_start_slider = gr.Slider(minimum=0, maximum=512, step=1, value=0, label="Seconds start", visible=has_seconds_start)
                seconds_total_slider = gr.Slider(minimum=0, maximum=512, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total)

            with gr.Row():
                # Steps slider
                steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")

                # Preview Every slider
                preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Preview Every")

                # CFG scale
                cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=7.0, label="CFG scale")

            with gr.Accordion("Sampler params", open=False):

                # Seed
                seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1")

                # Sampler params
                with gr.Row():
                    sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde")
                    sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min")
                    sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max")
                    cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.0, label="CFG rescale amount")

            if inpainting:
                # Inpainting Tab
                with gr.Accordion("Inpainting", open=False):
                    sigma_max_slider.maximum=1000

                    init_audio_checkbox = gr.Checkbox(label="Do inpainting")
                    init_audio_input = gr.Audio(label="Init audio")
                    init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.1, value=80, label="Init audio noise level", visible=False) # hide this

                    # Mask controls are expressed as percentages of the total length.
                    mask_cropfrom_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Crop From %")
                    mask_pastefrom_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Paste From %")
                    mask_pasteto_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=100, label="Paste To %")

                    mask_maskstart_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=50, label="Mask Start %")
                    mask_maskend_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=100, label="Mask End %")
                    mask_softnessL_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Left Crossfade Length %")
                    mask_softnessR_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Right Crossfade Length %")
                    mask_marination_slider = gr.Slider(minimum=0.0, maximum=1, step=0.0001, value=0, label="Marination level", visible=False) # still working on the usefulness of this

                # NOTE: the order of this list must match generate_cond's parameter order.
                inputs = [prompt,
                    negative_prompt,
                    seconds_start_slider,
                    seconds_total_slider,
                    cfg_scale_slider,
                    steps_slider,
                    preview_every_slider,
                    seed_textbox,
                    sampler_type_dropdown,
                    sigma_min_slider,
                    sigma_max_slider,
                    cfg_rescale_slider,
                    init_audio_checkbox,
                    init_audio_input,
                    init_noise_level_slider,
                    mask_cropfrom_slider,
                    mask_pastefrom_slider,
                    mask_pasteto_slider,
                    mask_maskstart_slider,
                    mask_maskend_slider,
                    mask_softnessL_slider,
                    mask_softnessR_slider,
                    mask_marination_slider
                ]
            else:
                # Default generation tab
                with gr.Accordion("Init audio", open=False):
                    init_audio_checkbox = gr.Checkbox(label="Use init audio")
                    init_audio_input = gr.Audio(label="Init audio")
                    init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.01, value=0.1, label="Init noise level")

                # NOTE: the order of this list must match generate_cond's parameter order.
                inputs = [prompt,
                    negative_prompt,
                    seconds_start_slider,
                    seconds_total_slider,
                    cfg_scale_slider,
                    steps_slider,
                    preview_every_slider,
                    seed_textbox,
                    sampler_type_dropdown,
                    sigma_min_slider,
                    sigma_max_slider,
                    cfg_rescale_slider,
                    init_audio_checkbox,
                    init_audio_input,
                    init_noise_level_slider
                ]

        with gr.Column():
            audio_output = gr.Audio(label="Output audio", interactive=False)
            audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
            send_to_init_button = gr.Button("Send to init audio", scale=1)
            # Copy the generated audio back into the init-audio slot for iterative refinement.
            send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input])

    generate_button.click(fn=generate_cond,
        inputs=inputs,
        outputs=[
            audio_output,
            audio_spectrogram_output
        ],
        api_name="generate")
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
def create_txt2audio_ui(model_config):
    """Build the text-to-audio interface: a plain generation tab plus an inpainting tab."""
    tab_specs = (("Generation", False), ("Inpainting", True))
    with gr.Blocks() as ui:
        for tab_name, use_inpainting in tab_specs:
            with gr.Tab(tab_name):
                create_sampling_ui(model_config, inpainting=use_inpainting)
    return ui
|
| 517 |
+
|
| 518 |
+
def create_diffusion_uncond_ui(model_config):
    """Wrap the unconditional-diffusion sampling controls in a standalone gradio Blocks app."""
    with gr.Blocks() as ui:
        create_uncond_sampling_ui(model_config)

    return ui
|
| 523 |
+
|
| 524 |
+
def autoencoder_process(audio, latent_noise, n_quantizers):
    """Round-trip gradio audio through the autoencoder (encode -> decode) and save to output.wav.

    `audio` is a (sample_rate, numpy int16 array) pair as produced by gr.Audio.
    """
    # Reclaim GPU memory left over from previous runs.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Get the device from the model
    device = next(model.parameters()).device

    in_sr, waveform = audio

    # Gradio hands back int16 numpy audio; rescale to [-1, 1] floats on the model device.
    waveform = torch.from_numpy(waveform).float().div(32767).to(device)

    # Mono arrives as (n,), multi-channel as (n, ch); normalize both to (ch, n).
    waveform = waveform.unsqueeze(0) if waveform.dim() == 1 else waveform.transpose(0, 1)

    waveform = model.preprocess_audio_for_encoder(waveform, in_sr)
    # Note: If you need to do chunked encoding, to reduce VRAM,
    # then add these arguments to encode_audio and decode_audio: chunked=True, overlap=32, chunk_size=128
    # To turn it off, do chunked=False
    # Optimal overlap and chunk_size values will depend on the model.
    # See encode_audio & decode_audio in autoencoders.py for more info

    # Match the model's parameter dtype (e.g. when running in half precision).
    waveform = waveform.to(next(model.parameters()).dtype)

    encode_kwargs = {"chunked": False}
    if n_quantizers > 0:
        encode_kwargs["n_quantizers"] = n_quantizers
    latents = model.encode_audio(waveform, **encode_kwargs)

    if latent_noise > 0:
        # Optionally perturb the latents before decoding.
        latents = latents + torch.randn_like(latents) * latent_noise

    decoded = model.decode_audio(latents, chunked=False)

    # (batch, ch, time) -> (ch, batch*time), then to 16-bit PCM for saving.
    decoded = rearrange(decoded, "b d n -> d (b n)")
    decoded = decoded.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    torchaudio.save("output.wav", decoded, sample_rate)

    return "output.wav"
|
| 569 |
+
|
| 570 |
+
def create_autoencoder_ui(model_config):
    """Build the autoencoder round-trip UI; the quantizer slider only shows for DAC-RVQ bottlenecks."""
    has_bottleneck = "model" in model_config and "bottleneck" in model_config["model"]
    is_dac_rvq = has_bottleneck and model_config["model"]["bottleneck"]["type"] in ("dac_rvq", "dac_rvq_vae")

    # For RVQ bottlenecks the slider tops out at the model's codebook count.
    n_quantizers = model_config["model"]["bottleneck"]["config"]["n_codebooks"] if is_dac_rvq else 0

    with gr.Blocks() as ui:
        input_audio = gr.Audio(label="Input audio")
        output_audio = gr.Audio(label="Output audio", interactive=False)
        n_quantizers_slider = gr.Slider(minimum=1, maximum=n_quantizers, step=1, value=n_quantizers, label="# quantizers", visible=is_dac_rvq)
        latent_noise_slider = gr.Slider(minimum=0.0, maximum=10.0, step=0.001, value=0.0, label="Add latent noise")
        process_button = gr.Button("Process", variant='primary', scale=1)
        process_button.click(
            fn=autoencoder_process,
            inputs=[input_audio, latent_noise_slider, n_quantizers_slider],
            outputs=output_audio,
            api_name="process",
        )

    return ui
|
| 588 |
+
|
| 589 |
+
def diffusion_prior_process(audio, steps, sampler_type, sigma_min, sigma_max):
    """Run the diffusion prior (stereoize) over gradio input audio and save to output.wav."""
    # Reclaim GPU memory left over from previous runs.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Get the device from the model
    device = next(model.parameters()).device

    in_sr, waveform = audio

    # int16 numpy from gradio -> float tensor in [-1, 1] on the model device.
    waveform = torch.from_numpy(waveform).float().div(32767).to(device)

    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)  # [1, n]
    elif waveform.dim() == 2:
        waveform = waveform.transpose(0, 1)  # [n, 2] -> [2, n]

    # Add the batch dimension the model expects.
    waveform = waveform.unsqueeze(0)

    sampler_kwargs = {"sampler_type": sampler_type, "sigma_min": sigma_min, "sigma_max": sigma_max}
    stereo = model.stereoize(waveform, in_sr, steps, sampler_kwargs=sampler_kwargs)

    # (batch, ch, time) -> (ch, batch*time), peak-normalize, then to 16-bit PCM.
    stereo = rearrange(stereo, "b d n -> d (b n)")
    peak = torch.max(torch.abs(stereo))
    stereo = stereo.to(torch.float32).div(peak).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    torchaudio.save("output.wav", stereo, sample_rate)

    return "output.wav"
|
| 618 |
+
|
| 619 |
+
def create_diffusion_prior_ui(model_config):
    """Build the diffusion-prior (stereoize) UI with basic sampler controls."""
    sampler_choices = ["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"]

    with gr.Blocks() as ui:
        input_audio = gr.Audio(label="Input audio")
        output_audio = gr.Audio(label="Output audio", interactive=False)
        # Sampler params
        with gr.Row():
            steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")
            sampler_type_dropdown = gr.Dropdown(sampler_choices, label="Sampler type", value="dpmpp-3m-sde")
            sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min")
            sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max")
        process_button = gr.Button("Process", variant='primary', scale=1)
        process_button.click(
            fn=diffusion_prior_process,
            inputs=[input_audio, steps_slider, sampler_type_dropdown, sigma_min_slider, sigma_max_slider],
            outputs=output_audio,
            api_name="process",
        )

    return ui
|
| 633 |
+
|
| 634 |
+
def create_lm_ui(model_config):
    """Build the language-model generation UI and wire it to generate_lm."""
    with gr.Blocks() as ui:
        output_audio = gr.Audio(label="Output audio", interactive=False)
        audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)

        # Sampling params
        with gr.Row():
            temperature_slider = gr.Slider(minimum=0, maximum=5, step=0.01, value=1.0, label="Temperature")
            top_p_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.95, label="Top p")
            top_k_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Top k")

        generate_button = gr.Button("Generate", variant='primary', scale=1)
        generate_button.click(
            fn=generate_lm,
            inputs=[
                temperature_slider,
                top_p_slider,
                top_k_slider
            ],
            outputs=[output_audio, audio_spectrogram_output],
            api_name="generate"
        )

    return ui
|
| 658 |
+
|
| 659 |
+
def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False):
    """Load a model and build the gradio UI matching its model_type.

    Exactly one of `pretrained_name` or (`model_config_path` + `ckpt_path`) must
    be provided. Returns the constructed gr.Blocks UI.

    Raises:
        ValueError: if the loaded config names an unsupported model_type.
    """
    assert (pretrained_name is not None) ^ (model_config_path is not None and ckpt_path is not None), "Must specify either pretrained name or provide a model config and checkpoint, but not both"

    if model_config_path is not None:
        # Load config from json file
        with open(model_config_path) as f:
            model_config = json.load(f)
    else:
        model_config = None

    try:
        has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
    except Exception:
        # In case this version of Torch doesn't even have `torch.backends.mps`...
        has_mps = False

    # Device preference: Apple MPS, then CUDA, then CPU.
    if has_mps:
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    print("Using device:", device)

    _, model_config = load_model(model_config, ckpt_path, pretrained_name=pretrained_name, pretransform_ckpt_path=pretransform_ckpt_path, model_half=model_half, device=device)

    model_type = model_config["model_type"]

    if model_type == "diffusion_cond":
        ui = create_txt2audio_ui(model_config)
    elif model_type == "diffusion_uncond":
        ui = create_diffusion_uncond_ui(model_config)
    elif model_type == "autoencoder" or model_type == "diffusion_autoencoder":
        ui = create_autoencoder_ui(model_config)
    elif model_type == "diffusion_prior":
        ui = create_diffusion_prior_ui(model_config)
    elif model_type == "lm":
        ui = create_lm_ui(model_config)
    else:
        # Previously an unknown model_type fell through and hit a NameError on
        # `ui`; fail with a clear, actionable message instead.
        raise ValueError(f"Unsupported model type: {model_type}")

    return ui
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/models/autoencoders.py
ADDED
|
@@ -0,0 +1,794 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
from torchaudio import transforms as T
|
| 8 |
+
from alias_free_torch import Activation1d
|
| 9 |
+
from dac.nn.layers import WNConv1d, WNConvTranspose1d
|
| 10 |
+
from typing import Literal, Dict, Any
|
| 11 |
+
|
| 12 |
+
from ..inference.sampling import sample
|
| 13 |
+
from ..inference.utils import prepare_audio
|
| 14 |
+
from .blocks import SnakeBeta
|
| 15 |
+
from .bottleneck import Bottleneck, DiscreteBottleneck
|
| 16 |
+
from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
|
| 17 |
+
from .factory import create_pretransform_from_config, create_bottleneck_from_config
|
| 18 |
+
from .pretransforms import Pretransform
|
| 19 |
+
|
| 20 |
+
def checkpoint(function, *args, **kwargs):
    """Gradient-checkpoint `function`, defaulting to the non-reentrant implementation.

    Callers may still pass `use_reentrant` explicitly to override the default.
    """
    if "use_reentrant" not in kwargs:
        kwargs["use_reentrant"] = False
    return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
| 23 |
+
|
| 24 |
+
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
    """Build the requested activation module, optionally wrapped for anti-aliasing.

    `channels` is only used by the "snake" activation. Raises ValueError for an
    unrecognized activation name.
    """
    factories = {
        "elu": lambda: nn.ELU(),
        "snake": lambda: SnakeBeta(channels),  # snake needs the channel count
        "none": lambda: nn.Identity(),
    }
    if activation not in factories:
        raise ValueError(f"Unknown activation {activation}")

    act = factories[activation]()

    # Optionally wrap with the anti-aliased activation applicator.
    if antialias:
        act = Activation1d(act)

    return act
|
| 38 |
+
|
| 39 |
+
class ResidualUnit(nn.Module):
    """Dilated residual block: act -> dilated 7-tap conv -> act -> 1x1 conv, with a skip connection."""

    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        super().__init__()

        self.dilation = dilation

        # "Same" padding for a kernel of 7 at this dilation.
        pad = (dilation * (7 - 1)) // 2

        act_name = "snake" if use_snake else "elu"
        self.layers = nn.Sequential(
            get_activation(act_name, antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=7, dilation=dilation, padding=pad),
            get_activation(act_name, antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=out_channels, out_channels=out_channels,
                     kernel_size=1),
        )

    def forward(self, x):
        # Residual connection around the conv stack.
        return self.layers(x) + x
|
| 63 |
+
|
| 64 |
+
class EncoderBlock(nn.Module):
    """Encoder stage: three dilated residual units, then a strided downsampling conv."""

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        super().__init__()

        # Residual units keep the channel count; dilations grow 1 -> 3 -> 9.
        residual_units = [
            ResidualUnit(in_channels=in_channels, out_channels=in_channels,
                         dilation=d, use_snake=use_snake)
            for d in (1, 3, 9)
        ]

        self.layers = nn.Sequential(
            *residual_units,
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)),
        )

    def forward(self, x):
        return self.layers(x)
|
| 82 |
+
|
| 83 |
+
class DecoderBlock(nn.Module):
    """Decoder stage: activation -> upsampling conv (transposed or nearest) -> three residual units."""

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()

        if use_nearest_upsample:
            # Nearest-neighbor upsample followed by a stride-1 conv (an alternative
            # to transposed convolution).
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=2 * stride,
                         stride=1,
                         bias=False,
                         padding='same'),
            )
        else:
            upsample_layer = WNConvTranspose1d(in_channels=in_channels,
                                               out_channels=out_channels,
                                               kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))

        # Residual units after upsampling; dilations grow 1 -> 3 -> 9.
        residual_units = [
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=d, use_snake=use_snake)
            for d in (1, 3, 9)
        ]

        self.layers = nn.Sequential(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            upsample_layer,
            *residual_units,
        )

    def forward(self, x):
        return self.layers(x)
|
| 115 |
+
|
| 116 |
+
class OobleckEncoder(nn.Module):
    """Strided convolutional encoder mapping audio to a continuous latent sequence."""

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False
                 ):
        super().__init__()

        # Prepend a multiplier of 1 for the stem conv's output width.
        c_mults = [1] + c_mults
        self.depth = len(c_mults)

        stem = WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)

        blocks = []
        for i in range(self.depth - 1):
            blocks.append(EncoderBlock(in_channels=c_mults[i] * channels,
                                       out_channels=c_mults[i + 1] * channels,
                                       stride=strides[i],
                                       use_snake=use_snake))

        head = [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
            WNConv1d(in_channels=c_mults[-1] * channels, out_channels=latent_dim, kernel_size=3, padding=1),
        ]

        self.layers = nn.Sequential(stem, *blocks, *head)

    def forward(self, x):
        return self.layers(x)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class OobleckDecoder(nn.Module):
    """Mirror of OobleckEncoder: maps a latent sequence back to audio via upsampling blocks."""

    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()

        # Prepend a multiplier of 1 so the last block lands on the head's width.
        c_mults = [1] + c_mults
        self.depth = len(c_mults)

        stem = WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1] * channels, kernel_size=7, padding=3)

        # Walk the multipliers from widest to narrowest, upsampling at each step.
        blocks = []
        for i in range(self.depth - 1, 0, -1):
            blocks.append(DecoderBlock(
                in_channels=c_mults[i] * channels,
                out_channels=c_mults[i - 1] * channels,
                stride=strides[i - 1],
                use_snake=use_snake,
                antialias_activation=antialias_activation,
                use_nearest_upsample=use_nearest_upsample,
            ))

        head = [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
            WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
            # Optional output squashing to [-1, 1].
            nn.Tanh() if final_tanh else nn.Identity(),
        ]

        self.layers = nn.Sequential(stem, *blocks, *head)

    def forward(self, x):
        return self.layers(x)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class DACEncoderWrapper(nn.Module):
    """Adapter around descript-audio-codec's Encoder with an optional latent projection."""

    def __init__(self, in_channels=1, **kwargs):
        super().__init__()

        # Imported lazily so the dac package is only required when this wrapper is used.
        from dac.model.dac import Encoder as DACEncoder

        latent_dim = kwargs.pop("latent_dim", None)

        # DAC's encoder output width is d_model doubled once per stride stage.
        encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
        self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
        self.latent_dim = latent_dim

        # Latent-dim support was added to DAC after this was first written, and
        # implemented differently, so this projection exists for backwards compatibility.
        if latent_dim is not None:
            self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1)
        else:
            self.proj_out = nn.Identity()

        if in_channels != 1:
            # Swap the stem conv so the encoder accepts multi-channel input.
            self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)

    def forward(self, x):
        return self.proj_out(self.encoder(x))
|
| 216 |
+
|
| 217 |
+
class DACDecoderWrapper(nn.Module):
    """Thin adapter around descript-audio-codec's ``Decoder``."""

    def __init__(self, latent_dim, out_channels=1, **kwargs):
        super().__init__()

        from dac.model.dac import Decoder as DACDecoder

        # DAC names the latent width "input_channel" and output width "d_out".
        self.decoder = DACDecoder(**kwargs, input_channel=latent_dim, d_out=out_channels)
        self.latent_dim = latent_dim

    def forward(self, x):
        return self.decoder(x)
|
| 229 |
+
|
| 230 |
+
class AudioAutoencoder(nn.Module):
    """Generic audio autoencoder: encoder -> (optional) bottleneck -> decoder.

    Optionally wraps the model in a ``pretransform`` applied before encoding
    and inverted after decoding, and supports chunked encode/decode for long
    inputs to bound memory usage.
    """
    def __init__(
        self,
        encoder,
        decoder,
        latent_dim,
        downsampling_ratio,
        sample_rate,
        io_channels=2,
        bottleneck: Bottleneck = None,
        pretransform: Pretransform = None,
        in_channels = None,
        out_channels = None,
        soft_clip = False
    ):
        super().__init__()

        self.downsampling_ratio = downsampling_ratio
        self.sample_rate = sample_rate

        self.latent_dim = latent_dim
        self.io_channels = io_channels
        self.in_channels = io_channels
        self.out_channels = io_channels

        # Inputs must be at least one latent frame long.
        self.min_length = self.downsampling_ratio

        if in_channels is not None:
            self.in_channels = in_channels

        if out_channels is not None:
            self.out_channels = out_channels

        self.bottleneck = bottleneck

        self.encoder = encoder

        self.decoder = decoder

        self.pretransform = pretransform

        self.soft_clip = soft_clip

        # True when the bottleneck quantizes latents into discrete tokens.
        self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete

    def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
        """Encode (batch, channels, samples) audio into latents.

        iterate_batch processes one batch item at a time to cap memory.
        With return_info=True also returns a dict of bottleneck metadata.
        """
        info = {}

        if self.pretransform is not None and not skip_pretransform:
            if self.pretransform.enable_grad:
                if iterate_batch:
                    audios = []
                    for i in range(audio.shape[0]):
                        audios.append(self.pretransform.encode(audio[i:i+1]))
                    audio = torch.cat(audios, dim=0)
                else:
                    audio = self.pretransform.encode(audio)
            else:
                # Frozen pretransform: skip autograd bookkeeping.
                with torch.no_grad():
                    if iterate_batch:
                        audios = []
                        for i in range(audio.shape[0]):
                            audios.append(self.pretransform.encode(audio[i:i+1]))
                        audio = torch.cat(audios, dim=0)
                    else:
                        audio = self.pretransform.encode(audio)

        if self.encoder is not None:
            if iterate_batch:
                latents = []
                for i in range(audio.shape[0]):
                    latents.append(self.encoder(audio[i:i+1]))
                latents = torch.cat(latents, dim=0)
            else:
                latents = self.encoder(audio)
        else:
            # No encoder: the (pretransformed) audio is the latent directly.
            latents = audio

        if self.bottleneck is not None:
            # TODO: Add iterate batch logic, needs to merge the info dicts
            latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)

            info.update(bottleneck_info)

        if return_info:
            return latents, info

        return latents

    def decode(self, latents, iterate_batch=False, **kwargs):
        """Decode latents back to audio, inverting bottleneck and pretransform."""
        if self.bottleneck is not None:
            if iterate_batch:
                decoded = []
                for i in range(latents.shape[0]):
                    decoded.append(self.bottleneck.decode(latents[i:i+1]))
                latents = torch.cat(decoded, dim=0)
            else:
                latents = self.bottleneck.decode(latents)

        if iterate_batch:
            # NOTE(review): this branch does not forward **kwargs to the
            # decoder, unlike the non-iterated branch below — confirm intended.
            decoded = []
            for i in range(latents.shape[0]):
                decoded.append(self.decoder(latents[i:i+1]))
            decoded = torch.cat(decoded, dim=0)
        else:
            decoded = self.decoder(latents, **kwargs)

        if self.pretransform is not None:
            if self.pretransform.enable_grad:
                if iterate_batch:
                    decodeds = []
                    for i in range(decoded.shape[0]):
                        decodeds.append(self.pretransform.decode(decoded[i:i+1]))
                    decoded = torch.cat(decodeds, dim=0)
                else:
                    decoded = self.pretransform.decode(decoded)
            else:
                with torch.no_grad():
                    if iterate_batch:
                        # NOTE(review): iterates latents.shape[0] but indexes
                        # `decoded`; batch sizes should match, but the sibling
                        # branch above uses decoded.shape[0] — confirm.
                        decodeds = []
                        for i in range(latents.shape[0]):
                            decodeds.append(self.pretransform.decode(decoded[i:i+1]))
                        decoded = torch.cat(decodeds, dim=0)
                    else:
                        decoded = self.pretransform.decode(decoded)

        if self.soft_clip:
            # Squash output into (-1, 1) to avoid clipping artifacts.
            decoded = torch.tanh(decoded)

        return decoded

    def decode_tokens(self, tokens, **kwargs):
        '''
        Decode discrete tokens to audio
        Only works with discrete autoencoders
        '''

        assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"

        latents = self.bottleneck.decode_tokens(tokens, **kwargs)

        return self.decode(latents, **kwargs)


    def preprocess_audio_for_encoder(self, audio, in_sr):
        '''
        Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.
        If the model is mono, stereo audio will be converted to mono.
        Audio will be silence-padded to be a multiple of the model's downsampling ratio.
        Audio will be resampled to the model's sample rate.
        The output will have batch size 1 and be shape (1 x Channels x Length)
        '''
        return self.preprocess_audio_list_for_encoder([audio], [in_sr])

    def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
        '''
        Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatible with the encoder.
        The audio in that list can be of different lengths and channels.
        in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio.
        All audio will be resampled to the model's sample rate.
        Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
        If the model is mono, all audio will be converted to mono.
        The output will be a tensor of shape (Batch x Channels x Length)
        '''
        batch_size = len(audio_list)
        if isinstance(in_sr_list, int):
            in_sr_list = [in_sr_list]*batch_size
        assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list"
        new_audio = []
        max_length = 0
        # resample & find the max length
        for i in range(batch_size):
            audio = audio_list[i]
            in_sr = in_sr_list[i]
            if len(audio.shape) == 3 and audio.shape[0] == 1:
                # batchsize 1 was given by accident. Just squeeze it.
                audio = audio.squeeze(0)
            elif len(audio.shape) == 1:
                # Mono signal, channel dimension is missing, unsqueeze it in
                audio = audio.unsqueeze(0)
            assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension"
            # Resample audio
            if in_sr != self.sample_rate:
                resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
                audio = resample_tf(audio)
            new_audio.append(audio)
            if audio.shape[-1] > max_length:
                max_length = audio.shape[-1]
        # Pad every audio to the same length, multiple of model's downsampling ratio
        padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
        for i in range(batch_size):
            # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
            # NOTE(review): `in_sr` here is the stale value from the previous
            # loop; harmless since in_sr == target_sr makes resampling a no-op,
            # but confirm prepare_audio treats equal rates as passthrough.
            new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length,
                target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
        # convert to tensor
        return torch.stack(new_audio)

    def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
        '''
        Encode audios into latents. Audios should already be preprocessed by preprocess_audio_for_encoder.
        If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
        Overlap and chunk_size params are both measured in number of latents (not audio samples)
        and therefore you likely could use the same values with decode_audio.
        An overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size.
        Every autoencoder will have a different receptive field size, and thus ideal overlap.
        You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
        Smaller chunk_size uses less memory, but more compute.
        The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
        For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
        '''
        if not chunked:
            # default behavior. Encode the entire audio in parallel
            return self.encode(audio, **kwargs)
        else:
            # CHUNKED ENCODING
            # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
            samples_per_latent = self.downsampling_ratio
            total_size = audio.shape[2] # in samples
            batch_size = audio.shape[0]
            chunk_size *= samples_per_latent # converting metric in latents to samples
            overlap *= samples_per_latent # converting metric in latents to samples
            hop_size = chunk_size - overlap
            chunks = []
            # NOTE(review): if total_size < chunk_size this loop never runs and
            # `i` below is unbound (NameError) — confirm callers always pad the
            # input to at least chunk_size samples.
            for i in range(0, total_size - chunk_size + 1, hop_size):
                chunk = audio[:,:,i:i+chunk_size]
                chunks.append(chunk)
            if i+chunk_size != total_size:
                # Final chunk
                chunk = audio[:,:,-chunk_size:]
                chunks.append(chunk)
            chunks = torch.stack(chunks)
            num_chunks = chunks.shape[0]
            # Note: y_size might be a different value from the latent length used in diffusion training
            # because we can encode audio of varying lengths
            # However, the audio should've been padded to a multiple of samples_per_latent by now.
            y_size = total_size // samples_per_latent
            # Create an empty latent, we will populate it with chunks as we encode them
            y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
            for i in range(num_chunks):
                x_chunk = chunks[i,:]
                # encode the chunk
                y_chunk = self.encode(x_chunk)
                # figure out where to put the audio along the time domain
                if i == num_chunks-1:
                    # final chunk always goes at the end
                    t_end = y_size
                    t_start = t_end - y_chunk.shape[2]
                else:
                    t_start = i * hop_size // samples_per_latent
                    t_end = t_start + chunk_size // samples_per_latent
                # remove the edges of the overlaps
                ol = overlap//samples_per_latent//2
                chunk_start = 0
                chunk_end = y_chunk.shape[2]
                if i > 0:
                    # no overlap for the start of the first chunk
                    t_start += ol
                    chunk_start += ol
                if i < num_chunks-1:
                    # no overlap for the end of the last chunk
                    t_end -= ol
                    chunk_end -= ol
                # paste the chunked audio into our y_final output audio
                y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
            return y_final

    def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
        '''
        Decode latents to audio.
        If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
        An overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size.
        Every autoencoder will have a different receptive field size, and thus ideal overlap.
        You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
        Smaller chunk_size uses less memory, but more compute.
        The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
        For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
        '''
        if not chunked:
            # default behavior. Decode the entire latent in parallel
            return self.decode(latents, **kwargs)
        else:
            # chunked decoding
            hop_size = chunk_size - overlap
            total_size = latents.shape[2]
            batch_size = latents.shape[0]
            chunks = []
            # NOTE(review): same unbound-`i` hazard as encode_audio when
            # total_size < chunk_size — confirm inputs are long enough.
            for i in range(0, total_size - chunk_size + 1, hop_size):
                chunk = latents[:,:,i:i+chunk_size]
                chunks.append(chunk)
            if i+chunk_size != total_size:
                # Final chunk
                chunk = latents[:,:,-chunk_size:]
                chunks.append(chunk)
            chunks = torch.stack(chunks)
            num_chunks = chunks.shape[0]
            # samples_per_latent is just the downsampling ratio
            samples_per_latent = self.downsampling_ratio
            # Create an empty waveform, we will populate it with chunks as decode them
            y_size = total_size * samples_per_latent
            y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
            for i in range(num_chunks):
                x_chunk = chunks[i,:]
                # decode the chunk
                y_chunk = self.decode(x_chunk)
                # figure out where to put the audio along the time domain
                if i == num_chunks-1:
                    # final chunk always goes at the end
                    t_end = y_size
                    t_start = t_end - y_chunk.shape[2]
                else:
                    t_start = i * hop_size * samples_per_latent
                    t_end = t_start + chunk_size * samples_per_latent
                # remove the edges of the overlaps
                ol = (overlap//2) * samples_per_latent
                chunk_start = 0
                chunk_end = y_chunk.shape[2]
                if i > 0:
                    # no overlap for the start of the first chunk
                    t_start += ol
                    chunk_start += ol
                if i < num_chunks-1:
                    # no overlap for the end of the last chunk
                    t_end -= ol
                    chunk_end -= ol
                # paste the chunked audio into our y_final output audio
                y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
            return y_final
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
class DiffusionAutoencoder(AudioAutoencoder):
    """Autoencoder whose decoding stage is a conditioned diffusion model.

    Latents are upsampled to the target audio length and used as a
    concatenation condition while sampling the diffusion model from noise.
    """

    def __init__(
        self,
        diffusion: ConditionedDiffusionModel,
        diffusion_downsampling_ratio,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)

        self.diffusion = diffusion

        # The diffusion model may itself downsample, so the minimum input
        # length must account for both the AE and diffusion ratios.
        self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio

        if self.encoder is not None:
            # Shrink the initial encoder parameters to avoid saturated latents
            with torch.no_grad():
                for param in self.encoder.parameters():
                    param *= 0.5

    def decode(self, latents, steps=100):
        """Decode latents to audio by sampling the diffusion model.

        Args:
            latents: (batch, latent_dim, length) latent tensor.
            steps: number of diffusion sampling steps.

        Returns:
            Decoded audio tensor (after the inverse pretransform, if any).
        """
        upsampled_length = latents.shape[2] * self.downsampling_ratio

        if self.bottleneck is not None:
            latents = self.bottleneck.decode(latents)

        if self.decoder is not None:
            # BUGFIX: this previously called self.decode(latents), which
            # recurses infinitely whenever a decoder module is present.
            latents = self.decoder(latents)

        # Upsample latents to match diffusion length
        if latents.shape[2] != upsampled_length:
            latents = F.interpolate(latents, size=upsampled_length, mode='nearest')

        noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)
        decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)

        if self.pretransform is not None:
            if self.pretransform.enable_grad:
                decoded = self.pretransform.decode(decoded)
            else:
                with torch.no_grad():
                    decoded = self.pretransform.decode(decoded)

        return decoded
|
| 608 |
+
|
| 609 |
+
# AE factories
|
| 610 |
+
|
| 611 |
+
def create_encoder_from_config(encoder_config: Dict[str, Any]):
    """Instantiate an encoder module from a config dict.

    Supported types: "oobleck", "seanet", "dac", "local_attn". Raises
    ValueError for unknown types; optionally freezes parameters when
    "requires_grad" is False.
    """
    encoder_type = encoder_config.get("type", None)
    assert encoder_type is not None, "Encoder type must be specified"

    if encoder_type == "oobleck":
        encoder = OobleckEncoder(**encoder_config["config"])
    elif encoder_type == "seanet":
        from encodec.modules import SEANetEncoder

        cfg = encoder_config["config"]
        # SEANet encoder expects strides in reverse order
        cfg["ratios"] = list(reversed(cfg.get("ratios", [2, 2, 2, 2, 2])))
        encoder = SEANetEncoder(**cfg)
    elif encoder_type == "dac":
        encoder = DACEncoderWrapper(**encoder_config["config"])
    elif encoder_type == "local_attn":
        from .local_attention import TransformerEncoder1D

        encoder = TransformerEncoder1D(**encoder_config["config"])
    else:
        raise ValueError(f"Unknown encoder type {encoder_type}")

    if not encoder_config.get("requires_grad", True):
        # Freeze the encoder when configured as non-trainable.
        for param in encoder.parameters():
            param.requires_grad = False

    return encoder
|
| 650 |
+
|
| 651 |
+
def create_decoder_from_config(decoder_config: Dict[str, Any]):
    """Instantiate a decoder module from a config dict.

    Supported types: "oobleck", "seanet", "dac", "local_attn". Raises
    ValueError for unknown types; optionally freezes parameters when
    "requires_grad" is False.
    """
    decoder_type = decoder_config.get("type", None)
    assert decoder_type is not None, "Decoder type must be specified"

    if decoder_type == "oobleck":
        decoder = OobleckDecoder(**decoder_config["config"])
    elif decoder_type == "seanet":
        from encodec.modules import SEANetDecoder

        decoder = SEANetDecoder(**decoder_config["config"])
    elif decoder_type == "dac":
        decoder = DACDecoderWrapper(**decoder_config["config"])
    elif decoder_type == "local_attn":
        from .local_attention import TransformerDecoder1D

        decoder = TransformerDecoder1D(**decoder_config["config"])
    else:
        raise ValueError(f"Unknown decoder type {decoder_type}")

    if not decoder_config.get("requires_grad", True):
        # Freeze the decoder when configured as non-trainable.
        for param in decoder.parameters():
            param.requires_grad = False

    return decoder
|
| 686 |
+
|
| 687 |
+
def create_autoencoder_from_config(config: Dict[str, Any]):
    """Build an AudioAutoencoder from a full model config dict.

    Expects config["model"] with "encoder"/"decoder" sub-configs plus
    latent_dim, downsampling_ratio and io_channels; sample_rate is read
    from the top-level config.
    """
    ae_config = config["model"]

    encoder = create_encoder_from_config(ae_config["encoder"])
    decoder = create_decoder_from_config(ae_config["decoder"])

    bottleneck = ae_config.get("bottleneck", None)

    latent_dim = ae_config.get("latent_dim", None)
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = ae_config.get("downsampling_ratio", None)
    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
    io_channels = ae_config.get("io_channels", None)
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    pretransform = ae_config.get("pretransform", None)
    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)

    if bottleneck is not None:
        bottleneck = create_bottleneck_from_config(bottleneck)

    return AudioAutoencoder(
        encoder,
        decoder,
        io_channels=io_channels,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        sample_rate=sample_rate,
        bottleneck=bottleneck,
        pretransform=pretransform,
        in_channels=ae_config.get("in_channels", None),
        out_channels=ae_config.get("out_channels", None),
        soft_clip=ae_config["decoder"].get("soft_clip", False),
    )
|
| 731 |
+
|
| 732 |
+
def create_diffAE_from_config(config: Dict[str, Any]):
    """Build a DiffusionAutoencoder from a full model config dict.

    Reads config["model"] for optional encoder/decoder sub-configs, the
    diffusion model config, and the AE hyperparameters; sample_rate comes
    from the top-level config.

    Raises:
        ValueError: if the diffusion model type is unknown.
    """
    diffae_config = config["model"]

    if "encoder" in diffae_config:
        encoder = create_encoder_from_config(diffae_config["encoder"])
    else:
        encoder = None

    if "decoder" in diffae_config:
        decoder = create_decoder_from_config(diffae_config["decoder"])
    else:
        decoder = None

    diffusion_model_type = diffae_config["diffusion"]["type"]
    diffusion_config = diffae_config["diffusion"]["config"]

    # Build the diffusion model and derive its own downsampling ratio in one
    # place so the two can never disagree.
    if diffusion_model_type == "DAU1d":
        diffusion = DAU1DCondWrapper(**diffusion_config)
        diffusion_downsampling_ratio = np.prod(diffusion_config["strides"])
    elif diffusion_model_type == "adp_1d":
        diffusion = UNet1DCondWrapper(**diffusion_config)
        diffusion_downsampling_ratio = np.prod(diffusion_config["factors"])
    elif diffusion_model_type == "dit":
        diffusion = DiTWrapper(**diffusion_config)
        diffusion_downsampling_ratio = 1
    else:
        # BUGFIX: an unknown type previously fell through, leaving `diffusion`
        # unbound (NameError) and `diffusion_downsampling_ratio` set to the
        # tuple (None,) via a stray trailing comma.
        raise ValueError(f"Unknown diffusion model type {diffusion_model_type}")

    latent_dim = diffae_config.get("latent_dim", None)
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = diffae_config.get("downsampling_ratio", None)
    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
    io_channels = diffae_config.get("io_channels", None)
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    bottleneck = diffae_config.get("bottleneck", None)

    pretransform = diffae_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)

    if bottleneck is not None:
        bottleneck = create_bottleneck_from_config(bottleneck)

    return DiffusionAutoencoder(
        encoder=encoder,
        decoder=decoder,
        diffusion=diffusion,
        io_channels=io_channels,
        sample_rate=sample_rate,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        diffusion_downsampling_ratio=diffusion_downsampling_ratio,
        bottleneck=bottleneck,
        pretransform=pretransform
    )
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/models/bottleneck.py
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torch.nn import functional as F
|
| 5 |
+
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
from vector_quantize_pytorch import ResidualVQ, FSQ
|
| 8 |
+
from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
|
| 9 |
+
|
| 10 |
+
class Bottleneck(nn.Module):
    """Abstract base class for latent-space bottlenecks.

    Subclasses implement encode()/decode(). ``is_discrete`` marks whether
    the bottleneck quantizes latents into discrete tokens.
    """

    def __init__(self, is_discrete: bool = False):
        super().__init__()
        self.is_discrete = is_discrete

    def encode(self, x, return_info=False, **kwargs):
        """Map latents through the bottleneck; subclasses must override."""
        raise NotImplementedError

    def decode(self, x):
        """Invert the bottleneck mapping; subclasses must override."""
        raise NotImplementedError
|
| 21 |
+
|
| 22 |
+
class DiscreteBottleneck(Bottleneck):
    """Base class for bottlenecks that produce discrete token sequences."""

    def __init__(self, num_quantizers, codebook_size, tokens_id):
        super().__init__(is_discrete=True)
        # Number of residual quantizer stages and per-stage codebook size.
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        # Key under which token indices appear in encode()'s info dict.
        self.tokens_id = tokens_id

    def decode_tokens(self, codes, **kwargs):
        """Decode discrete codes back to latents; subclasses must override."""
        raise NotImplementedError
|
| 32 |
+
|
| 33 |
+
class TanhBottleneck(Bottleneck):
    """Continuous bottleneck that squashes latents into (-1, 1) via tanh."""

    def __init__(self):
        super().__init__(is_discrete=False)
        self.tanh = nn.Tanh()

    def encode(self, x, return_info=False):
        squashed = torch.tanh(x)
        if return_info:
            return squashed, {}
        return squashed

    def decode(self, x):
        # Decoding is the identity; tanh is not inverted.
        return x
|
| 50 |
+
|
| 51 |
+
def vae_sample(mean, scale):
    """Reparameterized sample from a diagonal Gaussian plus its KL to N(0, I).

    ``scale`` is mapped through softplus (with a small floor) to get the
    standard deviation. Returns (latents, kl) where kl is summed over
    channels and averaged over the batch.
    """
    stdev = nn.functional.softplus(scale) + 1e-4
    variance = stdev * stdev
    # Reparameterization trick: mean + stdev * epsilon.
    latents = mean + torch.randn_like(mean) * stdev

    # KL(N(mean, var) || N(0, 1)) per element, summed/averaged.
    kl = (mean * mean + variance - torch.log(variance) - 1).sum(1).mean()

    return latents, kl
|
| 60 |
+
|
| 61 |
+
class VAEBottleneck(Bottleneck):
    """VAE bottleneck: incoming channels are split into (mean, scale) halves."""

    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(self, x, return_info=False, **kwargs):
        # First half of the channel dim is the mean, second half the scale.
        mean, scale = x.chunk(2, dim=1)

        latents, kl = vae_sample(mean, scale)

        if return_info:
            return latents, {"kl": kl}
        return latents

    def decode(self, x):
        # Sampled latents pass through unchanged.
        return x
|
| 81 |
+
|
| 82 |
+
def compute_mean_kernel(x, y):
    """Mean RBF kernel value between the rows of x and y.

    Squared distances are averaged over features and scaled by 1/dim
    before the Gaussian kernel is applied.
    """
    sq_dists = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
    return torch.exp(-sq_dists).mean()
|
| 85 |
+
|
| 86 |
+
def compute_mmd(latents):
    """Maximum mean discrepancy between latents and standard Gaussian noise.

    ``latents`` is (batch, channels, length); every time step is treated as
    one sample vector. Fresh Gaussian noise of matching shape serves as the
    reference distribution.
    """
    flat = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
    reference = torch.randn_like(flat)

    k_xx = compute_mean_kernel(flat, flat)
    k_yy = compute_mean_kernel(reference, reference)
    k_xy = compute_mean_kernel(flat, reference)

    mmd = k_xx + k_yy - 2 * k_xy
    return mmd.mean()
|
| 96 |
+
|
| 97 |
+
class WassersteinBottleneck(Bottleneck):
    """Wasserstein-AE-style bottleneck regularized with an MMD penalty.

    ``noise_augment_dim`` extra Gaussian channels can be appended on decode;
    ``bypass_mmd`` skips the penalty computation while keeping the key.
    """

    def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
        super().__init__(is_discrete=False)
        self.noise_augment_dim = noise_augment_dim
        self.bypass_mmd = bypass_mmd

    def encode(self, x, return_info=False):
        info = {}

        # The MMD penalty is only needed for the training loss.
        if self.training and return_info:
            if self.bypass_mmd:
                info["mmd"] = torch.tensor(0.0)
            else:
                info["mmd"] = compute_mmd(x)

        if return_info:
            return x, info
        return x

    def decode(self, x):
        if self.noise_augment_dim > 0:
            # Append fresh Gaussian channels as extra decoder stochasticity.
            noise = torch.randn(x.shape[0], self.noise_augment_dim,
                                x.shape[-1]).type_as(x)
            x = torch.cat([x, noise], dim=1)

        return x
|
| 128 |
+
|
| 129 |
+
class L2Bottleneck(Bottleneck):
    """Projects latents onto the unit hypersphere via L2 normalization over channels."""

    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(self, x, return_info=False):
        normalized = F.normalize(x, dim=1)
        return (normalized, {}) if return_info else normalized

    def decode(self, x):
        # Re-normalize on decode in case the latents were perturbed upstream.
        return F.normalize(x, dim=1)
|
| 145 |
+
|
| 146 |
+
class RVQBottleneck(DiscreteBottleneck):
    """Residual vector quantization bottleneck backed by ResidualVQ."""

    def __init__(self, **quantizer_kwargs):
        super().__init__(
            num_quantizers=quantizer_kwargs["num_quantizers"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="quantizer_indices",
        )
        self.quantizer = ResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["num_quantizers"]

    def encode(self, x, return_info=False, **kwargs):
        # ResidualVQ expects (batch, time, channels).
        quantized, indices, loss = self.quantizer(rearrange(x, "b c n -> b n c"))
        quantized = rearrange(quantized, "b n c -> b c n")

        info = {
            "quantizer_indices": indices,
            "quantizer_loss": loss.mean(),
        }

        return (quantized, info) if return_info else quantized

    def decode(self, x):
        return x

    def decode_tokens(self, codes, **kwargs):
        latents = self.quantizer.get_outputs_from_indices(codes)
        return self.decode(latents, **kwargs)
|
| 174 |
+
|
| 175 |
+
class RVQVAEBottleneck(DiscreteBottleneck):
    """VAE sampling followed by residual vector quantization."""

    def __init__(self, **quantizer_kwargs):
        super().__init__(
            num_quantizers=quantizer_kwargs["num_quantizers"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="quantizer_indices",
        )
        self.quantizer = ResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["num_quantizers"]

    def encode(self, x, return_info=False):
        info = {}

        # Input channels hold (mean, scale); sample a latent and track KL.
        sampled, kl = vae_sample(*x.chunk(2, dim=1))
        info["kl"] = kl

        # ResidualVQ expects (batch, time, channels).
        quantized, indices, loss = self.quantizer(rearrange(sampled, "b c n -> b n c"))
        quantized = rearrange(quantized, "b n c -> b c n")

        info["quantizer_indices"] = indices
        info["quantizer_loss"] = loss.mean()

        return (quantized, info) if return_info else quantized

    def decode(self, x):
        return x

    def decode_tokens(self, codes, **kwargs):
        latents = self.quantizer.get_outputs_from_indices(codes)
        return self.decode(latents, **kwargs)
|
| 207 |
+
|
| 208 |
+
class DACRVQBottleneck(DiscreteBottleneck):
    """DAC-style residual vector quantization bottleneck.

    Args:
        quantize_on_decode: if True, encode() passes un-quantized latents
            through and quantization happens in decode() instead.
        noise_augment_dim: extra Gaussian-noise channels appended on decode.
        quantizer_kwargs: forwarded to DACResidualVQ; must include
            "n_codebooks" and "codebook_size".
    """

    def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs):
        super().__init__(
            num_quantizers=quantizer_kwargs["n_codebooks"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="codes",
        )
        self.quantizer = DACResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["n_codebooks"]
        self.quantize_on_decode = quantize_on_decode
        self.noise_augment_dim = noise_augment_dim

    def encode(self, x, return_info=False, **kwargs):
        info = {"pre_quantizer": x}

        if self.quantize_on_decode:
            # BUGFIX: the original `return x, info if return_info else x`
            # parsed as `return x, (info if return_info else x)` and always
            # returned a tuple; parenthesize so a bare tensor is returned
            # when return_info is False.
            return (x, info) if return_info else x

        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)

        output = {
            "z": z,
            "codes": codes,
            "latents": latents,
            # Average per-codebook losses so the loss scale is independent
            # of the number of quantizers.
            "vq/commitment_loss": commitment_loss / self.num_quantizers,
            "vq/codebook_loss": codebook_loss / self.num_quantizers,
        }

        info.update(output)

        if return_info:
            return output["z"], info

        return output["z"]

    def decode(self, x):
        if self.quantize_on_decode:
            x = self.quantizer(x)[0]

        if self.noise_augment_dim > 0:
            noise = torch.randn(
                x.shape[0], self.noise_augment_dim, x.shape[-1]
            ).type_as(x)
            x = torch.cat([x, noise], dim=1)

        return x

    def decode_tokens(self, codes, **kwargs):
        latents, _, _ = self.quantizer.from_codes(codes)
        return self.decode(latents, **kwargs)
|
| 260 |
+
|
| 261 |
+
class DACRVQVAEBottleneck(DiscreteBottleneck):
    """VAE sampling followed by DAC-style residual vector quantization.

    Args:
        quantize_on_decode: if True, encode() passes the sampled (but
            un-quantized) latents through and decode() quantizes.
        quantizer_kwargs: forwarded to DACResidualVQ; must include
            "n_codebooks" and "codebook_size".
    """

    def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
        super().__init__(
            num_quantizers=quantizer_kwargs["n_codebooks"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="codes",
        )
        self.quantizer = DACResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["n_codebooks"]
        self.quantize_on_decode = quantize_on_decode

    def encode(self, x, return_info=False, n_quantizers: int = None):
        info = {}

        # Input channels hold (mean, scale); sample a latent and track KL.
        mean, scale = x.chunk(2, dim=1)
        x, kl = vae_sample(mean, scale)

        info["pre_quantizer"] = x
        info["kl"] = kl

        if self.quantize_on_decode:
            # BUGFIX: the original `return x, info if return_info else x`
            # parsed as `return x, (info if return_info else x)` and always
            # returned a tuple; parenthesize so a bare tensor is returned
            # when return_info is False.
            return (x, info) if return_info else x

        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)

        output = {
            "z": z,
            "codes": codes,
            "latents": latents,
            # Average per-codebook losses so the loss scale is independent
            # of the number of quantizers.
            "vq/commitment_loss": commitment_loss / self.num_quantizers,
            "vq/codebook_loss": codebook_loss / self.num_quantizers,
        }

        info.update(output)

        if return_info:
            return output["z"], info

        return output["z"]

    def decode(self, x):
        if self.quantize_on_decode:
            x = self.quantizer(x)[0]

        return x

    def decode_tokens(self, codes, **kwargs):
        latents, _, _ = self.quantizer.from_codes(codes)
        return self.decode(latents, **kwargs)
|
| 312 |
+
|
| 313 |
+
class FSQBottleneck(DiscreteBottleneck):
    """Finite scalar quantization (FSQ) bottleneck.

    Args:
        noise_augment_dim: extra Gaussian-noise channels appended on decode.
        kwargs: forwarded to FSQ; must include "levels".
    """

    def __init__(self, noise_augment_dim=0, **kwargs):
        super().__init__(
            num_quantizers=kwargs.get("num_codebooks", 1),
            codebook_size=np.prod(kwargs["levels"]),
            tokens_id="quantizer_indices",
        )

        self.noise_augment_dim = noise_augment_dim

        self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64])

    def encode(self, x, return_info=False):
        info = {}

        # FSQ runs in floating point; remember the incoming dtype to restore it.
        orig_dtype = x.dtype

        quantized, indices = self.quantizer(rearrange(x.float(), "b c n -> b n c"))
        quantized = rearrange(quantized, "b n c -> b c n").to(orig_dtype)

        # Reorder indices into the expected (batch, quantizer, time) layout.
        info["quantizer_indices"] = rearrange(indices, "b n q -> b q n")

        return (quantized, info) if return_info else quantized

    def decode(self, x):
        if self.noise_augment_dim > 0:
            noise = torch.randn(
                x.shape[0], self.noise_augment_dim, x.shape[-1]
            ).type_as(x)
            x = torch.cat([x, noise], dim=1)
        return x

    def decode_tokens(self, tokens, **kwargs):
        latents = self.quantizer.indices_to_codes(tokens)
        return self.decode(latents, **kwargs)
|
Levo_Song_Generation/SongGeneration-Runtime/third_party/stable_audio_tools/stable_audio_tools/models/conditioners.py
ADDED
|
@@ -0,0 +1,561 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#Heavily influenced by https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conditioners.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import logging, warnings
|
| 5 |
+
import string
|
| 6 |
+
import typing as tp
|
| 7 |
+
import gc
|
| 8 |
+
|
| 9 |
+
from .adp import NumberEmbedder
|
| 10 |
+
from ..inference.utils import set_audio_channels
|
| 11 |
+
from .factory import create_pretransform_from_config
|
| 12 |
+
from .pretransforms import Pretransform
|
| 13 |
+
from ..training.utils import copy_state_dict
|
| 14 |
+
from .utils import load_ckpt_state_dict
|
| 15 |
+
|
| 16 |
+
from torch import nn
|
| 17 |
+
|
| 18 |
+
class Conditioner(nn.Module):
    """Base class for conditioning modules.

    Args:
        dim: dimension of the raw conditioner embeddings.
        output_dim: dimension expected by the downstream model.
        project_out: force a linear projection even when dim == output_dim.
    """

    def __init__(self, dim: int, output_dim: int, project_out: bool = False):
        super().__init__()

        self.dim = dim
        self.output_dim = output_dim

        # Project only when the dimensions differ or a projection is forced.
        needs_projection = (dim != output_dim) or project_out
        self.proj_out = nn.Linear(dim, output_dim) if needs_projection else nn.Identity()

    def forward(self, x: tp.Any) -> tp.Any:
        raise NotImplementedError()
|
| 34 |
+
|
| 35 |
+
class IntConditioner(Conditioner):
    """Embeds integers from [min_val, max_val] using a learned lookup table.

    Args:
        output_dim: dimension of the returned embeddings.
        min_val: smallest representable integer (inclusive).
        max_val: largest representable integer (inclusive).
    """

    def __init__(self,
                 output_dim: int,
                 min_val: int = 0,
                 max_val: int = 512
                 ):
        super().__init__(output_dim, output_dim)

        self.min_val = min_val
        self.max_val = max_val
        # One embedding row per representable value in the inclusive range.
        self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim).requires_grad_(True)

    def forward(self, ints: tp.List[int], device=None) -> tp.Any:
        clamped = torch.tensor(ints).to(device).clamp(self.min_val, self.max_val)

        # BUGFIX: shift by min_val before indexing — the table has only
        # (max_val - min_val + 1) rows, so indexing with the raw clamped
        # value overflowed whenever min_val > 0. This is a no-op for the
        # default min_val = 0.
        int_embeds = self.int_embedder(clamped - self.min_val).unsqueeze(1)

        # Attention mask: a single always-attended token per batch item.
        return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)]
|
| 57 |
+
|
| 58 |
+
class NumberConditioner(Conditioner):
    '''
    Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
    '''

    def __init__(self,
                 output_dim: int,
                 min_val: float = 0,
                 max_val: float = 1
                 ):
        super().__init__(output_dim, output_dim)

        self.min_val = min_val
        self.max_val = max_val

        self.embedder = NumberEmbedder(features=output_dim)

    def forward(self, floats: tp.List[float], device=None) -> tp.Any:
        # Coerce every input to a Python float before tensorizing.
        values = torch.tensor([float(v) for v in floats]).to(device)
        values = values.clamp(self.min_val, self.max_val)

        # Map the clamped range onto [0, 1].
        normalized = (values - self.min_val) / (self.max_val - self.min_val)

        # Match the embedder's parameter dtype (e.g. under mixed precision).
        embedder_dtype = next(self.embedder.parameters()).dtype
        float_embeds = self.embedder(normalized.to(embedder_dtype)).unsqueeze(1)

        # Attention mask: a single always-attended token per batch item.
        return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
|
| 92 |
+
|
| 93 |
+
class CLAPTextConditioner(Conditioner):
    """Conditions on text via a pretrained LAION-CLAP model's text branch.

    With ``use_text_features=True`` it returns per-token hidden states from
    the text transformer (input dim 768); otherwise it returns the pooled
    CLAP text embedding as a single token (input dim 512).

    Args:
        output_dim: dimension expected by the downstream model.
        clap_ckpt_path: path to the LAION-CLAP checkpoint to load.
        use_text_features: return per-token hidden states instead of the
            pooled embedding.
        feature_layer_ix: which transformer layer's hidden states to use
            when use_text_features is enabled.
        audio_model_type: CLAP audio tower variant (unused here beyond
            model construction; the audio branch is deleted below).
        enable_fusion: whether the checkpoint uses feature fusion.
        project_out: force an output projection even if dims match.
        finetune: if True, register the CLAP model as a trainable submodule.
    """

    def __init__(self,
                 output_dim: int,
                 clap_ckpt_path,
                 use_text_features = False,
                 feature_layer_ix: int = -1,
                 audio_model_type="HTSAT-base",
                 enable_fusion=True,
                 project_out: bool = False,
                 finetune: bool = False):
        super().__init__(768 if use_text_features else 512, output_dim, project_out=project_out)

        self.use_text_features = use_text_features
        self.feature_layer_ix = feature_layer_ix
        self.finetune = finetune

        # Suppress logging from transformers while loading CLAP.
        previous_level = logging.root.manager.disable
        logging.disable(logging.ERROR)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                import laion_clap
                from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict

                model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')

                if self.finetune:
                    self.model = model
                else:
                    # Stash directly in __dict__ so the frozen CLAP model is
                    # NOT registered as a submodule — it stays out of the
                    # state dict and optimizer parameters.
                    self.__dict__["model"] = model

                state_dict = clap_load_state_dict(clap_ckpt_path)
                self.model.model.load_state_dict(state_dict, strict=False)

                if self.finetune:
                    self.model.model.text_branch.requires_grad_(True)
                    self.model.model.text_branch.train()
                else:
                    self.model.model.text_branch.requires_grad_(False)
                    self.model.model.text_branch.eval()

            finally:
                # Always restore the previous logging level.
                logging.disable(previous_level)

        # The audio branch is unused for text conditioning; free its memory.
        del self.model.model.audio_branch

        gc.collect()
        torch.cuda.empty_cache()

    def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"):
        """Return (hidden_states, attention_mask) for the tokenized prompts,
        taking hidden states from transformer layer ``layer_ix``."""
        prompt_tokens = self.model.tokenizer(prompts)
        attention_mask = prompt_tokens["attention_mask"].to(device=device, non_blocking=True)
        prompt_features = self.model.model.text_branch(
            input_ids=prompt_tokens["input_ids"].to(device=device, non_blocking=True),
            attention_mask=attention_mask,
            output_hidden_states=True
        )["hidden_states"][layer_ix]

        return prompt_features, attention_mask

    def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any:
        self.model.to(device)

        if self.use_text_features:
            if len(texts) == 1:
                # CLAP misbehaves on single-item batches; pad with an empty
                # string and slice the result back down to one item.
                text_features, text_attention_mask = self.get_clap_features([texts[0], ""], layer_ix=self.feature_layer_ix, device=device)
                text_features = text_features[:1, ...]
                text_attention_mask = text_attention_mask[:1, ...]
            else:
                text_features, text_attention_mask = self.get_clap_features(texts, layer_ix=self.feature_layer_ix, device=device)
            return [self.proj_out(text_features), text_attention_mask]

        # Fix for CLAP bug when only one text is passed
        if len(texts) == 1:
            text_embedding = self.model.get_text_embedding([texts[0], ""], use_tensor=True)[:1, ...]
        else:
            text_embedding = self.model.get_text_embedding(texts, use_tensor=True)

        text_embedding = text_embedding.unsqueeze(1).to(device)

        # Pooled embedding is a single token; mask is all ones.
        return [self.proj_out(text_embedding), torch.ones(text_embedding.shape[0], 1).to(device)]
|
| 175 |
+
|
| 176 |
+
class CLAPAudioConditioner(Conditioner):
    """Conditions on a pooled LAION-CLAP audio embedding (input dim 512).

    Args:
        output_dim: dimension expected by the downstream model.
        clap_ckpt_path: path to the LAION-CLAP checkpoint to load.
        audio_model_type: CLAP audio tower variant.
        enable_fusion: whether the checkpoint uses feature fusion.
        project_out: force an output projection even if dims match.
        finetune: if True, register the CLAP model as a trainable submodule
            and train the audio branch.
    """

    def __init__(self,
                 output_dim: int,
                 clap_ckpt_path,
                 audio_model_type="HTSAT-base",
                 enable_fusion=True,
                 project_out: bool = False,
                 finetune: bool = False):
        super().__init__(512, output_dim, project_out=project_out)

        # BUGFIX: self.finetune was read below but never assigned, raising
        # AttributeError on construction. It is now a real constructor
        # argument (default False), mirroring CLAPTextConditioner.
        self.finetune = finetune

        # Suppress logging from transformers while loading CLAP.
        previous_level = logging.root.manager.disable
        logging.disable(logging.ERROR)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                import laion_clap
                from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict

                model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')

                if self.finetune:
                    self.model = model
                else:
                    # Keep the frozen model out of the module hierarchy so it
                    # does not appear in the state dict or parameters.
                    self.__dict__["model"] = model

                state_dict = clap_load_state_dict(clap_ckpt_path)
                self.model.model.load_state_dict(state_dict, strict=False)

                if self.finetune:
                    self.model.model.audio_branch.requires_grad_(True)
                    self.model.model.audio_branch.train()
                else:
                    self.model.model.audio_branch.requires_grad_(False)
                    self.model.model.audio_branch.eval()

            finally:
                # Always restore the previous logging level.
                logging.disable(previous_level)

        # The text branch is unused for audio conditioning; free its memory.
        del self.model.model.text_branch

        gc.collect()
        torch.cuda.empty_cache()

    def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Any = "cuda") -> tp.Any:
        self.model.to(device)

        if isinstance(audios, (list, tuple)):
            audios = torch.cat(audios, dim=0)

        # Convert to mono by averaging over the channel dimension.
        mono_audios = audios.mean(dim=1)

        # CLAP's audio tower is run in full precision.
        with torch.cuda.amp.autocast(enabled=False):
            audio_embedding = self.model.get_audio_embedding_from_data(mono_audios.float(), use_tensor=True)

        audio_embedding = audio_embedding.unsqueeze(1).to(device)

        # Pooled embedding is a single token; mask is all ones.
        return [self.proj_out(audio_embedding), torch.ones(audio_embedding.shape[0], 1).to(device)]
|
| 237 |
+
|
| 238 |
+
class T5Conditioner(Conditioner):
    """Conditions on per-token embeddings from a (by default frozen) T5 encoder.

    Args:
        output_dim: dimension expected by the downstream model.
        t5_model_name: one of T5_MODELS.
        max_length: token length all texts are padded/truncated to.
        enable_grad: if True, register the encoder as a trainable submodule
            and run the forward pass with gradients enabled.
        project_out: force an output projection even if dims match.
    """

    T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
                 "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
                 "google/flan-t5-xl", "google/flan-t5-xxl"]

    T5_MODEL_DIMS = {
        "t5-small": 512,
        "t5-base": 768,
        "t5-large": 1024,
        "t5-3b": 1024,
        "t5-11b": 1024,
        "t5-xl": 2048,
        "t5-xxl": 4096,
        "google/flan-t5-small": 512,
        "google/flan-t5-base": 768,
        "google/flan-t5-large": 1024,
        "google/flan-t5-3b": 1024,
        "google/flan-t5-11b": 1024,
        "google/flan-t5-xl": 2048,
        "google/flan-t5-xxl": 4096,
    }

    def __init__(
            self,
            output_dim: int,
            t5_model_name: str = "t5-base",
            max_length: int = 128,  # FIX: was annotated `str`, value is an int
            enable_grad: bool = False,
            project_out: bool = False
    ):
        assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}"
        super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out)

        from transformers import T5EncoderModel, AutoTokenizer

        self.max_length = max_length
        self.enable_grad = enable_grad

        # Suppress logging from transformers while loading the model.
        previous_level = logging.root.manager.disable
        logging.disable(logging.ERROR)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name)
                model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
            finally:
                logging.disable(previous_level)

        if self.enable_grad:
            self.model = model
        else:
            # Keep the frozen encoder out of the module hierarchy so it is
            # excluded from the state dict and optimizer parameters.
            self.__dict__["model"] = model

    def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        self.model.to(device)
        self.proj_out.to(device)

        encoded = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device).to(torch.bool)

        self.model.eval()

        # BUGFIX: the original used `with ctx_a and ctx_b:`, which enters only
        # the right-hand context manager (`and` yields its second operand), so
        # autocast was constructed but never actually entered. Enter both.
        with torch.cuda.amp.autocast(dtype=torch.float16), torch.set_grad_enabled(self.enable_grad):
            embeddings = self.model(
                input_ids=input_ids, attention_mask=attention_mask
            )["last_hidden_state"]

        embeddings = self.proj_out(embeddings.float())

        # Zero out embeddings at padded positions.
        embeddings = embeddings * attention_mask.unsqueeze(-1).float()

        return embeddings, attention_mask
|
| 324 |
+
|
| 325 |
+
class PhonemeConditioner(Conditioner):
    """
    A conditioner that turns text into phonemes and embeds them using a lookup table

    Only works for English text

    Args:
        output_dim: the dimension of the output embeddings
        max_length: the maximum number of phonemes to embed
        project_out: whether to add another linear projection to the output embeddings
    """

    def __init__(
        self,
        output_dim: int,
        max_length: int = 1024,
        project_out: bool = False,
    ):
        super().__init__(output_dim, output_dim, project_out=project_out)

        from g2p_en import G2p

        self.max_length = max_length
        self.g2p = G2p()

        # Reserving 0 for padding, 1 for ignored
        self.phoneme_embedder = nn.Embedding(len(self.g2p.phonemes) + 2, output_dim)

    def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        self.phoneme_embedder.to(device)
        self.proj_out.to(device)

        batch_phonemes = [self.g2p(text) for text in texts]  # [batch_size, length]

        ignored = [" ", *string.punctuation]

        # Replace whitespace/punctuation phonemes with a placeholder token.
        batch_phonemes = [
            ["_" if p in ignored else p for p in phonemes]
            for phonemes in batch_phonemes
        ]

        # Map to ids; the +2 offset keeps 0 for padding and 1 for unknowns.
        phoneme_ids = [
            [self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes]
            for phonemes in batch_phonemes
        ]

        # Right-pad every sequence to the longest one in the batch.
        longest = max(len(ids) for ids in phoneme_ids)
        padded = [ids + [0] * (longest - len(ids)) for ids in phoneme_ids]

        id_tensor = torch.tensor(padded).to(device)

        # Embed, then project to the output dimension.
        phoneme_embeds = self.proj_out(self.phoneme_embedder(id_tensor))

        return phoneme_embeds, torch.ones(phoneme_embeds.shape[0], phoneme_embeds.shape[1]).to(device)
|
| 380 |
+
|
| 381 |
+
class TokenizerLUTConditioner(Conditioner):
    """
    A conditioner that embeds text using a lookup table on a pretrained tokenizer's vocabulary

    Args:
        tokenizer_name: the name of the tokenizer from the Hugging Face transformers library
        output_dim: the dimension of the output embeddings
        max_length: the maximum length of the text to embed
        project_out: whether to add another linear projection to the output embeddings
    """

    def __init__(
        self,
        tokenizer_name: str,  # Name of a tokenizer from the Hugging Face transformers library
        output_dim: int,
        max_length: int = 1024,
        project_out: bool = False,
    ):
        super().__init__(output_dim, output_dim, project_out=project_out)

        from transformers import AutoTokenizer

        # Silence transformers' noisy tokenizer-loading warnings.
        previous_level = logging.root.manager.disable
        logging.disable(logging.ERROR)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
            finally:
                logging.disable(previous_level)

        self.max_length = max_length

        # One embedding row per vocabulary entry.
        self.token_embedder = nn.Embedding(len(self.tokenizer), output_dim)

    def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        self.proj_out.to(device)

        encoded = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device).to(torch.bool)

        # Embed, project, then zero embeddings at padded positions.
        embeddings = self.proj_out(self.token_embedder(input_ids))
        embeddings = embeddings * attention_mask.unsqueeze(-1).float()

        return embeddings, attention_mask
|
| 438 |
+
|
| 439 |
+
class PretransformConditioner(Conditioner):
    """
    A conditioner that uses a pretransform's encoder for conditioning

    Args:
        pretransform: an instantiated pretransform to use for conditioning
        output_dim: the dimension of the output embeddings
    """

    def __init__(self, pretransform: Pretransform, output_dim: int):
        super().__init__(pretransform.encoded_channels, output_dim)
        self.pretransform = pretransform

    def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        self.pretransform.to(device)
        self.proj_out.to(device)

        if isinstance(audio, (list, tuple)):
            audio = torch.cat(audio, dim=0)

        # Match the channel count the pretransform was trained with.
        audio = set_audio_channels(audio, self.pretransform.io_channels)

        latents = self.proj_out(self.pretransform.encode(audio))

        # Every latent frame is valid, so the mask is all ones.
        return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)]
|
| 468 |
+
|
| 469 |
+
class MultiConditioner(nn.Module):
    """
    A module that applies multiple conditioners to an input dictionary based on the keys

    Args:
        conditioners: a dictionary of conditioners with keys corresponding to the keys of the conditioning input dictionary (e.g. "prompt")
        default_keys: a dictionary of default keys to use if the key is not in the input dictionary (e.g. {"prompt_t5": "prompt"})
    """
    def __init__(self, conditioners: tp.Dict[str, Conditioner], default_keys: tp.Dict[str, str] = {}):
        super().__init__()

        self.conditioners = nn.ModuleDict(conditioners)
        self.default_keys = default_keys

    def forward(self, batch_metadata: tp.List[tp.Dict[str, tp.Any]], device: tp.Union[torch.device, str]) -> tp.Dict[str, tp.Any]:
        output = {}

        for key, conditioner in self.conditioners.items():
            condition_key = key

            conditioner_inputs = []

            for x in batch_metadata:

                # Fall back to a configured default key when the metadata
                # does not carry the conditioner's own key.
                if condition_key not in x:
                    if condition_key in self.default_keys:
                        condition_key = self.default_keys[condition_key]
                    else:
                        raise ValueError(f"Conditioner key {condition_key} not found in batch metadata")

                # Unwrap the condition info if it's a single-element list or
                # tuple, to support collation functions that wrap everything
                # in a list.
                # BUGFIX: the original `isinstance(v, list) or
                # isinstance(v, tuple) and len(v) == 1` bound `and` tighter
                # than `or`, so ANY list — even a multi-element one — was
                # unwrapped to its first element. Parenthesized so only true
                # singletons are unwrapped, matching the stated intent.
                value = x[condition_key]
                if isinstance(value, (list, tuple)) and len(value) == 1:
                    conditioner_inputs.append(value[0])
                else:
                    conditioner_inputs.append(value)

            output[key] = conditioner(conditioner_inputs, device)

        return output
|
| 511 |
+
|
| 512 |
+
def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]) -> MultiConditioner:
    """
    Create a MultiConditioner from a conditioning config dictionary

    Args:
        config: the conditioning config dictionary. Must contain "cond_dim"
            (shared output embedding dimension) and "configs" (a list of
            per-conditioner dicts with "id", "type" and "config" entries);
            may contain "default_keys" for metadata-key fallbacks.

    Raises:
        ValueError: if a conditioner entry has an unknown "type".
    """
    conditioners = {}
    cond_dim = config["cond_dim"]

    default_keys = config.get("default_keys", {})

    # Dispatch table for conditioner types that need no extra setup;
    # replaces a long if/elif chain.
    simple_conditioner_classes = {
        "t5": T5Conditioner,
        "clap_text": CLAPTextConditioner,
        "clap_audio": CLAPAudioConditioner,
        "int": IntConditioner,
        "number": NumberConditioner,
        "phoneme": PhonemeConditioner,
        "lut": TokenizerLUTConditioner,
    }

    for conditioner_info in config["configs"]:
        # Renamed from `id` to avoid shadowing the builtin
        conditioner_id = conditioner_info["id"]

        conditioner_type = conditioner_info["type"]

        # Every conditioner projects to the shared conditioning dimension;
        # per-conditioner config entries may override/extend this.
        conditioner_config = {"output_dim": cond_dim}
        conditioner_config.update(conditioner_info["config"])

        if conditioner_type in simple_conditioner_classes:
            conditioners[conditioner_id] = simple_conditioner_classes[conditioner_type](**conditioner_config)
        elif conditioner_type == "pretransform":
            sample_rate = conditioner_config.pop("sample_rate", None)
            assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"

            pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)

            # Optionally load pretrained pretransform weights before wrapping
            if conditioner_config.get("pretransform_ckpt_path", None) is not None:
                pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path")))

            conditioners[conditioner_id] = PretransformConditioner(pretransform, **conditioner_config)
        else:
            raise ValueError(f"Unknown conditioner type: {conditioner_type}")

    return MultiConditioner(conditioners, default_keys=default_keys)