Commit
·
a6026b2
verified
·
0
Parent(s):
Duplicate from tencent/SongGeneration
Browse filesCo-authored-by: way tan <waytan22@users.noreply.huggingface.co>
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +36 -0
- LICENSE +211 -0
- README.md +39 -0
- ckpt/encode-s12k.pt +3 -0
- ckpt/model_1rvq/model_2_fixed.safetensors +3 -0
- ckpt/model_septoken/model_2.safetensors +3 -0
- ckpt/models--lengyue233--content-vec-best/.no_exist/c0b9ba13db21beaa4053faae94c102ebe326fd68/model.safetensors +0 -0
- ckpt/models--lengyue233--content-vec-best/.no_exist/c0b9ba13db21beaa4053faae94c102ebe326fd68/model.safetensors.index.json +0 -0
- ckpt/models--lengyue233--content-vec-best/blobs/5186a71b15933aca2d9942db95e1aff02642d1f0 +71 -0
- ckpt/models--lengyue233--content-vec-best/blobs/d8dd400e054ddf4e6be75dab5a2549db748cc99e756a097c496c099f65a4854e +3 -0
- ckpt/models--lengyue233--content-vec-best/refs/main +1 -0
- ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68/config.json +71 -0
- ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68/pytorch_model.bin +3 -0
- ckpt/prompt.pt +3 -0
- ckpt/songgeneration_base/config.yaml +141 -0
- ckpt/songgeneration_base/model.pt +3 -0
- ckpt/vae/autoencoder_music_1320k.ckpt +3 -0
- ckpt/vae/stable_audio_1920_vae.json +122 -0
- img/logo.jpg +0 -0
- third_party/Qwen2-7B/LICENSE +202 -0
- third_party/Qwen2-7B/README.md +97 -0
- third_party/Qwen2-7B/config.json +27 -0
- third_party/Qwen2-7B/generation_config.json +7 -0
- third_party/Qwen2-7B/merges.txt +0 -0
- third_party/Qwen2-7B/tokenizer.json +0 -0
- third_party/Qwen2-7B/tokenizer_config.json +40 -0
- third_party/Qwen2-7B/vocab.json +0 -0
- third_party/demucs/__init__.py +0 -0
- third_party/demucs/ckpt/htdemucs.pth +3 -0
- third_party/demucs/ckpt/htdemucs.yaml +1 -0
- third_party/demucs/models/__init__.py +0 -0
- third_party/demucs/models/apply.py +315 -0
- third_party/demucs/models/audio.py +291 -0
- third_party/demucs/models/demucs.py +452 -0
- third_party/demucs/models/htdemucs.py +955 -0
- third_party/demucs/models/pretrained.py +34 -0
- third_party/demucs/models/spec.py +51 -0
- third_party/demucs/models/states.py +102 -0
- third_party/demucs/models/transformer.py +765 -0
- third_party/demucs/models/utils.py +125 -0
- third_party/demucs/run.py +109 -0
- third_party/hub/version.txt +1 -0
- third_party/stable_audio_tools/.gitignore +164 -0
- third_party/stable_audio_tools/LICENSE +21 -0
- third_party/stable_audio_tools/LICENSES/LICENSE_ADP.txt +21 -0
- third_party/stable_audio_tools/LICENSES/LICENSE_AURALOSS.txt +201 -0
- third_party/stable_audio_tools/LICENSES/LICENSE_DESCRIPT.txt +21 -0
- third_party/stable_audio_tools/LICENSES/LICENSE_META.txt +21 -0
- third_party/stable_audio_tools/LICENSES/LICENSE_NVIDIA.txt +21 -0
- third_party/stable_audio_tools/LICENSES/LICENSE_XTRANSFORMERS.txt +21 -0
.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
ckpt/models--lengyue233--content-vec-best/blobs/d8dd400e054ddf4e6be75dab5a2549db748cc99e756a097c496c099f65a4854e filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Tencent is pleased to support the open source community by making SongGeneration available.
|
| 2 |
+
|
| 3 |
+
Copyright (C) 2025 Tencent. All rights reserved.
|
| 4 |
+
|
| 5 |
+
SongGeneration is licensed under the License Terms of SongGeneration except for the third-party components listed below, which are licensed under different terms. SongGeneration does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
License Terms of SongGeneration:
|
| 9 |
+
--------------------------------------------------------------------
|
| 10 |
+
|
| 11 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this Software and associated documentation files, to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
| 12 |
+
|
| 13 |
+
- You agree to use the SongGeneration only for academic, research and education purposes, and refrain from using it for any commercial or production purposes under any circumstances.
|
| 14 |
+
|
| 15 |
+
- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
| 16 |
+
|
| 17 |
+
For avoidance of doubts, "Software" means the SongGeneration inference-enabling code and the weights made available under this license excluding any pre-trained data and other AI components.
|
| 18 |
+
|
| 19 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
Other dependencies and licenses:
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
Open Source Software Licensed under the MIT License and Other Licenses of the Third-Party Components therein:
|
| 26 |
+
--------------------------------------------------------------------
|
| 27 |
+
1. stable_audio_tools
|
| 28 |
+
Copyright (c) 2023 Stability AI
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
Terms of the MIT:
|
| 32 |
+
--------------------------------------------------------------------
|
| 33 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
| 34 |
+
|
| 35 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
| 36 |
+
|
| 37 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
| 38 |
+
|
| 39 |
+
For the license of other third party components, please refer to the following URL:
|
| 40 |
+
https://github.com/Stability-AI/stable-audio-tools/tree/main/LICENSES
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
Open Source Software Licensed under the MIT License:
|
| 44 |
+
--------------------------------------------------------------------
|
| 45 |
+
1. demucs
|
| 46 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
A copy of the MIT is included in this file.
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
|
| 54 |
+
--------------------------------------------------------------------
|
| 55 |
+
1. torch
|
| 56 |
+
From PyTorch:
|
| 57 |
+
|
| 58 |
+
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
| 59 |
+
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
| 60 |
+
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
| 61 |
+
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
| 62 |
+
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
| 63 |
+
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
| 64 |
+
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
| 65 |
+
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
| 66 |
+
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
| 67 |
+
|
| 68 |
+
From Caffe2:
|
| 69 |
+
|
| 70 |
+
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
|
| 71 |
+
|
| 72 |
+
All contributions by Facebook:
|
| 73 |
+
Copyright (c) 2016 Facebook Inc.
|
| 74 |
+
|
| 75 |
+
All contributions by Google:
|
| 76 |
+
Copyright (c) 2015 Google Inc.
|
| 77 |
+
All rights reserved.
|
| 78 |
+
|
| 79 |
+
All contributions by Yangqing Jia:
|
| 80 |
+
Copyright (c) 2015 Yangqing Jia
|
| 81 |
+
All rights reserved.
|
| 82 |
+
|
| 83 |
+
All contributions by Kakao Brain:
|
| 84 |
+
Copyright 2019-2020 Kakao Brain
|
| 85 |
+
|
| 86 |
+
All contributions by Cruise LLC:
|
| 87 |
+
Copyright (c) 2022 Cruise LLC.
|
| 88 |
+
All rights reserved.
|
| 89 |
+
|
| 90 |
+
All contributions from Caffe:
|
| 91 |
+
Copyright(c) 2013, 2014, 2015, the respective contributors
|
| 92 |
+
All rights reserved.
|
| 93 |
+
|
| 94 |
+
All other contributions:
|
| 95 |
+
Copyright(c) 2015, 2016 the respective contributors
|
| 96 |
+
All rights reserved.
|
| 97 |
+
|
| 98 |
+
Caffe2 uses a copyright model similar to Caffe: each contributor holds
|
| 99 |
+
copyright over their contributions to Caffe2. The project versioning records
|
| 100 |
+
all such contribution and copyright details. If a contributor wants to further
|
| 101 |
+
mark their specific copyright on a particular contribution, they should
|
| 102 |
+
indicate their copyright solely in the commit message of the change when it is
|
| 103 |
+
committed.
|
| 104 |
+
|
| 105 |
+
All rights reserved.
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
Terms of the BSD 3-Clause:
|
| 109 |
+
--------------------------------------------------------------------
|
| 110 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
| 111 |
+
|
| 112 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
| 113 |
+
|
| 114 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
| 115 |
+
|
| 116 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
| 117 |
+
|
| 118 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 119 |
+
|
| 120 |
+
For the license of other third party components, please refer to the following URL:
|
| 121 |
+
https://github.com/pytorch/pytorch/blob/v2.0.1/NOTICE
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
Open Source Software Licensed under the BSD 2-Clause License and Other Licenses of the Third-Party Components therein:
|
| 125 |
+
--------------------------------------------------------------------
|
| 126 |
+
1. torchaudio
|
| 127 |
+
Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
|
| 128 |
+
All rights reserved.
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
Terms of the BSD 2-Clause:
|
| 132 |
+
--------------------------------------------------------------------
|
| 133 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
| 134 |
+
|
| 135 |
+
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
| 136 |
+
|
| 137 |
+
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
| 138 |
+
|
| 139 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 140 |
+
|
| 141 |
+
For the license of other third party components, please refer to the following URL:
|
| 142 |
+
https://github.com/pytorch/audio/blob/v2.0.2/LICENSE
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
Open Source Software License under the Apache License Version 2.0:
|
| 146 |
+
--------------------------------------------------------------------
|
| 147 |
+
1. huggingface-hub
|
| 148 |
+
Copyright (c) huggingface-hub original author and authors
|
| 149 |
+
|
| 150 |
+
2. transformers
|
| 151 |
+
Copyright 2018- The Hugging Face team. All rights reserved.
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
Terms of the Apache License Version 2.0:
|
| 155 |
+
--------------------------------------------------------------------
|
| 156 |
+
Apache License
|
| 157 |
+
|
| 158 |
+
Version 2.0, January 2004
|
| 159 |
+
|
| 160 |
+
http://www.apache.org/licenses/
|
| 161 |
+
|
| 162 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 163 |
+
1. Definitions.
|
| 164 |
+
|
| 165 |
+
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
| 166 |
+
|
| 167 |
+
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
| 168 |
+
|
| 169 |
+
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
| 170 |
+
|
| 171 |
+
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
|
| 172 |
+
|
| 173 |
+
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
| 174 |
+
|
| 175 |
+
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
| 176 |
+
|
| 177 |
+
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
| 178 |
+
|
| 179 |
+
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
| 180 |
+
|
| 181 |
+
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
| 182 |
+
|
| 183 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
| 184 |
+
|
| 185 |
+
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
| 186 |
+
|
| 187 |
+
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
| 188 |
+
|
| 189 |
+
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
| 190 |
+
|
| 191 |
+
You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
| 192 |
+
|
| 193 |
+
You must cause any modified files to carry prominent notices stating that You changed the files; and
|
| 194 |
+
|
| 195 |
+
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
| 196 |
+
|
| 197 |
+
If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
| 198 |
+
|
| 199 |
+
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
| 200 |
+
|
| 201 |
+
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
| 202 |
+
|
| 203 |
+
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
| 204 |
+
|
| 205 |
+
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
| 206 |
+
|
| 207 |
+
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
| 208 |
+
|
| 209 |
+
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
| 210 |
+
|
| 211 |
+
END OF TERMS AND CONDITIONS
|
README.md
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- en
|
| 4 |
+
- zh
|
| 5 |
+
pipeline_tag: text-to-audio
|
| 6 |
+
library_name: tencent-song-generation
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
# SongGeneration
|
| 10 |
+
|
| 11 |
+
<p align="center"><img src="img/logo.jpg" width="40%"></p>
|
| 12 |
+
<p align="center">
|
| 13 |
+
<a href="https://levo-demo.github.io/">Demo</a> | <a href="https://arxiv.org/abs/2506.07520">Paper</a> | <a href="https://github.com/tencent-ailab/songgeneration">Code</a> | <a href="https://huggingface.co/spaces/tencent/SongGeneration">Space Demo</a>
|
| 14 |
+
</p>
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
This repository is the official weight repository for LeVo: High-Quality Song Generation with Multi-Preference Alignment. In this repository, we provide the SongGeneration model, inference scripts, and the checkpoint that has been trained on the Million Song Dataset.
|
| 18 |
+
|
| 19 |
+
## Model Versions
|
| 20 |
+
|
| 21 |
+
| Model | Max Length | Language | GPU Memory | RFT(A100) | Download Link |
|
| 22 |
+
| ------------------------- | :--------: | :------------------: | :---------: | :-------: | ------------------------------------------------------------ |
|
| 23 |
+
| SongGeneration-base | 2m30s | zh | 10G/16G | 1.26 | You were here |
|
| 24 |
+
| SongGeneration-base-new | 2m30s | zh, en | 10G/16G | 1.26 | [Huggingface](https://huggingface.co/lglg666/SongGeneration-base-new) |
|
| 25 |
+
| SongGeneration-base-full | 4m30s | zh, en | 12G/18G | 1.30 | [Huggingface](https://huggingface.co/lglg666/SongGeneration-base-full) |
|
| 26 |
+
| SongGeneration-large | 4m30s | zh, en | 22G/28G | 1.51 | [Huggingface](https://huggingface.co/lglg666/SongGeneration-large) |
|
| 27 |
+
| SongGeneration-v1.5-small | 2m | zh, en, es, ja, etc. | - | - | Coming soon |
|
| 28 |
+
| SongGeneration-v1.5-base | 4m30s | zh, en, es, ja, etc. | - | - | Coming soon |
|
| 29 |
+
| SongGeneration-v1.5-large | 4m30s | zh, en, es, ja, etc. | - | - | Coming soon |
|
| 30 |
+
|
| 31 |
+
## Overview
|
| 32 |
+
|
| 33 |
+
We develop the SongGeneration model. It is an LM-based framework consisting of **LeLM** and a **music codec**. LeLM is capable of parallelly modeling two types of tokens: mixed tokens, which represent the combined audio of vocals and accompaniment to achieve vocal-instrument harmony, and dual-track tokens, which separately encode vocals and accompaniment for high-quality song generation. The music codec reconstructs the dual-track tokens into high-fidelity music audio. SongGeneration significantly improves over the open-source music generation models and performs competitively with current state-of-the-art industry systems. For more details, please refer to our [paper](https://arxiv.org/abs/2506.07520).
|
| 34 |
+
|
| 35 |
+
<img src="https://github.com/tencent-ailab/songgeneration/blob/main/img/over.jpg?raw=true" alt="img" style="zoom:100%;" />
|
| 36 |
+
|
| 37 |
+
## License
|
| 38 |
+
|
| 39 |
+
The code and weights in this repository are released under the terms described in the [LICENSE](LICENSE) file.
|
ckpt/encode-s12k.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e250df56b035f74c1f66f15133f4c78f664d70fa0b09aa9a752b7871bb58c02f
|
| 3 |
+
size 3957949089
|
ckpt/model_1rvq/model_2_fixed.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:339a16956b859a82defc02bfd32c3744d11ff942065f6ec9306dfd4400d62110
|
| 3 |
+
size 4704507596
|
ckpt/model_septoken/model_2.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:430b7c1c245722fbe3893cd621b3d4a90076404596e9fb1ce987a4a0f2a4fc6f
|
| 3 |
+
size 4808167708
|
ckpt/models--lengyue233--content-vec-best/.no_exist/c0b9ba13db21beaa4053faae94c102ebe326fd68/model.safetensors
ADDED
|
File without changes
|
ckpt/models--lengyue233--content-vec-best/.no_exist/c0b9ba13db21beaa4053faae94c102ebe326fd68/model.safetensors.index.json
ADDED
|
File without changes
|
ckpt/models--lengyue233--content-vec-best/blobs/5186a71b15933aca2d9942db95e1aff02642d1f0
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation_dropout": 0.1,
|
| 3 |
+
"apply_spec_augment": true,
|
| 4 |
+
"architectures": [
|
| 5 |
+
"HubertModelWithFinalProj"
|
| 6 |
+
],
|
| 7 |
+
"attention_dropout": 0.1,
|
| 8 |
+
"bos_token_id": 1,
|
| 9 |
+
"classifier_proj_size": 256,
|
| 10 |
+
"conv_bias": false,
|
| 11 |
+
"conv_dim": [
|
| 12 |
+
512,
|
| 13 |
+
512,
|
| 14 |
+
512,
|
| 15 |
+
512,
|
| 16 |
+
512,
|
| 17 |
+
512,
|
| 18 |
+
512
|
| 19 |
+
],
|
| 20 |
+
"conv_kernel": [
|
| 21 |
+
10,
|
| 22 |
+
3,
|
| 23 |
+
3,
|
| 24 |
+
3,
|
| 25 |
+
3,
|
| 26 |
+
2,
|
| 27 |
+
2
|
| 28 |
+
],
|
| 29 |
+
"conv_stride": [
|
| 30 |
+
5,
|
| 31 |
+
2,
|
| 32 |
+
2,
|
| 33 |
+
2,
|
| 34 |
+
2,
|
| 35 |
+
2,
|
| 36 |
+
2
|
| 37 |
+
],
|
| 38 |
+
"ctc_loss_reduction": "sum",
|
| 39 |
+
"ctc_zero_infinity": false,
|
| 40 |
+
"do_stable_layer_norm": false,
|
| 41 |
+
"eos_token_id": 2,
|
| 42 |
+
"feat_extract_activation": "gelu",
|
| 43 |
+
"feat_extract_norm": "group",
|
| 44 |
+
"feat_proj_dropout": 0.0,
|
| 45 |
+
"feat_proj_layer_norm": true,
|
| 46 |
+
"final_dropout": 0.1,
|
| 47 |
+
"hidden_act": "gelu",
|
| 48 |
+
"hidden_dropout": 0.1,
|
| 49 |
+
"hidden_size": 768,
|
| 50 |
+
"initializer_range": 0.02,
|
| 51 |
+
"intermediate_size": 3072,
|
| 52 |
+
"layer_norm_eps": 1e-05,
|
| 53 |
+
"layerdrop": 0.1,
|
| 54 |
+
"mask_feature_length": 10,
|
| 55 |
+
"mask_feature_min_masks": 0,
|
| 56 |
+
"mask_feature_prob": 0.0,
|
| 57 |
+
"mask_time_length": 10,
|
| 58 |
+
"mask_time_min_masks": 2,
|
| 59 |
+
"mask_time_prob": 0.05,
|
| 60 |
+
"model_type": "hubert",
|
| 61 |
+
"num_attention_heads": 12,
|
| 62 |
+
"num_conv_pos_embedding_groups": 16,
|
| 63 |
+
"num_conv_pos_embeddings": 128,
|
| 64 |
+
"num_feat_extract_layers": 7,
|
| 65 |
+
"num_hidden_layers": 12,
|
| 66 |
+
"pad_token_id": 0,
|
| 67 |
+
"torch_dtype": "float32",
|
| 68 |
+
"transformers_version": "4.27.3",
|
| 69 |
+
"use_weighted_layer_sum": false,
|
| 70 |
+
"vocab_size": 32
|
| 71 |
+
}
|
ckpt/models--lengyue233--content-vec-best/blobs/d8dd400e054ddf4e6be75dab5a2549db748cc99e756a097c496c099f65a4854e
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d8dd400e054ddf4e6be75dab5a2549db748cc99e756a097c496c099f65a4854e
|
| 3 |
+
size 378342945
|
ckpt/models--lengyue233--content-vec-best/refs/main
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
c0b9ba13db21beaa4053faae94c102ebe326fd68
|
ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68/config.json
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation_dropout": 0.1,
|
| 3 |
+
"apply_spec_augment": true,
|
| 4 |
+
"architectures": [
|
| 5 |
+
"HubertModelWithFinalProj"
|
| 6 |
+
],
|
| 7 |
+
"attention_dropout": 0.1,
|
| 8 |
+
"bos_token_id": 1,
|
| 9 |
+
"classifier_proj_size": 256,
|
| 10 |
+
"conv_bias": false,
|
| 11 |
+
"conv_dim": [
|
| 12 |
+
512,
|
| 13 |
+
512,
|
| 14 |
+
512,
|
| 15 |
+
512,
|
| 16 |
+
512,
|
| 17 |
+
512,
|
| 18 |
+
512
|
| 19 |
+
],
|
| 20 |
+
"conv_kernel": [
|
| 21 |
+
10,
|
| 22 |
+
3,
|
| 23 |
+
3,
|
| 24 |
+
3,
|
| 25 |
+
3,
|
| 26 |
+
2,
|
| 27 |
+
2
|
| 28 |
+
],
|
| 29 |
+
"conv_stride": [
|
| 30 |
+
5,
|
| 31 |
+
2,
|
| 32 |
+
2,
|
| 33 |
+
2,
|
| 34 |
+
2,
|
| 35 |
+
2,
|
| 36 |
+
2
|
| 37 |
+
],
|
| 38 |
+
"ctc_loss_reduction": "sum",
|
| 39 |
+
"ctc_zero_infinity": false,
|
| 40 |
+
"do_stable_layer_norm": false,
|
| 41 |
+
"eos_token_id": 2,
|
| 42 |
+
"feat_extract_activation": "gelu",
|
| 43 |
+
"feat_extract_norm": "group",
|
| 44 |
+
"feat_proj_dropout": 0.0,
|
| 45 |
+
"feat_proj_layer_norm": true,
|
| 46 |
+
"final_dropout": 0.1,
|
| 47 |
+
"hidden_act": "gelu",
|
| 48 |
+
"hidden_dropout": 0.1,
|
| 49 |
+
"hidden_size": 768,
|
| 50 |
+
"initializer_range": 0.02,
|
| 51 |
+
"intermediate_size": 3072,
|
| 52 |
+
"layer_norm_eps": 1e-05,
|
| 53 |
+
"layerdrop": 0.1,
|
| 54 |
+
"mask_feature_length": 10,
|
| 55 |
+
"mask_feature_min_masks": 0,
|
| 56 |
+
"mask_feature_prob": 0.0,
|
| 57 |
+
"mask_time_length": 10,
|
| 58 |
+
"mask_time_min_masks": 2,
|
| 59 |
+
"mask_time_prob": 0.05,
|
| 60 |
+
"model_type": "hubert",
|
| 61 |
+
"num_attention_heads": 12,
|
| 62 |
+
"num_conv_pos_embedding_groups": 16,
|
| 63 |
+
"num_conv_pos_embeddings": 128,
|
| 64 |
+
"num_feat_extract_layers": 7,
|
| 65 |
+
"num_hidden_layers": 12,
|
| 66 |
+
"pad_token_id": 0,
|
| 67 |
+
"torch_dtype": "float32",
|
| 68 |
+
"transformers_version": "4.27.3",
|
| 69 |
+
"use_weighted_layer_sum": false,
|
| 70 |
+
"vocab_size": 32
|
| 71 |
+
}
|
ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68/pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d8dd400e054ddf4e6be75dab5a2549db748cc99e756a097c496c099f65a4854e
|
| 3 |
+
size 378342945
|
ckpt/prompt.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:93fbcc4b88050a0ac9180beee723eb3600229ebcd22afa70aaeb450a622b9f49
|
| 3 |
+
size 3133236
|
ckpt/songgeneration_base/config.yaml
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ================ Train Config ================ #
|
| 2 |
+
lyric_processor:
|
| 3 |
+
max_dur: 150
|
| 4 |
+
min_dur: 30
|
| 5 |
+
prompt_len: 10
|
| 6 |
+
pad_to_max: true
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# ================ Audio tokenzier ================ #
|
| 10 |
+
audio_tokenizer_checkpoint: Flow1dVAE1rvq_./ckpt/model_1rvq/model_2_fixed.safetensors
|
| 11 |
+
audio_tokenizer_frame_rate: 25
|
| 12 |
+
audio_tokenizer_code_depth: 1
|
| 13 |
+
sample_rate: 48000
|
| 14 |
+
|
| 15 |
+
audio_tokenizer_checkpoint_sep: Flow1dVAESeparate_./ckpt/model_septoken/model_2.safetensors
|
| 16 |
+
audio_tokenizer_frame_rate_sep: 25
|
| 17 |
+
audio_tokenizer_code_depth_sep: 2
|
| 18 |
+
sample_rate_sep: 48000
|
| 19 |
+
|
| 20 |
+
# ================ VAE ================ #
|
| 21 |
+
vae_config: ./ckpt/vae/stable_audio_1920_vae.json
|
| 22 |
+
vae_model: ./ckpt/vae/autoencoder_music_1320k.ckpt
|
| 23 |
+
|
| 24 |
+
# ================== LM =========================== #
|
| 25 |
+
lm:
|
| 26 |
+
lm_type: Llama # [Llama]
|
| 27 |
+
dim: 1536
|
| 28 |
+
intermediate_size: 8960
|
| 29 |
+
num_heads: 12
|
| 30 |
+
num_layers: 28
|
| 31 |
+
num_layers_sub: 12
|
| 32 |
+
code_depth: 3
|
| 33 |
+
code_size: 16384
|
| 34 |
+
max_position_embeddings: 8196
|
| 35 |
+
max_position_embeddings_sub: 10000
|
| 36 |
+
rope_theta: 100000.0
|
| 37 |
+
rope_theta_sub: 500000.0
|
| 38 |
+
dropout: 0.0
|
| 39 |
+
use_flash_attn_2: true
|
| 40 |
+
activation: gelu
|
| 41 |
+
norm_first: true
|
| 42 |
+
bias_ff: false
|
| 43 |
+
bias_attn: false
|
| 44 |
+
causal: true
|
| 45 |
+
custom: false
|
| 46 |
+
memory_efficient: true
|
| 47 |
+
attention_as_float32: false
|
| 48 |
+
layer_scale: null
|
| 49 |
+
positional_embedding: sin
|
| 50 |
+
xpos: false
|
| 51 |
+
checkpointing: torch
|
| 52 |
+
weight_init: gaussian
|
| 53 |
+
depthwise_init: current
|
| 54 |
+
zero_bias_init: true
|
| 55 |
+
norm: layer_norm
|
| 56 |
+
cross_attention: false
|
| 57 |
+
qk_layer_norm: false
|
| 58 |
+
qk_layer_norm_cross: false
|
| 59 |
+
attention_dropout: null
|
| 60 |
+
kv_repeat: 1
|
| 61 |
+
|
| 62 |
+
codebooks_pattern:
|
| 63 |
+
modeling: delay
|
| 64 |
+
delay:
|
| 65 |
+
delays: [ 0, 250, 250 ]
|
| 66 |
+
flatten_first: 0
|
| 67 |
+
empty_initial: 0
|
| 68 |
+
|
| 69 |
+
# ================ Conditioners ===================== #
|
| 70 |
+
classifier_free_guidance:
|
| 71 |
+
# drop all conditions simultaneously
|
| 72 |
+
training_dropout: 0.15
|
| 73 |
+
inference_coef: 1.5
|
| 74 |
+
|
| 75 |
+
attribute_dropout:
|
| 76 |
+
# drop each condition separately
|
| 77 |
+
args:
|
| 78 |
+
active_on_eval: false
|
| 79 |
+
text:
|
| 80 |
+
description: 0.0
|
| 81 |
+
type_info: 0.5
|
| 82 |
+
audio:
|
| 83 |
+
prompt_audio: 0.0
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
use_text_training: True
|
| 87 |
+
fuser:
|
| 88 |
+
sum: []
|
| 89 |
+
prepend: [ description, prompt_audio, type_info ] # this order is the SAME with the input concatenation order
|
| 90 |
+
|
| 91 |
+
conditioners:
|
| 92 |
+
prompt_audio:
|
| 93 |
+
model: qt_embedding
|
| 94 |
+
qt_embedding:
|
| 95 |
+
code_size: 16384
|
| 96 |
+
code_depth: 3
|
| 97 |
+
max_len: ${eval:${prompt_len}*${audio_tokenizer_frame_rate}+2} # 25*10+2+1
|
| 98 |
+
description:
|
| 99 |
+
model: QwTokenizer
|
| 100 |
+
QwTokenizer:
|
| 101 |
+
token_path: third_party/Qwen2-7B
|
| 102 |
+
max_len: 300
|
| 103 |
+
add_token_list: ${load_yaml:conf/vocab.yaml}
|
| 104 |
+
type_info:
|
| 105 |
+
model: QwTextTokenizer
|
| 106 |
+
QwTextTokenizer:
|
| 107 |
+
token_path: third_party/Qwen2-7B
|
| 108 |
+
max_len: 50
|
| 109 |
+
|
| 110 |
+
offload:
|
| 111 |
+
audiolm:
|
| 112 |
+
offload_module: self
|
| 113 |
+
cpu_mem_gb: 0
|
| 114 |
+
pre_copy_step: 1
|
| 115 |
+
clean_cache_after_forward: false
|
| 116 |
+
dtype: torch.float16
|
| 117 |
+
offload_layer_dict:
|
| 118 |
+
transformer: 4
|
| 119 |
+
transformer2: 4
|
| 120 |
+
ignore_layer_list: []
|
| 121 |
+
clean_cache_wrapper:
|
| 122 |
+
module: self
|
| 123 |
+
method_name: _sample_next_token
|
| 124 |
+
diff_mem_gb_thre: 2
|
| 125 |
+
debug: false
|
| 126 |
+
|
| 127 |
+
wav_tokenizer_diffusion:
|
| 128 |
+
offload_module: self.model.model
|
| 129 |
+
pre_copy_step: 1
|
| 130 |
+
clean_cache_after_forward: false
|
| 131 |
+
cpu_mem_gb: -1
|
| 132 |
+
dtype: null
|
| 133 |
+
offload_layer_dict:
|
| 134 |
+
cfm_wrapper: 5
|
| 135 |
+
hubert: 4
|
| 136 |
+
ignore_layer_list: []
|
| 137 |
+
clean_cache_wrapper:
|
| 138 |
+
module: self.model.model.cfm_wrapper.estimator
|
| 139 |
+
method_name: forward
|
| 140 |
+
diff_mem_gb_thre: 1
|
| 141 |
+
debug: false
|
ckpt/songgeneration_base/model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8763fc75e5db768c334a9fbadd08e2004eccb6e15156c76b4c2a3984f8fbb884
|
| 3 |
+
size 11318365872
|
ckpt/vae/autoencoder_music_1320k.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:10ccb6c83613781ad32e998a90597ba7eb9292911a224598da1fd53728eb4cd3
|
| 3 |
+
size 674920616
|
ckpt/vae/stable_audio_1920_vae.json
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "autoencoder",
|
| 3 |
+
"sample_size": 403200,
|
| 4 |
+
"sample_rate": 48000,
|
| 5 |
+
"audio_channels": 2,
|
| 6 |
+
"model": {
|
| 7 |
+
"encoder": {
|
| 8 |
+
"type": "oobleck",
|
| 9 |
+
"config": {
|
| 10 |
+
"in_channels": 2,
|
| 11 |
+
"channels": 128,
|
| 12 |
+
"c_mults": [1, 2, 4, 8, 16],
|
| 13 |
+
"strides": [2, 4, 4, 6, 10],
|
| 14 |
+
"latent_dim": 128,
|
| 15 |
+
"use_snake": true
|
| 16 |
+
}
|
| 17 |
+
},
|
| 18 |
+
"decoder": {
|
| 19 |
+
"type": "oobleck",
|
| 20 |
+
"config": {
|
| 21 |
+
"out_channels": 2,
|
| 22 |
+
"channels": 128,
|
| 23 |
+
"c_mults": [1, 2, 4, 8, 16],
|
| 24 |
+
"strides": [2, 4, 4, 6, 10],
|
| 25 |
+
"latent_dim": 64,
|
| 26 |
+
"use_snake": true,
|
| 27 |
+
"final_tanh": false
|
| 28 |
+
}
|
| 29 |
+
},
|
| 30 |
+
"bottleneck": {
|
| 31 |
+
"type": "vae"
|
| 32 |
+
},
|
| 33 |
+
"latent_dim": 64,
|
| 34 |
+
"downsampling_ratio": 1920,
|
| 35 |
+
"io_channels": 2
|
| 36 |
+
},
|
| 37 |
+
"training": {
|
| 38 |
+
"learning_rate": 1.5e-4,
|
| 39 |
+
"warmup_steps": 0,
|
| 40 |
+
"use_ema": true,
|
| 41 |
+
"optimizer_configs": {
|
| 42 |
+
"autoencoder": {
|
| 43 |
+
"optimizer": {
|
| 44 |
+
"type": "AdamW",
|
| 45 |
+
"config": {
|
| 46 |
+
"betas": [0.8, 0.99],
|
| 47 |
+
"lr": 1.5e-4,
|
| 48 |
+
"weight_decay": 1e-3
|
| 49 |
+
}
|
| 50 |
+
},
|
| 51 |
+
"scheduler": {
|
| 52 |
+
"type": "InverseLR",
|
| 53 |
+
"config": {
|
| 54 |
+
"inv_gamma": 200000,
|
| 55 |
+
"power": 0.5,
|
| 56 |
+
"warmup": 0.999
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
},
|
| 60 |
+
"discriminator": {
|
| 61 |
+
"optimizer": {
|
| 62 |
+
"type": "AdamW",
|
| 63 |
+
"config": {
|
| 64 |
+
"betas": [0.8, 0.99],
|
| 65 |
+
"lr": 3e-4,
|
| 66 |
+
"weight_decay": 1e-3
|
| 67 |
+
}
|
| 68 |
+
},
|
| 69 |
+
"scheduler": {
|
| 70 |
+
"type": "InverseLR",
|
| 71 |
+
"config": {
|
| 72 |
+
"inv_gamma": 200000,
|
| 73 |
+
"power": 0.5,
|
| 74 |
+
"warmup": 0.999
|
| 75 |
+
}
|
| 76 |
+
}
|
| 77 |
+
}
|
| 78 |
+
},
|
| 79 |
+
"loss_configs": {
|
| 80 |
+
"discriminator": {
|
| 81 |
+
"type": "encodec",
|
| 82 |
+
"config": {
|
| 83 |
+
"filters": 64,
|
| 84 |
+
"n_ffts": [2048, 1024, 512, 256, 128],
|
| 85 |
+
"hop_lengths": [512, 256, 128, 64, 32],
|
| 86 |
+
"win_lengths": [2048, 1024, 512, 256, 128]
|
| 87 |
+
},
|
| 88 |
+
"weights": {
|
| 89 |
+
"adversarial": 0.1,
|
| 90 |
+
"feature_matching": 5.0
|
| 91 |
+
}
|
| 92 |
+
},
|
| 93 |
+
"spectral": {
|
| 94 |
+
"type": "mrstft",
|
| 95 |
+
"config": {
|
| 96 |
+
"fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
|
| 97 |
+
"hop_sizes": [512, 256, 128, 64, 32, 16, 8],
|
| 98 |
+
"win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
|
| 99 |
+
"perceptual_weighting": true
|
| 100 |
+
},
|
| 101 |
+
"weights": {
|
| 102 |
+
"mrstft": 1.0
|
| 103 |
+
}
|
| 104 |
+
},
|
| 105 |
+
"time": {
|
| 106 |
+
"type": "l1",
|
| 107 |
+
"weights": {
|
| 108 |
+
"l1": 0.0
|
| 109 |
+
}
|
| 110 |
+
},
|
| 111 |
+
"bottleneck": {
|
| 112 |
+
"type": "kl",
|
| 113 |
+
"weights": {
|
| 114 |
+
"kl": 1e-4
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
},
|
| 118 |
+
"demo": {
|
| 119 |
+
"demo_every": 2000
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
}
|
img/logo.jpg
ADDED
|
third_party/Qwen2-7B/LICENSE
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
Apache License
|
| 3 |
+
Version 2.0, January 2004
|
| 4 |
+
http://www.apache.org/licenses/
|
| 5 |
+
|
| 6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 7 |
+
|
| 8 |
+
1. Definitions.
|
| 9 |
+
|
| 10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 12 |
+
|
| 13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 14 |
+
the copyright owner that is granting the License.
|
| 15 |
+
|
| 16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 17 |
+
other entities that control, are controlled by, or are under common
|
| 18 |
+
control with that entity. For the purposes of this definition,
|
| 19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 20 |
+
direction or management of such entity, whether by contract or
|
| 21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 23 |
+
|
| 24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 25 |
+
exercising permissions granted by this License.
|
| 26 |
+
|
| 27 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 28 |
+
including but not limited to software source code, documentation
|
| 29 |
+
source, and configuration files.
|
| 30 |
+
|
| 31 |
+
"Object" form shall mean any form resulting from mechanical
|
| 32 |
+
transformation or translation of a Source form, including but
|
| 33 |
+
not limited to compiled object code, generated documentation,
|
| 34 |
+
and conversions to other media types.
|
| 35 |
+
|
| 36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 37 |
+
Object form, made available under the License, as indicated by a
|
| 38 |
+
copyright notice that is included in or attached to the work
|
| 39 |
+
(an example is provided in the Appendix below).
|
| 40 |
+
|
| 41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 42 |
+
form, that is based on (or derived from) the Work and for which the
|
| 43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 45 |
+
of this License, Derivative Works shall not include works that remain
|
| 46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 47 |
+
the Work and Derivative Works thereof.
|
| 48 |
+
|
| 49 |
+
"Contribution" shall mean any work of authorship, including
|
| 50 |
+
the original version of the Work and any modifications or additions
|
| 51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 55 |
+
means any form of electronic, verbal, or written communication sent
|
| 56 |
+
to the Licensor or its representatives, including but not limited to
|
| 57 |
+
communication on electronic mailing lists, source code control systems,
|
| 58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 60 |
+
excluding communication that is conspicuously marked or otherwise
|
| 61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 62 |
+
|
| 63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 65 |
+
subsequently incorporated within the Work.
|
| 66 |
+
|
| 67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 72 |
+
Work and such Derivative Works in Source or Object form.
|
| 73 |
+
|
| 74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 77 |
+
(except as stated in this section) patent license to make, have made,
|
| 78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 79 |
+
where such license applies only to those patent claims licensable
|
| 80 |
+
by such Contributor that are necessarily infringed by their
|
| 81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 83 |
+
institute patent litigation against any entity (including a
|
| 84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 85 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 86 |
+
or contributory patent infringement, then any patent licenses
|
| 87 |
+
granted to You under this License for that Work shall terminate
|
| 88 |
+
as of the date such litigation is filed.
|
| 89 |
+
|
| 90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 91 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 92 |
+
modifications, and in Source or Object form, provided that You
|
| 93 |
+
meet the following conditions:
|
| 94 |
+
|
| 95 |
+
(a) You must give any other recipients of the Work or
|
| 96 |
+
Derivative Works a copy of this License; and
|
| 97 |
+
|
| 98 |
+
(b) You must cause any modified files to carry prominent notices
|
| 99 |
+
stating that You changed the files; and
|
| 100 |
+
|
| 101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 102 |
+
that You distribute, all copyright, patent, trademark, and
|
| 103 |
+
attribution notices from the Source form of the Work,
|
| 104 |
+
excluding those notices that do not pertain to any part of
|
| 105 |
+
the Derivative Works; and
|
| 106 |
+
|
| 107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 108 |
+
distribution, then any Derivative Works that You distribute must
|
| 109 |
+
include a readable copy of the attribution notices contained
|
| 110 |
+
within such NOTICE file, excluding those notices that do not
|
| 111 |
+
pertain to any part of the Derivative Works, in at least one
|
| 112 |
+
of the following places: within a NOTICE text file distributed
|
| 113 |
+
as part of the Derivative Works; within the Source form or
|
| 114 |
+
documentation, if provided along with the Derivative Works; or,
|
| 115 |
+
within a display generated by the Derivative Works, if and
|
| 116 |
+
wherever such third-party notices normally appear. The contents
|
| 117 |
+
of the NOTICE file are for informational purposes only and
|
| 118 |
+
do not modify the License. You may add Your own attribution
|
| 119 |
+
notices within Derivative Works that You distribute, alongside
|
| 120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 121 |
+
that such additional attribution notices cannot be construed
|
| 122 |
+
as modifying the License.
|
| 123 |
+
|
| 124 |
+
You may add Your own copyright statement to Your modifications and
|
| 125 |
+
may provide additional or different license terms and conditions
|
| 126 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 127 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 128 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 129 |
+
the conditions stated in this License.
|
| 130 |
+
|
| 131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 133 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 134 |
+
this License, without any additional terms or conditions.
|
| 135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 136 |
+
the terms of any separate license agreement you may have executed
|
| 137 |
+
with Licensor regarding such Contributions.
|
| 138 |
+
|
| 139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 141 |
+
except as required for reasonable and customary use in describing the
|
| 142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 143 |
+
|
| 144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 145 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 148 |
+
implied, including, without limitation, any warranties or conditions
|
| 149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 151 |
+
appropriateness of using or redistributing the Work and assume any
|
| 152 |
+
risks associated with Your exercise of permissions under this License.
|
| 153 |
+
|
| 154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 155 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 156 |
+
unless required by applicable law (such as deliberate and grossly
|
| 157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 158 |
+
liable to You for damages, including any direct, indirect, special,
|
| 159 |
+
incidental, or consequential damages of any character arising as a
|
| 160 |
+
result of this License or out of the use or inability to use the
|
| 161 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 162 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 163 |
+
other commercial damages or losses), even if such Contributor
|
| 164 |
+
has been advised of the possibility of such damages.
|
| 165 |
+
|
| 166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 169 |
+
or other liability obligations and/or rights consistent with this
|
| 170 |
+
License. However, in accepting such obligations, You may act only
|
| 171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 172 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 173 |
+
defend, and hold each Contributor harmless for any liability
|
| 174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 175 |
+
of your accepting any such warranty or additional liability.
|
| 176 |
+
|
| 177 |
+
END OF TERMS AND CONDITIONS
|
| 178 |
+
|
| 179 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 180 |
+
|
| 181 |
+
To apply the Apache License to your work, attach the following
|
| 182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 183 |
+
replaced with your own identifying information. (Don't include
|
| 184 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 185 |
+
comment syntax for the file format. We also recommend that a
|
| 186 |
+
file or class name and description of purpose be included on the
|
| 187 |
+
same "printed page" as the copyright notice for easier
|
| 188 |
+
identification within third-party archives.
|
| 189 |
+
|
| 190 |
+
Copyright 2024 Alibaba Cloud
|
| 191 |
+
|
| 192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 193 |
+
you may not use this file except in compliance with the License.
|
| 194 |
+
You may obtain a copy of the License at
|
| 195 |
+
|
| 196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 197 |
+
|
| 198 |
+
Unless required by applicable law or agreed to in writing, software
|
| 199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 201 |
+
See the License for the specific language governing permissions and
|
| 202 |
+
limitations under the License.
|
third_party/Qwen2-7B/README.md
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- en
|
| 4 |
+
pipeline_tag: text-generation
|
| 5 |
+
tags:
|
| 6 |
+
- pretrained
|
| 7 |
+
license: apache-2.0
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# Qwen2-7B
|
| 11 |
+
|
| 12 |
+
## Introduction
|
| 13 |
+
|
| 14 |
+
Qwen2 is the new series of Qwen large language models. For Qwen2, we release a number of base language models and instruction-tuned language models ranging from 0.5 to 72 billion parameters, including a Mixture-of-Experts model. This repo contains the 7B Qwen2 base language model.
|
| 15 |
+
|
| 16 |
+
Compared with the state-of-the-art opensource language models, including the previous released Qwen1.5, Qwen2 has generally surpassed most opensource models and demonstrated competitiveness against proprietary models across a series of benchmarks targeting for language understanding, language generation, multilingual capability, coding, mathematics, reasoning, etc.
|
| 17 |
+
|
| 18 |
+
For more details, please refer to our [blog](https://qwenlm.github.io/blog/qwen2/), [GitHub](https://github.com/QwenLM/Qwen2), and [Documentation](https://qwen.readthedocs.io/en/latest/).
|
| 19 |
+
<br>
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
## Model Details
|
| 23 |
+
Qwen2 is a language model series including decoder language models of different model sizes. For each size, we release the base language model and the aligned chat model. It is based on the Transformer architecture with SwiGLU activation, attention QKV bias, group query attention, etc. Additionally, we have an improved tokenizer adaptive to multiple natural languages and codes.
|
| 24 |
+
|
| 25 |
+
## Requirements
|
| 26 |
+
The code of Qwen2 has been in the latest Hugging face transformers and we advise you to install `transformers>=4.37.0`, or you might encounter the following error:
|
| 27 |
+
```
|
| 28 |
+
KeyError: 'qwen2'
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
## Usage
|
| 33 |
+
|
| 34 |
+
We do not advise you to use base language models for text generation. Instead, you can apply post-training, e.g., SFT, RLHF, continued pretraining, etc., on this model.
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
### Performance
|
| 38 |
+
|
| 39 |
+
The evaluation of base models mainly focuses on the model performance of natural language understanding, general question answering, coding, mathematics, scientific knowledge, reasoning, multilingual capability, etc.
|
| 40 |
+
|
| 41 |
+
The datasets for evaluation include:
|
| 42 |
+
|
| 43 |
+
**English Tasks**: MMLU (5-shot), MMLU-Pro (5-shot), GPQA (5shot), Theorem QA (5-shot), BBH (3-shot), HellaSwag (10-shot), Winogrande (5-shot), TruthfulQA (0-shot), ARC-C (25-shot)
|
| 44 |
+
|
| 45 |
+
**Coding Tasks**: EvalPlus (0-shot) (HumanEval, MBPP, HumanEval+, MBPP+), MultiPL-E (0-shot) (Python, C++, JAVA, PHP, TypeScript, C#, Bash, JavaScript)
|
| 46 |
+
|
| 47 |
+
**Math Tasks**: GSM8K (4-shot), MATH (4-shot)
|
| 48 |
+
|
| 49 |
+
**Chinese Tasks**: C-Eval(5-shot), CMMLU (5-shot)
|
| 50 |
+
|
| 51 |
+
**Multilingual Tasks**: Multi-Exam (M3Exam 5-shot, IndoMMLU 3-shot, ruMMLU 5-shot, mMMLU 5-shot), Multi-Understanding (BELEBELE 5-shot, XCOPA 5-shot, XWinograd 5-shot, XStoryCloze 0-shot, PAWS-X 5-shot), Multi-Mathematics (MGSM 8-shot), Multi-Translation (Flores-101 5-shot)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
#### Qwen2-7B performance
|
| 56 |
+
| Datasets | Mistral-7B | Gemma-7B | Llama-3-8B | Qwen1.5-7B | Qwen2-7B |
|
| 57 |
+
| :--------| :---------: | :------------: | :------------: | :------------: | :------------: |
|
| 58 |
+
|# Params | 7.2B | 8.5B | 8.0B | 7.7B | 7.6B |
|
| 59 |
+
|# Non-emb Params | 7.0B | 7.8B | 7.0B | 6.5B | 6.5B |
|
| 60 |
+
| ***English*** | | | | | |
|
| 61 |
+
|MMLU | 64.2 | 64.6 | 66.6 | 61.0 | **70.3** |
|
| 62 |
+
|MMLU-Pro | 30.9 | 33.7 | 35.4 | 29.9 | **40.0** |
|
| 63 |
+
|GPQA | 24.7 | 25.7 | 25.8 | 26.7 | **31.8** |
|
| 64 |
+
|Theorem QA | 19.2 | 21.5 | 22.1 | 14.2 | **31.1** |
|
| 65 |
+
|BBH | 56.1 | 55.1 | 57.7 | 40.2 | **62.6** |
|
| 66 |
+
|HellaSwag | **83.2** | 82.2 | 82.1 | 78.5 | 80.7 |
|
| 67 |
+
|Winogrande | 78.4 | **79.0** | 77.4 | 71.3 | 77.0 |
|
| 68 |
+
|ARC-C | 60.0 | **61.1** | 59.3 | 54.2 | 60.6 |
|
| 69 |
+
|TruthfulQA | 42.2 | 44.8 | 44.0 | 51.1 | **54.2** |
|
| 70 |
+
| ***Coding*** | | | | | |
|
| 71 |
+
|HumanEval | 29.3 | 37.2 | 33.5 | 36.0 | **51.2** |
|
| 72 |
+
|MBPP | 51.1 | 50.6 | 53.9 | 51.6 | **65.9** |
|
| 73 |
+
|EvalPlus | 36.4 | 39.6 | 40.3 | 40.0 | **54.2** |
|
| 74 |
+
|MultiPL-E | 29.4 | 29.7 | 22.6 | 28.1 | **46.3** |
|
| 75 |
+
| ***Mathematics*** | | | | | |
|
| 76 |
+
|GSM8K | 52.2 | 46.4 | 56.0 | 62.5 | **79.9** |
|
| 77 |
+
|MATH | 13.1 | 24.3 | 20.5 | 20.3 | **44.2** |
|
| 78 |
+
| ***Chinese*** | | | | | |
|
| 79 |
+
|C-Eval | 47.4 | 43.6 | 49.5 | 74.1 | **83.2** |
|
| 80 |
+
|CMMLU | - | - | 50.8 | 73.1 | **83.9** |
|
| 81 |
+
| ***Multilingual*** | | | | | |
|
| 82 |
+
|Multi-Exam | 47.1 | 42.7 | 52.3 | 47.7 | **59.2** |
|
| 83 |
+
|Multi-Understanding | 63.3 | 58.3 | 68.6 | 67.6 | **72.0** |
|
| 84 |
+
|Multi-Mathematics | 26.3 | 39.1 | 36.3 | 37.3 | **57.5** |
|
| 85 |
+
|Multi-Translation | 23.3 | 31.2 | **31.9** | 28.4 | 31.5 |
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
## Citation
|
| 89 |
+
|
| 90 |
+
If you find our work helpful, feel free to give us a cite.
|
| 91 |
+
|
| 92 |
+
```
|
| 93 |
+
@article{qwen2,
      title={Qwen2 Technical Report},
      author={An Yang and Baosong Yang and Binyuan Hui and others},
      journal={arXiv preprint arXiv:2407.10671},
      year={2024}
}
|
| 97 |
+
```
|
third_party/Qwen2-7B/config.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"Qwen2ForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_dropout": 0.0,
|
| 6 |
+
"bos_token_id": 151643,
|
| 7 |
+
"eos_token_id": 151643,
|
| 8 |
+
"hidden_act": "silu",
|
| 9 |
+
"hidden_size": 3584,
|
| 10 |
+
"initializer_range": 0.02,
|
| 11 |
+
"intermediate_size": 18944,
|
| 12 |
+
"max_position_embeddings": 131072,
|
| 13 |
+
"max_window_layers": 28,
|
| 14 |
+
"model_type": "qwen2",
|
| 15 |
+
"num_attention_heads": 28,
|
| 16 |
+
"num_hidden_layers": 28,
|
| 17 |
+
"num_key_value_heads": 4,
|
| 18 |
+
"rms_norm_eps": 1e-06,
|
| 19 |
+
"rope_theta": 1000000.0,
|
| 20 |
+
"sliding_window": 131072,
|
| 21 |
+
"tie_word_embeddings": false,
|
| 22 |
+
"torch_dtype": "bfloat16",
|
| 23 |
+
"transformers_version": "4.37.2",
|
| 24 |
+
"use_cache": true,
|
| 25 |
+
"use_sliding_window": false,
|
| 26 |
+
"vocab_size": 152064
|
| 27 |
+
}
|
third_party/Qwen2-7B/generation_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 151643,
|
| 3 |
+
"do_sample": false,
|
| 4 |
+
"eos_token_id": 151643,
|
| 5 |
+
"max_new_tokens": 2048,
|
| 6 |
+
"transformers_version": "4.37.0"
|
| 7 |
+
}
|
third_party/Qwen2-7B/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
third_party/Qwen2-7B/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
third_party/Qwen2-7B/tokenizer_config.json
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"added_tokens_decoder": {
|
| 4 |
+
"151643": {
|
| 5 |
+
"content": "<|endoftext|>",
|
| 6 |
+
"lstrip": false,
|
| 7 |
+
"normalized": false,
|
| 8 |
+
"rstrip": false,
|
| 9 |
+
"single_word": false,
|
| 10 |
+
"special": true
|
| 11 |
+
},
|
| 12 |
+
"151644": {
|
| 13 |
+
"content": "<|im_start|>",
|
| 14 |
+
"lstrip": false,
|
| 15 |
+
"normalized": false,
|
| 16 |
+
"rstrip": false,
|
| 17 |
+
"single_word": false,
|
| 18 |
+
"special": true
|
| 19 |
+
},
|
| 20 |
+
"151645": {
|
| 21 |
+
"content": "<|im_end|>",
|
| 22 |
+
"lstrip": false,
|
| 23 |
+
"normalized": false,
|
| 24 |
+
"rstrip": false,
|
| 25 |
+
"single_word": false,
|
| 26 |
+
"special": true
|
| 27 |
+
}
|
| 28 |
+
},
|
| 29 |
+
"additional_special_tokens": ["<|im_start|>", "<|im_end|>"],
|
| 30 |
+
"bos_token": null,
|
| 31 |
+
"chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
| 32 |
+
"clean_up_tokenization_spaces": false,
|
| 33 |
+
"eos_token": "<|endoftext|>",
|
| 34 |
+
"errors": "replace",
|
| 35 |
+
"model_max_length": 32768,
|
| 36 |
+
"pad_token": "<|endoftext|>",
|
| 37 |
+
"split_special_tokens": false,
|
| 38 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 39 |
+
"unk_token": null
|
| 40 |
+
}
|
third_party/Qwen2-7B/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
third_party/demucs/__init__.py
ADDED
|
File without changes
|
third_party/demucs/ckpt/htdemucs.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e4378974c3df2fbcf872d2aeb32218e4de376799494579655775a375d09931c2
|
| 3 |
+
size 168138881
|
third_party/demucs/ckpt/htdemucs.yaml
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
models: ['htdemucs']
|
third_party/demucs/models/__init__.py
ADDED
|
File without changes
|
third_party/demucs/models/apply.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File : apply.py
|
| 5 |
+
@Time : 2023/8/8 下午4:22
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2023, Tencent
|
| 9 |
+
@Desc : Apply
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 13 |
+
import torch
|
| 14 |
+
import os
|
| 15 |
+
import random
|
| 16 |
+
import typing as tp
|
| 17 |
+
|
| 18 |
+
import torch as th
|
| 19 |
+
from torch import nn
|
| 20 |
+
from torch.nn import functional as F
|
| 21 |
+
import tqdm
|
| 22 |
+
|
| 23 |
+
from .htdemucs import HTDemucs
|
| 24 |
+
from .audio import load_track, save_audio
|
| 25 |
+
from .utils import center_trim, DummyPoolExecutor
|
| 26 |
+
|
| 27 |
+
Model = tp.Union[HTDemucs]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class BagOfModels(nn.Module):
    def __init__(self, models: tp.List[Model],
                 weights: tp.Optional[tp.List[tp.List[float]]] = None,
                 segment: tp.Optional[float] = None):
        """
        Represents a bag of models with specific weights.
        You should call `apply_model` rather than calling directly the forward here for
        optimal performance.

        Args:
            models (list[nn.Module]): list of Demucs/HDemucs models.
            weights (list[list[float]]): list of weights. If None, assumed to
                be all ones, otherwise it should be a list of N list (N number of models),
                each containing S floats (S number of sources).
            segment (None or float): overrides the `segment` attribute of each model
                (this is performed inplace, be careful is you reuse the models passed).
        """
        super().__init__()
        assert len(models) > 0
        first = models[0]
        for other in models:
            # All bagged models must agree on the separation contract
            # (same sources, sample rate and channel count).
            assert other.sources == first.sources
            assert other.samplerate == first.samplerate
            assert other.audio_channels == first.audio_channels
            if segment is not None:
                other.segment = segment

        self.audio_channels = first.audio_channels
        self.samplerate = first.samplerate
        self.sources = first.sources
        self.models = nn.ModuleList(models)

        if weights is None:
            # Default: every model contributes equally to every source.
            weights = [[1. for _ in first.sources] for _ in models]
        else:
            assert len(weights) == len(models)
            for weight in weights:
                assert len(weight) == len(first.sources)
        self.weights = weights

    @property
    def max_allowed_segment(self) -> float:
        """Largest segment length (in seconds) supported by every bagged HTDemucs model."""
        max_allowed_segment = float('inf')
        for model in self.models:
            if isinstance(model, HTDemucs):
                max_allowed_segment = min(max_allowed_segment, float(model.segment))
        return max_allowed_segment

    def forward(self, x):
        # Direct forward is intentionally unsupported: the ensemble logic
        # (per-source weighting, shifts, splitting) lives in `apply_model`.
        raise NotImplementedError("Call `apply_model` on this.")

    def separate(self, source_file, output_dir, stem=None, device=None):
        """Separate `source_file` into stems written to `output_dir` as FLAC.

        If `stem` is None, writes one file per source and returns their paths.
        Otherwise writes the requested stem plus the sum of all remaining
        stems ("no_<stem>") and returns those two paths.
        """
        wav, _ = load_track(source_file, self.audio_channels, self.samplerate)
        # Normalize with the mono reference (zero mean, unit std) before
        # separation; the scaling is undone on the outputs afterwards.
        ref = wav.mean(0)
        wav -= ref.mean()
        wav /= ref.std()
        sources = apply_model(self, wav[None], device=device, shifts=1, split=True, overlap=0.25,
                              progress=True, num_workers=0, segment=None)[0]
        sources *= ref.std()
        sources += ref.mean()

        output_paths = []
        name, ext = os.path.splitext(os.path.split(source_file)[-1])
        # Output is always written as FLAC regardless of the input format.
        if ext != ".flac":
            ext = ".flac"
        kwargs = {
            'samplerate': self.samplerate,
            'bitrate': 320,
            'clip': "rescale",
            'as_float': False,
            'bits_per_sample': 16,
        }
        if stem is None:
            # One file per source. NOTE: the loop variable shadows the `stem`
            # parameter, which is harmless here since `stem` is known to be None.
            for source, stem in zip(sources, self.sources):
                output_stem_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
                save_audio(source, output_stem_path, **kwargs)
                output_paths.append(output_stem_path)
        else:
            # Write the requested stem, then the mix of everything else.
            sources = list(sources)
            output_stem_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
            save_audio(sources.pop(self.sources.index(stem)), output_stem_path, **kwargs)
            other_stem = torch.zeros_like(sources[0])
            for i in sources:
                other_stem += i
            output_no_stem_path = os.path.join(output_dir, f"{name}_no_{stem}{ext}")
            save_audio(other_stem, output_no_stem_path, **kwargs)
            output_paths = [output_stem_path, output_no_stem_path]

        return output_paths
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class TensorChunk:
    """Lazy window over the last axis of a tensor.

    Records an (offset, length) view without copying any data; the actual
    slice (with optional centered zero-padding) is only materialized when
    `padded` is called.
    """

    def __init__(self, tensor, offset=0, length=None):
        available = tensor.shape[-1]
        assert 0 <= offset < available
        remaining = available - offset

        if isinstance(tensor, TensorChunk):
            # Collapse nested chunks so we always reference the raw tensor.
            self.tensor = tensor.tensor
            self.offset = tensor.offset + offset
        else:
            self.tensor = tensor
            self.offset = offset
        self.length = remaining if length is None else min(remaining, length)
        self.device = tensor.device

    @property
    def shape(self):
        """Virtual shape: the underlying shape with the last dim replaced by the window length."""
        dims = list(self.tensor.shape)
        dims[-1] = self.length
        return dims

    def padded(self, target_length):
        """Materialize the chunk centered in `target_length` samples, zero-padding both sides."""
        delta = target_length - self.length
        assert delta >= 0

        begin = self.offset - delta // 2
        stop = begin + target_length
        total = self.tensor.shape[-1]

        lo = max(0, begin)
        hi = min(total, stop)

        out = F.pad(self.tensor[..., lo:hi], (lo - begin, stop - hi))
        assert out.shape[-1] == target_length
        return out
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def tensor_chunk(tensor_or_chunk):
    """Coerce the argument to a TensorChunk; existing chunks pass through unchanged."""
    if not isinstance(tensor_or_chunk, TensorChunk):
        assert isinstance(tensor_or_chunk, th.Tensor)
        tensor_or_chunk = TensorChunk(tensor_or_chunk)
    return tensor_or_chunk
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def apply_model(model: tp.Union[BagOfModels, Model],
                mix: tp.Union[th.Tensor, TensorChunk],
                shifts: int = 1, split: bool = True,
                overlap: float = 0.25, transition_power: float = 1.,
                progress: bool = False, device=None,
                num_workers: int = 0, segment: tp.Optional[float] = None,
                pool=None) -> th.Tensor:
    """
    Apply model to a given mixture.

    Args:
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the oppositve shift to the output. This is repeated `shifts` time and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down in 8 seconds extracts
            and predictions will be performed individually on each and concatenated.
            Useful for model with large memory footprint like Tasnet.
        progress (bool): if True, show a progress bar (requires split=True)
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.
        num_workers (int): if non zero, device is 'cpu', how many threads to
            use in parallel.
        segment (float or None): override the model segment parameter.
    """
    if device is None:
        device = mix.device
    else:
        device = th.device(device)
    if pool is None:
        # Threads only help on CPU; on other devices work is serialized anyway.
        if num_workers > 0 and device.type == 'cpu':
            pool = ThreadPoolExecutor(num_workers)
        else:
            pool = DummyPoolExecutor()
    # Shared kwargs for the recursive calls below; each branch disables the
    # feature it has just handled (shifts, then split) before recursing.
    kwargs: tp.Dict[str, tp.Any] = {
        'shifts': shifts,
        'split': split,
        'overlap': overlap,
        'transition_power': transition_power,
        'progress': progress,
        'device': device,
        'pool': pool,
        'segment': segment,
    }
    out: tp.Union[float, th.Tensor]
    if isinstance(model, BagOfModels):
        # Special treatment for bag of model.
        # We explicitely apply multiple times `apply_model` so that the random shifts
        # are different for each model.
        estimates: tp.Union[float, th.Tensor] = 0.
        totals = [0.] * len(model.sources)
        for sub_model, model_weights in zip(model.models, model.weights):
            original_model_device = next(iter(sub_model.parameters())).device
            sub_model.to(device)

            out = apply_model(sub_model, mix, **kwargs)
            sub_model.to(original_model_device)
            # Scale each source estimate by its per-model weight and track
            # the running weight totals for normalization below.
            for k, inst_weight in enumerate(model_weights):
                out[:, k, :, :] *= inst_weight
                totals[k] += inst_weight
            estimates += out
            del out

        assert isinstance(estimates, th.Tensor)
        # Normalize by the accumulated per-source weights (weighted average).
        for k in range(estimates.shape[1]):
            estimates[:, k, :, :] /= totals[k]
        return estimates

    model.to(device)
    model.eval()
    assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
    batch, channels, length = mix.shape
    if shifts:
        # Randomized time shifts: pad, shift, predict, undo the shift, average.
        kwargs['shifts'] = 0
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        assert isinstance(mix, TensorChunk)
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0.
        for _ in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
            shifted_out = apply_model(model, shifted, **kwargs)
            out += shifted_out[..., max_shift - offset:]
        out /= shifts
        assert isinstance(out, th.Tensor)
        return out
    elif split:
        # Overlap-add over fixed-size segments with a triangular weighting.
        kwargs['split'] = False
        out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
        sum_weight = th.zeros(length, device=mix.device)
        if segment is None:
            segment = model.segment
        assert segment is not None and segment > 0.
        segment_length: int = int(model.samplerate * segment)
        stride = int((1 - overlap) * segment_length)
        offsets = range(0, length, stride)
        # Progress-bar unit scale: seconds of audio per processed segment.
        scale = float(format(stride / model.samplerate, ".2f"))
        # We start from a triangle shaped weight, with maximal weight in the middle
        # of the segment. Then we normalize and take to the power `transition_power`.
        # Large values of transition power will lead to sharper transitions.
        weight = th.cat([th.arange(1, segment_length // 2 + 1, device=device),
                         th.arange(segment_length - segment_length // 2, 0, -1, device=device)])
        assert len(weight) == segment_length
        # If the overlap < 50%, this will translate to linear transition when
        # transition_power is 1.
        weight = (weight / weight.max())**transition_power
        futures = []
        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment_length)
            future = pool.submit(apply_model, model, chunk, **kwargs)
            futures.append((future, offset))
            # NOTE(review): this increment is dead code — `offset` is re-bound
            # by the loop on the next iteration.
            offset += segment_length
        if progress:
            futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
        for future, offset in futures:
            chunk_out = future.result()
            chunk_length = chunk_out.shape[-1]
            # Accumulate the weighted chunk, then divide by the total weight
            # per sample so overlapping regions average correctly.
            out[..., offset:offset + segment_length] += (
                weight[:chunk_length] * chunk_out).to(mix.device)
            sum_weight[offset:offset + segment_length] += weight[:chunk_length].to(mix.device)
        assert sum_weight.min() > 0
        out /= sum_weight
        assert isinstance(out, th.Tensor)
        return out
    else:
        # Single forward pass over the (possibly padded) full mixture.
        valid_length: int
        if isinstance(model, HTDemucs) and segment is not None:
            valid_length = int(segment * model.samplerate)
        elif hasattr(model, 'valid_length'):
            valid_length = model.valid_length(length)  # type: ignore
        else:
            valid_length = length
        mix = tensor_chunk(mix)
        assert isinstance(mix, TensorChunk)
        padded_mix = mix.padded(valid_length).to(device)
        with th.no_grad():
            out = model(padded_mix)
        assert isinstance(out, th.Tensor)
        # Trim the padding back off so the output matches the input length.
        return center_trim(out, length)
|
third_party/demucs/models/audio.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File : audio.py
|
| 5 |
+
@Time : 2023/8/8 下午7:18
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2023, Tencent
|
| 9 |
+
@Desc : Audio
|
| 10 |
+
"""
|
| 11 |
+
import json
|
| 12 |
+
import subprocess as sp
|
| 13 |
+
import typing as tp
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
import lameenc
|
| 17 |
+
import julius
|
| 18 |
+
import torch
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torchaudio as ta
|
| 21 |
+
|
| 22 |
+
from .utils import temp_filenames
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _read_info(path):
    """Probe `path` with ffprobe and return the parsed JSON metadata (format + streams)."""
    cmd = [
        'ffprobe', "-loglevel", "panic",
        str(path), '-print_format', 'json', '-show_format', '-show_streams'
    ]
    raw = sp.check_output(cmd)
    return json.loads(raw.decode('utf-8'))
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class AudioFile:
    """
    Allows to read audio from any format supported by ffmpeg, as well as resampling or
    converting to mono on the fly. See :method:`read` for more details.
    """
    def __init__(self, path: Path):
        self.path = Path(path)
        # Lazily populated ffprobe metadata cache (see `info`).
        self._info = None

    def __repr__(self):
        features = [("path", self.path)]
        features.append(("samplerate", self.samplerate()))
        features.append(("channels", self.channels()))
        features.append(("streams", len(self)))
        features_str = ", ".join(f"{name}={value}" for name, value in features)
        return f"AudioFile({features_str})"

    @property
    def info(self):
        # ffprobe is invoked at most once per file; the JSON result is cached.
        if self._info is None:
            self._info = _read_info(self.path)
        return self._info

    @property
    def duration(self):
        # Container-level duration in seconds, as reported by ffprobe.
        return float(self.info['format']['duration'])

    @property
    def _audio_streams(self):
        # Indices of the audio streams inside the container (video/subtitle
        # streams are skipped).
        return [
            index for index, stream in enumerate(self.info["streams"])
            if stream["codec_type"] == "audio"
        ]

    def __len__(self):
        # Number of audio streams in the file.
        return len(self._audio_streams)

    def channels(self, stream=0):
        # Channel count of the given audio stream (0-based among audio streams).
        return int(self.info['streams'][self._audio_streams[stream]]['channels'])

    def samplerate(self, stream=0):
        # Native sample rate (Hz) of the given audio stream.
        return int(self.info['streams'][self._audio_streams[stream]]['sample_rate'])

    def read(self,
             seek_time=None,
             duration=None,
             streams=slice(None),
             samplerate=None,
             channels=None):
        """
        Slightly more efficient implementation than stempeg,
        in particular, this will extract all stems at once
        rather than having to loop over one file multiple times
        for each stream.

        Args:
            seek_time (float): seek time in seconds or None if no seeking is needed.
            duration (float): duration in seconds to extract or None to extract until the end.
            streams (slice, int or list): streams to extract, can be a single int, a list or
                a slice. If it is a slice or list, the output will be of size [S, C, T]
                with S the number of streams, C the number of channels and T the number of samples.
                If it is an int, the output will be [C, T].
            samplerate (int): if provided, will resample on the fly. If None, no resampling will
                be done. Original sampling rate can be obtained with :method:`samplerate`.
            channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that
                as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers.
                See https://sound.stackexchange.com/a/42710.
                Our definition of mono is simply the average of the two channels. Any other
                value will be ignored.
        """
        # Normalize `streams` to a list of indices; remember whether the
        # caller passed a single int so we can unwrap the result at the end.
        streams = np.array(range(len(self)))[streams]
        single = not isinstance(streams, np.ndarray)
        if single:
            streams = [streams]

        if duration is None:
            target_size = None
            query_duration = None
        else:
            # Ask ffmpeg for one extra sample so the exact truncation below
            # never comes up short due to rounding.
            target_size = int((samplerate or self.samplerate()) * duration)
            query_duration = float((target_size + 1) / (samplerate or self.samplerate()))

        with temp_filenames(len(streams)) as filenames:
            # One ffmpeg invocation extracts every requested stream to its
            # own raw float32 (f32le) temp file.
            command = ['ffmpeg', '-y']
            command += ['-loglevel', 'panic']
            if seek_time:
                command += ['-ss', str(seek_time)]
            command += ['-i', str(self.path)]
            for stream, filename in zip(streams, filenames):
                command += ['-map', f'0:{self._audio_streams[stream]}']
                if query_duration is not None:
                    command += ['-t', str(query_duration)]
                command += ['-threads', '1']
                command += ['-f', 'f32le']
                if samplerate is not None:
                    command += ['-ar', str(samplerate)]
                command += [filename]

            sp.run(command, check=True)
            wavs = []
            for filename in filenames:
                wav = np.fromfile(filename, dtype=np.float32)
                wav = torch.from_numpy(wav)
                # Raw f32le output is interleaved; reshape to (channels, time).
                wav = wav.view(-1, self.channels()).t()
                if channels is not None:
                    wav = convert_audio_channels(wav, channels)
                if target_size is not None:
                    wav = wav[..., :target_size]
                wavs.append(wav)
        wav = torch.stack(wavs, dim=0)
        if single:
            wav = wav[0]
        return wav
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def convert_audio_channels(wav, channels=2):
    """Convert audio to the given number of channels."""
    *shape, src_channels, length = wav.shape
    if src_channels == channels:
        # Already the requested layout.
        return wav
    if channels == 1:
        # Downmix: mono is defined as the mean of all source channels.
        return wav.mean(dim=-2, keepdim=True)
    if src_channels == 1:
        # Upmix: replicate the single channel to the requested count.
        return wav.expand(*shape, channels, length)
    if src_channels >= channels:
        # Too many channels: keep only the first `channels` ones.
        return wav[..., :channels, :]
    # Fewer (but >1) channels than requested: no reasonable policy exists.
    raise ValueError('The audio file has less channels than requested but is not mono.')
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def convert_audio(wav, from_samplerate, to_samplerate, channels):
    """Convert audio from a given samplerate to a target one and target number of channels."""
    remixed = convert_audio_channels(wav, channels)
    return julius.resample_frac(remixed, from_samplerate, to_samplerate)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def i16_pcm(wav):
    """Convert audio to 16 bits integer PCM format.

    Floating-point input is clamped to [-1, 1] and scaled to the int16
    range; integer input is returned unchanged.
    """
    if wav.dtype.is_floating_point:
        # Use the out-of-place clamp: the original `clamp_` silently mutated
        # the caller's tensor as a side effect of a pure conversion.
        return (wav.clamp(-1, 1) * (2**15 - 1)).short()
    else:
        return wav
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def f32_pcm(wav):
    """Convert audio to float 32 bits PCM format."""
    if not wav.dtype.is_floating_point:
        # Integer PCM: rescale from the int16 range into [-1, 1].
        wav = wav.float() / (2**15 - 1)
    return wav
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def as_dtype_pcm(wav):
    """Convert audio to either f32 pcm or i16 pcm depending on the given dtype."""
    # NOTE(review): as written this is effectively the identity — `f32_pcm`
    # is a no-op for floating input and `i16_pcm` is a no-op for integer
    # input. Upstream demucs passes an explicit target `dtype` argument here;
    # confirm whether that parameter was dropped intentionally.
    if wav.dtype.is_floating_point:
        return f32_pcm(wav)
    else:
        return i16_pcm(wav)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def encode_mp3(wav, path, samplerate=44100, bitrate=320, verbose=False):
    """Save given audio as mp3. This should work on all OSes."""
    num_channels, _ = wav.shape
    pcm = i16_pcm(wav)

    encoder = lameenc.Encoder()
    encoder.set_bit_rate(bitrate)
    encoder.set_in_sample_rate(samplerate)
    encoder.set_channels(num_channels)
    encoder.set_quality(2)  # 2-highest, 7-fastest
    if not verbose:
        encoder.silence()

    # lame expects interleaved samples, i.e. (time, channels) raw bytes.
    interleaved = pcm.data.cpu().transpose(0, 1).numpy()
    payload = encoder.encode(interleaved.tobytes())
    payload += encoder.flush()
    with open(path, "wb") as f:
        f.write(payload)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def prevent_clip(wav, mode='rescale'):
    """
    different strategies for avoiding raw clipping.
    """
    if mode is None or mode == 'none':
        return wav
    assert wav.dtype.is_floating_point, "too late for clipping"
    if mode == 'rescale':
        # Scale down only when the peak exceeds ~1; quiet audio is untouched.
        return wav / max(1.01 * wav.abs().max(), 1)
    if mode == 'clamp':
        return wav.clamp(-0.99, 0.99)
    if mode == 'tanh':
        return wav.tanh()
    raise ValueError(f"Invalid mode {mode}")
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def save_audio(wav: torch.Tensor,
               path: tp.Union[str, Path],
               samplerate: int,
               bitrate: int = 320,
               clip: str = 'rescale',
               bits_per_sample: int = 16,
               as_float: bool = False):
    """Save audio file, automatically preventing clipping if necessary
    based on the given `clip` strategy. If the path ends in `.mp3`, this
    will save as mp3 with the given `bitrate`.

    Args:
        wav: audio tensor — assumes (channels, samples); TODO confirm against callers.
        path: destination file; the suffix selects the encoder (.mp3/.wav/.flac).
        samplerate: sample rate of `wav` in Hz.
        bitrate: mp3 bitrate in kbps (mp3 output only).
        clip: anti-clipping strategy forwarded to `prevent_clip`
            ('rescale', 'clamp', 'tanh', 'none' or None).
        bits_per_sample: PCM bit depth for wav/flac output.
        as_float: if True, wav output is written as 32-bit float PCM.

    Raises:
        ValueError: if the path suffix is not .mp3, .wav or .flac.
    """
    wav = prevent_clip(wav, mode=clip)
    path = Path(path)
    suffix = path.suffix.lower()
    if suffix == ".mp3":
        # NOTE(review): verbose=True is hard-coded, so the lame encoder always
        # prints its output regardless of any caller preference — confirm intended.
        encode_mp3(wav, path, samplerate, bitrate, verbose=True)
    elif suffix == ".wav":
        if as_float:
            bits_per_sample = 32
            encoding = 'PCM_F'
        else:
            encoding = 'PCM_S'
        ta.save(str(path), wav, sample_rate=samplerate,
                encoding=encoding, bits_per_sample=bits_per_sample)
    elif suffix == ".flac":
        ta.save(str(path), wav, sample_rate=samplerate, bits_per_sample=bits_per_sample)
    else:
        raise ValueError(f"Invalid suffix for path: {suffix}")
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def load_track(track, audio_channels, samplerate):
    """Decode `track` into a waveform, trying ffmpeg first, then torchaudio.

    Returns a tuple `(wav, errors)`; `wav` is None when every backend failed,
    and `errors` maps the backend name to its failure message.
    """
    errors = {}
    wav = None

    # First attempt: ffmpeg-backed reader, which also resamples/remixes.
    try:
        wav = AudioFile(track).read(
            streams=0,
            samplerate=samplerate,
            channels=audio_channels)
    except sp.CalledProcessError:
        errors['ffmpeg'] = 'FFmpeg could not read the file.'

    # Fallback: torchaudio, followed by an explicit conversion step.
    if wav is None:
        try:
            raw, sr = ta.load(str(track))
        except RuntimeError as err:
            errors['torchaudio'] = err.args[0]
        else:
            wav = convert_audio(raw, sr, samplerate, audio_channels)

    return wav, errors
|
third_party/demucs/models/demucs.py
ADDED
|
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File : demucs.py
|
| 5 |
+
@Time : 2023/8/8 下午4:36
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2023, Tencent
|
| 9 |
+
@Desc : Demucs
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import math
|
| 13 |
+
import typing as tp
|
| 14 |
+
|
| 15 |
+
import julius
|
| 16 |
+
import torch
|
| 17 |
+
from torch import nn
|
| 18 |
+
from torch.nn import functional as F
|
| 19 |
+
|
| 20 |
+
from .states import capture_init
|
| 21 |
+
from .utils import center_trim, unfold
|
| 22 |
+
from .transformer import LayerScale
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class BLSTM(nn.Module):
    """
    BiLSTM with same hidden units as input dim.
    If `max_steps` is not None, input will be split in overlapping
    chunks and the LSTM applied separately on each chunk.
    """
    def __init__(self, dim, layers=1, max_steps=None, skip=False):
        super().__init__()
        # Chunked processing assumes max_steps is divisible by 4 so that the
        # half-stride and quarter-width trims below are exact.
        assert max_steps is None or max_steps % 4 == 0
        self.max_steps = max_steps
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        # Projects the bidirectional output (2 * dim) back to dim.
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x):
        # x: (batch, channels, time)
        b, c, t = x.shape
        y = x  # kept for the optional residual connection
        framed = False
        if self.max_steps is not None and t > self.max_steps:
            # Split into 50%-overlapping chunks and fold them into the batch
            # dimension so the LSTM sees sequences of at most `max_steps`.
            width = self.max_steps
            stride = width // 2
            frames = unfold(x, width, stride)
            nframes = frames.shape[2]
            framed = True
            x = frames.permute(0, 2, 1, 3).reshape(-1, c, width)

        # nn.LSTM expects (time, batch, features).
        x = x.permute(2, 0, 1)

        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)
        if framed:
            # Stitch overlapping chunks back together, trimming half a stride
            # (a quarter window) from each interior boundary to drop the
            # less-reliable chunk edges.
            out = []
            frames = x.reshape(b, -1, c, width)
            limit = stride // 2
            for k in range(nframes):
                if k == 0:
                    out.append(frames[:, k, :, :-limit])
                elif k == nframes - 1:
                    out.append(frames[:, k, :, limit:])
                else:
                    out.append(frames[:, k, :, limit:-limit])
            out = torch.cat(out, -1)
            # unfold may have padded the signal; crop back to the input length.
            out = out[..., :t]
            x = out
        if self.skip:
            x = x + y
        return x
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def rescale_conv(conv, reference):
    """Divide a convolution's weights (and bias) in place so that the weight
    standard deviation moves toward `reference`.

    The applied factor is `sqrt(std / reference)`. It is unclear why this
    rescaling helps, but empirically it does.
    """
    current_std = conv.weight.std().detach()
    factor = (current_std / reference) ** 0.5
    conv.weight.data /= factor
    bias = conv.bias
    if bias is not None:
        bias.data /= factor
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def rescale_module(module, reference):
    """Apply `rescale_conv` to every 1d/2d (transposed) convolution found
    anywhere inside `module`."""
    conv_types = (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)
    for layer in module.modules():
        if isinstance(layer, conv_types):
            rescale_conv(layer, reference)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class DConv(nn.Module):
    """
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also before entering each residual branch, dimension is projected on a smaller subspace,
    e.g. of dim `channels // compress`.
    """
    def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
                 norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True,
                 kernel=3, dilate=True):
        """
        Args:
            channels: input/output channels for residual branch.
            compress: amount of channel compression inside the branch.
            depth: number of layers in the residual branch. Each layer has its own
                projection, and potentially LSTM and attention. A negative value
                disables dilation (see note in the body) while keeping `abs(depth)`
                layers.
            init: initial scale for LayerNorm.
            norm: use GroupNorm.
            attn: use LocalAttention.
            heads: number of heads for the LocalAttention.
            ndecay: number of decay controls in the LocalAttention.
            lstm: use LSTM.
            gelu: Use GELU activation.
            kernel: kernel size for the (dilated) convolutions.
            dilate: if true, use dilation, increasing with the depth.
        """

        super().__init__()
        # Odd kernel so `dilation * (kernel // 2)` padding keeps the length.
        assert kernel % 2 == 1
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        # NOTE: this overwrites the `dilate` argument — dilation is effectively
        # controlled by the sign of `depth` (negative depth disables it).
        dilate = depth > 0

        norm_fn: tp.Callable[[int], nn.Module]
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(1, d)  # noqa

        # Bottleneck width inside each residual branch.
        hidden = int(channels / compress)

        act: tp.Type[nn.Module]
        if gelu:
            act = nn.GELU
        else:
            act = nn.ReLU

        self.layers = nn.ModuleList([])
        for d in range(self.depth):
            # Dilation grows exponentially with layer index when enabled.
            dilation = 2 ** d if dilate else 1
            padding = dilation * (kernel // 2)
            mods = [
                # compress -> (optional attn/lstm at index 3) -> expand + GLU -> LayerScale
                nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
                norm_fn(hidden), act(),
                nn.Conv1d(hidden, 2 * channels, 1),
                norm_fn(2 * channels), nn.GLU(1),
                LayerScale(channels, init),
            ]
            if attn:
                mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
            if lstm:
                mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
            layer = nn.Sequential(*mods)
            self.layers.append(layer)

    def forward(self, x):
        # Each branch is residual: output length/channels match the input.
        for layer in self.layers:
            x = x + layer(x)
        return x
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class LocalState(nn.Module):
    """Local state allows to have attention based only on data (no positional embedding),
    but while setting a constraint on the time window (e.g. decaying penalty term).

    Also a failed experiments with trying to provide some frequency based attention.
    """
    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        self.ndecay = ndecay
        # 1x1 convolutions acting as per-timestep linear projections.
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
        if ndecay:
            self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
            # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
            self.query_decay.weight.data *= 0.01
            assert self.query_decay.bias is not None  # stupid type checker
            self.query_decay.bias.data[:] = -2
        # Output projection; widened when frequency signals are concatenated.
        self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)

    def forward(self, x):
        # x: (batch, channels, time)
        b, _, t = x.shape
        heads = self.heads
        indexes = torch.arange(t, device=x.device, dtype=x.dtype)
        # left index are keys, right index are queries
        delta = indexes[:, None] - indexes[None, :]

        queries = self.query(x).view(b, heads, -1, t)
        keys = self.key(x).view(b, heads, -1, t)
        # t are keys, s are queries
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        # Standard scaled dot-product normalization by sqrt(head_dim).
        dots /= keys.shape[2]**0.5
        if self.nfreqs:
            # Cosine kernels over the key/query offset add a learned
            # frequency-based bias to the attention logits.
            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
            freq_q = self.query_freqs(x).view(b, heads, -1, t) / self.nfreqs ** 0.5
            dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
        if self.ndecay:
            # Penalize distant positions with a learned, per-query decay rate.
            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_q = self.query_decay(x).view(b, heads, -1, t)
            decay_q = torch.sigmoid(decay_q) / 2
            decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        # Kill self reference.
        dots.masked_fill_(torch.eye(t, device=dots.device, dtype=torch.bool), -100)
        # Softmax over keys (dim=2 is the `t` axis of "bhts").
        weights = torch.softmax(dots, dim=2)

        content = self.content(x).view(b, heads, -1, t)
        result = torch.einsum("bhts,bhct->bhcs", weights, content)
        if self.nfreqs:
            # Attach the attention-weighted frequency kernels as extra channels.
            time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
            result = torch.cat([result, time_sig], 2)
        result = result.reshape(b, -1, t)
        # Residual connection around the projected attention output.
        return x + self.proj(result)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class Demucs(nn.Module):
    # Time-domain U-Net for source separation: a stack of strided Conv1d
    # encoders mirrored by ConvTranspose1d decoders with skip connections.
    @capture_init
    def __init__(self,
                 sources,
                 # Channels
                 audio_channels=2,
                 channels=64,
                 growth=2.,
                 # Main structure
                 depth=6,
                 rewrite=True,
                 lstm_layers=0,
                 # Convolutions
                 kernel_size=8,
                 stride=4,
                 context=1,
                 # Activations
                 gelu=True,
                 glu=True,
                 # Normalization
                 norm_starts=4,
                 norm_groups=4,
                 # DConv residual branch
                 dconv_mode=1,
                 dconv_depth=2,
                 dconv_comp=4,
                 dconv_attn=4,
                 dconv_lstm=4,
                 dconv_init=1e-4,
                 # Pre/post processing
                 normalize=True,
                 resample=True,
                 # Weight init
                 rescale=0.1,
                 # Metadata
                 samplerate=44100,
                 segment=4 * 10):
        """
        Args:
            sources (list[str]): list of source names
            audio_channels (int): stereo or mono
            channels (int): first convolution channels
            depth (int): number of encoder/decoder layers
            growth (float): multiply (resp divide) number of channels by that
                for each layer of the encoder (resp decoder)
            rewrite (bool): add 1x1 convolution to each layer.
            lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
                by default, as this is now replaced by the smaller and faster small LSTMs
                in the DConv branches.
            kernel_size (int): kernel size for convolutions
            stride (int): stride for convolutions
            context (int): kernel size of the convolution in the
                decoder before the transposed convolution. If > 1,
                will provide some context from neighboring time steps.
            gelu: use GELU activation function.
            glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_attn: adds attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            normalize (bool): normalizes the input audio on the fly, and scales back
                the output by the same amount.
            resample (bool): upsample x2 the input and downsample /2 the output.
            rescale (float): rescale initial weights of convolutions
                to get their standard deviation closer to `rescale`.
            samplerate (int): stored as meta information for easing
                future evaluations of the model.
            segment (float): duration of the chunks of audio to ideally evaluate the model on.
                This is used by `demucs.apply.apply_model`.
        """

        super().__init__()
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.resample = resample
        self.channels = channels
        self.normalize = normalize
        self.samplerate = samplerate
        self.segment = segment
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.skip_scales = nn.ModuleList()

        if glu:
            activation = nn.GLU(dim=1)
            # GLU halves the channel count, so the rewrite conv doubles it first.
            ch_scale = 2
        else:
            activation = nn.ReLU()
            ch_scale = 1
        if gelu:
            act2 = nn.GELU
        else:
            act2 = nn.ReLU

        in_channels = audio_channels
        padding = 0
        for index in range(depth):
            # Group norm only from `norm_starts` onwards (shallow layers stay raw).
            norm_fn = lambda d: nn.Identity()  # noqa
            if index >= norm_starts:
                norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa

            encode = []
            encode += [
                nn.Conv1d(in_channels, channels, kernel_size, stride),
                norm_fn(channels),
                act2(),
            ]
            attn = index >= dconv_attn
            lstm = index >= dconv_lstm
            if dconv_mode & 1:
                encode += [DConv(channels, depth=dconv_depth, init=dconv_init,
                                 compress=dconv_comp, attn=attn, lstm=lstm)]
            if rewrite:
                encode += [
                    nn.Conv1d(channels, ch_scale * channels, 1),
                    norm_fn(ch_scale * channels), activation]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            if index > 0:
                out_channels = in_channels
            else:
                # The innermost decoder (built first, inserted last) emits one
                # waveform per source.
                out_channels = len(self.sources) * audio_channels
            if rewrite:
                decode += [
                    nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context),
                    norm_fn(ch_scale * channels), activation]
            if dconv_mode & 2:
                decode += [DConv(channels, depth=dconv_depth, init=dconv_init,
                                 compress=dconv_comp, attn=attn, lstm=lstm)]
            decode += [nn.ConvTranspose1d(channels, out_channels,
                                          kernel_size, stride, padding=padding)]
            if index > 0:
                decode += [norm_fn(out_channels), act2()]
            # insert(0, ...) builds the decoder in mirror order of the encoder.
            self.decoder.insert(0, nn.Sequential(*decode))
            in_channels = channels
            channels = int(growth * channels)

        channels = in_channels
        if lstm_layers:
            self.lstm = BLSTM(channels, lstm_layers)
        else:
            self.lstm = None

        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there is no time steps left over in a convolution, e.g. for all
        layers, size of the input - kernel_size % stride = 0.

        Note that input are automatically padded if necessary to ensure that the output
        has the same length as the input.
        """
        if self.resample:
            length *= 2

        # Simulate the encoder's strided convolutions...
        for _ in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(1, length)

        # ...then invert them through the transposed convolutions.
        for _ in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size

        if self.resample:
            length = math.ceil(length / 2)
        return int(length)

    def forward(self, mix):
        # mix: (batch, audio_channels, time)
        x = mix
        length = x.shape[-1]

        if self.normalize:
            # Normalize with mono statistics; the inverse transform is applied
            # to the output below.
            mono = mix.mean(dim=1, keepdim=True)
            mean = mono.mean(dim=-1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            x = (x - mean) / (1e-5 + std)
        else:
            mean = 0
            std = 1

        # Pad symmetrically up to the nearest valid length for the conv stack.
        delta = self.valid_length(length) - length
        x = F.pad(x, (delta // 2, delta - delta // 2))

        if self.resample:
            x = julius.resample_frac(x, 1, 2)

        saved = []
        for encode in self.encoder:
            x = encode(x)
            saved.append(x)

        if self.lstm:
            x = self.lstm(x)

        for decode in self.decoder:
            # U-Net skip connection, trimmed to the current resolution.
            skip = saved.pop(-1)
            skip = center_trim(skip, x)
            x = decode(x + skip)

        if self.resample:
            x = julius.resample_frac(x, 2, 1)
        # Undo the input normalization and padding.
        x = x * std + mean
        x = center_trim(x, length)
        # Split the flat channel axis into (source, audio_channel).
        x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
        return x

    def load_state_dict(self, state, strict=True):
        # fix a mismatch with previous generation Demucs models.
        for idx in range(self.depth):
            for a in ['encoder', 'decoder']:
                for b in ['bias', 'weight']:
                    new = f'{a}.{idx}.3.{b}'
                    old = f'{a}.{idx}.2.{b}'
                    if old in state and new not in state:
                        state[new] = state.pop(old)
        super().load_state_dict(state, strict=strict)
|
third_party/demucs/models/htdemucs.py
ADDED
|
@@ -0,0 +1,955 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File : htdemucs.py
|
| 5 |
+
@Time : 2023/8/8 下午4:27
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2023, Tencent
|
| 9 |
+
@Desc : The spectrogram and Hybrid version of Demucs
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import math
|
| 13 |
+
import typing as tp
|
| 14 |
+
from copy import deepcopy
|
| 15 |
+
from fractions import Fraction
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch import nn
|
| 19 |
+
from torch.nn import functional as F
|
| 20 |
+
from einops import rearrange
|
| 21 |
+
from openunmix.filtering import wiener
|
| 22 |
+
|
| 23 |
+
from .transformer import CrossTransformerEncoder
|
| 24 |
+
from .demucs import DConv, rescale_module
|
| 25 |
+
from .states import capture_init
|
| 26 |
+
from .spec import spectro, ispectro
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right before the reflection happens."""
    original = x
    length = x.shape[-1]
    left, right = paddings
    if mode == 'reflect':
        biggest = max(left, right)
        if length <= biggest:
            # Reflect padding requires input longer than the pad amount;
            # zero-extend (preferring the right side) until it is.
            missing = biggest - length + 1
            zeros_right = min(right, missing)
            zeros_left = missing - zeros_right
            left -= zeros_left
            right -= zeros_right
            x = F.pad(x, (zeros_left, zeros_right))
    out = F.pad(x, (left, right), mode, value)
    assert out.shape[-1] == length + paddings[0] + paddings[1]
    assert (out[..., paddings[0]: paddings[0] + length] == original).all()
    return out
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ScaledEmbedding(nn.Module):
    """Embedding whose stored weights are divided by `scale` and whose output
    is multiplied back by it.

    Because the optimizer only sees the downscaled parameters, this effectively
    boosts the learning rate of the embedding by `scale`. With `smooth`, the
    initial table is replaced by a normalized cumulative sum, making
    neighboring embeddings continuous.
    """
    def __init__(self, num_embeddings: int, embedding_dim: int,
                 scale: float = 10., smooth=False):
        super().__init__()
        self.scale = scale
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        data = self.embedding.weight.data
        if smooth:
            smoothed = torch.cumsum(data, dim=0)
            # Summing k gaussians grows the scale like sqrt(k); undo that growth.
            counts = torch.arange(1, num_embeddings + 1).to(smoothed)
            smoothed = smoothed / counts.sqrt()[:, None]
            data[:] = smoothed
        data /= scale

    @property
    def weight(self):
        # Weights at their effective (scaled-back) magnitude.
        return self.embedding.weight * self.scale

    def forward(self, x):
        return self.embedding(x) * self.scale
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class HEncLayer(nn.Module):
    """Encoder layer used by both the time and the frequency branch of HTDemucs.

    In frequency mode (`freq=True`) the convolutions are 2d and stride only
    along the frequency axis; in time mode they are 1d.
    """
    def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=0, dconv_kw=None, pad=True,
                 rewrite=True):
        """Encoder layer. This is used both by the time and the frequency branch.

        Args:
            chin: input channels.
            chout: output channels.
            kernel_size / stride: main (strided) convolution parameters.
            norm_groups: groups for GroupNorm when `norm` is True.
            empty: if True, keep only the strided convolution (no activation,
                rewrite or DConv).
            freq: frequency-branch layer (2d convs) instead of time-branch (1d).
            dconv: add a residual DConv branch after the activation.
            norm: use GroupNorm (otherwise Identity).
            context: context (half kernel) of the optional 1x1 rewrite conv.
            dconv_kw: extra keyword arguments forwarded to `DConv`.
            pad: pad the input so output length is ~input length / stride.
            rewrite: add a 1x1 conv + GLU after the (optional) DConv.
        """
        super().__init__()
        # BUGFIX: the default was `dconv_kw=None`, but it was unpacked as
        # `DConv(chout, **dconv_kw)` whenever `dconv=True` (the default),
        # raising TypeError. Normalize to an empty dict here.
        dconv_kw = dconv_kw or {}
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            pad = kernel_size // 4
        else:
            pad = 0
        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.norm = norm
        self.pad = pad
        if freq:
            # 2d convolution that strides along frequency only.
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad = [pad, 0]
            klass = nn.Conv2d
        self.conv = klass(chin, chout, kernel_size, stride, pad)
        if self.empty:
            return
        self.norm1 = norm_fn(chout)
        self.rewrite = None
        if rewrite:
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chout, **dconv_kw)

    def forward(self, x, inject=None):
        """
        `inject` is used to inject the result from the time branch into the frequency branch,
        when both have the same stride.
        """
        if not self.freq and x.dim() == 4:
            # Collapse (channels, freq) into one axis for the time branch.
            b, c, fr, t = x.shape
            x = x.view(b, -1, t)

        if not self.freq:
            # Pad so the time length is a multiple of the stride.
            le = x.shape[-1]
            if not le % self.stride == 0:
                x = F.pad(x, (0, self.stride - (le % self.stride)))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
            if inject.dim() == 3 and y.dim() == 4:
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.dconv:
            if self.freq:
                # DConv is 1d: fold the frequency axis into the batch.
                b, c, fr, t = y.shape
                y = y.permute(0, 2, 1, 3).reshape(-1, c, t)
            y = self.dconv(y)
            if self.freq:
                y = y.view(b, fr, c, t).permute(0, 2, 1, 3)
        if self.rewrite:
            z = self.norm2(self.rewrite(y))
            z = F.glu(z, dim=1)
        else:
            z = y
        return z
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class MultiWrap(nn.Module):
    """
    Takes one layer and replicate it N times. each replica will act
    on a frequency band. All is done so that if the N replica have the same weights,
    then this is exactly equivalent to applying the original module on all frequencies.
    """
    def __init__(self, layer, split_ratios):
        # `layer` is either an HEncLayer or an HDecLayer; `split_ratios` holds
        # the band boundaries as fractions of the frequency axis, so N ratios
        # produce N + 1 independent replicas.
        super().__init__()
        self.split_ratios = split_ratios
        self.layers = nn.ModuleList()
        # Encoder layers take a plain conv path in forward(); decoder layers
        # go through the (skip, length) interface instead.
        self.conv = isinstance(layer, HEncLayer)
        assert not layer.norm
        assert layer.freq
        assert layer.pad
        if not self.conv:
            assert not layer.context_freq
        for _ in range(len(split_ratios) + 1):
            lay = deepcopy(layer)
            if self.conv:
                # Padding is applied manually in forward() so the bands line up.
                lay.conv.padding = (0, 0)
            else:
                lay.pad = False
            for m in lay.modules():
                # Give each replica its own fresh initialization.
                if hasattr(m, 'reset_parameters'):
                    m.reset_parameters()
            self.layers.append(lay)

    def forward(self, x, skip=None, length=None):
        _, _, fr, _ = x.shape

        ratios = list(self.split_ratios) + [1]
        start = 0
        outs = []
        for ratio, layer in zip(ratios, self.layers):
            if self.conv:
                # Encoder path: slice one frequency band, pad its edges and
                # convolve. Band limits are snapped so an integral number of
                # frames is produced and consecutive bands overlap correctly.
                pad = layer.kernel_size // 4
                if ratio == 1:
                    limit = fr
                    frames = -1
                else:
                    limit = int(round(fr * ratio))
                    le = limit - start
                    if start == 0:
                        le += pad
                    frames = round((le - layer.kernel_size) / layer.stride + 1)
                    limit = start + (frames - 1) * layer.stride + layer.kernel_size
                    if start == 0:
                        limit -= pad
                assert limit - start > 0, (limit, start)
                assert limit <= fr, (limit, fr)
                y = x[:, :, start:limit, :]
                if start == 0:
                    y = F.pad(y, (0, 0, pad, 0))
                if ratio == 1:
                    y = F.pad(y, (0, 0, 0, pad))
                outs.append(layer(y))
                start = limit - layer.kernel_size + layer.stride
            else:
                # Decoder path: neighboring bands overlap by one stride; the
                # overlapping slice is summed back into the previous band's
                # output (minus the doubly-counted transposed-conv bias).
                if ratio == 1:
                    limit = fr
                else:
                    limit = int(round(fr * ratio))
                # Temporarily force `last=True` so the replica skips its final
                # activation; the shared GELU is applied once, below.
                last = layer.last
                layer.last = True

                y = x[:, :, start:limit]
                s = skip[:, :, start:limit]
                out, _ = layer(y, s, None)
                if outs:
                    outs[-1][:, :, -layer.stride:] += (
                        out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1))
                    out = out[:, :, layer.stride:]
                if ratio == 1:
                    out = out[:, :, :-layer.stride // 2, :]
                if start == 0:
                    out = out[:, :, layer.stride // 2:, :]
                outs.append(out)
                layer.last = last
                start = limit
        out = torch.cat(outs, dim=2)
        if not self.conv and not last:
            out = F.gelu(out)
        if self.conv:
            return out
        else:
            return out, None
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class HDecLayer(nn.Module):
    def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=1, dconv_kw=None, pad=True,
                 context_freq=True, rewrite=True):
        """
        Same as HEncLayer but for decoder. See `HEncLayer` for documentation.

        Extra args:
            last: if True, skip the final GELU activation in forward().
            context_freq: apply the rewrite context along the freq. axis
                (otherwise along time).
        """
        super().__init__()
        if dconv_kw is None:
            # Bug fix: the original unconditionally did `DConv(chin, **dconv_kw)`,
            # which raises TypeError under the default `dconv_kw=None`.
            dconv_kw = {}
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm:
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            pad = kernel_size // 4
        else:
            pad = 0
        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        self.norm = norm
        self.context_freq = context_freq
        klass = nn.Conv1d
        klass_tr = nn.ConvTranspose1d
        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            klass = nn.Conv2d
            klass_tr = nn.ConvTranspose2d
        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)
        if self.empty:
            return
        self.rewrite = None
        if rewrite:
            if context_freq:
                self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            else:
                # Context only along the time axis.
                self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
                                     [0, context])
            self.norm1 = norm_fn(2 * chin)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chin, **dconv_kw)

    def forward(self, x, skip, length):
        """Decode `x` (adding the `skip` connection) back to `length` samples
        (time branch) or trim the frequency padding (freq. branch).
        Returns `(z, y)`, `y` being the pre-transposed-conv activations."""
        if self.freq and x.dim() == 3:
            # Un-fold the frequency axis out of the channel dimension.
            b, c, t = x.shape
            x = x.view(b, self.chin, -1, t)

        if not self.empty:
            x = x + skip

            if self.rewrite:
                y = F.glu(self.norm1(self.rewrite(x)), dim=1)
            else:
                y = x
            if self.dconv:
                if self.freq:
                    # DConv is 1D: fold frequencies into the batch dim.
                    b, c, fr, t = y.shape
                    y = y.permute(0, 2, 1, 3).reshape(-1, c, t)
                y = self.dconv(y)
                if self.freq:
                    y = y.view(b, fr, c, t).permute(0, 2, 1, 3)
        else:
            y = x
            assert skip is None
        z = self.norm2(self.conv_tr(y))
        if self.freq:
            if self.pad:
                z = z[..., self.pad:-self.pad, :]
        else:
            z = z[..., self.pad:self.pad + length]
            assert z.shape[-1] == length, (z.shape[-1], length)
        if not self.last:
            z = F.gelu(z)
        return z, y
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class HTDemucs(nn.Module):
    """
    Spectrogram and hybrid Demucs model.
    The spectrogram model has the same structure as Demucs, except the first few layers are over the
    frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
    Frequency layers can still access information across time steps thanks to the DConv residual.

    Hybrid model have a parallel time branch. At some layer, the time branch has the same stride
    as the frequency branch and then the two are combined. The opposite happens in the decoder.

    Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]),
    or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
    Open Unmix implementation [Stoter et al. 2019].

    The loss is always on the temporal domain, by backpropagating through the above
    output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
    a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
    contribution, without changing the one from the waveform, which will lead to worse performance.
    I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
    CaC on the other hand provides similar performance for hybrid, and works naturally with
    hybrid models.

    This model also uses frequency embeddings are used to improve efficiency on convolutions
    over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).

    Unlike classic Demucs, there is no resampling here, and normalization is always applied.
    """

    @capture_init
    def __init__(
        self,
        sources,
        # Channels
        audio_channels=2,
        channels=48,
        channels_time=None,
        growth=2,
        # STFT
        nfft=4096,
        wiener_iters=0,
        end_iters=0,
        wiener_residual=False,
        cac=True,
        # Main structure
        depth=4,
        rewrite=True,
        # Frequency branch
        multi_freqs=None,
        multi_freqs_depth=3,
        freq_emb=0.2,
        emb_scale=10,
        emb_smooth=True,
        # Convolutions
        kernel_size=8,
        time_stride=2,
        stride=4,
        context=1,
        context_enc=0,
        # Normalization
        norm_starts=4,
        norm_groups=4,
        # DConv residual branch
        dconv_mode=1,
        dconv_depth=2,
        dconv_comp=8,
        dconv_init=1e-3,
        # Before the Transformer
        bottom_channels=0,
        # Transformer
        t_layers=5,
        t_emb="sin",
        t_hidden_scale=4.0,
        t_heads=8,
        t_dropout=0.0,
        t_max_positions=10000,
        t_norm_in=True,
        t_norm_in_group=False,
        t_group_norm=False,
        t_norm_first=True,
        t_norm_out=True,
        t_max_period=10000.0,
        t_weight_decay=0.0,
        t_lr=None,
        t_layer_scale=True,
        t_gelu=True,
        t_weight_pos_embed=1.0,
        t_sin_random_shift=0,
        t_cape_mean_normalize=True,
        t_cape_augment=True,
        t_cape_glob_loc_scale=None,
        t_sparse_self_attn=False,
        t_sparse_cross_attn=False,
        t_mask_type="diag",
        t_mask_random_seed=42,
        t_sparse_attn_window=500,
        t_global_window=100,
        t_sparsity=0.95,
        t_auto_sparsity=False,
        # ------ Particuliar parameters
        t_cross_first=False,
        # Weight init
        rescale=0.1,
        # Metadata
        samplerate=44100,
        segment=10,
        use_train_segment=True,
    ):
        """
        Args:
            sources (list[str]): list of source names.
            audio_channels (int): input/output audio channels.
            channels (int): initial number of hidden channels.
            channels_time: if not None, use a different `channels` value for the time branch.
            growth: increase the number of hidden channels by this factor at each layer.
            nfft: number of fft bins. Note that changing this require careful computation of
                various shape parameters and will not work out of the box for hybrid models.
            wiener_iters: when using Wiener filtering, number of iterations at test time.
            end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
            wiener_residual: add residual source before wiener filtering.
            cac: uses complex as channels, i.e. complex numbers are 2 channels each
                in input and output. no further processing is done before ISTFT.
            depth (int): number of layers in the encoder and in the decoder.
            rewrite (bool): add 1x1 convolution to each layer.
            multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
            multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
                layers will be wrapped.
            freq_emb: add frequency embedding after the first frequency layer if > 0,
                the actual value controls the weight of the embedding.
            emb_scale: equivalent to scaling the embedding learning rate
            emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
            kernel_size: kernel_size for encoder and decoder layers.
            stride: stride for encoder and decoder layers.
            time_stride: stride for the final time layer, after the merge.
            context: context for 1x1 conv in the decoder.
            context_enc: context for 1x1 conv in the encoder.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_attn: adds attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
                transformer in order to change the number of channels
            t_layers: number of layers in each branch (waveform and spec) of the transformer
            t_emb: "sin", "cape" or "scaled"
            t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
                for instance if C = 384 (the number of channels in the transformer) and
                t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
                384 * 4 = 1536
            t_heads: number of heads for the transformer
            t_dropout: dropout in the transformer
            t_max_positions: max_positions for the "scaled" positional embedding, only
                useful if t_emb="scaled"
            t_norm_in: (bool) norm before addinf positional embedding and getting into the
                transformer layers
            t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
                timesteps (GroupNorm with group=1)
            t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
                timesteps (GroupNorm with group=1)
            t_norm_first: (bool) if True the norm is before the attention and before the FFN
            t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
            t_max_period: (float) denominator in the sinusoidal embedding expression
            t_weight_decay: (float) weight decay for the transformer
            t_lr: (float) specific learning rate for the transformer
            t_layer_scale: (bool) Layer Scale for the transformer
            t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
            t_weight_pos_embed: (float) weighting of the positional embedding
            t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
                see: https://arxiv.org/abs/2106.03143
            t_cape_augment: (bool) if t_emb="cape", must be True during training and False
                during the inference, see: https://arxiv.org/abs/2106.03143
            t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
                see: https://arxiv.org/abs/2106.03143
            t_sparse_self_attn: (bool) if True, the self attentions are sparse
            t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
                unless you designed really specific masks)
            t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
                with '_' between: i.e. "diag_jmask_random" (note that this is permutation
                invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
            t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
                that generated the random part of the mask
            t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
                a key (j), the mask is True id |i-j|<=t_sparse_attn_window
            t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
                and mask[:, :t_global_window] will be True
            t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
                level of the random part of the mask.
            t_cross_first: (bool) if True cross attention is the first layer of the
                transformer (False seems to be better)
            rescale: weight rescaling trick
            use_train_segment: (bool) if True, the actual size that is used during the
                training is used during inference.
        """
        super().__init__()
        self.cac = cac
        self.wiener_residual = wiener_residual
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.bottom_channels = bottom_channels
        self.channels = channels
        self.samplerate = samplerate
        # NOTE: `segment` is mutated in forward() during training (set to the
        # actual mix duration as a Fraction).
        self.segment = segment
        self.use_train_segment = use_train_segment
        self.nfft = nfft
        self.hop_length = nfft // 4
        self.wiener_iters = wiener_iters
        self.end_iters = end_iters
        self.freq_emb = None
        assert wiener_iters == end_iters

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        self.tencoder = nn.ModuleList()
        self.tdecoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin  # number of channels for the freq branch
        if self.cac:
            chin_z *= 2
        chout = channels_time or channels
        chout_z = channels
        freqs = nfft // 2

        for index in range(depth):
            norm = index >= norm_starts
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                # Past the last frequency layer: switch to pure time convs.
                assert freqs == 1
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                # Final frequency layer: collapse all remaining bins at once.
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                "kernel_size": ker,
                "stride": stri,
                "freq": freq,
                "pad": pad,
                "norm": norm,
                "rewrite": rewrite,
                "norm_groups": norm_groups,
                "dconv_kw": {
                    "depth": dconv_depth,
                    "compress": dconv_comp,
                    "init": dconv_init,
                    "gelu": True,
                },
            }
            # kwt: variant of kw for the parallel time-branch layers.
            kwt = dict(kw)
            kwt["freq"] = 0
            kwt["kernel_size"] = kernel_size
            kwt["stride"] = stride
            kwt["pad"] = True
            kw_dec = dict(kw)
            multi = False
            if multi_freqs and index < multi_freqs_depth:
                multi = True
                kw_dec["context_freq"] = False

            if last_freq:
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = HEncLayer(
                chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
            )
            if freq:
                tenc = HEncLayer(
                    chin,
                    chout,
                    dconv=dconv_mode & 1,
                    context=context_enc,
                    empty=last_freq,
                    **kwt
                )
                self.tencoder.append(tenc)

            if multi:
                enc = MultiWrap(enc, multi_freqs)
            self.encoder.append(enc)
            if index == 0:
                # The decoder of the first layer outputs all sources at once.
                chin = self.audio_channels * len(self.sources)
                chin_z = chin
                if self.cac:
                    chin_z *= 2
            dec = HDecLayer(
                chout_z,
                chin_z,
                dconv=dconv_mode & 2,
                last=index == 0,
                context=context,
                **kw_dec
            )
            if multi:
                dec = MultiWrap(dec, multi_freqs)
            if freq:
                tdec = HDecLayer(
                    chout,
                    chin,
                    dconv=dconv_mode & 2,
                    empty=last_freq,
                    last=index == 0,
                    context=context,
                    **kwt
                )
                self.tdecoder.insert(0, tdec)
            self.decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride
            if index == 0 and freq_emb:
                self.freq_emb = ScaledEmbedding(
                    freqs, chin_z, smooth=emb_smooth, scale=emb_scale
                )
                self.freq_emb_scale = freq_emb

        if rescale:
            rescale_module(self, reference=rescale)

        transformer_channels = channels * growth ** (depth - 1)
        if bottom_channels:
            # 1x1 convs to change channel count before/after the transformer.
            self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
            self.channel_downsampler = nn.Conv1d(
                bottom_channels, transformer_channels, 1
            )
            self.channel_upsampler_t = nn.Conv1d(
                transformer_channels, bottom_channels, 1
            )
            self.channel_downsampler_t = nn.Conv1d(
                bottom_channels, transformer_channels, 1
            )

            transformer_channels = bottom_channels

        if t_layers > 0:
            if t_cape_glob_loc_scale is None:
                # Default CAPE parameters (see class docstring references).
                t_cape_glob_loc_scale = [5000.0, 1.0, 1.4]
            self.crosstransformer = CrossTransformerEncoder(
                dim=transformer_channels,
                emb=t_emb,
                hidden_scale=t_hidden_scale,
                num_heads=t_heads,
                num_layers=t_layers,
                cross_first=t_cross_first,
                dropout=t_dropout,
                max_positions=t_max_positions,
                norm_in=t_norm_in,
                norm_in_group=t_norm_in_group,
                group_norm=t_group_norm,
                norm_first=t_norm_first,
                norm_out=t_norm_out,
                max_period=t_max_period,
                weight_decay=t_weight_decay,
                lr=t_lr,
                layer_scale=t_layer_scale,
                gelu=t_gelu,
                sin_random_shift=t_sin_random_shift,
                weight_pos_embed=t_weight_pos_embed,
                cape_mean_normalize=t_cape_mean_normalize,
                cape_augment=t_cape_augment,
                cape_glob_loc_scale=t_cape_glob_loc_scale,
                sparse_self_attn=t_sparse_self_attn,
                sparse_cross_attn=t_sparse_cross_attn,
                mask_type=t_mask_type,
                mask_random_seed=t_mask_random_seed,
                sparse_attn_window=t_sparse_attn_window,
                global_window=t_global_window,
                sparsity=t_sparsity,
                auto_sparsity=t_auto_sparsity,
            )
        else:
            self.crosstransformer = None
|
| 715 |
+
|
| 716 |
+
    def _spec(self, x):
        """Compute the STFT of `x`, trimmed so that the number of frames is
        exactly `ceil(samples / hop_length)` and the Nyquist bin is dropped."""
        hl = self.hop_length
        nfft = self.nfft

        # We re-pad the signal in order to keep the property
        # that the size of the output is exactly the size of the input
        # divided by the stride (here hop_length), when divisible.
        # This is achieved by padding by 1/4th of the kernel size (here nfft).
        # which is not supported by torch.stft.
        # Having all convolution operations follow this convention allow to easily
        # align the time and frequency branches later on.
        assert hl == nfft // 4
        le = int(math.ceil(x.shape[-1] / hl))
        pad = hl // 2 * 3
        x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")

        # Drop the last (Nyquist) frequency bin, then the 2 extra frames the
        # reflection padding produced on each side.
        z = spectro(x, nfft, hl)[..., :-1, :]
        assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
        z = z[..., 2: 2 + le]
        return z
|
| 736 |
+
|
| 737 |
+
    def _ispec(self, z, length=None, scale=0):
        """Inverse of `_spec`: restore the Nyquist bin and edge frames removed
        there, run the iSTFT, then strip the reflection padding so exactly
        `length` samples are returned."""
        hl = self.hop_length // (4**scale)
        # Re-add the dropped Nyquist bin (zeros) ...
        z = F.pad(z, (0, 0, 0, 1))
        # ... and the 2 frames trimmed on each side in `_spec`.
        z = F.pad(z, (2, 2))
        pad = hl // 2 * 3
        le = hl * int(math.ceil(length / hl)) + 2 * pad
        x = ispectro(z, hl, length=le)
        x = x[..., pad: pad + length]
        return x
|
| 746 |
+
|
| 747 |
+
def _magnitude(self, z):
|
| 748 |
+
# return the magnitude of the spectrogram, except when cac is True,
|
| 749 |
+
# in which case we just move the complex dimension to the channel one.
|
| 750 |
+
if self.cac:
|
| 751 |
+
b, c, fr, t = z.shape
|
| 752 |
+
m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
|
| 753 |
+
m = m.reshape(b, c * 2, fr, t)
|
| 754 |
+
else:
|
| 755 |
+
m = z.abs()
|
| 756 |
+
return m
|
| 757 |
+
|
| 758 |
+
    def _mask(self, z, m):
        # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
        # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
        niters = self.wiener_iters
        if self.cac:
            # Complex-as-channels: fold the channel pairs back into complex values.
            b, s, _, fr, t = m.shape
            out = m.view(b, s, -1, 2, fr, t).permute(0, 1, 2, 4, 5, 3)
            out = torch.view_as_complex(out.contiguous())
            return out
        if self.training:
            niters = self.end_iters
        if niters < 0:
            # Naive masking: keep the mixture phase, scale by the mask magnitude.
            z = z[:, None]
            return z / (1e-8 + z.abs()) * m
        else:
            return self._wiener(m, z, niters)
|
| 774 |
+
|
| 775 |
+
    def _wiener(self, mag_out, mix_stft, niters):
        # apply wiener filtering from OpenUnmix.
        init = mix_stft.dtype
        # Filter in windows of 300 frames to bound the memory usage.
        wiener_win_len = 300
        residual = self.wiener_residual

        b, s, c, fq, t = mag_out.shape
        # Reorder to the (frames, bins, channels, sources) layout `wiener` expects.
        mag_out = mag_out.permute(0, 4, 3, 2, 1)
        mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

        outs = []
        for sample in range(b):
            pos = 0
            out = []
            for pos in range(0, t, wiener_win_len):
                frame = slice(pos, pos + wiener_win_len)
                z_out = wiener(
                    mag_out[sample, frame],
                    mix_stft[sample, frame],
                    niters,
                    residual=residual,
                )
                out.append(z_out.transpose(-1, -2))
            outs.append(torch.cat(out, dim=0))
        out = torch.view_as_complex(torch.stack(outs, 0))
        out = out.permute(0, 4, 3, 2, 1).contiguous()
        if residual:
            # `wiener` appended an extra residual source; drop it.
            out = out[:, :-1]
        assert list(out.shape) == [b, s, c, fq, t]
        return out.to(init)
|
| 805 |
+
|
| 806 |
+
def valid_length(self, length: int):
|
| 807 |
+
"""
|
| 808 |
+
Return a length that is appropriate for evaluation.
|
| 809 |
+
In our case, always return the training length, unless
|
| 810 |
+
it is smaller than the given length, in which case this
|
| 811 |
+
raises an error.
|
| 812 |
+
"""
|
| 813 |
+
if not self.use_train_segment:
|
| 814 |
+
return length
|
| 815 |
+
training_length = int(self.segment * self.samplerate)
|
| 816 |
+
if training_length < length:
|
| 817 |
+
raise ValueError(
|
| 818 |
+
f"Given length {length} is longer than "
|
| 819 |
+
f"training length {training_length}")
|
| 820 |
+
return training_length
|
| 821 |
+
|
| 822 |
+
    def forward(self, mix):
        """Separate the mixture `mix` into `len(self.sources)` stems.

        Runs the frequency and time branches in parallel through the
        encoder, the cross-transformer and the decoder, then sums the
        iSTFT of the masked spectrogram with the time-branch waveform.
        """
        length = mix.shape[-1]
        length_pre_pad = None
        if self.use_train_segment:
            if self.training:
                self.segment = Fraction(mix.shape[-1], self.samplerate)
            else:
                # NOTE(review): `training_length` is only bound on this eval
                # path but read again further down; the later reads are also
                # guarded by `not self.training` so this looks intentional —
                # confirm.
                training_length = int(self.segment * self.samplerate)
                if mix.shape[-1] < training_length:
                    length_pre_pad = mix.shape[-1]
                    mix = F.pad(mix, (0, training_length - length_pre_pad))
        z = self._spec(mix)
        mag = self._magnitude(z).to(mix.device)
        x = mag

        b, _, fq, t = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # x will be the freq. branch input.

        # Prepare the time branch input.
        xt = mix
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)

        # okay, this is a giant mess I know...
        saved = []  # skip connections, freq.
        saved_t = []  # skip connections, time.
        lengths = []  # saved lengths to properly remove padding, freq branch.
        lengths_t = []  # saved lengths for time branch.
        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            if idx < len(self.tencoder):
                # we have not yet merged branches.
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # save for skip connection
                    saved_t.append(xt)
                else:
                    # tenc contains just the first conv., so that now time and freq.
                    # branches have the same shape and can be merged.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)
        if self.crosstransformer:
            if self.bottom_channels:
                # Flatten (freq, time), change channel count, restore shape.
                _, _, f, _ = x.shape
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_upsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)
                xt = self.channel_upsampler_t(xt)

            x, xt = self.crosstransformer(x, xt)

            if self.bottom_channels:
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_downsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)
                xt = self.channel_downsampler_t(xt)

        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            # `pre` contains the output just before final transposed convolution,
            # which is used when the freq. and time branch separate.

            offset = self.depth - len(self.tdecoder)
            if idx >= offset:
                tdec = self.tdecoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    assert pre.shape[2] == 1, pre.shape
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        # Let's make sure we used all stored skip connections.
        assert len(saved) == 0
        assert len(lengths_t) == 0
        assert len(saved_t) == 0

        s = len(self.sources)
        x = x.view(b, s, -1, fq, t)
        # Undo the spectrogram normalization applied above.
        x = x * std[:, None] + mean[:, None]

        # to cpu as mps doesnt support complex numbers
        # demucs issue #435 ##432
        # NOTE: in this case z already is on cpu
        # TODO: remove this when mps supports complex numbers
        x_is_mps = x.device.type == "mps"
        if x_is_mps:
            x = x.cpu()

        zout = self._mask(z, x)
        if self.use_train_segment:
            if self.training:
                x = self._ispec(zout, length)
            else:
                x = self._ispec(zout, training_length)
        else:
            x = self._ispec(zout, length)

        # back to mps device
        if x_is_mps:
            x = x.to("mps")

        if self.use_train_segment:
            if self.training:
                xt = xt.view(b, s, -1, length)
            else:
                xt = xt.view(b, s, -1, training_length)
        else:
            xt = xt.view(b, s, -1, length)
        # Undo the waveform normalization, then merge the two branches.
        xt = xt * stdt[:, None] + meant[:, None]
        x = xt + x
        if length_pre_pad:
            x = x[..., :length_pre_pad]
        return x
|
third_party/demucs/models/pretrained.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File : pretrained.py
|
| 5 |
+
@Time : 2023/8/8 下午7:22
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2023, Tencent
|
| 9 |
+
@Desc : Loading pretrained models.
|
| 10 |
+
"""
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
import yaml
|
| 14 |
+
|
| 15 |
+
from .apply import BagOfModels
|
| 16 |
+
from .htdemucs import HTDemucs
|
| 17 |
+
from .states import load_state_dict
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def add_model_flags(parser):
    """Register the pretrained-model selection flags on *parser*.

    ``-s``/``--sig`` and ``-n``/``--name`` are mutually exclusive: either a
    locally trained XP signature or a named pretrained model, never both.
    ``--repo`` points at the folder holding the pretrained models used with -n.
    """
    exclusive = parser.add_mutually_exclusive_group(required=False)
    exclusive.add_argument("-s", "--sig", help="Locally trained XP signature.")
    exclusive.add_argument(
        "-n",
        "--name",
        default=None,
        help="Pretrained model name or signature. Default is htdemucs.",
    )
    parser.add_argument(
        "--repo",
        type=Path,
        help="Folder containing all pre-trained models for use with -n.",
    )
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_model_from_yaml(yaml_file, model_file):
    """Build a single-model ``BagOfModels`` from a YAML bag description.

    Parameters
    ----------
    yaml_file : path to the YAML file describing the bag; optional keys
        ``weights`` and ``segment`` are read (``None`` when absent).
    model_file : path to the serialized HTDemucs state dict.

    Returns
    -------
    BagOfModels wrapping the single loaded HTDemucs model.
    """
    # Use a context manager so the YAML file handle is closed deterministically
    # (the original left the handle open, relying on garbage collection).
    with open(yaml_file) as fh:
        bag = yaml.safe_load(fh)
    model = load_state_dict(HTDemucs, model_file)
    # Both keys are optional; .get() yields None when missing, which
    # BagOfModels treats as "use defaults".
    weights = bag.get('weights')
    segment = bag.get('segment')
    return BagOfModels([model], weights, segment)
|
third_party/demucs/models/spec.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File : spec.py
|
| 5 |
+
@Time : 2023/8/8 下午5:10
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2023, Tencent
|
| 9 |
+
@Desc : Spec
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch as th
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def spectro(x, n_fft=512, hop_length=None, pad=0):
    """Normalized complex STFT of *x* along its last dimension.

    Leading dimensions are flattened for the transform and restored on the
    result, which has shape ``(*leading, freqs, frames)``.
    """
    *batch_dims, n_samples = x.shape
    x = x.reshape(-1, n_samples)
    # torch.stft lacks complex support on Apple MPS; fall back to CPU there.
    if x.device.type == 'mps':
        x = x.cpu()
    fft_size = n_fft * (1 + pad)
    hop = hop_length if hop_length else n_fft // 4
    z = th.stft(
        x,
        fft_size,
        hop,
        window=th.hann_window(n_fft).to(x),
        win_length=n_fft,
        normalized=True,
        center=True,
        return_complex=True,
        pad_mode='reflect',
    )
    _, n_freqs, n_frames = z.shape
    return z.view(*batch_dims, n_freqs, n_frames)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def ispectro(z, hop_length=None, length=None, pad=0):
    """Inverse of :func:`spectro`: reconstruct a waveform from a complex STFT.

    Leading dimensions are flattened for the transform and restored on the
    result, which has shape ``(*leading, length)``.
    """
    *batch_dims, n_freqs, n_frames = z.shape
    # freqs = n_fft // 2 + 1 for a one-sided transform, hence:
    n_fft = 2 * n_freqs - 2
    z = z.view(-1, n_freqs, n_frames)
    win_length = n_fft // (1 + pad)
    # torch.istft cannot handle complex inputs on Apple MPS; run on CPU there.
    if z.device.type == 'mps':
        z = z.cpu()
    x = th.istft(
        z,
        n_fft,
        hop_length,
        window=th.hann_window(win_length).to(z.real),
        win_length=win_length,
        normalized=True,
        length=length,
        center=True,
    )
    _, out_length = x.shape
    return x.view(*batch_dims, out_length)
|
third_party/demucs/models/states.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File : states.py
|
| 5 |
+
@Time : 2023/8/8 下午7:01
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2023, Tencent
|
| 9 |
+
@Desc : Utilities to save and load models.
|
| 10 |
+
"""
|
| 11 |
+
import functools
|
| 12 |
+
import inspect
|
| 13 |
+
import warnings
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from fractions import Fraction
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_state_dict(net, pth_path):
    """Instantiate *net* with the fixed 4-stem HTDemucs hyper-parameters and
    load its weights from *pth_path*.

    Parameters
    ----------
    net : the model class (expected to be HTDemucs) accepting the kwargs below.
    pth_path : path to a checkpoint containing a plain ``state_dict``.

    Returns
    -------
    The instantiated model with weights loaded.
    """
    # Fixed hyper-parameters of the released 4-source (drums/bass/other/vocal)
    # HTDemucs checkpoint; they must match the checkpoint layout exactly.
    kwargs = {'sources': ['drums', 'bass', 'other', 'vocal'], 'audio_channels': 2, 'samplerate': 44100,
              'segment': Fraction(39, 5), 'channels': 48, 'channels_time': None, 'growth': 2, 'nfft': 4096,
              'wiener_iters': 0, 'end_iters': 0, 'wiener_residual': False, 'cac': True, 'depth': 4, 'rewrite': True,
              'multi_freqs': [], 'multi_freqs_depth': 3, 'freq_emb': 0.2, 'emb_scale': 10, 'emb_smooth': True,
              'kernel_size': 8, 'stride': 4, 'time_stride': 2, 'context': 1, 'context_enc': 0, 'norm_starts': 4,
              'norm_groups': 4, 'dconv_mode': 3, 'dconv_depth': 2, 'dconv_comp': 8, 'dconv_init': 0.001,
              'bottom_channels': 512, 't_layers': 5, 't_hidden_scale': 4.0, 't_heads': 8, 't_dropout': 0.02,
              't_layer_scale': True, 't_gelu': True, 't_emb': 'sin', 't_max_positions': 10000, 't_max_period': 10000.0,
              't_weight_pos_embed': 1.0, 't_cape_mean_normalize': True, 't_cape_augment': True,
              't_cape_glob_loc_scale': [5000.0, 1.0, 1.4], 't_sin_random_shift': 0, 't_norm_in': True,
              't_norm_in_group': False, 't_group_norm': False, 't_norm_first': True, 't_norm_out': True,
              't_weight_decay': 0.0, 't_lr': None, 't_sparse_self_attn': False, 't_sparse_cross_attn': False,
              't_mask_type': 'diag', 't_mask_random_seed': 42, 't_sparse_attn_window': 400, 't_global_window': 100,
              't_sparsity': 0.95, 't_auto_sparsity': False, 't_cross_first': False, 'rescale': 0.1}
    model = net(**kwargs)
    # Map to CPU so loading works on machines without the device the
    # checkpoint was saved on (same convention as load_model in this module).
    state_dict = torch.load(pth_path, map_location='cpu')
    model.load_state_dict(state_dict)
    return model
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def load_model(path_or_package, strict=False):
    """Load a model from the given serialized model, either given as a dict
    (already loaded) or a path to a file on disk."""
    if isinstance(path_or_package, dict):
        package = path_or_package
    elif isinstance(path_or_package, (str, Path)):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            package = torch.load(path_or_package, 'cpu')
    else:
        raise ValueError(f"Invalid type for {path_or_package}.")

    klass = package["klass"]
    args = package["args"]
    kwargs = package["kwargs"]

    if not strict:
        # Drop constructor kwargs the current class no longer accepts, so
        # older checkpoints keep loading after refactors.
        accepted = inspect.signature(klass).parameters
        for key in list(kwargs):
            if key not in accepted:
                warnings.warn("Dropping inexistant parameter " + key)
                del kwargs[key]
    model = klass(*args, **kwargs)

    set_state(model, package["state"])
    return model
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_state(model, quantizer, half=False):
    """Get the state from a model, potentially with quantization applied.
    If `half` is True, model are stored as half precision, which shouldn't impact performance
    but half the state size."""
    if quantizer is not None:
        state = quantizer.get_quantized_state()
        state['__quantized'] = True
        return state
    # No quantizer: plain CPU state dict, optionally cast to fp16.
    target_dtype = torch.half if half else None
    return {
        name: param.data.to(device='cpu', dtype=target_dtype)
        for name, param in model.state_dict().items()
    }
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def set_state(model, state, quantizer=None):
    """Set the state on a given model."""
    if not state.get('__quantized'):
        # Ordinary (non-quantized) state dict.
        model.load_state_dict(state)
    else:
        quantizer.restore_quantized_state(model, state['quantized'])
    return state
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def capture_init(init):
    """Decorator for ``__init__`` that records the call arguments on the
    instance as ``_init_args_kwargs`` before running the real initializer."""
    @functools.wraps(init)
    def wrapped(self, *args, **kwargs):
        # Remember how the object was built so it can be re-serialized later.
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return wrapped
|
third_party/demucs/models/transformer.py
ADDED
|
@@ -0,0 +1,765 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File : transformer.py
|
| 5 |
+
@Time : 2023/8/8 下午5:05
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2023, Tencent
|
| 9 |
+
@Desc : Transformer
|
| 10 |
+
"""
|
| 11 |
+
import math
|
| 12 |
+
import random
|
| 13 |
+
import typing as tp
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import numpy as np
|
| 19 |
+
from einops import rearrange
|
| 20 |
+
from torch.nn import TransformerEncoderLayer, MultiheadAttention, Linear, LayerNorm
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def create_sin_embedding(
    length: int, dim: int, shift: int = 0, device="cpu", max_period=10000
):
    """Classic sinusoidal positional embedding, returned in (T, 1, C) layout
    with the cosine components in the first half of C and sines in the second.
    """
    assert dim % 2 == 0
    half = dim // 2
    positions = shift + torch.arange(length, device=device).view(-1, 1, 1)
    freq_idx = torch.arange(half, device=device).view(1, 1, -1)
    phase = positions / (max_period ** (freq_idx / (half - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
    """
    :param d_model: dimension of the model
    :param height: height of the positions
    :param width: width of the positions
    :return: d_model*height*width position matrix
    """
    if d_model % 4 != 0:
        raise ValueError(
            "Cannot use sin/cos positional encoding with "
            "odd dimension (got dim={:d})".format(d_model)
        )
    pe = torch.zeros(d_model, height, width)
    # Half the channels encode the width axis, the other half the height axis.
    half = d_model // 2
    inv_freq = torch.exp(
        torch.arange(0.0, half, 2) * -(math.log(max_period) / half)
    )
    pos_w = torch.arange(0.0, width).unsqueeze(1)
    pos_h = torch.arange(0.0, height).unsqueeze(1)
    # Width positions: broadcast over all heights.
    w_phase = (pos_w * inv_freq).transpose(0, 1).unsqueeze(1)
    pe[0:half:2, :, :] = torch.sin(w_phase).repeat(1, height, 1)
    pe[1:half:2, :, :] = torch.cos(w_phase).repeat(1, height, 1)
    # Height positions: broadcast over all widths.
    h_phase = (pos_h * inv_freq).transpose(0, 1).unsqueeze(2)
    pe[half::2, :, :] = torch.sin(h_phase).repeat(1, 1, width)
    pe[half + 1:: 2, :, :] = torch.cos(h_phase).repeat(1, 1, width)

    return pe[None, :].to(device)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def create_sin_embedding_cape(
    length: int,
    dim: int,
    batch_size: int,
    mean_normalize: bool,
    augment: bool,  # True during training
    max_global_shift: float = 0.0,  # delta max
    max_local_shift: float = 0.0,  # epsilon max
    max_scale: float = 1.0,
    device: str = "cpu",
    max_period: float = 10000.0,
):
    """CAPE positional embedding (Likhomanenko et al.): sinusoidal positions,
    optionally mean-centered and, during training, randomly shifted (globally
    and per-position) and rescaled per batch element. Output is (T, B, C).
    """
    assert dim % 2 == 0
    positions = 1.0 * torch.arange(length).view(-1, 1, 1)  # (length, 1, 1)
    positions = positions.repeat(1, batch_size, 1)  # (length, batch_size, 1)
    if mean_normalize:
        positions = positions - torch.nanmean(positions, dim=0, keepdim=True)

    if augment:
        # Draw the three perturbations in the same order as before so global
        # NumPy RNG state consumption is unchanged.
        delta = np.random.uniform(
            -max_global_shift, +max_global_shift, size=[1, batch_size, 1]
        )
        delta_local = np.random.uniform(
            -max_local_shift, +max_local_shift, size=[length, batch_size, 1]
        )
        log_lambdas = np.random.uniform(
            -np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1]
        )
        positions = (positions + delta + delta_local) * np.exp(log_lambdas)

    positions = positions.to(device)

    half = dim // 2
    freq_idx = torch.arange(half, device=device).view(1, 1, -1)
    phase = positions / (max_period ** (freq_idx / (half - 1)))
    return torch.cat(
        [
            torch.cos(phase),
            torch.sin(phase),
        ],
        dim=-1,
    ).float()
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def get_causal_mask(length):
    """Boolean (length, length) mask, True strictly above the diagonal —
    i.e. at future positions a query must not attend to."""
    idx = torch.arange(length)
    return idx[None, :] > idx[:, None]
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def get_elementary_mask(
    t1,
    t2,
    mask_type,
    sparse_attn_window,
    global_window,
    mask_random_seed,
    sparsity,
    device,
):
    """
    When the input of the Decoder has length T1 and the output T2
    The mask matrix has shape (T2, T1)
    """
    assert mask_type in ["diag", "jmask", "random", "global"]

    if mask_type == "global":
        # Everyone attends to the first `global_window` inputs, and the first
        # (proportionally rescaled) output rows attend to everything.
        mask = torch.zeros(t2, t1, dtype=torch.bool)
        mask[:, :global_window] = True
        mask[: int(global_window * t2 / t1), :] = True

    if mask_type == "diag":
        # Band of half-width `sparse_attn_window` around the rescaled diagonal.
        mask = torch.zeros(t2, t1, dtype=torch.bool)
        rows = torch.arange(t2)[:, None]
        offsets = torch.arange(-sparse_attn_window, sparse_attn_window + 1)
        cols = (t1 / t2 * rows + offsets).long().clamp(0, t1 - 1)
        mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))

    elif mask_type == "jmask":
        # Attend at triangular-number offsets from the diagonal (denser close
        # by, sparser far away); the +2 padding absorbs clamped edge writes.
        mask = torch.zeros(t2 + 2, t1 + 2, dtype=torch.bool)
        rows = torch.arange(t2 + 2)[:, None]
        tri = torch.arange(0, int((2 * t1) ** 0.5 + 1))
        tri = (tri * (tri + 1) / 2).int()
        tri = torch.cat([-tri.flip(0)[:-1], tri])
        cols = (t1 / t2 * rows + tri).long().clamp(0, t1 + 1)
        mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
        mask = mask[1:-1, 1:-1]

    elif mask_type == "random":
        # Seeded Bernoulli mask keeping roughly (1 - sparsity) of positions.
        rng = torch.Generator(device=device)
        rng.manual_seed(mask_random_seed)
        mask = (
            torch.rand(t1 * t2, generator=rng, device=device).reshape(t2, t1)
            > sparsity
        )

    mask = mask.to(device)
    return mask
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def get_mask(
    t1,
    t2,
    mask_type,
    sparse_attn_window,
    global_window,
    mask_random_seed,
    sparsity,
    device,
):
    """
    Return a SparseCSRTensor mask that is a combination of elementary masks
    mask_type can be a combination of multiple masks: for instance "diag_jmask_random"
    """
    from xformers.sparse import SparseCSRTensor

    # Union of the elementary masks named in the underscore-separated spec.
    parts = [
        get_elementary_mask(
            t1,
            t2,
            name,
            sparse_attn_window,
            global_window,
            mask_random_seed,
            sparsity,
            device,
        )
        for name in mask_type.split("_")
    ]
    combined = torch.stack(parts).sum(axis=0) > 0

    return SparseCSRTensor.from_dense(combined[None])
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class ScaledEmbedding(nn.Module):
    """Embedding whose stored table is divided by ``boost`` while the output
    (and exposed ``weight``) is multiplied back by it, so effective values are
    unchanged but gradients on the table are ``boost`` times larger."""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        scale: float = 1.0,
        boost: float = 3.0,
    ):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Pre-divide by `boost`; forward/weight multiply it back.
        self.embedding.weight.data.mul_(scale / boost)
        self.boost = boost

    @property
    def weight(self):
        # Effective (boosted) embedding table.
        return self.embedding.weight * self.boost

    def forward(self, x):
        return self.embedding(x) * self.boost
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales diagonaly residual outputs close to 0 initially, then learnt.
    """

    def __init__(self, channels: int, init: float = 0, channel_last=False):
        """
        channel_last = False corresponds to (B, C, T) tensors
        channel_last = True corresponds to (T, B, C) tensors
        """
        super().__init__()
        self.channel_last = channel_last
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data.fill_(init)

    def forward(self, x):
        # Broadcast the per-channel scale along the trailing (time) dim when
        # the layout is channels-first.
        factor = self.scale if self.channel_last else self.scale[:, None]
        return factor * x
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class MyGroupNorm(nn.GroupNorm):
    """GroupNorm accepting channels-last (B, T, C) input instead of (B, C, T)."""

    def forward(self, x):
        """
        x: (B, T, C)
        if num_groups=1: Normalisation on all T and C together for each B
        """
        # Swap to channels-first for the parent implementation, then swap back.
        return super().forward(x.transpose(1, 2)).transpose(1, 2)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class MyTransformerEncoderLayer(TransformerEncoderLayer):
    """``nn.TransformerEncoderLayer`` variant with optional GroupNorm instead
    of LayerNorm, LayerScale-weighted residual branches, an optional output
    norm (pre-norm mode only), and optional sparse self-attention whose mask
    is cached between calls."""

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=F.relu,
        group_norm=0,  # if truthy, replace LayerNorms with MyGroupNorm using this many groups
        norm_first=False,
        norm_out=False,  # extra output GroupNorm; only used together with norm_first
        layer_norm_eps=1e-5,
        layer_scale=False,  # enable LayerScale on the two residual branches
        init_values=1e-4,  # initial LayerScale value
        device=None,
        dtype=None,
        sparse=False,  # enable sparse self-attention masking
        mask_type="diag",  # underscore-separated combination, see get_mask
        mask_random_seed=42,
        sparse_attn_window=500,
        global_window=50,
        auto_sparsity=False,
        sparsity=0.95,
        batch_first=False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        # The stock parent builds self_attn, the feed-forward linears and the
        # two LayerNorms; some of these are overridden below.
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=batch_first,
            norm_first=norm_first,
            device=device,
            dtype=dtype,
        )
        self.sparse = sparse
        self.auto_sparsity = auto_sparsity
        if sparse:
            if not auto_sparsity:
                # Stored for the lazy mask construction in forward().
                self.mask_type = mask_type
                self.sparse_attn_window = sparse_attn_window
                self.global_window = global_window
            self.sparsity = sparsity
        if group_norm:
            # Replace the inherited LayerNorms with time-last GroupNorms.
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None
        # Bitwise & is acceptable here: both operands are plain bools.
        if self.norm_first & norm_out:
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
        # LayerScale (or identity) on each residual branch.
        self.gamma_1 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )
        self.gamma_2 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )

        if sparse:
            # NOTE(review): `auto_sparsity` is not an argument of stock
            # torch.nn.MultiheadAttention — a patched/xformers-backed variant
            # appears to be expected when sparse=True; confirm before enabling.
            self.self_attn = MultiheadAttention(
                d_model, nhead, dropout=dropout, batch_first=batch_first,
                auto_sparsity=sparsity if auto_sparsity else 0,
            )
            # Placeholder mask, rebuilt lazily in forward(); __setattr__ keeps
            # the tensor out of nn.Module parameter/buffer registration.
            self.__setattr__("src_mask", torch.zeros(1, 1))
            self.mask_random_seed = mask_random_seed

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        if batch_first = False, src shape is (t, b, c)
        the case where batch_first=True is not covered
        """
        device = src.device
        x = src
        t, _, _ = x.shape
        if self.sparse and not self.auto_sparsity:
            assert src_mask is None
            # Reuse the cached sparse mask; rebuild only when the sequence
            # length differs from the previous call.
            src_mask = self.src_mask
            if src_mask.shape[-1] != t:
                src_mask = get_mask(
                    t,
                    t,
                    self.mask_type,
                    self.sparse_attn_window,
                    self.global_window,
                    self.mask_random_seed,
                    self.sparsity,
                    device,
                )
                self.__setattr__("src_mask", src_mask)

        if self.norm_first:
            # Pre-norm: norm -> attn/ff -> LayerScale -> residual add.
            x = x + self.gamma_1(
                self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            )
            x = x + self.gamma_2(self._ff_block(self.norm2(x)))

            if self.norm_out:
                x = self.norm_out(x)
        else:
            # Post-norm: attn/ff -> LayerScale -> residual add -> norm.
            x = self.norm1(
                x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask))
            )
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
class CrossTransformerEncoderLayer(nn.Module):
|
| 385 |
+
def __init__(
|
| 386 |
+
self,
|
| 387 |
+
d_model: int,
|
| 388 |
+
nhead: int,
|
| 389 |
+
dim_feedforward: int = 2048,
|
| 390 |
+
dropout: float = 0.1,
|
| 391 |
+
activation=F.relu,
|
| 392 |
+
layer_norm_eps: float = 1e-5,
|
| 393 |
+
layer_scale: bool = False,
|
| 394 |
+
init_values: float = 1e-4,
|
| 395 |
+
norm_first: bool = False,
|
| 396 |
+
group_norm: bool = False,
|
| 397 |
+
norm_out: bool = False,
|
| 398 |
+
sparse=False,
|
| 399 |
+
mask_type="diag",
|
| 400 |
+
mask_random_seed=42,
|
| 401 |
+
sparse_attn_window=500,
|
| 402 |
+
global_window=50,
|
| 403 |
+
sparsity=0.95,
|
| 404 |
+
auto_sparsity=None,
|
| 405 |
+
device=None,
|
| 406 |
+
dtype=None,
|
| 407 |
+
batch_first=False,
|
| 408 |
+
):
|
| 409 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 410 |
+
super().__init__()
|
| 411 |
+
|
| 412 |
+
self.sparse = sparse
|
| 413 |
+
self.auto_sparsity = auto_sparsity
|
| 414 |
+
if sparse:
|
| 415 |
+
if not auto_sparsity:
|
| 416 |
+
self.mask_type = mask_type
|
| 417 |
+
self.sparse_attn_window = sparse_attn_window
|
| 418 |
+
self.global_window = global_window
|
| 419 |
+
self.sparsity = sparsity
|
| 420 |
+
|
| 421 |
+
self.cross_attn: nn.Module
|
| 422 |
+
self.cross_attn = MultiheadAttention(
|
| 423 |
+
d_model, nhead, dropout=dropout, batch_first=batch_first)
|
| 424 |
+
# Implementation of Feedforward model
|
| 425 |
+
self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
|
| 426 |
+
self.dropout = nn.Dropout(dropout)
|
| 427 |
+
self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)
|
| 428 |
+
|
| 429 |
+
self.norm_first = norm_first
|
| 430 |
+
self.norm1: nn.Module
|
| 431 |
+
self.norm2: nn.Module
|
| 432 |
+
self.norm3: nn.Module
|
| 433 |
+
if group_norm:
|
| 434 |
+
self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
|
| 435 |
+
self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
|
| 436 |
+
self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
|
| 437 |
+
else:
|
| 438 |
+
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
|
| 439 |
+
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
|
| 440 |
+
self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
|
| 441 |
+
|
| 442 |
+
self.norm_out = None
|
| 443 |
+
if self.norm_first & norm_out:
|
| 444 |
+
self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
|
| 445 |
+
|
| 446 |
+
self.gamma_1 = (
|
| 447 |
+
LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
|
| 448 |
+
)
|
| 449 |
+
self.gamma_2 = (
|
| 450 |
+
LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 454 |
+
self.dropout2 = nn.Dropout(dropout)
|
| 455 |
+
|
| 456 |
+
# Legacy string support for activation function.
|
| 457 |
+
if isinstance(activation, str):
|
| 458 |
+
self.activation = self._get_activation_fn(activation)
|
| 459 |
+
else:
|
| 460 |
+
self.activation = activation
|
| 461 |
+
|
| 462 |
+
if sparse:
|
| 463 |
+
self.cross_attn = MultiheadAttention(
|
| 464 |
+
d_model, nhead, dropout=dropout, batch_first=batch_first,
|
| 465 |
+
auto_sparsity=sparsity if auto_sparsity else 0)
|
| 466 |
+
if not auto_sparsity:
|
| 467 |
+
self.__setattr__("mask", torch.zeros(1, 1))
|
| 468 |
+
self.mask_random_seed = mask_random_seed
|
| 469 |
+
|
| 470 |
+
def forward(self, q, k, mask=None):
|
| 471 |
+
"""
|
| 472 |
+
Args:
|
| 473 |
+
q: tensor of shape (T, B, C)
|
| 474 |
+
k: tensor of shape (S, B, C)
|
| 475 |
+
mask: tensor of shape (T, S)
|
| 476 |
+
|
| 477 |
+
"""
|
| 478 |
+
device = q.device
|
| 479 |
+
t, _, _ = q.shape
|
| 480 |
+
s, _, _ = k.shape
|
| 481 |
+
if self.sparse and not self.auto_sparsity:
|
| 482 |
+
assert mask is None
|
| 483 |
+
mask = self.mask
|
| 484 |
+
if mask.shape[-1] != s or mask.shape[-2] != t:
|
| 485 |
+
mask = get_mask(
|
| 486 |
+
s,
|
| 487 |
+
t,
|
| 488 |
+
self.mask_type,
|
| 489 |
+
self.sparse_attn_window,
|
| 490 |
+
self.global_window,
|
| 491 |
+
self.mask_random_seed,
|
| 492 |
+
self.sparsity,
|
| 493 |
+
device,
|
| 494 |
+
)
|
| 495 |
+
self.__setattr__("mask", mask)
|
| 496 |
+
|
| 497 |
+
if self.norm_first:
|
| 498 |
+
x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
|
| 499 |
+
x = x + self.gamma_2(self._ff_block(self.norm3(x)))
|
| 500 |
+
if self.norm_out:
|
| 501 |
+
x = self.norm_out(x)
|
| 502 |
+
else:
|
| 503 |
+
x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
|
| 504 |
+
x = self.norm2(x + self.gamma_2(self._ff_block(x)))
|
| 505 |
+
|
| 506 |
+
return x
|
| 507 |
+
|
| 508 |
+
    # cross-attention block (queries from q; keys and values both from k)
    def _ca_block(self, q, k, attn_mask=None):
        # need_weights=False skips materializing averaged attention weights.
        x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
        return self.dropout1(x)
|
| 512 |
+
|
| 513 |
+
    # feed forward block
    def _ff_block(self, x):
        # Standard transformer FFN: linear -> activation -> dropout -> linear,
        # followed by the residual-branch dropout.
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
|
| 517 |
+
|
| 518 |
+
def _get_activation_fn(self, activation):
|
| 519 |
+
if activation == "relu":
|
| 520 |
+
return F.relu
|
| 521 |
+
elif activation == "gelu":
|
| 522 |
+
return F.gelu
|
| 523 |
+
|
| 524 |
+
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
# ----------------- MULTI-BLOCKS MODELS: -----------------------
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
class CrossTransformerEncoder(nn.Module):
    """Interleaved stack of per-branch transformer layers and cross-attention
    layers over two parallel branches: a spectrogram branch ``x`` of shape
    (B, C, Fr, T1) and a temporal branch ``xt`` of shape (B, C, T2).

    Even/odd layer indices (relative to ``classic_parity``) alternate between
    classic self-attention encoders and cross-attention encoders.
    """

    def __init__(
        self,
        dim: int,
        emb: str = "sin",  # positional embedding kind: "sin", "cape" or "scaled"
        hidden_scale: float = 4.0,  # FFN hidden size = dim * hidden_scale
        num_heads: int = 8,
        num_layers: int = 6,
        cross_first: bool = False,  # if True, cross layers sit at even indices
        dropout: float = 0.0,
        max_positions: int = 1000,  # only used by the "scaled" embedding table
        norm_in: bool = True,
        norm_in_group: bool = False,
        group_norm: int = False,  # NOTE(review): bool default on an int-annotated param
        norm_first: bool = False,
        norm_out: bool = False,
        max_period: float = 10000.0,
        weight_decay: float = 0.0,
        lr: tp.Optional[float] = None,  # optional per-module learning rate
        layer_scale: bool = False,
        gelu: bool = True,
        sin_random_shift: int = 0,
        weight_pos_embed: float = 1.0,  # scaling of the positional embedding
        cape_mean_normalize: bool = True,
        cape_augment: bool = True,
        cape_glob_loc_scale: list = None,
        sparse_self_attn: bool = False,
        sparse_cross_attn: bool = False,
        mask_type: str = "diag",
        mask_random_seed: int = 42,
        sparse_attn_window: int = 500,
        global_window: int = 50,
        auto_sparsity: bool = False,
        sparsity: float = 0.95,
    ):
        super().__init__()
        """
        """
        assert dim % num_heads == 0

        hidden_dim = int(dim * hidden_scale)

        self.num_layers = num_layers
        # classic parity = 1 means that if idx%2 == 1 there is a
        # classical encoder else there is a cross encoder
        self.classic_parity = 1 if cross_first else 0
        self.emb = emb
        self.max_period = max_period
        self.weight_decay = weight_decay
        self.weight_pos_embed = weight_pos_embed
        self.sin_random_shift = sin_random_shift
        if emb == "cape":
            self.cape_mean_normalize = cape_mean_normalize
            self.cape_augment = cape_augment
            if cape_glob_loc_scale is None:
                # Defaults for (max_global_shift, max_local_shift, max_scale).
                self.cape_glob_loc_scale = [5000.0, 1.0, 1.4]
            else:
                self.cape_glob_loc_scale = cape_glob_loc_scale
        if emb == "scaled":
            self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)

        self.lr = lr

        activation: tp.Any = F.gelu if gelu else F.relu

        self.norm_in: nn.Module
        self.norm_in_t: nn.Module
        if norm_in:
            self.norm_in = LayerNorm(dim)
            self.norm_in_t = LayerNorm(dim)
        elif norm_in_group:
            self.norm_in = MyGroupNorm(int(norm_in_group), dim)
            self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
        else:
            self.norm_in = nn.Identity()
            self.norm_in_t = nn.Identity()

        # spectrogram layers
        self.layers = nn.ModuleList()
        # temporal layers
        self.layers_t = nn.ModuleList()

        # Shared constructor arguments for both layer kinds.
        kwargs_common = {
            "d_model": dim,
            "nhead": num_heads,
            "dim_feedforward": hidden_dim,
            "dropout": dropout,
            "activation": activation,
            "group_norm": group_norm,
            "norm_first": norm_first,
            "norm_out": norm_out,
            "layer_scale": layer_scale,
            "mask_type": mask_type,
            "mask_random_seed": mask_random_seed,
            "sparse_attn_window": sparse_attn_window,
            "global_window": global_window,
            "sparsity": sparsity,
            "auto_sparsity": auto_sparsity,
            "batch_first": True,
        }

        kwargs_classic_encoder = dict(kwargs_common)
        kwargs_classic_encoder.update({
            "sparse": sparse_self_attn,
        })
        kwargs_cross_encoder = dict(kwargs_common)
        kwargs_cross_encoder.update({
            "sparse": sparse_cross_attn,
        })

        for idx in range(num_layers):
            if idx % 2 == self.classic_parity:

                self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
                self.layers_t.append(
                    MyTransformerEncoderLayer(**kwargs_classic_encoder)
                )

            else:
                self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))

                self.layers_t.append(
                    CrossTransformerEncoderLayer(**kwargs_cross_encoder)
                )

    def forward(self, x, xt):
        """Run both branches through the interleaved layer stack.

        Args:
            x: spectrogram branch, shape (B, C, Fr, T1).
            xt: temporal branch, shape (B, C, T2).

        Returns:
            Tuple ``(x, xt)`` with the same shapes as the inputs.
        """
        _, c, fr, t1 = x.shape
        pos_emb_2d = create_2d_sin_embedding(
            c, fr, t1, x.device, self.max_period
        )  # (1, C, Fr, T1)
        # Flatten the 2D (Fr, T1) grid into a single sequence axis.
        pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
        x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
        x = self.norm_in(x)
        x = x + self.weight_pos_embed * pos_emb_2d

        b, c, t2 = xt.shape
        xt = rearrange(xt, "b c t2 -> b t2 c")  # now (B, T2, C)
        pos_emb = self._get_pos_embedding(t2, b, c, x.device)
        pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
        xt = self.norm_in_t(xt)
        xt = xt + self.weight_pos_embed * pos_emb

        for idx in range(self.num_layers):
            if idx % 2 == self.classic_parity:
                # Per-branch self-attention layer.
                x = self.layers[idx](x)
                xt = self.layers_t[idx](xt)
            else:
                # Cross-attention: each branch attends to the other branch's
                # pre-update activations (hence old_x).
                old_x = x
                x = self.layers[idx](x, xt)
                xt = self.layers_t[idx](xt, old_x)

        x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=t1)
        xt = rearrange(xt, "b t2 c -> b c t2")
        return x, xt

    def _get_pos_embedding(self, t, b, c, device):
        # Build a positional embedding for the temporal branch according to
        # self.emb ("sin", "cape" or "scaled").
        if self.emb == "sin":
            # Optionally jitter positions by a random shift (0 when
            # sin_random_shift == 0).
            shift = random.randrange(self.sin_random_shift + 1)
            pos_emb = create_sin_embedding(
                t, c, shift=shift, device=device, max_period=self.max_period
            )
        elif self.emb == "cape":
            if self.training:
                # CAPE with train-time augmentation (global/local shift, scale).
                pos_emb = create_sin_embedding_cape(
                    t,
                    c,
                    b,
                    device=device,
                    max_period=self.max_period,
                    mean_normalize=self.cape_mean_normalize,
                    augment=self.cape_augment,
                    max_global_shift=self.cape_glob_loc_scale[0],
                    max_local_shift=self.cape_glob_loc_scale[1],
                    max_scale=self.cape_glob_loc_scale[2],
                )
            else:
                pos_emb = create_sin_embedding_cape(
                    t,
                    c,
                    b,
                    device=device,
                    max_period=self.max_period,
                    mean_normalize=self.cape_mean_normalize,
                    augment=False,
                )

        elif self.emb == "scaled":
            pos = torch.arange(t, device=device)
            pos_emb = self.position_embeddings(pos)[:, None]

        # NOTE(review): pos_emb is unbound (UnboundLocalError) if self.emb is
        # not one of "sin"/"cape"/"scaled" — confirm upstream validation.
        return pos_emb

    def make_optim_group(self):
        # Single parameter group carrying this module's weight decay and, when
        # set, its dedicated learning rate.
        group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
        if self.lr is not None:
            group["lr"] = self.lr
        return group
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
def scaled_query_key_softmax(q, k, att_mask):
    """Masked, scaled attention scores: softmax(q @ k^T / sqrt(d_k)).

    Uses xformers' ``masked_matmul`` so masked positions are handled inside
    the matmul kernel.
    """
    from xformers.ops import masked_matmul
    scale = (k.size(-1)) ** 0.5
    scores = masked_matmul(q / scale, k.transpose(-2, -1), att_mask)
    return torch.nn.functional.softmax(scores, -1)
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
def scaled_dot_product_attention(q, k, v, att_mask, dropout):
    """Full masked attention: dropout(softmax(q k^T / sqrt(d))) @ v."""
    weights = dropout(scaled_query_key_softmax(q, k, att_mask=att_mask))
    return weights @ v
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
def _compute_buckets(x, r):
|
| 745 |
+
qq = torch.einsum('btf,bfhi->bhti', x, r)
|
| 746 |
+
qq = torch.cat([qq, -qq], dim=-1)
|
| 747 |
+
buckets = qq.argmax(dim=-1)
|
| 748 |
+
|
| 749 |
+
return buckets.permute(0, 2, 1).byte().contiguous()
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None):
    """Dynamic sparse attention using xformers' custom sparse kernel.

    Queries and keys are hashed into buckets with shared random projections
    (`_compute_buckets`); attention is then restricted to the resulting
    bucket-induced sparsity pattern.
    """
    # assert False, "The code for the custom sparse kernel is not ready for release yet."
    from xformers.ops import find_locations, sparse_memory_efficient_attention
    n_hashes = 32  # number of hash rounds
    proj_size = 4  # random-projection size (two half-spaces per dimension)
    query, key, value = [x.contiguous() for x in [query, key, value]]
    with torch.no_grad():
        # Shared projections so matching query/key content lands in the same
        # buckets; no gradient flows through the hashing itself.
        r = torch.randn(1, query.shape[-1], n_hashes, proj_size // 2, device=query.device)
        bucket_query = _compute_buckets(query, r)
        bucket_key = _compute_buckets(key, r)
        row_offsets, column_indices = find_locations(
            bucket_query, bucket_key, sparsity, infer_sparsity)
    return sparse_memory_efficient_attention(
        query, key, value, row_offsets, column_indices, attn_bias)
|
third_party/demucs/models/utils.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File : utils.py
|
| 5 |
+
@Time : 2023/8/8 下午4:26
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2023, Tencent
|
| 9 |
+
@Desc : utils
|
| 10 |
+
"""
|
| 11 |
+
from contextlib import contextmanager
|
| 12 |
+
import math
|
| 13 |
+
import os
|
| 14 |
+
import tempfile
|
| 15 |
+
import typing as tp
|
| 16 |
+
import json
|
| 17 |
+
import subprocess
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from torch.nn import functional as F
|
| 21 |
+
from torch.utils.data import Subset
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def unfold(a, kernel_size, stride):
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.

    This will pad the input so that `F = ceil(T / K)`.

    see https://github.com/pytorch/pytorch/issues/60466
    """
    *batch_dims, length = a.shape
    n_frames = math.ceil(length / stride)
    padded_length = (n_frames - 1) * stride + kernel_size
    padded = F.pad(a, (0, padded_length - length))
    base_strides = list(padded.stride())
    assert base_strides[-1] == 1, 'data should be contiguous'
    # Frames advance by `stride` elements; within a frame the step is 1, so
    # the view aliases the padded storage without copying.
    frame_strides = base_strides[:-1] + [stride, 1]
    return padded.as_strided([*batch_dims, n_frames, kernel_size], frame_strides)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
    """
    Center trim `tensor` with respect to `reference`, along the last dimension.
    `reference` can also be a number, representing the length to trim to.
    If the size difference != 0 mod 2, the extra sample is removed on the right side.
    """
    if isinstance(reference, torch.Tensor):
        target_length = reference.size(-1)
    else:
        target_length = reference
    excess = tensor.size(-1) - target_length
    if excess < 0:
        raise ValueError(f"tensor must be larger than reference. Delta is {excess}.")
    if excess:
        # Drop floor(excess/2) samples on the left and the remainder on the right.
        tensor = tensor[..., excess // 2:-(excess - excess // 2)]
    return tensor
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def pull_metric(history: tp.List[dict], name: str):
    """Collect metric `name` (a dot-separated nested key) from each dict in `history`."""
    keys = name.split(".")
    values = []
    for entry in history:
        node = entry
        for key in keys:
            node = node[key]
        values.append(node)
    return values
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def sizeof_fmt(num: float, suffix: str = 'B'):
    """
    Given `num` bytes, return human readable size.
    Taken from https://stackoverflow.com/a/1094933
    """
    for unit in ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi'):
        if abs(num) < 1024.0:
            return f"{num:3.1f}{unit}{suffix}"
        num /= 1024.0
    # Anything past zebibytes is reported in yobi-units.
    return f"{num:.1f}Yi{suffix}"
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@contextmanager
def temp_filenames(count: int, delete=True):
    """Yield `count` fresh temporary file paths; unlink them on exit if `delete`."""
    created = []
    try:
        while len(created) < count:
            created.append(tempfile.NamedTemporaryFile(delete=False).name)
        yield created
    finally:
        if delete:
            for path in created:
                os.unlink(path)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def random_subset(dataset, max_samples: int, seed: int = 42):
    """Return a deterministic random `Subset` of at most `max_samples` items.

    The dataset itself is returned unchanged when it is already small enough.
    """
    if max_samples >= len(dataset):
        return dataset

    rng = torch.Generator().manual_seed(seed)
    indices = torch.randperm(len(dataset), generator=rng)[:max_samples]
    return Subset(dataset, indices.tolist())
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class DummyPoolExecutor:
    """Single-threaded stand-in for a concurrent.futures executor.

    `submit` defers the call; it only runs when `.result()` is requested,
    synchronously in the caller's thread.
    """

    class DummyResult:
        """Future-like wrapper that evaluates its call lazily."""

        def __init__(self, func, *args, **kwargs):
            self.func = func
            self.args = args
            self.kwargs = kwargs

        def result(self):
            # Run now, in the calling thread — no pool involved.
            return self.func(*self.args, **self.kwargs)

    def __init__(self, workers=0):
        # `workers` exists purely for signature compatibility.
        pass

    def submit(self, func, *args, **kwargs):
        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        return
|
third_party/demucs/run.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
@File    : run.py
|
| 5 |
+
@Time : 2024/4/22 下午2:40
|
| 6 |
+
@Author : waytan
|
| 7 |
+
@Contact : waytan@tencent.com
|
| 8 |
+
@License : (C)Copyright 2024, Tencent
|
| 9 |
+
"""
|
| 10 |
+
import os
|
| 11 |
+
import json
|
| 12 |
+
import time
|
| 13 |
+
import logging
|
| 14 |
+
import argparse
|
| 15 |
+
from datetime import datetime
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from models.apply import BagOfModels
|
| 21 |
+
from models.pretrained import get_model_from_yaml
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Separator:
    """Thin wrapper around a pretrained Demucs model for 4-stem separation."""

    def __init__(self, dm_model_path, dm_config_path, gpu_id=0) -> None:
        # Fall back to CPU when CUDA or the requested GPU index is unavailable.
        if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
            self.device = torch.device(f"cuda:{gpu_id}")
        else:
            self.device = torch.device("cpu")
        self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)

    def init_demucs_model(self, model_path, config_path) -> BagOfModels:
        """Load the model described by `config_path` with weights from
        `model_path`, moved to `self.device` and set to eval mode."""
        model = get_model_from_yaml(config_path, model_path)
        model.to(self.device)
        model.eval()
        return model

    def run(self, audio_path, output_dir, ext=".flac"):
        """Separate `audio_path` into stems under `output_dir`.

        Reuses existing stem files when all of them are already on disk;
        otherwise runs the model. Returns a dict with "vocal_path" and a
        3-element "bgm_path" list (drums, bass, other).
        """
        name, _ = os.path.splitext(os.path.split(audio_path)[-1])
        output_paths = []
        for stem in self.demucs_model.sources:
            output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
            if os.path.exists(output_path):
                output_paths.append(output_path)
        if len(output_paths) == 4:
            # All four stems cached: skip separation entirely.
            # NOTE(review): assumes self.demucs_model.sources is exactly
            # (drums, bass, other, vocals) in this order — confirm against
            # the model config.
            drums_path, bass_path, other_path, vocal_path = output_paths
        else:
            drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
        data_dict = {
            "vocal_path": vocal_path,
            "bgm_path": [drums_path, bass_path, other_path]
        }
        return data_dict
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def json_io(input_json, output_json, model_dir, dst_dir, gpu_id=0):
    """Batch-separate every audio file listed in `input_json`.

    `input_json` is JSON-lines with one object per line containing at least
    "path" and "idx". Successful items are augmented with "vocal_path" and
    "bgm_path" and written (JSON-lines) to `output_json`; failures are logged
    and skipped.
    """
    current_datetime = datetime.now()
    current_datetime_str = current_datetime.strftime('%Y-%m-%d-%H:%M')
    # NOTE(review): the ':' in the timestamp makes this log filename invalid
    # on Windows; fine on POSIX.
    logging.basicConfig(filename=os.path.join(dst_dir, f'logger-separate-{os.path.split(input_json)[1]}-{current_datetime_str}.log'), level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

    sp = Separator(os.path.join(model_dir, "htdemucs.pth"), os.path.join(model_dir, "htdemucs.yaml"), gpu_id=gpu_id)
    with open(input_json, "r") as fp:
        lines = fp.readlines()
    t1 = time.time()
    success_num = 0
    fail_num = 0
    total_num = len(lines)
    sep_items = []
    for line in lines:
        item = json.loads(line)
        flac_file = item["path"]
        try:
            fix_data = sp.run(flac_file, dst_dir)
        except Exception as e:
            # Best-effort batch processing: log the failure and continue with
            # the remaining items.
            fail_num += 1
            logging.error(f"process-{success_num + fail_num}/{total_num}|success-{success_num}|fail-{fail_num}|{item['idx']} process fail for {str(e)}")
            continue

        item["vocal_path"] = fix_data["vocal_path"]
        item["bgm_path"] = fix_data["bgm_path"]
        sep_items.append(item)
        success_num += 1
        logging.debug(f"process-{success_num + fail_num}/{total_num}|success-{success_num}|fail-{fail_num}|{item['idx']} process success")

    with open(output_json, "w", encoding='utf-8') as fw:
        for item in sep_items:
            fw.write(json.dumps(item, ensure_ascii=False) + "\n")

    t2 = time.time()
    logging.debug(f"total cost {round(t2-t1, 3)}s")
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
if __name__ == "__main__":
    # CLI entry point: separate all files listed in the input JSON-lines file.
    parser = argparse.ArgumentParser(description='')
    parser.add_argument("-m", dest="model_dir")
    parser.add_argument("-d", dest="dst_dir")
    parser.add_argument("-j", dest="input_json")
    parser.add_argument("-o", dest="output_json")
    parser.add_argument("-gid", dest="gpu_id", default=0, type=int)
    args = parser.parse_args()

    # Default to the current working directory when no destination was given;
    # either way the results go into a "separate_result" subfolder.
    base_dir = args.dst_dir if args.dst_dir else os.getcwd()
    dst_dir = os.path.join(base_dir, "separate_result")
    os.makedirs(dst_dir, exist_ok=True)

    json_io(args.input_json, args.output_json, args.model_dir, dst_dir, gpu_id=args.gpu_id)
|
third_party/hub/version.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
1
|
third_party/stable_audio_tools/.gitignore
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
#poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 110 |
+
.pdm.toml
|
| 111 |
+
|
| 112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 113 |
+
__pypackages__/
|
| 114 |
+
|
| 115 |
+
# Celery stuff
|
| 116 |
+
celerybeat-schedule
|
| 117 |
+
celerybeat.pid
|
| 118 |
+
|
| 119 |
+
# SageMath parsed files
|
| 120 |
+
*.sage.py
|
| 121 |
+
|
| 122 |
+
# Environments
|
| 123 |
+
.env
|
| 124 |
+
.venv
|
| 125 |
+
env/
|
| 126 |
+
venv/
|
| 127 |
+
ENV/
|
| 128 |
+
env.bak/
|
| 129 |
+
venv.bak/
|
| 130 |
+
|
| 131 |
+
# Spyder project settings
|
| 132 |
+
.spyderproject
|
| 133 |
+
.spyproject
|
| 134 |
+
|
| 135 |
+
# Rope project settings
|
| 136 |
+
.ropeproject
|
| 137 |
+
|
| 138 |
+
# mkdocs documentation
|
| 139 |
+
/site
|
| 140 |
+
|
| 141 |
+
# mypy
|
| 142 |
+
.mypy_cache/
|
| 143 |
+
.dmypy.json
|
| 144 |
+
dmypy.json
|
| 145 |
+
|
| 146 |
+
# Pyre type checker
|
| 147 |
+
.pyre/
|
| 148 |
+
|
| 149 |
+
# pytype static type analyzer
|
| 150 |
+
.pytype/
|
| 151 |
+
|
| 152 |
+
# Cython debug symbols
|
| 153 |
+
cython_debug/
|
| 154 |
+
|
| 155 |
+
# PyCharm
|
| 156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 160 |
+
#.idea/
|
| 161 |
+
|
| 162 |
+
*.ckpt
|
| 163 |
+
*.wav
|
| 164 |
+
wandb/*
|
third_party/stable_audio_tools/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 Stability AI
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
third_party/stable_audio_tools/LICENSES/LICENSE_ADP.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2022 archinet.ai
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
third_party/stable_audio_tools/LICENSES/LICENSE_AURALOSS.txt
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
third_party/stable_audio_tools/LICENSES/LICENSE_DESCRIPT.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023-present, Descript
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
third_party/stable_audio_tools/LICENSES/LICENSE_META.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
third_party/stable_audio_tools/LICENSES/LICENSE_NVIDIA.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2022 NVIDIA CORPORATION.
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
third_party/stable_audio_tools/LICENSES/LICENSE_XTRANSFORMERS.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2020 Phil Wang
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|