Delete MuCodec
Browse files- MuCodec/.DS_Store +0 -0
- MuCodec/.gitattributes +0 -2
- MuCodec/.gitignore +0 -3
- MuCodec/LICENSE +0 -21
- MuCodec/LICENSE_weights +0 -399
- MuCodec/configs/models/transformer2D.json +0 -25
- MuCodec/configs/scheduler/stable_diffusion_2.1_largenoise_sample.json +0 -14
- MuCodec/generate.py +0 -248
- MuCodec/libs/rvq/descript_quantize3.py +0 -298
- MuCodec/model.py +0 -367
- MuCodec/models/attention.py +0 -682
- MuCodec/models/transformer_2d_flow.py +0 -545
- MuCodec/muq_dev/muq_fairseq/data/__init__.py +0 -1
- MuCodec/muq_dev/muq_fairseq/data/ark_dataset.py +0 -71
- MuCodec/muq_dev/muq_fairseq/data/mert_dataset.py +0 -295
- MuCodec/muq_dev/muq_fairseq/data/utils/data_utils.py +0 -535
- MuCodec/muq_dev/muq_fairseq/models/muq/__init__.py +0 -1
- MuCodec/muq_dev/muq_fairseq/models/muq/model/__init__.py +0 -2
- MuCodec/muq_dev/muq_fairseq/models/muq/model/muq.py +0 -520
- MuCodec/muq_dev/muq_fairseq/models/muq/model/pred_ark_target_with_model.py +0 -151
- MuCodec/muq_dev/muq_fairseq/models/muq/model/rvq.py +0 -459
- MuCodec/muq_dev/muq_fairseq/models/muq/model/rvq_muq.py +0 -394
- MuCodec/muq_dev/muq_fairseq/models/muq/model/w2v2_config.json +0 -113
- MuCodec/muq_dev/muq_fairseq/models/muq/modules/__init__.py +0 -2
- MuCodec/muq_dev/muq_fairseq/models/muq/modules/conv.py +0 -77
- MuCodec/muq_dev/muq_fairseq/models/muq/modules/features.py +0 -67
- MuCodec/muq_dev/muq_fairseq/models/muq/modules/flash_conformer.py +0 -2114
- MuCodec/muq_dev/muq_fairseq/models/muq/modules/random_quantizer.py +0 -68
- MuCodec/muq_dev/muq_fairseq/models/muq/muq_model.py +0 -139
- MuCodec/muq_dev/muq_fairseq/tasks/muq_pretraining.py +0 -354
- MuCodec/muq_dev/test.py +0 -22
- MuCodec/readme.md +0 -67
- MuCodec/reconstructed/test.wav +0 -3
- MuCodec/requirements.txt +0 -335
- MuCodec/test_wav/test.wav +0 -3
- MuCodec/tools/get_melvaehifigan48k.py +0 -1551
- MuCodec/tools/torch_tools.py +0 -100
MuCodec/.DS_Store
DELETED
|
Binary file (8.2 kB)
|
|
|
MuCodec/.gitattributes
DELETED
|
@@ -1,2 +0,0 @@
|
|
| 1 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
MuCodec/.gitignore
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
__pycache__
|
| 2 |
-
*.pt
|
| 3 |
-
*.pth
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/LICENSE
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
MIT License
|
| 2 |
-
|
| 3 |
-
Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 4 |
-
|
| 5 |
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
-
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
-
in the Software without restriction, including without limitation the rights
|
| 8 |
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
-
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
-
furnished to do so, subject to the following conditions:
|
| 11 |
-
|
| 12 |
-
The above copyright notice and this permission notice shall be included in all
|
| 13 |
-
copies or substantial portions of the Software.
|
| 14 |
-
|
| 15 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
-
SOFTWARE.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/LICENSE_weights
DELETED
|
@@ -1,399 +0,0 @@
|
|
| 1 |
-
Attribution-NonCommercial 4.0 International
|
| 2 |
-
|
| 3 |
-
=======================================================================
|
| 4 |
-
|
| 5 |
-
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
| 6 |
-
does not provide legal services or legal advice. Distribution of
|
| 7 |
-
Creative Commons public licenses does not create a lawyer-client or
|
| 8 |
-
other relationship. Creative Commons makes its licenses and related
|
| 9 |
-
information available on an "as-is" basis. Creative Commons gives no
|
| 10 |
-
warranties regarding its licenses, any material licensed under their
|
| 11 |
-
terms and conditions, or any related information. Creative Commons
|
| 12 |
-
disclaims all liability for damages resulting from their use to the
|
| 13 |
-
fullest extent possible.
|
| 14 |
-
|
| 15 |
-
Using Creative Commons Public Licenses
|
| 16 |
-
|
| 17 |
-
Creative Commons public licenses provide a standard set of terms and
|
| 18 |
-
conditions that creators and other rights holders may use to share
|
| 19 |
-
original works of authorship and other material subject to copyright
|
| 20 |
-
and certain other rights specified in the public license below. The
|
| 21 |
-
following considerations are for informational purposes only, are not
|
| 22 |
-
exhaustive, and do not form part of our licenses.
|
| 23 |
-
|
| 24 |
-
Considerations for licensors: Our public licenses are
|
| 25 |
-
intended for use by those authorized to give the public
|
| 26 |
-
permission to use material in ways otherwise restricted by
|
| 27 |
-
copyright and certain other rights. Our licenses are
|
| 28 |
-
irrevocable. Licensors should read and understand the terms
|
| 29 |
-
and conditions of the license they choose before applying it.
|
| 30 |
-
Licensors should also secure all rights necessary before
|
| 31 |
-
applying our licenses so that the public can reuse the
|
| 32 |
-
material as expected. Licensors should clearly mark any
|
| 33 |
-
material not subject to the license. This includes other CC-
|
| 34 |
-
licensed material, or material used under an exception or
|
| 35 |
-
limitation to copyright. More considerations for licensors:
|
| 36 |
-
wiki.creativecommons.org/Considerations_for_licensors
|
| 37 |
-
|
| 38 |
-
Considerations for the public: By using one of our public
|
| 39 |
-
licenses, a licensor grants the public permission to use the
|
| 40 |
-
licensed material under specified terms and conditions. If
|
| 41 |
-
the licensor's permission is not necessary for any reason--for
|
| 42 |
-
example, because of any applicable exception or limitation to
|
| 43 |
-
copyright--then that use is not regulated by the license. Our
|
| 44 |
-
licenses grant only permissions under copyright and certain
|
| 45 |
-
other rights that a licensor has authority to grant. Use of
|
| 46 |
-
the licensed material may still be restricted for other
|
| 47 |
-
reasons, including because others have copyright or other
|
| 48 |
-
rights in the material. A licensor may make special requests,
|
| 49 |
-
such as asking that all changes be marked or described.
|
| 50 |
-
Although not required by our licenses, you are encouraged to
|
| 51 |
-
respect those requests where reasonable. More_considerations
|
| 52 |
-
for the public:
|
| 53 |
-
wiki.creativecommons.org/Considerations_for_licensees
|
| 54 |
-
|
| 55 |
-
=======================================================================
|
| 56 |
-
|
| 57 |
-
Creative Commons Attribution-NonCommercial 4.0 International Public
|
| 58 |
-
License
|
| 59 |
-
|
| 60 |
-
By exercising the Licensed Rights (defined below), You accept and agree
|
| 61 |
-
to be bound by the terms and conditions of this Creative Commons
|
| 62 |
-
Attribution-NonCommercial 4.0 International Public License ("Public
|
| 63 |
-
License"). To the extent this Public License may be interpreted as a
|
| 64 |
-
contract, You are granted the Licensed Rights in consideration of Your
|
| 65 |
-
acceptance of these terms and conditions, and the Licensor grants You
|
| 66 |
-
such rights in consideration of benefits the Licensor receives from
|
| 67 |
-
making the Licensed Material available under these terms and
|
| 68 |
-
conditions.
|
| 69 |
-
|
| 70 |
-
Section 1 -- Definitions.
|
| 71 |
-
|
| 72 |
-
a. Adapted Material means material subject to Copyright and Similar
|
| 73 |
-
Rights that is derived from or based upon the Licensed Material
|
| 74 |
-
and in which the Licensed Material is translated, altered,
|
| 75 |
-
arranged, transformed, or otherwise modified in a manner requiring
|
| 76 |
-
permission under the Copyright and Similar Rights held by the
|
| 77 |
-
Licensor. For purposes of this Public License, where the Licensed
|
| 78 |
-
Material is a musical work, performance, or sound recording,
|
| 79 |
-
Adapted Material is always produced where the Licensed Material is
|
| 80 |
-
synched in timed relation with a moving image.
|
| 81 |
-
|
| 82 |
-
b. Adapter's License means the license You apply to Your Copyright
|
| 83 |
-
and Similar Rights in Your contributions to Adapted Material in
|
| 84 |
-
accordance with the terms and conditions of this Public License.
|
| 85 |
-
|
| 86 |
-
c. Copyright and Similar Rights means copyright and/or similar rights
|
| 87 |
-
closely related to copyright including, without limitation,
|
| 88 |
-
performance, broadcast, sound recording, and Sui Generis Database
|
| 89 |
-
Rights, without regard to how the rights are labeled or
|
| 90 |
-
categorized. For purposes of this Public License, the rights
|
| 91 |
-
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
| 92 |
-
Rights.
|
| 93 |
-
d. Effective Technological Measures means those measures that, in the
|
| 94 |
-
absence of proper authority, may not be circumvented under laws
|
| 95 |
-
fulfilling obligations under Article 11 of the WIPO Copyright
|
| 96 |
-
Treaty adopted on December 20, 1996, and/or similar international
|
| 97 |
-
agreements.
|
| 98 |
-
|
| 99 |
-
e. Exceptions and Limitations means fair use, fair dealing, and/or
|
| 100 |
-
any other exception or limitation to Copyright and Similar Rights
|
| 101 |
-
that applies to Your use of the Licensed Material.
|
| 102 |
-
|
| 103 |
-
f. Licensed Material means the artistic or literary work, database,
|
| 104 |
-
or other material to which the Licensor applied this Public
|
| 105 |
-
License.
|
| 106 |
-
|
| 107 |
-
g. Licensed Rights means the rights granted to You subject to the
|
| 108 |
-
terms and conditions of this Public License, which are limited to
|
| 109 |
-
all Copyright and Similar Rights that apply to Your use of the
|
| 110 |
-
Licensed Material and that the Licensor has authority to license.
|
| 111 |
-
|
| 112 |
-
h. Licensor means the individual(s) or entity(ies) granting rights
|
| 113 |
-
under this Public License.
|
| 114 |
-
|
| 115 |
-
i. NonCommercial means not primarily intended for or directed towards
|
| 116 |
-
commercial advantage or monetary compensation. For purposes of
|
| 117 |
-
this Public License, the exchange of the Licensed Material for
|
| 118 |
-
other material subject to Copyright and Similar Rights by digital
|
| 119 |
-
file-sharing or similar means is NonCommercial provided there is
|
| 120 |
-
no payment of monetary compensation in connection with the
|
| 121 |
-
exchange.
|
| 122 |
-
|
| 123 |
-
j. Share means to provide material to the public by any means or
|
| 124 |
-
process that requires permission under the Licensed Rights, such
|
| 125 |
-
as reproduction, public display, public performance, distribution,
|
| 126 |
-
dissemination, communication, or importation, and to make material
|
| 127 |
-
available to the public including in ways that members of the
|
| 128 |
-
public may access the material from a place and at a time
|
| 129 |
-
individually chosen by them.
|
| 130 |
-
|
| 131 |
-
k. Sui Generis Database Rights means rights other than copyright
|
| 132 |
-
resulting from Directive 96/9/EC of the European Parliament and of
|
| 133 |
-
the Council of 11 March 1996 on the legal protection of databases,
|
| 134 |
-
as amended and/or succeeded, as well as other essentially
|
| 135 |
-
equivalent rights anywhere in the world.
|
| 136 |
-
|
| 137 |
-
l. You means the individual or entity exercising the Licensed Rights
|
| 138 |
-
under this Public License. Your has a corresponding meaning.
|
| 139 |
-
|
| 140 |
-
Section 2 -- Scope.
|
| 141 |
-
|
| 142 |
-
a. License grant.
|
| 143 |
-
|
| 144 |
-
1. Subject to the terms and conditions of this Public License,
|
| 145 |
-
the Licensor hereby grants You a worldwide, royalty-free,
|
| 146 |
-
non-sublicensable, non-exclusive, irrevocable license to
|
| 147 |
-
exercise the Licensed Rights in the Licensed Material to:
|
| 148 |
-
|
| 149 |
-
a. reproduce and Share the Licensed Material, in whole or
|
| 150 |
-
in part, for NonCommercial purposes only; and
|
| 151 |
-
|
| 152 |
-
b. produce, reproduce, and Share Adapted Material for
|
| 153 |
-
NonCommercial purposes only.
|
| 154 |
-
|
| 155 |
-
2. Exceptions and Limitations. For the avoidance of doubt, where
|
| 156 |
-
Exceptions and Limitations apply to Your use, this Public
|
| 157 |
-
License does not apply, and You do not need to comply with
|
| 158 |
-
its terms and conditions.
|
| 159 |
-
|
| 160 |
-
3. Term. The term of this Public License is specified in Section
|
| 161 |
-
6(a).
|
| 162 |
-
|
| 163 |
-
4. Media and formats; technical modifications allowed. The
|
| 164 |
-
Licensor authorizes You to exercise the Licensed Rights in
|
| 165 |
-
all media and formats whether now known or hereafter created,
|
| 166 |
-
and to make technical modifications necessary to do so. The
|
| 167 |
-
Licensor waives and/or agrees not to assert any right or
|
| 168 |
-
authority to forbid You from making technical modifications
|
| 169 |
-
necessary to exercise the Licensed Rights, including
|
| 170 |
-
technical modifications necessary to circumvent Effective
|
| 171 |
-
Technological Measures. For purposes of this Public License,
|
| 172 |
-
simply making modifications authorized by this Section 2(a)
|
| 173 |
-
(4) never produces Adapted Material.
|
| 174 |
-
|
| 175 |
-
5. Downstream recipients.
|
| 176 |
-
|
| 177 |
-
a. Offer from the Licensor -- Licensed Material. Every
|
| 178 |
-
recipient of the Licensed Material automatically
|
| 179 |
-
receives an offer from the Licensor to exercise the
|
| 180 |
-
Licensed Rights under the terms and conditions of this
|
| 181 |
-
Public License.
|
| 182 |
-
|
| 183 |
-
b. No downstream restrictions. You may not offer or impose
|
| 184 |
-
any additional or different terms or conditions on, or
|
| 185 |
-
apply any Effective Technological Measures to, the
|
| 186 |
-
Licensed Material if doing so restricts exercise of the
|
| 187 |
-
Licensed Rights by any recipient of the Licensed
|
| 188 |
-
Material.
|
| 189 |
-
|
| 190 |
-
6. No endorsement. Nothing in this Public License constitutes or
|
| 191 |
-
may be construed as permission to assert or imply that You
|
| 192 |
-
are, or that Your use of the Licensed Material is, connected
|
| 193 |
-
with, or sponsored, endorsed, or granted official status by,
|
| 194 |
-
the Licensor or others designated to receive attribution as
|
| 195 |
-
provided in Section 3(a)(1)(A)(i).
|
| 196 |
-
|
| 197 |
-
b. Other rights.
|
| 198 |
-
|
| 199 |
-
1. Moral rights, such as the right of integrity, are not
|
| 200 |
-
licensed under this Public License, nor are publicity,
|
| 201 |
-
privacy, and/or other similar personality rights; however, to
|
| 202 |
-
the extent possible, the Licensor waives and/or agrees not to
|
| 203 |
-
assert any such rights held by the Licensor to the limited
|
| 204 |
-
extent necessary to allow You to exercise the Licensed
|
| 205 |
-
Rights, but not otherwise.
|
| 206 |
-
|
| 207 |
-
2. Patent and trademark rights are not licensed under this
|
| 208 |
-
Public License.
|
| 209 |
-
|
| 210 |
-
3. To the extent possible, the Licensor waives any right to
|
| 211 |
-
collect royalties from You for the exercise of the Licensed
|
| 212 |
-
Rights, whether directly or through a collecting society
|
| 213 |
-
under any voluntary or waivable statutory or compulsory
|
| 214 |
-
licensing scheme. In all other cases the Licensor expressly
|
| 215 |
-
reserves any right to collect such royalties, including when
|
| 216 |
-
the Licensed Material is used other than for NonCommercial
|
| 217 |
-
purposes.
|
| 218 |
-
|
| 219 |
-
Section 3 -- License Conditions.
|
| 220 |
-
|
| 221 |
-
Your exercise of the Licensed Rights is expressly made subject to the
|
| 222 |
-
following conditions.
|
| 223 |
-
|
| 224 |
-
a. Attribution.
|
| 225 |
-
|
| 226 |
-
1. If You Share the Licensed Material (including in modified
|
| 227 |
-
form), You must:
|
| 228 |
-
|
| 229 |
-
a. retain the following if it is supplied by the Licensor
|
| 230 |
-
with the Licensed Material:
|
| 231 |
-
|
| 232 |
-
i. identification of the creator(s) of the Licensed
|
| 233 |
-
Material and any others designated to receive
|
| 234 |
-
attribution, in any reasonable manner requested by
|
| 235 |
-
the Licensor (including by pseudonym if
|
| 236 |
-
designated);
|
| 237 |
-
|
| 238 |
-
ii. a copyright notice;
|
| 239 |
-
|
| 240 |
-
iii. a notice that refers to this Public License;
|
| 241 |
-
|
| 242 |
-
iv. a notice that refers to the disclaimer of
|
| 243 |
-
warranties;
|
| 244 |
-
|
| 245 |
-
v. a URI or hyperlink to the Licensed Material to the
|
| 246 |
-
extent reasonably practicable;
|
| 247 |
-
|
| 248 |
-
b. indicate if You modified the Licensed Material and
|
| 249 |
-
retain an indication of any previous modifications; and
|
| 250 |
-
|
| 251 |
-
c. indicate the Licensed Material is licensed under this
|
| 252 |
-
Public License, and include the text of, or the URI or
|
| 253 |
-
hyperlink to, this Public License.
|
| 254 |
-
|
| 255 |
-
2. You may satisfy the conditions in Section 3(a)(1) in any
|
| 256 |
-
reasonable manner based on the medium, means, and context in
|
| 257 |
-
which You Share the Licensed Material. For example, it may be
|
| 258 |
-
reasonable to satisfy the conditions by providing a URI or
|
| 259 |
-
hyperlink to a resource that includes the required
|
| 260 |
-
information.
|
| 261 |
-
|
| 262 |
-
3. If requested by the Licensor, You must remove any of the
|
| 263 |
-
information required by Section 3(a)(1)(A) to the extent
|
| 264 |
-
reasonably practicable.
|
| 265 |
-
|
| 266 |
-
4. If You Share Adapted Material You produce, the Adapter's
|
| 267 |
-
License You apply must not prevent recipients of the Adapted
|
| 268 |
-
Material from complying with this Public License.
|
| 269 |
-
|
| 270 |
-
Section 4 -- Sui Generis Database Rights.
|
| 271 |
-
|
| 272 |
-
Where the Licensed Rights include Sui Generis Database Rights that
|
| 273 |
-
apply to Your use of the Licensed Material:
|
| 274 |
-
|
| 275 |
-
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
| 276 |
-
to extract, reuse, reproduce, and Share all or a substantial
|
| 277 |
-
portion of the contents of the database for NonCommercial purposes
|
| 278 |
-
only;
|
| 279 |
-
|
| 280 |
-
b. if You include all or a substantial portion of the database
|
| 281 |
-
contents in a database in which You have Sui Generis Database
|
| 282 |
-
Rights, then the database in which You have Sui Generis Database
|
| 283 |
-
Rights (but not its individual contents) is Adapted Material; and
|
| 284 |
-
|
| 285 |
-
c. You must comply with the conditions in Section 3(a) if You Share
|
| 286 |
-
all or a substantial portion of the contents of the database.
|
| 287 |
-
|
| 288 |
-
For the avoidance of doubt, this Section 4 supplements and does not
|
| 289 |
-
replace Your obligations under this Public License where the Licensed
|
| 290 |
-
Rights include other Copyright and Similar Rights.
|
| 291 |
-
|
| 292 |
-
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
| 293 |
-
|
| 294 |
-
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
| 295 |
-
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
| 296 |
-
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
| 297 |
-
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
| 298 |
-
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
| 299 |
-
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
| 300 |
-
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
| 301 |
-
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
| 302 |
-
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
| 303 |
-
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
| 304 |
-
|
| 305 |
-
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
| 306 |
-
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
| 307 |
-
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
| 308 |
-
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
| 309 |
-
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
| 310 |
-
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
| 311 |
-
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
| 312 |
-
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
| 313 |
-
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
| 314 |
-
|
| 315 |
-
c. The disclaimer of warranties and limitation of liability provided
|
| 316 |
-
above shall be interpreted in a manner that, to the extent
|
| 317 |
-
possible, most closely approximates an absolute disclaimer and
|
| 318 |
-
waiver of all liability.
|
| 319 |
-
|
| 320 |
-
Section 6 -- Term and Termination.
|
| 321 |
-
|
| 322 |
-
a. This Public License applies for the term of the Copyright and
|
| 323 |
-
Similar Rights licensed here. However, if You fail to comply with
|
| 324 |
-
this Public License, then Your rights under this Public License
|
| 325 |
-
terminate automatically.
|
| 326 |
-
|
| 327 |
-
b. Where Your right to use the Licensed Material has terminated under
|
| 328 |
-
Section 6(a), it reinstates:
|
| 329 |
-
|
| 330 |
-
1. automatically as of the date the violation is cured, provided
|
| 331 |
-
it is cured within 30 days of Your discovery of the
|
| 332 |
-
violation; or
|
| 333 |
-
|
| 334 |
-
2. upon express reinstatement by the Licensor.
|
| 335 |
-
|
| 336 |
-
For the avoidance of doubt, this Section 6(b) does not affect any
|
| 337 |
-
right the Licensor may have to seek remedies for Your violations
|
| 338 |
-
of this Public License.
|
| 339 |
-
|
| 340 |
-
c. For the avoidance of doubt, the Licensor may also offer the
|
| 341 |
-
Licensed Material under separate terms or conditions or stop
|
| 342 |
-
distributing the Licensed Material at any time; however, doing so
|
| 343 |
-
will not terminate this Public License.
|
| 344 |
-
|
| 345 |
-
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
| 346 |
-
License.
|
| 347 |
-
|
| 348 |
-
Section 7 -- Other Terms and Conditions.
|
| 349 |
-
|
| 350 |
-
a. The Licensor shall not be bound by any additional or different
|
| 351 |
-
terms or conditions communicated by You unless expressly agreed.
|
| 352 |
-
|
| 353 |
-
b. Any arrangements, understandings, or agreements regarding the
|
| 354 |
-
Licensed Material not stated herein are separate from and
|
| 355 |
-
independent of the terms and conditions of this Public License.
|
| 356 |
-
|
| 357 |
-
Section 8 -- Interpretation.
|
| 358 |
-
|
| 359 |
-
a. For the avoidance of doubt, this Public License does not, and
|
| 360 |
-
shall not be interpreted to, reduce, limit, restrict, or impose
|
| 361 |
-
conditions on any use of the Licensed Material that could lawfully
|
| 362 |
-
be made without permission under this Public License.
|
| 363 |
-
|
| 364 |
-
b. To the extent possible, if any provision of this Public License is
|
| 365 |
-
deemed unenforceable, it shall be automatically reformed to the
|
| 366 |
-
minimum extent necessary to make it enforceable. If the provision
|
| 367 |
-
cannot be reformed, it shall be severed from this Public License
|
| 368 |
-
without affecting the enforceability of the remaining terms and
|
| 369 |
-
conditions.
|
| 370 |
-
|
| 371 |
-
c. No term or condition of this Public License will be waived and no
|
| 372 |
-
failure to comply consented to unless expressly agreed to by the
|
| 373 |
-
Licensor.
|
| 374 |
-
|
| 375 |
-
d. Nothing in this Public License constitutes or may be interpreted
|
| 376 |
-
as a limitation upon, or waiver of, any privileges and immunities
|
| 377 |
-
that apply to the Licensor or You, including from the legal
|
| 378 |
-
processes of any jurisdiction or authority.
|
| 379 |
-
|
| 380 |
-
=======================================================================
|
| 381 |
-
|
| 382 |
-
Creative Commons is not a party to its public
|
| 383 |
-
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
| 384 |
-
its public licenses to material it publishes and in those instances
|
| 385 |
-
will be considered the “Licensor.” The text of the Creative Commons
|
| 386 |
-
public licenses is dedicated to the public domain under the CC0 Public
|
| 387 |
-
Domain Dedication. Except for the limited purpose of indicating that
|
| 388 |
-
material is shared under a Creative Commons public license or as
|
| 389 |
-
otherwise permitted by the Creative Commons policies published at
|
| 390 |
-
creativecommons.org/policies, Creative Commons does not authorize the
|
| 391 |
-
use of the trademark "Creative Commons" or any other trademark or logo
|
| 392 |
-
of Creative Commons without its prior written consent including,
|
| 393 |
-
without limitation, in connection with any unauthorized modifications
|
| 394 |
-
to any of its public licenses or any other arrangements,
|
| 395 |
-
understandings, or agreements concerning use of licensed material. For
|
| 396 |
-
the avoidance of doubt, this paragraph does not form part of the
|
| 397 |
-
public licenses.
|
| 398 |
-
|
| 399 |
-
Creative Commons may be contacted at creativecommons.org.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/configs/models/transformer2D.json
DELETED
|
@@ -1,25 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"_class_name": "Transformer2DModel",
|
| 3 |
-
"activation_fn": "gelu-approximate",
|
| 4 |
-
"attention_bias": true,
|
| 5 |
-
"attention_head_dim": 72,
|
| 6 |
-
"attention_type": "default",
|
| 7 |
-
"cross_attention_dim": null,
|
| 8 |
-
"double_self_attention": false,
|
| 9 |
-
"dropout": 0.0,
|
| 10 |
-
"in_channels": 96,
|
| 11 |
-
"norm_elementwise_affine": false,
|
| 12 |
-
"norm_eps": 1e-06,
|
| 13 |
-
"norm_num_groups": 32,
|
| 14 |
-
"norm_type": "ada_norm_single",
|
| 15 |
-
"num_attention_heads": 22,
|
| 16 |
-
"num_embeds_ada_norm": 1000,
|
| 17 |
-
"num_layers": 24,
|
| 18 |
-
"num_vector_embeds": null,
|
| 19 |
-
"only_cross_attention": false,
|
| 20 |
-
"out_channels": 32,
|
| 21 |
-
"patch_size": 2,
|
| 22 |
-
"sample_size": 384,
|
| 23 |
-
"upcast_attention": false,
|
| 24 |
-
"use_linear_projection": false
|
| 25 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/configs/scheduler/stable_diffusion_2.1_largenoise_sample.json
DELETED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"_class_name": "DDIMScheduler",
|
| 3 |
-
"_diffusers_version": "0.8.0",
|
| 4 |
-
"beta_end": 0.02,
|
| 5 |
-
"beta_schedule": "scaled_linear",
|
| 6 |
-
"beta_start": 0.0015,
|
| 7 |
-
"clip_sample": false,
|
| 8 |
-
"num_train_timesteps": 1000,
|
| 9 |
-
"prediction_type": "sample",
|
| 10 |
-
"set_alpha_to_one": false,
|
| 11 |
-
"skip_prk_steps": true,
|
| 12 |
-
"steps_offset": 1,
|
| 13 |
-
"trained_betas": null
|
| 14 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/generate.py
DELETED
|
@@ -1,248 +0,0 @@
|
|
| 1 |
-
import json
|
| 2 |
-
import torch
|
| 3 |
-
from tqdm import tqdm
|
| 4 |
-
import sys
|
| 5 |
-
from model import PromptCondAudioDiffusion
|
| 6 |
-
from diffusers import DDIMScheduler, DDPMScheduler
|
| 7 |
-
import torchaudio
|
| 8 |
-
import librosa
|
| 9 |
-
import os
|
| 10 |
-
import math
|
| 11 |
-
import numpy as np
|
| 12 |
-
from tools.get_melvaehifigan48k import build_pretrained_models
|
| 13 |
-
import tools.torch_tools as torch_tools
|
| 14 |
-
from safetensors.torch import load_file
|
| 15 |
-
from cached_path import cached_path
|
| 16 |
-
|
| 17 |
-
class MuCodec:
|
| 18 |
-
def __init__(self, \
|
| 19 |
-
model_path, \
|
| 20 |
-
layer_num, \
|
| 21 |
-
load_main_model=True, \
|
| 22 |
-
device="cuda:0"):
|
| 23 |
-
|
| 24 |
-
self.layer_num = layer_num - 1
|
| 25 |
-
self.sample_rate = 48000
|
| 26 |
-
self.device = device
|
| 27 |
-
|
| 28 |
-
self.MAX_DURATION = 360
|
| 29 |
-
if load_main_model:
|
| 30 |
-
audio_ldm_path = str(cached_path("hf://haoheliu/audioldm_48k/audioldm_48k.pth"))
|
| 31 |
-
self.vae, self.stft = build_pretrained_models(audio_ldm_path)
|
| 32 |
-
self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
|
| 33 |
-
main_config = {
|
| 34 |
-
"num_channels":32,
|
| 35 |
-
"unet_model_name":None,
|
| 36 |
-
"unet_model_config_path":os.path.dirname(os.path.abspath(__file__)) + "/configs/models/transformer2D.json",
|
| 37 |
-
"snr_gamma":None,
|
| 38 |
-
}
|
| 39 |
-
self.model = PromptCondAudioDiffusion(**main_config)
|
| 40 |
-
if model_path.endswith('.safetensors'):
|
| 41 |
-
main_weights = load_file(model_path)
|
| 42 |
-
else:
|
| 43 |
-
main_weights = torch.load(model_path, map_location='cpu')
|
| 44 |
-
self.model.load_state_dict(main_weights, strict=False)
|
| 45 |
-
self.model = self.model.to(device)
|
| 46 |
-
print ("Successfully loaded checkpoint from:", model_path)
|
| 47 |
-
else:
|
| 48 |
-
main_config = {
|
| 49 |
-
"num_channels":32,
|
| 50 |
-
"unet_model_name":None,
|
| 51 |
-
"unet_model_config_path":None,
|
| 52 |
-
"snr_gamma":None,
|
| 53 |
-
}
|
| 54 |
-
self.model = PromptCondAudioDiffusion(**main_config).to(device)
|
| 55 |
-
main_weights = torch.load(model_path, map_location='cpu')
|
| 56 |
-
self.model.load_state_dict(main_weights, strict=False)
|
| 57 |
-
self.model = self.model.to(device)
|
| 58 |
-
print ("Successfully loaded checkpoint from:", model_path)
|
| 59 |
-
|
| 60 |
-
self.model.eval()
|
| 61 |
-
self.model.init_device_dtype(torch.device(device), torch.float32)
|
| 62 |
-
print("scaling factor: ", self.model.normfeat.std)
|
| 63 |
-
|
| 64 |
-
def file2code(self, fname):
|
| 65 |
-
orig_samples, fs = torchaudio.load(fname)
|
| 66 |
-
if(fs!=self.sample_rate):
|
| 67 |
-
orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
|
| 68 |
-
fs = self.sample_rate
|
| 69 |
-
if orig_samples.shape[0] == 1:
|
| 70 |
-
orig_samples = torch.cat([orig_samples, orig_samples], 0)
|
| 71 |
-
return self.sound2code(orig_samples)
|
| 72 |
-
|
| 73 |
-
@torch.no_grad()
|
| 74 |
-
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
| 75 |
-
def sound2code(self, orig_samples, batch_size=3):
|
| 76 |
-
if(orig_samples.ndim == 2):
|
| 77 |
-
audios = orig_samples.unsqueeze(0).to(self.device)
|
| 78 |
-
elif(orig_samples.ndim == 3):
|
| 79 |
-
audios = orig_samples.to(self.device)
|
| 80 |
-
else:
|
| 81 |
-
assert orig_samples.ndim in (2,3), orig_samples.shape
|
| 82 |
-
audios = self.preprocess_audio(audios)
|
| 83 |
-
audios = audios.squeeze(0)
|
| 84 |
-
orig_length = audios.shape[-1]
|
| 85 |
-
min_samples = int(40.96 * self.sample_rate)
|
| 86 |
-
output_len = int(orig_length / float(self.sample_rate) * 25) + 1
|
| 87 |
-
print("output_len: ", output_len)
|
| 88 |
-
|
| 89 |
-
while(audios.shape[-1] < min_samples + 480):
|
| 90 |
-
audios = torch.cat([audios, audios], -1)
|
| 91 |
-
int_max_len=audios.shape[-1]//min_samples+1
|
| 92 |
-
# print("int_max_len: ", int_max_len)
|
| 93 |
-
audios = torch.cat([audios, audios], -1)
|
| 94 |
-
# print("audios:",audios.shape)
|
| 95 |
-
audios=audios[:,:int(int_max_len*(min_samples+480))]
|
| 96 |
-
codes_list=[]
|
| 97 |
-
|
| 98 |
-
audio_input = audios.reshape(2, -1, min_samples+480).permute(1, 0, 2).reshape(-1, 2, min_samples+480)
|
| 99 |
-
|
| 100 |
-
for audio_inx in range(0, audio_input.shape[0], batch_size):
|
| 101 |
-
# import pdb; pdb.set_trace()
|
| 102 |
-
codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num)
|
| 103 |
-
codes_list.append(torch.cat(codes, 1))
|
| 104 |
-
# print("codes_list",codes_list[0].shape)
|
| 105 |
-
|
| 106 |
-
codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(1, -1)[None] # B 3 T -> 3 B T
|
| 107 |
-
codes=codes[:,:,:output_len]
|
| 108 |
-
|
| 109 |
-
return codes
|
| 110 |
-
|
| 111 |
-
@torch.no_grad()
|
| 112 |
-
def code2sound(self, codes, prompt=None, duration=40.96, guidance_scale=1.5, num_steps=20, disable_progress=False):
|
| 113 |
-
codes = codes.to(self.device)
|
| 114 |
-
first_latent = torch.randn(codes.shape[0], 32, 512, 32).to(self.device)
|
| 115 |
-
first_latent_length = 0
|
| 116 |
-
first_latent_codes_length = 0
|
| 117 |
-
if(isinstance(prompt, torch.Tensor)):
|
| 118 |
-
prompt = prompt.to(self.device)
|
| 119 |
-
if(prompt.ndim == 3):
|
| 120 |
-
assert prompt.shape[0] == 1, prompt.shape
|
| 121 |
-
prompt = prompt[0]
|
| 122 |
-
elif(prompt.ndim == 1):
|
| 123 |
-
prompt = prompt.unsqueeze(0).repeat(2,1)
|
| 124 |
-
elif(prompt.ndim == 2):
|
| 125 |
-
if(prompt.shape[0] == 1):
|
| 126 |
-
prompt = prompt.repeat(2,1)
|
| 127 |
-
|
| 128 |
-
if(prompt.shape[-1] < int(30.76 * self.sample_rate)):
|
| 129 |
-
prompt = prompt[:,:int(10.24*self.sample_rate)] # limit max length to 10.24
|
| 130 |
-
else:
|
| 131 |
-
prompt = prompt[:,int(20.48*self.sample_rate):int(30.72*self.sample_rate)] # limit max length to 10.24
|
| 132 |
-
|
| 133 |
-
true_mel , _, _ = torch_tools.wav_to_fbank2(prompt, -1, fn_STFT=self.stft) # maximum 10.24s
|
| 134 |
-
true_mel = true_mel.unsqueeze(1).to(self.device)
|
| 135 |
-
true_latent = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(true_mel[[m]])) for m in range(true_mel.shape[0])],0)
|
| 136 |
-
true_latent = true_latent.reshape(true_latent.shape[0]//2, -1, true_latent.shape[2], true_latent.shape[3]).detach()
|
| 137 |
-
|
| 138 |
-
first_latent[:,:,0:true_latent.shape[2],:] = true_latent
|
| 139 |
-
first_latent_length = true_latent.shape[2]
|
| 140 |
-
first_latent_codes = self.sound2code(prompt)[:,:,0:first_latent_length*2] # B 4 T
|
| 141 |
-
first_latent_codes_length = first_latent_codes.shape[-1]
|
| 142 |
-
codes = torch.cat([first_latent_codes, codes], -1)
|
| 143 |
-
|
| 144 |
-
min_samples = 1024
|
| 145 |
-
hop_samples = min_samples // 4 * 3
|
| 146 |
-
ovlp_samples = min_samples - hop_samples
|
| 147 |
-
hop_frames = hop_samples // 2
|
| 148 |
-
ovlp_frames = ovlp_samples // 2
|
| 149 |
-
|
| 150 |
-
codes_len= codes.shape[-1]
|
| 151 |
-
target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
|
| 152 |
-
|
| 153 |
-
if(codes_len < min_samples):
|
| 154 |
-
while(codes.shape[-1] < min_samples):
|
| 155 |
-
codes = torch.cat([codes, codes], -1)
|
| 156 |
-
codes = codes[:,:,0:min_samples]
|
| 157 |
-
codes_len = codes.shape[-1]
|
| 158 |
-
if((codes_len - ovlp_frames) % hop_samples > 0):
|
| 159 |
-
len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
|
| 160 |
-
while(codes.shape[-1] < len_codes):
|
| 161 |
-
codes = torch.cat([codes, codes], -1)
|
| 162 |
-
codes = codes[:,:,0:len_codes]
|
| 163 |
-
latent_length = 512
|
| 164 |
-
latent_list = []
|
| 165 |
-
spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device)
|
| 166 |
-
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
| 167 |
-
for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples):
|
| 168 |
-
codes_input=[]
|
| 169 |
-
codes_input.append(codes[:,:,sinx:sinx+min_samples])
|
| 170 |
-
if(sinx == 0):
|
| 171 |
-
incontext_length = first_latent_length
|
| 172 |
-
latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
|
| 173 |
-
latent_list.append(latents)
|
| 174 |
-
else:
|
| 175 |
-
true_latent = latent_list[-1][:,:,-ovlp_frames:,:]
|
| 176 |
-
len_add_to_512 = 512 - true_latent.shape[-2]
|
| 177 |
-
incontext_length = true_latent.shape[-2]
|
| 178 |
-
true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], true_latent.shape[1], len_add_to_512, true_latent.shape[-1]).to(self.device)], -2)
|
| 179 |
-
latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
|
| 180 |
-
latent_list.append(latents)
|
| 181 |
-
|
| 182 |
-
latent_list = [l.float() for l in latent_list]
|
| 183 |
-
latent_list[0] = latent_list[0][:,:,first_latent_length:,:]
|
| 184 |
-
min_samples = int(duration * self.sample_rate)
|
| 185 |
-
hop_samples = min_samples // 4 * 3
|
| 186 |
-
ovlp_samples = min_samples - hop_samples
|
| 187 |
-
with torch.no_grad():
|
| 188 |
-
output = None
|
| 189 |
-
for i in range(len(latent_list)):
|
| 190 |
-
latent = latent_list[i]
|
| 191 |
-
bsz , ch, t, f = latent.shape
|
| 192 |
-
latent = latent.reshape(bsz*2, ch//2, t, f)
|
| 193 |
-
mel = self.vae.decode_first_stage(latent)
|
| 194 |
-
cur_output = self.vae.decode_to_waveform(mel)
|
| 195 |
-
cur_output = torch.from_numpy(cur_output)[:, 0:min_samples]
|
| 196 |
-
|
| 197 |
-
if output is None:
|
| 198 |
-
output = cur_output
|
| 199 |
-
else:
|
| 200 |
-
ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
|
| 201 |
-
ov_win = torch.cat([ov_win, 1 - ov_win], -1)
|
| 202 |
-
output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
|
| 203 |
-
output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
|
| 204 |
-
output = output[:, 0:target_len]
|
| 205 |
-
return output
|
| 206 |
-
|
| 207 |
-
@torch.no_grad()
|
| 208 |
-
def preprocess_audio(self, input_audios, threshold=0.8):
|
| 209 |
-
assert len(input_audios.shape) == 3, input_audios.shape
|
| 210 |
-
nchan = input_audios.shape[1]
|
| 211 |
-
input_audios = input_audios.reshape(input_audios.shape[0], -1)
|
| 212 |
-
norm_value = torch.ones_like(input_audios[:,0])
|
| 213 |
-
max_volume = input_audios.abs().max(dim=-1)[0]
|
| 214 |
-
norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
|
| 215 |
-
return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1)
|
| 216 |
-
|
| 217 |
-
@torch.no_grad()
|
| 218 |
-
def sound2sound(self, sound, prompt=None, min_duration=40.96, steps=50, disable_progress=False):
|
| 219 |
-
codes = self.sound2code(sound)
|
| 220 |
-
wave = self.code2sound(codes, prompt, duration=min_duration, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
|
| 221 |
-
return wave
|
| 222 |
-
|
| 223 |
-
if __name__=="__main__":
|
| 224 |
-
ckpt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ckpt/mucodec.pt")
|
| 225 |
-
mucodec = MuCodec(model_path=ckpt_path,layer_num=7,load_main_model=True)
|
| 226 |
-
|
| 227 |
-
filelist = []
|
| 228 |
-
|
| 229 |
-
root_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_wav")
|
| 230 |
-
for f in [os.path.join(root_dir, f) for f in os.listdir(root_dir) if '.flac' in f or '.wav' in f or '.mp3' in f]:
|
| 231 |
-
a, fs = torchaudio.load(f)
|
| 232 |
-
if(fs!=48000):
|
| 233 |
-
a = torchaudio.functional.resample(a, fs, 48000)
|
| 234 |
-
if(a.shape[0]==1):
|
| 235 |
-
a = torch.cat([a,a],0)
|
| 236 |
-
ori_len = a.shape[-1]
|
| 237 |
-
filelist.append([a, '', [0, a.shape[-1]/48000.], f,ori_len])
|
| 238 |
-
|
| 239 |
-
reconstructed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "reconstructed")
|
| 240 |
-
|
| 241 |
-
os.makedirs(reconstructed_dir, exist_ok=True)
|
| 242 |
-
|
| 243 |
-
for sample_idx, (orig_samples, lyric, st_et, fname,ori_len) in enumerate(filelist):
|
| 244 |
-
print(fname, lyric)
|
| 245 |
-
wave = mucodec.sound2sound(orig_samples,None)
|
| 246 |
-
wave = wave[:,0:ori_len]
|
| 247 |
-
torchaudio.save(os.path.join(reconstructed_dir, os.path.basename(fname)),wave.detach().cpu(), 48000)
|
| 248 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/libs/rvq/descript_quantize3.py
DELETED
|
@@ -1,298 +0,0 @@
|
|
| 1 |
-
from typing import Union
|
| 2 |
-
|
| 3 |
-
import numpy as np
|
| 4 |
-
import torch
|
| 5 |
-
import torch.nn as nn
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
-
from einops import rearrange
|
| 8 |
-
from torch.nn.utils import weight_norm
|
| 9 |
-
|
| 10 |
-
def WNConv1d(*args, **kwargs):
|
| 11 |
-
return weight_norm(nn.Conv1d(*args, **kwargs))
|
| 12 |
-
|
| 13 |
-
class VectorQuantize(nn.Module):
|
| 14 |
-
"""
|
| 15 |
-
Implementation of VQ similar to Karpathy's repo:
|
| 16 |
-
https://github.com/karpathy/deep-vector-quantization
|
| 17 |
-
Additionally uses following tricks from Improved VQGAN
|
| 18 |
-
(https://arxiv.org/pdf/2110.04627.pdf):
|
| 19 |
-
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
|
| 20 |
-
for improved codebook usage
|
| 21 |
-
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
|
| 22 |
-
improves training stability
|
| 23 |
-
"""
|
| 24 |
-
|
| 25 |
-
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100):
|
| 26 |
-
super().__init__()
|
| 27 |
-
self.codebook_size = codebook_size
|
| 28 |
-
self.codebook_dim = codebook_dim
|
| 29 |
-
|
| 30 |
-
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
|
| 31 |
-
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
|
| 32 |
-
self.codebook = nn.Embedding(codebook_size, codebook_dim)
|
| 33 |
-
self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
|
| 34 |
-
self.stale_tolerance = stale_tolerance
|
| 35 |
-
|
| 36 |
-
def forward(self, z):
|
| 37 |
-
"""Quantized the input tensor using a fixed codebook and returns
|
| 38 |
-
the corresponding codebook vectors
|
| 39 |
-
|
| 40 |
-
Parameters
|
| 41 |
-
----------
|
| 42 |
-
z : Tensor[B x D x T]
|
| 43 |
-
|
| 44 |
-
Returns
|
| 45 |
-
-------
|
| 46 |
-
Tensor[B x D x T]
|
| 47 |
-
Quantized continuous representation of input
|
| 48 |
-
Tensor[1]
|
| 49 |
-
Commitment loss to train encoder to predict vectors closer to codebook
|
| 50 |
-
entries
|
| 51 |
-
Tensor[1]
|
| 52 |
-
Codebook loss to update the codebook
|
| 53 |
-
Tensor[B x T]
|
| 54 |
-
Codebook indices (quantized discrete representation of input)
|
| 55 |
-
Tensor[B x D x T]
|
| 56 |
-
Projected latents (continuous representation of input before quantization)
|
| 57 |
-
"""
|
| 58 |
-
|
| 59 |
-
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
|
| 60 |
-
z_e = self.in_proj(z) # z_e : (B x D x T)
|
| 61 |
-
z_q, indices = self.decode_latents(z_e)
|
| 62 |
-
|
| 63 |
-
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
| 64 |
-
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
| 65 |
-
|
| 66 |
-
z_q = (
|
| 67 |
-
z_e + (z_q - z_e).detach()
|
| 68 |
-
) # noop in forward pass, straight-through gradient estimator in backward pass
|
| 69 |
-
|
| 70 |
-
z_q = self.out_proj(z_q)
|
| 71 |
-
|
| 72 |
-
return z_q, commitment_loss, codebook_loss, indices, z_e
|
| 73 |
-
|
| 74 |
-
def embed_code(self, embed_id):
|
| 75 |
-
return F.embedding(embed_id, self.codebook.weight)
|
| 76 |
-
|
| 77 |
-
def decode_code(self, embed_id):
|
| 78 |
-
return self.embed_code(embed_id).transpose(1, 2)
|
| 79 |
-
|
| 80 |
-
def decode_latents(self, latents):
|
| 81 |
-
encodings = rearrange(latents, "b d t -> (b t) d")
|
| 82 |
-
codebook = self.codebook.weight # codebook: (N x D)
|
| 83 |
-
|
| 84 |
-
# L2 normalize encodings and codebook (ViT-VQGAN)
|
| 85 |
-
encodings = F.normalize(encodings)
|
| 86 |
-
codebook = F.normalize(codebook)
|
| 87 |
-
|
| 88 |
-
# Compute euclidean distance with codebook
|
| 89 |
-
dist = (
|
| 90 |
-
encodings.pow(2).sum(1, keepdim=True)
|
| 91 |
-
- 2 * encodings @ codebook.t()
|
| 92 |
-
+ codebook.pow(2).sum(1, keepdim=True).t()
|
| 93 |
-
)
|
| 94 |
-
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
| 95 |
-
z_q = self.decode_code(indices)
|
| 96 |
-
|
| 97 |
-
if(self.training):
|
| 98 |
-
onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size
|
| 99 |
-
stale_codes = (onehots.sum(0).sum(0) == 0).float()
|
| 100 |
-
self.stale_counter = self.stale_counter * stale_codes + stale_codes
|
| 101 |
-
|
| 102 |
-
# random replace codes that haven't been used for a while
|
| 103 |
-
replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size
|
| 104 |
-
if replace_code.sum(-1) > 0:
|
| 105 |
-
print("Replace {} codes".format(replace_code.sum(-1)))
|
| 106 |
-
random_input_idx = torch.randperm(encodings.shape[0])
|
| 107 |
-
random_input = encodings[random_input_idx].view(encodings.shape)
|
| 108 |
-
if random_input.shape[0] < self.codebook_size:
|
| 109 |
-
random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
|
| 110 |
-
random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim
|
| 111 |
-
|
| 112 |
-
self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
|
| 113 |
-
self.stale_counter = self.stale_counter * (1 - replace_code)
|
| 114 |
-
|
| 115 |
-
return z_q, indices
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
class ResidualVectorQuantize(nn.Module):
|
| 119 |
-
"""
|
| 120 |
-
Introduced in SoundStream: An end2end neural audio codec
|
| 121 |
-
https://arxiv.org/abs/2107.03312
|
| 122 |
-
"""
|
| 123 |
-
|
| 124 |
-
def __init__(
|
| 125 |
-
self,
|
| 126 |
-
input_dim: int = 512,
|
| 127 |
-
n_codebooks: int = 9,
|
| 128 |
-
codebook_size: int = 1024,
|
| 129 |
-
codebook_dim: Union[int, list] = 8,
|
| 130 |
-
quantizer_dropout: float = 0.0,
|
| 131 |
-
stale_tolerance: int = 100,
|
| 132 |
-
):
|
| 133 |
-
super().__init__()
|
| 134 |
-
if isinstance(codebook_dim, int):
|
| 135 |
-
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
|
| 136 |
-
|
| 137 |
-
self.n_codebooks = n_codebooks
|
| 138 |
-
self.codebook_dim = codebook_dim
|
| 139 |
-
self.codebook_size = codebook_size
|
| 140 |
-
|
| 141 |
-
self.quantizers = nn.ModuleList(
|
| 142 |
-
[
|
| 143 |
-
VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance)
|
| 144 |
-
for i in range(n_codebooks)
|
| 145 |
-
]
|
| 146 |
-
)
|
| 147 |
-
self.quantizer_dropout = quantizer_dropout
|
| 148 |
-
|
| 149 |
-
def forward(self, z, n_quantizers: int = None):
|
| 150 |
-
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
|
| 151 |
-
the corresponding codebook vectors
|
| 152 |
-
Parameters
|
| 153 |
-
----------
|
| 154 |
-
z : Tensor[B x D x T]
|
| 155 |
-
n_quantizers : int, optional
|
| 156 |
-
No. of quantizers to use
|
| 157 |
-
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
| 158 |
-
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
| 159 |
-
when in training mode, and a random number of quantizers is used.
|
| 160 |
-
Returns
|
| 161 |
-
-------
|
| 162 |
-
dict
|
| 163 |
-
A dictionary with the following keys:
|
| 164 |
-
|
| 165 |
-
"z" : Tensor[B x D x T]
|
| 166 |
-
Quantized continuous representation of input
|
| 167 |
-
"codes" : Tensor[B x N x T]
|
| 168 |
-
Codebook indices for each codebook
|
| 169 |
-
(quantized discrete representation of input)
|
| 170 |
-
"latents" : Tensor[B x N*D x T]
|
| 171 |
-
Projected latents (continuous representation of input before quantization)
|
| 172 |
-
"vq/commitment_loss" : Tensor[1]
|
| 173 |
-
Commitment loss to train encoder to predict vectors closer to codebook
|
| 174 |
-
entries
|
| 175 |
-
"vq/codebook_loss" : Tensor[1]
|
| 176 |
-
Codebook loss to update the codebook
|
| 177 |
-
"""
|
| 178 |
-
z_q = 0
|
| 179 |
-
residual = z
|
| 180 |
-
commitment_loss = 0
|
| 181 |
-
codebook_loss = 0
|
| 182 |
-
|
| 183 |
-
codebook_indices = []
|
| 184 |
-
latents = []
|
| 185 |
-
|
| 186 |
-
if n_quantizers is None:
|
| 187 |
-
n_quantizers = self.n_codebooks
|
| 188 |
-
if self.training:
|
| 189 |
-
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
|
| 190 |
-
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
|
| 191 |
-
n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
| 192 |
-
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
| 193 |
-
n_quantizers = n_quantizers.to(z.device)
|
| 194 |
-
else:
|
| 195 |
-
n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1
|
| 196 |
-
n_quantizers = n_quantizers.to(z.device)
|
| 197 |
-
|
| 198 |
-
for i, quantizer in enumerate(self.quantizers):
|
| 199 |
-
# if self.training is False and i >= n_quantizers:
|
| 200 |
-
# break
|
| 201 |
-
|
| 202 |
-
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
|
| 203 |
-
residual
|
| 204 |
-
)
|
| 205 |
-
|
| 206 |
-
# Create mask to apply quantizer dropout
|
| 207 |
-
mask = (
|
| 208 |
-
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
| 209 |
-
)
|
| 210 |
-
z_q = z_q + z_q_i * mask[:, None, None]
|
| 211 |
-
residual = residual - z_q_i
|
| 212 |
-
|
| 213 |
-
# Sum losses
|
| 214 |
-
commitment_loss += (commitment_loss_i * mask).mean()
|
| 215 |
-
codebook_loss += (codebook_loss_i * mask).mean()
|
| 216 |
-
|
| 217 |
-
codebook_indices.append(indices_i)
|
| 218 |
-
latents.append(z_e_i)
|
| 219 |
-
|
| 220 |
-
codes = torch.stack(codebook_indices, dim=1)
|
| 221 |
-
latents = torch.cat(latents, dim=1)
|
| 222 |
-
|
| 223 |
-
encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
|
| 224 |
-
for n in range(encodings.shape[1]):
|
| 225 |
-
print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n,
|
| 226 |
-
(encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100.
|
| 227 |
-
))
|
| 228 |
-
|
| 229 |
-
return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1
|
| 230 |
-
|
| 231 |
-
def from_codes(self, codes: torch.Tensor):
|
| 232 |
-
"""Given the quantized codes, reconstruct the continuous representation
|
| 233 |
-
Parameters
|
| 234 |
-
----------
|
| 235 |
-
codes : Tensor[B x N x T]
|
| 236 |
-
Quantized discrete representation of input
|
| 237 |
-
Returns
|
| 238 |
-
-------
|
| 239 |
-
Tensor[B x D x T]
|
| 240 |
-
Quantized continuous representation of input
|
| 241 |
-
"""
|
| 242 |
-
z_q = 0.0
|
| 243 |
-
z_p = []
|
| 244 |
-
n_codebooks = codes.shape[1]
|
| 245 |
-
for i in range(n_codebooks):
|
| 246 |
-
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
|
| 247 |
-
z_p.append(z_p_i)
|
| 248 |
-
|
| 249 |
-
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
| 250 |
-
z_q = z_q + z_q_i
|
| 251 |
-
return z_q, torch.cat(z_p, dim=1), codes
|
| 252 |
-
|
| 253 |
-
def from_latents(self, latents: torch.Tensor):
|
| 254 |
-
"""Given the unquantized latents, reconstruct the
|
| 255 |
-
continuous representation after quantization.
|
| 256 |
-
|
| 257 |
-
Parameters
|
| 258 |
-
----------
|
| 259 |
-
latents : Tensor[B x N x T]
|
| 260 |
-
Continuous representation of input after projection
|
| 261 |
-
|
| 262 |
-
Returns
|
| 263 |
-
-------
|
| 264 |
-
Tensor[B x D x T]
|
| 265 |
-
Quantized representation of full-projected space
|
| 266 |
-
Tensor[B x D x T]
|
| 267 |
-
Quantized representation of latent space
|
| 268 |
-
"""
|
| 269 |
-
z_q = 0
|
| 270 |
-
z_p = []
|
| 271 |
-
codes = []
|
| 272 |
-
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
|
| 273 |
-
|
| 274 |
-
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
|
| 275 |
-
0
|
| 276 |
-
]
|
| 277 |
-
for i in range(n_codebooks):
|
| 278 |
-
j, k = dims[i], dims[i + 1]
|
| 279 |
-
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
|
| 280 |
-
z_p.append(z_p_i)
|
| 281 |
-
codes.append(codes_i)
|
| 282 |
-
|
| 283 |
-
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
| 284 |
-
z_q = z_q + z_q_i
|
| 285 |
-
|
| 286 |
-
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
if __name__ == "__main__":
|
| 290 |
-
rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0)
|
| 291 |
-
x = torch.randn(16, 1024, 80)
|
| 292 |
-
quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x)
|
| 293 |
-
print(quantized_prompt_embeds.shape)
|
| 294 |
-
print(codes.shape)
|
| 295 |
-
# w/o reconstruction
|
| 296 |
-
loss = commitment_loss * 0.25 + codebook_loss * 1.0
|
| 297 |
-
# w/ reconstruction
|
| 298 |
-
loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/model.py
DELETED
|
@@ -1,367 +0,0 @@
|
|
| 1 |
-
import yaml
|
| 2 |
-
import random
|
| 3 |
-
import inspect
|
| 4 |
-
import numpy as np
|
| 5 |
-
from tqdm import tqdm
|
| 6 |
-
import typing as tp
|
| 7 |
-
from abc import ABC
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
import torch.nn.functional as F
|
| 12 |
-
import torchaudio
|
| 13 |
-
|
| 14 |
-
from einops import repeat
|
| 15 |
-
from tools.torch_tools import wav_to_fbank
|
| 16 |
-
import os
|
| 17 |
-
import diffusers
|
| 18 |
-
from diffusers.utils.torch_utils import randn_tensor
|
| 19 |
-
from diffusers import DDPMScheduler
|
| 20 |
-
from models.transformer_2d_flow import Transformer2DModel
|
| 21 |
-
from libs.rvq.descript_quantize3 import ResidualVectorQuantize
|
| 22 |
-
from torch.cuda.amp import autocast
|
| 23 |
-
from muq_dev.test import load_model
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
class SampleProcessor(torch.nn.Module):
|
| 29 |
-
def project_sample(self, x: torch.Tensor):
|
| 30 |
-
"""Project the original sample to the 'space' where the diffusion will happen."""
|
| 31 |
-
return x
|
| 32 |
-
|
| 33 |
-
def return_sample(self, z: torch.Tensor):
|
| 34 |
-
"""Project back from diffusion space to the actual sample space."""
|
| 35 |
-
return z
|
| 36 |
-
|
| 37 |
-
class Feature2DProcessor(SampleProcessor):
|
| 38 |
-
def __init__(self, dim: int = 8, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1., \
|
| 39 |
-
num_samples: int = 100_000):
|
| 40 |
-
super().__init__()
|
| 41 |
-
self.num_samples = num_samples
|
| 42 |
-
self.dim = dim
|
| 43 |
-
self.power_std = power_std
|
| 44 |
-
self.register_buffer('counts', torch.zeros(1))
|
| 45 |
-
self.register_buffer('sum_x', torch.zeros(dim, 32))
|
| 46 |
-
self.register_buffer('sum_x2', torch.zeros(dim, 32))
|
| 47 |
-
self.register_buffer('sum_target_x2', torch.zeros(dim, 32))
|
| 48 |
-
self.counts: torch.Tensor
|
| 49 |
-
self.sum_x: torch.Tensor
|
| 50 |
-
self.sum_x2: torch.Tensor
|
| 51 |
-
|
| 52 |
-
@property
|
| 53 |
-
def mean(self):
|
| 54 |
-
mean = self.sum_x / self.counts
|
| 55 |
-
return mean
|
| 56 |
-
|
| 57 |
-
@property
|
| 58 |
-
def std(self):
|
| 59 |
-
std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
|
| 60 |
-
return std
|
| 61 |
-
|
| 62 |
-
@property
|
| 63 |
-
def target_std(self):
|
| 64 |
-
return 1
|
| 65 |
-
|
| 66 |
-
def project_sample(self, x: torch.Tensor):
|
| 67 |
-
assert x.dim() == 4
|
| 68 |
-
if self.counts.item() < self.num_samples:
|
| 69 |
-
self.counts += len(x)
|
| 70 |
-
self.sum_x += x.mean(dim=(2,)).sum(dim=0)
|
| 71 |
-
self.sum_x2 += x.pow(2).mean(dim=(2,)).sum(dim=0)
|
| 72 |
-
rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
|
| 73 |
-
x = (x - self.mean.view(1, -1, 1, 32).contiguous()) * rescale.view(1, -1, 1, 32).contiguous()
|
| 74 |
-
return x
|
| 75 |
-
|
| 76 |
-
def return_sample(self, x: torch.Tensor):
|
| 77 |
-
assert x.dim() == 4
|
| 78 |
-
rescale = (self.std / self.target_std) ** self.power_std
|
| 79 |
-
x = x * rescale.view(1, -1, 1, 32).contiguous() + self.mean.view(1, -1, 1, 32).contiguous()
|
| 80 |
-
return x
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
class BASECFM(torch.nn.Module, ABC):
|
| 84 |
-
def __init__(
|
| 85 |
-
self,
|
| 86 |
-
estimator,
|
| 87 |
-
):
|
| 88 |
-
super().__init__()
|
| 89 |
-
self.sigma_min = 1e-4
|
| 90 |
-
|
| 91 |
-
self.estimator = estimator
|
| 92 |
-
|
| 93 |
-
@torch.inference_mode()
|
| 94 |
-
def forward(self, mu, n_timesteps, temperature=1.0):
|
| 95 |
-
"""Forward diffusion
|
| 96 |
-
|
| 97 |
-
Args:
|
| 98 |
-
mu (torch.Tensor): output of encoder
|
| 99 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
| 100 |
-
n_timesteps (int): number of diffusion steps
|
| 101 |
-
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
| 102 |
-
|
| 103 |
-
Returns:
|
| 104 |
-
sample: generated mel-spectrogram
|
| 105 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
| 106 |
-
"""
|
| 107 |
-
z = torch.randn_like(mu) * temperature
|
| 108 |
-
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
| 109 |
-
return self.solve_euler(z, t_span=t_span)
|
| 110 |
-
|
| 111 |
-
def solve_euler(self, x, incontext_x, incontext_length, t_span, mu, added_cond_kwargs, guidance_scale):
|
| 112 |
-
"""
|
| 113 |
-
Fixed euler solver for ODEs.
|
| 114 |
-
Args:
|
| 115 |
-
x (torch.Tensor): random noise
|
| 116 |
-
t_span (torch.Tensor): n_timesteps interpolated
|
| 117 |
-
shape: (n_timesteps + 1,)
|
| 118 |
-
mu (torch.Tensor): output of encoder
|
| 119 |
-
shape: (batch_size, n_channels, mel_timesteps, n_feats)
|
| 120 |
-
"""
|
| 121 |
-
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
| 122 |
-
noise = x.clone()
|
| 123 |
-
|
| 124 |
-
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
| 125 |
-
# Or in future might add like a return_all_steps flag
|
| 126 |
-
sol = []
|
| 127 |
-
|
| 128 |
-
for step in tqdm(range(1, len(t_span))):
|
| 129 |
-
x[:,:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,:,0:incontext_length,:] + t * incontext_x[:,:,0:incontext_length,:]
|
| 130 |
-
if(guidance_scale > 1.0):
|
| 131 |
-
dphi_dt = self.estimator( \
|
| 132 |
-
torch.cat([ \
|
| 133 |
-
torch.cat([x, x], 0), \
|
| 134 |
-
torch.cat([incontext_x, incontext_x], 0), \
|
| 135 |
-
torch.cat([torch.zeros_like(mu), mu], 0), \
|
| 136 |
-
], 1), \
|
| 137 |
-
timestep = t.unsqueeze(-1).repeat(2), \
|
| 138 |
-
added_cond_kwargs={k:torch.cat([v,v],0) for k,v in added_cond_kwargs.items()}).sample
|
| 139 |
-
dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0)
|
| 140 |
-
dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond)
|
| 141 |
-
else:
|
| 142 |
-
dphi_dt = self.estimator(torch.cat([x, incontext_x, mu], 1), \
|
| 143 |
-
timestep = t.unsqueeze(-1),
|
| 144 |
-
added_cond_kwargs=added_cond_kwargs).sample
|
| 145 |
-
|
| 146 |
-
x = x + dt * dphi_dt
|
| 147 |
-
t = t + dt
|
| 148 |
-
sol.append(x)
|
| 149 |
-
if step < len(t_span) - 1:
|
| 150 |
-
dt = t_span[step + 1] - t
|
| 151 |
-
|
| 152 |
-
return sol[-1]
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
class PromptCondAudioDiffusion(nn.Module):
|
| 156 |
-
def __init__(
|
| 157 |
-
self,
|
| 158 |
-
num_channels,
|
| 159 |
-
unet_model_name=None,
|
| 160 |
-
unet_model_config_path=None,
|
| 161 |
-
snr_gamma=None,
|
| 162 |
-
uncondition=True,
|
| 163 |
-
out_paint=False,
|
| 164 |
-
):
|
| 165 |
-
super().__init__()
|
| 166 |
-
|
| 167 |
-
assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
|
| 168 |
-
|
| 169 |
-
self.unet_model_name = unet_model_name
|
| 170 |
-
self.unet_model_config_path = unet_model_config_path
|
| 171 |
-
self.snr_gamma = snr_gamma
|
| 172 |
-
self.uncondition = uncondition
|
| 173 |
-
self.num_channels = num_channels
|
| 174 |
-
|
| 175 |
-
# https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
|
| 176 |
-
self.normfeat = Feature2DProcessor(dim=num_channels)
|
| 177 |
-
|
| 178 |
-
self.sample_rate = 48000
|
| 179 |
-
self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
|
| 180 |
-
self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
|
| 181 |
-
muencoder_dir = "muq_dev/muq_fairseq"
|
| 182 |
-
muencoder_ckpt = "muq_dev/muq.pt"
|
| 183 |
-
|
| 184 |
-
self.muencoder = load_model(
|
| 185 |
-
model_dir=os.path.abspath(muencoder_dir),
|
| 186 |
-
checkpoint_dir=os.path.abspath(muencoder_ckpt),
|
| 187 |
-
)
|
| 188 |
-
self.rsq48tomuencoder = torchaudio.transforms.Resample(48000, 24000)
|
| 189 |
-
for v in self.muencoder.parameters():v.requires_grad = False
|
| 190 |
-
self.rvq_muencoder_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
|
| 191 |
-
self.cond_muencoder_emb = nn.Linear(1024, 16*32)
|
| 192 |
-
self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
|
| 193 |
-
|
| 194 |
-
unet = Transformer2DModel.from_config(
|
| 195 |
-
unet_model_config_path,
|
| 196 |
-
)
|
| 197 |
-
self.set_from = "random"
|
| 198 |
-
self.cfm_wrapper = BASECFM(unet)
|
| 199 |
-
print("Transformer initialized from pretrain.")
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
def compute_snr(self, timesteps):
|
| 203 |
-
"""
|
| 204 |
-
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
| 205 |
-
"""
|
| 206 |
-
alphas_cumprod = self.noise_scheduler.alphas_cumprod
|
| 207 |
-
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
| 208 |
-
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
| 209 |
-
|
| 210 |
-
# Expand the tensors.
|
| 211 |
-
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
| 212 |
-
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
| 213 |
-
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
| 214 |
-
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
| 215 |
-
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
| 216 |
-
|
| 217 |
-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
| 218 |
-
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
| 219 |
-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
| 220 |
-
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
| 221 |
-
|
| 222 |
-
# Compute SNR.
|
| 223 |
-
snr = (alpha / sigma) ** 2
|
| 224 |
-
return snr
|
| 225 |
-
|
| 226 |
-
def preprocess_audio(self, input_audios, threshold=0.9):
|
| 227 |
-
assert len(input_audios.shape) == 2, input_audios.shape
|
| 228 |
-
norm_value = torch.ones_like(input_audios[:,0])
|
| 229 |
-
max_volume = input_audios.abs().max(dim=-1)[0]
|
| 230 |
-
norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
|
| 231 |
-
return input_audios/norm_value.unsqueeze(-1)
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
def extract_muencoder_embeds(self, input_audio_0,input_audio_1,layer):
|
| 237 |
-
input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
|
| 238 |
-
input_wav_mean = self.muencoder(self.rsq48tomuencoder(input_wav_mean), features_only = True)
|
| 239 |
-
layer_results = input_wav_mean['layer_results']
|
| 240 |
-
muencoder_emb = layer_results[layer]
|
| 241 |
-
muencoder_emb = muencoder_emb.permute(0,2,1).contiguous()
|
| 242 |
-
return muencoder_emb
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
def init_device_dtype(self, device, dtype):
|
| 248 |
-
self.device = device
|
| 249 |
-
self.dtype = dtype
|
| 250 |
-
|
| 251 |
-
@torch.no_grad()
|
| 252 |
-
def fetch_codes(self, input_audios, additional_feats,layer):
|
| 253 |
-
input_audio_0 = input_audios[[0],:]
|
| 254 |
-
input_audio_1 = input_audios[[1],:]
|
| 255 |
-
input_audio_0 = self.preprocess_audio(input_audio_0)
|
| 256 |
-
input_audio_1 = self.preprocess_audio(input_audio_1)
|
| 257 |
-
|
| 258 |
-
self.muencoder.eval()
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
muencoder_emb = self.extract_muencoder_embeds(input_audio_0,input_audio_1,layer)
|
| 262 |
-
muencoder_emb = muencoder_emb.detach()
|
| 263 |
-
|
| 264 |
-
self.rvq_muencoder_emb.eval()
|
| 265 |
-
quantized_muencoder_emb, codes_muencoder_emb, *_ = self.rvq_muencoder_emb(muencoder_emb)
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
spk_embeds = None
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
return [codes_muencoder_emb], [muencoder_emb], spk_embeds
|
| 272 |
-
@torch.no_grad()
|
| 273 |
-
def fetch_codes_batch(self, input_audios, additional_feats,layer):
|
| 274 |
-
input_audio_0 = input_audios[:,0,:]
|
| 275 |
-
input_audio_1 = input_audios[:,1,:]
|
| 276 |
-
input_audio_0 = self.preprocess_audio(input_audio_0)
|
| 277 |
-
input_audio_1 = self.preprocess_audio(input_audio_1)
|
| 278 |
-
|
| 279 |
-
self.muencoder.eval()
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
muencoder_emb = self.extract_muencoder_embeds(input_audio_0,input_audio_1,layer)
|
| 283 |
-
muencoder_emb = muencoder_emb.detach()
|
| 284 |
-
|
| 285 |
-
self.rvq_muencoder_emb.eval()
|
| 286 |
-
quantized_muencoder_emb, codes_muencoder_emb, *_ = self.rvq_muencoder_emb(muencoder_emb) # b,d,t
|
| 287 |
-
|
| 288 |
-
spk_embeds = None
|
| 289 |
-
|
| 290 |
-
return [codes_muencoder_emb], [muencoder_emb], spk_embeds
|
| 291 |
-
@torch.no_grad()
|
| 292 |
-
def inference_codes(self, codes, spk_embeds, true_latents, latent_length,incontext_length, additional_feats,
|
| 293 |
-
guidance_scale=2, num_steps=20,
|
| 294 |
-
disable_progress=True, scenario='start_seg'):
|
| 295 |
-
classifier_free_guidance = guidance_scale > 1.0
|
| 296 |
-
device = self.device
|
| 297 |
-
dtype = self.dtype
|
| 298 |
-
codes_muencoder_emb = codes[0]
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
batch_size = codes_muencoder_emb.shape[0]
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
quantized_muencoder_emb,_,_=self.rvq_muencoder_emb.from_codes(codes_muencoder_emb)
|
| 305 |
-
|
| 306 |
-
quantized_muencoder_emb = self.cond_muencoder_emb(quantized_muencoder_emb.permute(0,2,1)) # b t 16*32
|
| 307 |
-
quantized_muencoder_emb = quantized_muencoder_emb.reshape(quantized_muencoder_emb.shape[0], quantized_muencoder_emb.shape[1]//2, 2, 16, 32).reshape(quantized_muencoder_emb.shape[0], quantized_muencoder_emb.shape[1]//2, 2*16, 32).permute(0,2,1,3).contiguous() # b 32 t f
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
num_frames = quantized_muencoder_emb.shape[-2]
|
| 311 |
-
|
| 312 |
-
num_channels_latents = self.num_channels
|
| 313 |
-
latents = self.prepare_latents(batch_size, num_frames, num_channels_latents, dtype, device)
|
| 314 |
-
|
| 315 |
-
bsz, _, height, width = latents.shape
|
| 316 |
-
resolution = torch.tensor([height, width]).repeat(bsz, 1)
|
| 317 |
-
aspect_ratio = torch.tensor([float(height / width)]).repeat(bsz, 1)
|
| 318 |
-
resolution = resolution.to(dtype=quantized_muencoder_emb.dtype, device=device)
|
| 319 |
-
aspect_ratio = aspect_ratio.to(dtype=quantized_muencoder_emb.dtype, device=device)
|
| 320 |
-
if classifier_free_guidance:
|
| 321 |
-
resolution = torch.cat([resolution, resolution], 0)
|
| 322 |
-
aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], 0)
|
| 323 |
-
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
|
| 324 |
-
|
| 325 |
-
latent_masks = torch.zeros(latents.shape[0], latents.shape[2], dtype=torch.int64, device=latents.device)
|
| 326 |
-
latent_masks[:,0:latent_length] = 2
|
| 327 |
-
if(scenario=='other_seg'):
|
| 328 |
-
latent_masks[:,0:incontext_length] = 1
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
quantized_muencoder_emb = (latent_masks > 0.5).unsqueeze(1).unsqueeze(-1) * quantized_muencoder_emb \
|
| 333 |
-
+ (latent_masks < 0.5).unsqueeze(1).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,32,1,32)
|
| 334 |
-
true_latents = self.normfeat.project_sample(true_latents)
|
| 335 |
-
incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(1).unsqueeze(-1).float()
|
| 336 |
-
incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
|
| 337 |
-
|
| 338 |
-
additional_model_input = torch.cat([quantized_muencoder_emb],1)
|
| 339 |
-
|
| 340 |
-
temperature = 1.0
|
| 341 |
-
t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_muencoder_emb.device)
|
| 342 |
-
latents = self.cfm_wrapper.solve_euler(latents * temperature, incontext_latents, incontext_length, t_span, additional_model_input, added_cond_kwargs, guidance_scale)
|
| 343 |
-
|
| 344 |
-
latents[:,:,0:incontext_length,:] = incontext_latents[:,:,0:incontext_length,:]
|
| 345 |
-
latents = self.normfeat.return_sample(latents)
|
| 346 |
-
return latents
|
| 347 |
-
|
| 348 |
-
@torch.no_grad()
|
| 349 |
-
def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
|
| 350 |
-
disable_progress=True,layer=5,scenario='start_seg'):
|
| 351 |
-
codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
|
| 352 |
-
|
| 353 |
-
latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
|
| 354 |
-
guidance_scale=guidance_scale, num_steps=num_steps, \
|
| 355 |
-
disable_progress=disable_progress,scenario=scenario)
|
| 356 |
-
return latents
|
| 357 |
-
|
| 358 |
-
def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
|
| 359 |
-
divisor = 4
|
| 360 |
-
shape = (batch_size, num_channels_latents, num_frames, 32)
|
| 361 |
-
if(num_frames%divisor>0):
|
| 362 |
-
num_frames = round(num_frames/float(divisor))*divisor
|
| 363 |
-
shape = (batch_size, num_channels_latents, num_frames, 32)
|
| 364 |
-
latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
|
| 365 |
-
return latents
|
| 366 |
-
|
| 367 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/models/attention.py
DELETED
|
@@ -1,682 +0,0 @@
|
|
| 1 |
-
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
from typing import Any, Dict, Optional
|
| 15 |
-
|
| 16 |
-
import torch
|
| 17 |
-
import torch.nn.functional as F
|
| 18 |
-
from torch import nn
|
| 19 |
-
|
| 20 |
-
from diffusers.utils import USE_PEFT_BACKEND
|
| 21 |
-
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 22 |
-
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
|
| 23 |
-
from diffusers.models.attention_processor import Attention
|
| 24 |
-
from diffusers.models.embeddings import SinusoidalPositionalEmbedding
|
| 25 |
-
from diffusers.models.lora import LoRACompatibleLinear
|
| 26 |
-
from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def _chunked_feed_forward(
|
| 30 |
-
ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
|
| 31 |
-
):
|
| 32 |
-
# "feed_forward_chunk_size" can be used to save memory
|
| 33 |
-
if hidden_states.shape[chunk_dim] % chunk_size != 0:
|
| 34 |
-
raise ValueError(
|
| 35 |
-
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
| 36 |
-
)
|
| 37 |
-
|
| 38 |
-
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
|
| 39 |
-
if lora_scale is None:
|
| 40 |
-
ff_output = torch.cat(
|
| 41 |
-
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
|
| 42 |
-
dim=chunk_dim,
|
| 43 |
-
)
|
| 44 |
-
else:
|
| 45 |
-
# TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete
|
| 46 |
-
ff_output = torch.cat(
|
| 47 |
-
[ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
|
| 48 |
-
dim=chunk_dim,
|
| 49 |
-
)
|
| 50 |
-
|
| 51 |
-
return ff_output
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
@maybe_allow_in_graph
|
| 55 |
-
class GatedSelfAttentionDense(nn.Module):
|
| 56 |
-
r"""
|
| 57 |
-
A gated self-attention dense layer that combines visual features and object features.
|
| 58 |
-
|
| 59 |
-
Parameters:
|
| 60 |
-
query_dim (`int`): The number of channels in the query.
|
| 61 |
-
context_dim (`int`): The number of channels in the context.
|
| 62 |
-
n_heads (`int`): The number of heads to use for attention.
|
| 63 |
-
d_head (`int`): The number of channels in each head.
|
| 64 |
-
"""
|
| 65 |
-
|
| 66 |
-
def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
|
| 67 |
-
super().__init__()
|
| 68 |
-
|
| 69 |
-
# we need a linear projection since we need cat visual feature and obj feature
|
| 70 |
-
self.linear = nn.Linear(context_dim, query_dim)
|
| 71 |
-
|
| 72 |
-
self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
|
| 73 |
-
self.ff = FeedForward(query_dim, activation_fn="geglu")
|
| 74 |
-
|
| 75 |
-
self.norm1 = nn.LayerNorm(query_dim)
|
| 76 |
-
self.norm2 = nn.LayerNorm(query_dim)
|
| 77 |
-
|
| 78 |
-
self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
|
| 79 |
-
self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
|
| 80 |
-
|
| 81 |
-
self.enabled = True
|
| 82 |
-
|
| 83 |
-
def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
|
| 84 |
-
if not self.enabled:
|
| 85 |
-
return x
|
| 86 |
-
|
| 87 |
-
n_visual = x.shape[1]
|
| 88 |
-
objs = self.linear(objs)
|
| 89 |
-
|
| 90 |
-
x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
|
| 91 |
-
x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
|
| 92 |
-
|
| 93 |
-
return x
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
@maybe_allow_in_graph
|
| 97 |
-
class BasicTransformerBlock(nn.Module):
|
| 98 |
-
r"""
|
| 99 |
-
A basic Transformer block.
|
| 100 |
-
|
| 101 |
-
Parameters:
|
| 102 |
-
dim (`int`): The number of channels in the input and output.
|
| 103 |
-
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
| 104 |
-
attention_head_dim (`int`): The number of channels in each head.
|
| 105 |
-
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
| 106 |
-
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
| 107 |
-
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
| 108 |
-
num_embeds_ada_norm (:
|
| 109 |
-
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
| 110 |
-
attention_bias (:
|
| 111 |
-
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
| 112 |
-
only_cross_attention (`bool`, *optional*):
|
| 113 |
-
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
| 114 |
-
double_self_attention (`bool`, *optional*):
|
| 115 |
-
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
| 116 |
-
upcast_attention (`bool`, *optional*):
|
| 117 |
-
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
| 118 |
-
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
| 119 |
-
Whether to use learnable elementwise affine parameters for normalization.
|
| 120 |
-
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
| 121 |
-
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
| 122 |
-
final_dropout (`bool` *optional*, defaults to False):
|
| 123 |
-
Whether to apply a final dropout after the last feed-forward layer.
|
| 124 |
-
attention_type (`str`, *optional*, defaults to `"default"`):
|
| 125 |
-
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
| 126 |
-
positional_embeddings (`str`, *optional*, defaults to `None`):
|
| 127 |
-
The type of positional embeddings to apply to.
|
| 128 |
-
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
| 129 |
-
The maximum number of positional embeddings to apply.
|
| 130 |
-
"""
|
| 131 |
-
|
| 132 |
-
def __init__(
|
| 133 |
-
self,
|
| 134 |
-
dim: int,
|
| 135 |
-
num_attention_heads: int,
|
| 136 |
-
attention_head_dim: int,
|
| 137 |
-
dropout=0.0,
|
| 138 |
-
cross_attention_dim: Optional[int] = None,
|
| 139 |
-
activation_fn: str = "geglu",
|
| 140 |
-
num_embeds_ada_norm: Optional[int] = None,
|
| 141 |
-
attention_bias: bool = False,
|
| 142 |
-
only_cross_attention: bool = False,
|
| 143 |
-
double_self_attention: bool = False,
|
| 144 |
-
upcast_attention: bool = False,
|
| 145 |
-
norm_elementwise_affine: bool = True,
|
| 146 |
-
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
|
| 147 |
-
norm_eps: float = 1e-5,
|
| 148 |
-
final_dropout: bool = False,
|
| 149 |
-
attention_type: str = "default",
|
| 150 |
-
positional_embeddings: Optional[str] = None,
|
| 151 |
-
num_positional_embeddings: Optional[int] = None,
|
| 152 |
-
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
|
| 153 |
-
ada_norm_bias: Optional[int] = None,
|
| 154 |
-
ff_inner_dim: Optional[int] = None,
|
| 155 |
-
ff_bias: bool = True,
|
| 156 |
-
attention_out_bias: bool = True,
|
| 157 |
-
):
|
| 158 |
-
super().__init__()
|
| 159 |
-
self.only_cross_attention = only_cross_attention
|
| 160 |
-
|
| 161 |
-
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
| 162 |
-
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
| 163 |
-
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
| 164 |
-
self.use_layer_norm = norm_type == "layer_norm"
|
| 165 |
-
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
|
| 166 |
-
|
| 167 |
-
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
| 168 |
-
raise ValueError(
|
| 169 |
-
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
| 170 |
-
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
| 171 |
-
)
|
| 172 |
-
|
| 173 |
-
if positional_embeddings and (num_positional_embeddings is None):
|
| 174 |
-
raise ValueError(
|
| 175 |
-
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
|
| 176 |
-
)
|
| 177 |
-
|
| 178 |
-
if positional_embeddings == "sinusoidal":
|
| 179 |
-
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
| 180 |
-
else:
|
| 181 |
-
self.pos_embed = None
|
| 182 |
-
|
| 183 |
-
# Define 3 blocks. Each block has its own normalization layer.
|
| 184 |
-
# 1. Self-Attn
|
| 185 |
-
if self.use_ada_layer_norm:
|
| 186 |
-
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
| 187 |
-
elif self.use_ada_layer_norm_zero:
|
| 188 |
-
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
| 189 |
-
elif self.use_ada_layer_norm_continuous:
|
| 190 |
-
self.norm1 = AdaLayerNormContinuous(
|
| 191 |
-
dim,
|
| 192 |
-
ada_norm_continous_conditioning_embedding_dim,
|
| 193 |
-
norm_elementwise_affine,
|
| 194 |
-
norm_eps,
|
| 195 |
-
ada_norm_bias,
|
| 196 |
-
"rms_norm",
|
| 197 |
-
)
|
| 198 |
-
else:
|
| 199 |
-
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
| 200 |
-
|
| 201 |
-
self.attn1 = Attention(
|
| 202 |
-
query_dim=dim,
|
| 203 |
-
heads=num_attention_heads,
|
| 204 |
-
dim_head=attention_head_dim,
|
| 205 |
-
dropout=dropout,
|
| 206 |
-
bias=attention_bias,
|
| 207 |
-
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
| 208 |
-
upcast_attention=upcast_attention,
|
| 209 |
-
out_bias=attention_out_bias,
|
| 210 |
-
)
|
| 211 |
-
|
| 212 |
-
# 2. Cross-Attn
|
| 213 |
-
if cross_attention_dim is not None or double_self_attention:
|
| 214 |
-
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
| 215 |
-
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
| 216 |
-
# the second cross attention block.
|
| 217 |
-
if self.use_ada_layer_norm:
|
| 218 |
-
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
| 219 |
-
elif self.use_ada_layer_norm_continuous:
|
| 220 |
-
self.norm2 = AdaLayerNormContinuous(
|
| 221 |
-
dim,
|
| 222 |
-
ada_norm_continous_conditioning_embedding_dim,
|
| 223 |
-
norm_elementwise_affine,
|
| 224 |
-
norm_eps,
|
| 225 |
-
ada_norm_bias,
|
| 226 |
-
"rms_norm",
|
| 227 |
-
)
|
| 228 |
-
else:
|
| 229 |
-
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
| 230 |
-
|
| 231 |
-
self.attn2 = Attention(
|
| 232 |
-
query_dim=dim,
|
| 233 |
-
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
| 234 |
-
heads=num_attention_heads,
|
| 235 |
-
dim_head=attention_head_dim,
|
| 236 |
-
dropout=dropout,
|
| 237 |
-
bias=attention_bias,
|
| 238 |
-
upcast_attention=upcast_attention,
|
| 239 |
-
out_bias=attention_out_bias,
|
| 240 |
-
) # is self-attn if encoder_hidden_states is none
|
| 241 |
-
else:
|
| 242 |
-
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
| 243 |
-
self.attn2 = None
|
| 244 |
-
|
| 245 |
-
# 3. Feed-forward
|
| 246 |
-
if self.use_ada_layer_norm_continuous:
|
| 247 |
-
self.norm3 = AdaLayerNormContinuous(
|
| 248 |
-
dim,
|
| 249 |
-
ada_norm_continous_conditioning_embedding_dim,
|
| 250 |
-
norm_elementwise_affine,
|
| 251 |
-
norm_eps,
|
| 252 |
-
ada_norm_bias,
|
| 253 |
-
"layer_norm",
|
| 254 |
-
)
|
| 255 |
-
elif not self.use_ada_layer_norm_single:
|
| 256 |
-
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
| 257 |
-
|
| 258 |
-
self.ff = FeedForward(
|
| 259 |
-
dim,
|
| 260 |
-
dropout=dropout,
|
| 261 |
-
activation_fn=activation_fn,
|
| 262 |
-
final_dropout=final_dropout,
|
| 263 |
-
inner_dim=ff_inner_dim,
|
| 264 |
-
bias=ff_bias,
|
| 265 |
-
)
|
| 266 |
-
|
| 267 |
-
# 4. Fuser
|
| 268 |
-
if attention_type == "gated" or attention_type == "gated-text-image":
|
| 269 |
-
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
|
| 270 |
-
|
| 271 |
-
# 5. Scale-shift for PixArt-Alpha.
|
| 272 |
-
if self.use_ada_layer_norm_single:
|
| 273 |
-
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
| 274 |
-
|
| 275 |
-
# let chunk size default to None
|
| 276 |
-
self._chunk_size = None
|
| 277 |
-
self._chunk_dim = 0
|
| 278 |
-
|
| 279 |
-
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
| 280 |
-
# Sets chunk feed-forward
|
| 281 |
-
self._chunk_size = chunk_size
|
| 282 |
-
self._chunk_dim = dim
|
| 283 |
-
|
| 284 |
-
def forward(
|
| 285 |
-
self,
|
| 286 |
-
hidden_states: torch.FloatTensor,
|
| 287 |
-
attention_mask: Optional[torch.FloatTensor] = None,
|
| 288 |
-
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 289 |
-
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 290 |
-
timestep: Optional[torch.LongTensor] = None,
|
| 291 |
-
cross_attention_kwargs: Dict[str, Any] = None,
|
| 292 |
-
class_labels: Optional[torch.LongTensor] = None,
|
| 293 |
-
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
| 294 |
-
) -> torch.FloatTensor:
|
| 295 |
-
# Notice that normalization is always applied before the real computation in the following blocks.
|
| 296 |
-
# 0. Self-Attention
|
| 297 |
-
batch_size = hidden_states.shape[0]
|
| 298 |
-
|
| 299 |
-
if self.use_ada_layer_norm:
|
| 300 |
-
norm_hidden_states = self.norm1(hidden_states, timestep)
|
| 301 |
-
elif self.use_ada_layer_norm_zero:
|
| 302 |
-
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
| 303 |
-
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
| 304 |
-
)
|
| 305 |
-
elif self.use_layer_norm:
|
| 306 |
-
norm_hidden_states = self.norm1(hidden_states)
|
| 307 |
-
elif self.use_ada_layer_norm_continuous:
|
| 308 |
-
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
| 309 |
-
elif self.use_ada_layer_norm_single:
|
| 310 |
-
# print("Using PixArt-Alpha norm")
|
| 311 |
-
# print("time step: ", timestep.shape)
|
| 312 |
-
# print("self.scale_shift_table: ", self.scale_shift_table.shape)
|
| 313 |
-
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
| 314 |
-
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
| 315 |
-
).chunk(6, dim=1)
|
| 316 |
-
norm_hidden_states = self.norm1(hidden_states)
|
| 317 |
-
# print("scale_msa: ", scale_msa.shape)
|
| 318 |
-
# print("shift_msa: ", shift_msa.shape)
|
| 319 |
-
#scale_msa: torch.Size([5, 1, 1152])
|
| 320 |
-
#shift_msa: torch.Size([5, 1, 1152])
|
| 321 |
-
# exit()
|
| 322 |
-
# print("before: ", norm_hidden_states.shape)
|
| 323 |
-
#before: torch.Size([5, 3584, 1152])
|
| 324 |
-
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
| 325 |
-
# print("after: ", norm_hidden_states.shape)
|
| 326 |
-
#before: torch.Size([5, 3584, 1152])
|
| 327 |
-
# exit()
|
| 328 |
-
norm_hidden_states = norm_hidden_states.squeeze(1)
|
| 329 |
-
else:
|
| 330 |
-
raise ValueError("Incorrect norm used")
|
| 331 |
-
|
| 332 |
-
if self.pos_embed is not None:
|
| 333 |
-
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
# 1. Retrieve lora scale.
|
| 337 |
-
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
| 338 |
-
|
| 339 |
-
# 2. Prepare GLIGEN inputs
|
| 340 |
-
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
| 341 |
-
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
|
| 342 |
-
|
| 343 |
-
attn_output = self.attn1(
|
| 344 |
-
norm_hidden_states,
|
| 345 |
-
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
| 346 |
-
attention_mask=attention_mask,
|
| 347 |
-
**cross_attention_kwargs,
|
| 348 |
-
)
|
| 349 |
-
if self.use_ada_layer_norm_zero:
|
| 350 |
-
attn_output = gate_msa.unsqueeze(1) * attn_output
|
| 351 |
-
elif self.use_ada_layer_norm_single:
|
| 352 |
-
attn_output = gate_msa * attn_output
|
| 353 |
-
|
| 354 |
-
hidden_states = attn_output + hidden_states
|
| 355 |
-
if hidden_states.ndim == 4:
|
| 356 |
-
hidden_states = hidden_states.squeeze(1)
|
| 357 |
-
|
| 358 |
-
# 2.5 GLIGEN Control
|
| 359 |
-
if gligen_kwargs is not None:
|
| 360 |
-
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
| 361 |
-
|
| 362 |
-
# 3. Cross-Attention
|
| 363 |
-
if self.attn2 is not None:
|
| 364 |
-
if self.use_ada_layer_norm:
|
| 365 |
-
norm_hidden_states = self.norm2(hidden_states, timestep)
|
| 366 |
-
elif self.use_ada_layer_norm_zero or self.use_layer_norm:
|
| 367 |
-
norm_hidden_states = self.norm2(hidden_states)
|
| 368 |
-
elif self.use_ada_layer_norm_single:
|
| 369 |
-
# For PixArt norm2 isn't applied here:
|
| 370 |
-
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
|
| 371 |
-
norm_hidden_states = hidden_states
|
| 372 |
-
elif self.use_ada_layer_norm_continuous:
|
| 373 |
-
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
| 374 |
-
else:
|
| 375 |
-
raise ValueError("Incorrect norm")
|
| 376 |
-
|
| 377 |
-
if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
|
| 378 |
-
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
| 379 |
-
|
| 380 |
-
attn_output = self.attn2(
|
| 381 |
-
norm_hidden_states,
|
| 382 |
-
encoder_hidden_states=encoder_hidden_states,
|
| 383 |
-
attention_mask=encoder_attention_mask,
|
| 384 |
-
**cross_attention_kwargs,
|
| 385 |
-
)
|
| 386 |
-
hidden_states = attn_output + hidden_states
|
| 387 |
-
|
| 388 |
-
# 4. Feed-forward
|
| 389 |
-
if self.use_ada_layer_norm_continuous:
|
| 390 |
-
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
| 391 |
-
elif not self.use_ada_layer_norm_single:
|
| 392 |
-
norm_hidden_states = self.norm3(hidden_states)
|
| 393 |
-
|
| 394 |
-
if self.use_ada_layer_norm_zero:
|
| 395 |
-
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 396 |
-
|
| 397 |
-
if self.use_ada_layer_norm_single:
|
| 398 |
-
norm_hidden_states = self.norm2(hidden_states)
|
| 399 |
-
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
| 400 |
-
|
| 401 |
-
if self._chunk_size is not None:
|
| 402 |
-
# "feed_forward_chunk_size" can be used to save memory
|
| 403 |
-
ff_output = _chunked_feed_forward(
|
| 404 |
-
self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
|
| 405 |
-
)
|
| 406 |
-
else:
|
| 407 |
-
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
|
| 408 |
-
|
| 409 |
-
if self.use_ada_layer_norm_zero:
|
| 410 |
-
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
| 411 |
-
elif self.use_ada_layer_norm_single:
|
| 412 |
-
ff_output = gate_mlp * ff_output
|
| 413 |
-
|
| 414 |
-
hidden_states = ff_output + hidden_states
|
| 415 |
-
if hidden_states.ndim == 4:
|
| 416 |
-
hidden_states = hidden_states.squeeze(1)
|
| 417 |
-
|
| 418 |
-
return hidden_states
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
@maybe_allow_in_graph
|
| 422 |
-
class TemporalBasicTransformerBlock(nn.Module):
|
| 423 |
-
r"""
|
| 424 |
-
A basic Transformer block for video like data.
|
| 425 |
-
|
| 426 |
-
Parameters:
|
| 427 |
-
dim (`int`): The number of channels in the input and output.
|
| 428 |
-
time_mix_inner_dim (`int`): The number of channels for temporal attention.
|
| 429 |
-
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
| 430 |
-
attention_head_dim (`int`): The number of channels in each head.
|
| 431 |
-
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
| 432 |
-
"""
|
| 433 |
-
|
| 434 |
-
def __init__(
|
| 435 |
-
self,
|
| 436 |
-
dim: int,
|
| 437 |
-
time_mix_inner_dim: int,
|
| 438 |
-
num_attention_heads: int,
|
| 439 |
-
attention_head_dim: int,
|
| 440 |
-
cross_attention_dim: Optional[int] = None,
|
| 441 |
-
):
|
| 442 |
-
super().__init__()
|
| 443 |
-
self.is_res = dim == time_mix_inner_dim
|
| 444 |
-
|
| 445 |
-
self.norm_in = nn.LayerNorm(dim)
|
| 446 |
-
|
| 447 |
-
# Define 3 blocks. Each block has its own normalization layer.
|
| 448 |
-
# 1. Self-Attn
|
| 449 |
-
self.norm_in = nn.LayerNorm(dim)
|
| 450 |
-
self.ff_in = FeedForward(
|
| 451 |
-
dim,
|
| 452 |
-
dim_out=time_mix_inner_dim,
|
| 453 |
-
activation_fn="geglu",
|
| 454 |
-
)
|
| 455 |
-
|
| 456 |
-
self.norm1 = nn.LayerNorm(time_mix_inner_dim)
|
| 457 |
-
self.attn1 = Attention(
|
| 458 |
-
query_dim=time_mix_inner_dim,
|
| 459 |
-
heads=num_attention_heads,
|
| 460 |
-
dim_head=attention_head_dim,
|
| 461 |
-
cross_attention_dim=None,
|
| 462 |
-
)
|
| 463 |
-
|
| 464 |
-
# 2. Cross-Attn
|
| 465 |
-
if cross_attention_dim is not None:
|
| 466 |
-
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
| 467 |
-
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
| 468 |
-
# the second cross attention block.
|
| 469 |
-
self.norm2 = nn.LayerNorm(time_mix_inner_dim)
|
| 470 |
-
self.attn2 = Attention(
|
| 471 |
-
query_dim=time_mix_inner_dim,
|
| 472 |
-
cross_attention_dim=cross_attention_dim,
|
| 473 |
-
heads=num_attention_heads,
|
| 474 |
-
dim_head=attention_head_dim,
|
| 475 |
-
) # is self-attn if encoder_hidden_states is none
|
| 476 |
-
else:
|
| 477 |
-
self.norm2 = None
|
| 478 |
-
self.attn2 = None
|
| 479 |
-
|
| 480 |
-
# 3. Feed-forward
|
| 481 |
-
self.norm3 = nn.LayerNorm(time_mix_inner_dim)
|
| 482 |
-
self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
|
| 483 |
-
|
| 484 |
-
# let chunk size default to None
|
| 485 |
-
self._chunk_size = None
|
| 486 |
-
self._chunk_dim = None
|
| 487 |
-
|
| 488 |
-
def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
|
| 489 |
-
# Sets chunk feed-forward
|
| 490 |
-
self._chunk_size = chunk_size
|
| 491 |
-
# chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
|
| 492 |
-
self._chunk_dim = 1
|
| 493 |
-
|
| 494 |
-
def forward(
|
| 495 |
-
self,
|
| 496 |
-
hidden_states: torch.FloatTensor,
|
| 497 |
-
num_frames: int,
|
| 498 |
-
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 499 |
-
) -> torch.FloatTensor:
|
| 500 |
-
# Notice that normalization is always applied before the real computation in the following blocks.
|
| 501 |
-
# 0. Self-Attention
|
| 502 |
-
batch_size = hidden_states.shape[0]
|
| 503 |
-
|
| 504 |
-
batch_frames, seq_length, channels = hidden_states.shape
|
| 505 |
-
batch_size = batch_frames // num_frames
|
| 506 |
-
|
| 507 |
-
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
|
| 508 |
-
hidden_states = hidden_states.permute(0, 2, 1, 3)
|
| 509 |
-
hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
|
| 510 |
-
|
| 511 |
-
residual = hidden_states
|
| 512 |
-
hidden_states = self.norm_in(hidden_states)
|
| 513 |
-
|
| 514 |
-
if self._chunk_size is not None:
|
| 515 |
-
hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
|
| 516 |
-
else:
|
| 517 |
-
hidden_states = self.ff_in(hidden_states)
|
| 518 |
-
|
| 519 |
-
if self.is_res:
|
| 520 |
-
hidden_states = hidden_states + residual
|
| 521 |
-
|
| 522 |
-
norm_hidden_states = self.norm1(hidden_states)
|
| 523 |
-
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
|
| 524 |
-
hidden_states = attn_output + hidden_states
|
| 525 |
-
|
| 526 |
-
# 3. Cross-Attention
|
| 527 |
-
if self.attn2 is not None:
|
| 528 |
-
norm_hidden_states = self.norm2(hidden_states)
|
| 529 |
-
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
|
| 530 |
-
hidden_states = attn_output + hidden_states
|
| 531 |
-
|
| 532 |
-
# 4. Feed-forward
|
| 533 |
-
norm_hidden_states = self.norm3(hidden_states)
|
| 534 |
-
|
| 535 |
-
if self._chunk_size is not None:
|
| 536 |
-
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
| 537 |
-
else:
|
| 538 |
-
ff_output = self.ff(norm_hidden_states)
|
| 539 |
-
|
| 540 |
-
if self.is_res:
|
| 541 |
-
hidden_states = ff_output + hidden_states
|
| 542 |
-
else:
|
| 543 |
-
hidden_states = ff_output
|
| 544 |
-
|
| 545 |
-
hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
|
| 546 |
-
hidden_states = hidden_states.permute(0, 2, 1, 3)
|
| 547 |
-
hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
|
| 548 |
-
|
| 549 |
-
return hidden_states
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
class SkipFFTransformerBlock(nn.Module):
|
| 553 |
-
def __init__(
|
| 554 |
-
self,
|
| 555 |
-
dim: int,
|
| 556 |
-
num_attention_heads: int,
|
| 557 |
-
attention_head_dim: int,
|
| 558 |
-
kv_input_dim: int,
|
| 559 |
-
kv_input_dim_proj_use_bias: bool,
|
| 560 |
-
dropout=0.0,
|
| 561 |
-
cross_attention_dim: Optional[int] = None,
|
| 562 |
-
attention_bias: bool = False,
|
| 563 |
-
attention_out_bias: bool = True,
|
| 564 |
-
):
|
| 565 |
-
super().__init__()
|
| 566 |
-
if kv_input_dim != dim:
|
| 567 |
-
self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
|
| 568 |
-
else:
|
| 569 |
-
self.kv_mapper = None
|
| 570 |
-
|
| 571 |
-
self.norm1 = RMSNorm(dim, 1e-06)
|
| 572 |
-
|
| 573 |
-
self.attn1 = Attention(
|
| 574 |
-
query_dim=dim,
|
| 575 |
-
heads=num_attention_heads,
|
| 576 |
-
dim_head=attention_head_dim,
|
| 577 |
-
dropout=dropout,
|
| 578 |
-
bias=attention_bias,
|
| 579 |
-
cross_attention_dim=cross_attention_dim,
|
| 580 |
-
out_bias=attention_out_bias,
|
| 581 |
-
)
|
| 582 |
-
|
| 583 |
-
self.norm2 = RMSNorm(dim, 1e-06)
|
| 584 |
-
|
| 585 |
-
self.attn2 = Attention(
|
| 586 |
-
query_dim=dim,
|
| 587 |
-
cross_attention_dim=cross_attention_dim,
|
| 588 |
-
heads=num_attention_heads,
|
| 589 |
-
dim_head=attention_head_dim,
|
| 590 |
-
dropout=dropout,
|
| 591 |
-
bias=attention_bias,
|
| 592 |
-
out_bias=attention_out_bias,
|
| 593 |
-
)
|
| 594 |
-
|
| 595 |
-
def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
|
| 596 |
-
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
| 597 |
-
|
| 598 |
-
if self.kv_mapper is not None:
|
| 599 |
-
encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
|
| 600 |
-
|
| 601 |
-
norm_hidden_states = self.norm1(hidden_states)
|
| 602 |
-
|
| 603 |
-
attn_output = self.attn1(
|
| 604 |
-
norm_hidden_states,
|
| 605 |
-
encoder_hidden_states=encoder_hidden_states,
|
| 606 |
-
**cross_attention_kwargs,
|
| 607 |
-
)
|
| 608 |
-
|
| 609 |
-
hidden_states = attn_output + hidden_states
|
| 610 |
-
|
| 611 |
-
norm_hidden_states = self.norm2(hidden_states)
|
| 612 |
-
|
| 613 |
-
attn_output = self.attn2(
|
| 614 |
-
norm_hidden_states,
|
| 615 |
-
encoder_hidden_states=encoder_hidden_states,
|
| 616 |
-
**cross_attention_kwargs,
|
| 617 |
-
)
|
| 618 |
-
|
| 619 |
-
hidden_states = attn_output + hidden_states
|
| 620 |
-
|
| 621 |
-
return hidden_states
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
class FeedForward(nn.Module):
|
| 625 |
-
r"""
|
| 626 |
-
A feed-forward layer.
|
| 627 |
-
|
| 628 |
-
Parameters:
|
| 629 |
-
dim (`int`): The number of channels in the input.
|
| 630 |
-
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
| 631 |
-
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
| 632 |
-
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
| 633 |
-
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
| 634 |
-
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
| 635 |
-
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
| 636 |
-
"""
|
| 637 |
-
|
| 638 |
-
def __init__(
|
| 639 |
-
self,
|
| 640 |
-
dim: int,
|
| 641 |
-
dim_out: Optional[int] = None,
|
| 642 |
-
mult: int = 4,
|
| 643 |
-
dropout: float = 0.0,
|
| 644 |
-
activation_fn: str = "geglu",
|
| 645 |
-
final_dropout: bool = False,
|
| 646 |
-
inner_dim=None,
|
| 647 |
-
bias: bool = True,
|
| 648 |
-
):
|
| 649 |
-
super().__init__()
|
| 650 |
-
if inner_dim is None:
|
| 651 |
-
inner_dim = int(dim * mult)
|
| 652 |
-
dim_out = dim_out if dim_out is not None else dim
|
| 653 |
-
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
|
| 654 |
-
|
| 655 |
-
if activation_fn == "gelu":
|
| 656 |
-
act_fn = GELU(dim, inner_dim, bias=bias)
|
| 657 |
-
if activation_fn == "gelu-approximate":
|
| 658 |
-
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
|
| 659 |
-
elif activation_fn == "geglu":
|
| 660 |
-
act_fn = GEGLU(dim, inner_dim, bias=bias)
|
| 661 |
-
elif activation_fn == "geglu-approximate":
|
| 662 |
-
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
|
| 663 |
-
|
| 664 |
-
self.net = nn.ModuleList([])
|
| 665 |
-
# project in
|
| 666 |
-
self.net.append(act_fn)
|
| 667 |
-
# project dropout
|
| 668 |
-
self.net.append(nn.Dropout(dropout))
|
| 669 |
-
# project out
|
| 670 |
-
self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
|
| 671 |
-
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
| 672 |
-
if final_dropout:
|
| 673 |
-
self.net.append(nn.Dropout(dropout))
|
| 674 |
-
|
| 675 |
-
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
| 676 |
-
compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
|
| 677 |
-
for module in self.net:
|
| 678 |
-
if isinstance(module, compatible_cls):
|
| 679 |
-
hidden_states = module(hidden_states, scale)
|
| 680 |
-
else:
|
| 681 |
-
hidden_states = module(hidden_states)
|
| 682 |
-
return hidden_states
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/models/transformer_2d_flow.py
DELETED
|
@@ -1,545 +0,0 @@
|
|
| 1 |
-
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
from dataclasses import dataclass
|
| 15 |
-
import math
|
| 16 |
-
from typing import Any, Dict, Optional, Tuple
|
| 17 |
-
|
| 18 |
-
import torch
|
| 19 |
-
import torch.nn.functional as F
|
| 20 |
-
from torch import nn
|
| 21 |
-
|
| 22 |
-
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 23 |
-
from diffusers.models.embeddings import ImagePositionalEmbeddings
|
| 24 |
-
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
|
| 25 |
-
from models.attention import BasicTransformerBlock
|
| 26 |
-
from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection
|
| 27 |
-
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
|
| 28 |
-
from diffusers.models.modeling_utils import ModelMixin
|
| 29 |
-
from diffusers.models.embeddings import TimestepEmbedding
|
| 30 |
-
|
| 31 |
-
class PixArtAlphaCombinedFlowEmbeddings(nn.Module):
|
| 32 |
-
"""
|
| 33 |
-
For PixArt-Alpha.
|
| 34 |
-
|
| 35 |
-
Reference:
|
| 36 |
-
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
| 37 |
-
"""
|
| 38 |
-
|
| 39 |
-
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
|
| 40 |
-
super().__init__()
|
| 41 |
-
|
| 42 |
-
self.flow_t_size = 512
|
| 43 |
-
self.outdim = size_emb_dim
|
| 44 |
-
self.timestep_embedder = TimestepEmbedding(in_channels=self.flow_t_size, time_embed_dim=embedding_dim)
|
| 45 |
-
|
| 46 |
-
self.use_additional_conditions = use_additional_conditions
|
| 47 |
-
if use_additional_conditions:
|
| 48 |
-
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
| 49 |
-
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
|
| 50 |
-
self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
|
| 51 |
-
|
| 52 |
-
# https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/nn.py#L87
|
| 53 |
-
def timestep_embedding(self, timesteps, max_period=10000, scale=1000):
|
| 54 |
-
"""Create sinusoidal timestep embeddings.
|
| 55 |
-
|
| 56 |
-
:param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
| 57 |
-
:param dim: the dimension of the output.
|
| 58 |
-
:param max_period: controls the minimum frequency of the embeddings.
|
| 59 |
-
:return: an [N x dim] Tensor of positional embeddings.
|
| 60 |
-
"""
|
| 61 |
-
half = self.flow_t_size // 2
|
| 62 |
-
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, device=timesteps.device) / half).type(timesteps.type())
|
| 63 |
-
args = timesteps[:, None] * freqs[None] * scale
|
| 64 |
-
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 65 |
-
if self.flow_t_size % 2:
|
| 66 |
-
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 67 |
-
return embedding
|
| 68 |
-
|
| 69 |
-
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
| 70 |
-
timesteps_proj = self.timestep_embedding(timestep)
|
| 71 |
-
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
| 72 |
-
|
| 73 |
-
if self.use_additional_conditions:
|
| 74 |
-
resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
|
| 75 |
-
resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
|
| 76 |
-
aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
|
| 77 |
-
aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
|
| 78 |
-
conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
|
| 79 |
-
else:
|
| 80 |
-
conditioning = timesteps_emb
|
| 81 |
-
|
| 82 |
-
return conditioning
|
| 83 |
-
|
| 84 |
-
class AdaLayerNormSingleFlow(nn.Module):
|
| 85 |
-
r"""
|
| 86 |
-
Norm layer adaptive layer norm single (adaLN-single).
|
| 87 |
-
|
| 88 |
-
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
|
| 89 |
-
|
| 90 |
-
Parameters:
|
| 91 |
-
embedding_dim (`int`): The size of each embedding vector.
|
| 92 |
-
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
| 93 |
-
"""
|
| 94 |
-
|
| 95 |
-
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
|
| 96 |
-
super().__init__()
|
| 97 |
-
|
| 98 |
-
self.emb = PixArtAlphaCombinedFlowEmbeddings(
|
| 99 |
-
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
|
| 100 |
-
)
|
| 101 |
-
|
| 102 |
-
self.silu = nn.SiLU()
|
| 103 |
-
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
| 104 |
-
|
| 105 |
-
def forward(
|
| 106 |
-
self,
|
| 107 |
-
timestep: torch.Tensor,
|
| 108 |
-
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
| 109 |
-
batch_size: Optional[int] = None,
|
| 110 |
-
hidden_dtype: Optional[torch.dtype] = None,
|
| 111 |
-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 112 |
-
# No modulation happening here.
|
| 113 |
-
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
|
| 114 |
-
return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
@dataclass
|
| 118 |
-
class Transformer2DModelOutput(BaseOutput):
|
| 119 |
-
"""
|
| 120 |
-
The output of [`Transformer2DModel`].
|
| 121 |
-
|
| 122 |
-
Args:
|
| 123 |
-
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
| 124 |
-
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
|
| 125 |
-
distributions for the unnoised latent pixels.
|
| 126 |
-
"""
|
| 127 |
-
|
| 128 |
-
sample: torch.FloatTensor
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
class Transformer2DModel(ModelMixin, ConfigMixin):
|
| 132 |
-
"""
|
| 133 |
-
A 2D Transformer model for image-like data.
|
| 134 |
-
|
| 135 |
-
Parameters:
|
| 136 |
-
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
| 137 |
-
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
| 138 |
-
in_channels (`int`, *optional*):
|
| 139 |
-
The number of channels in the input and output (specify if the input is **continuous**).
|
| 140 |
-
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
| 141 |
-
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
| 142 |
-
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
| 143 |
-
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
| 144 |
-
This is fixed during training since it is used to learn a number of position embeddings.
|
| 145 |
-
num_vector_embeds (`int`, *optional*):
|
| 146 |
-
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
|
| 147 |
-
Includes the class for the masked latent pixel.
|
| 148 |
-
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
| 149 |
-
num_embeds_ada_norm ( `int`, *optional*):
|
| 150 |
-
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
| 151 |
-
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
| 152 |
-
added to the hidden states.
|
| 153 |
-
|
| 154 |
-
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
|
| 155 |
-
attention_bias (`bool`, *optional*):
|
| 156 |
-
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
| 157 |
-
"""
|
| 158 |
-
|
| 159 |
-
_supports_gradient_checkpointing = True
|
| 160 |
-
|
| 161 |
-
@register_to_config
|
| 162 |
-
def __init__(
|
| 163 |
-
self,
|
| 164 |
-
num_attention_heads: int = 16,
|
| 165 |
-
attention_head_dim: int = 88,
|
| 166 |
-
in_channels: Optional[int] = None,
|
| 167 |
-
out_channels: Optional[int] = None,
|
| 168 |
-
num_layers: int = 1,
|
| 169 |
-
dropout: float = 0.0,
|
| 170 |
-
norm_num_groups: int = 32,
|
| 171 |
-
cross_attention_dim: Optional[int] = None,
|
| 172 |
-
attention_bias: bool = False,
|
| 173 |
-
sample_size: Optional[int] = None,
|
| 174 |
-
num_vector_embeds: Optional[int] = None,
|
| 175 |
-
patch_size: Optional[int] = None,
|
| 176 |
-
activation_fn: str = "geglu",
|
| 177 |
-
num_embeds_ada_norm: Optional[int] = None,
|
| 178 |
-
use_linear_projection: bool = False,
|
| 179 |
-
only_cross_attention: bool = False,
|
| 180 |
-
double_self_attention: bool = False,
|
| 181 |
-
upcast_attention: bool = False,
|
| 182 |
-
norm_type: str = "layer_norm",
|
| 183 |
-
norm_elementwise_affine: bool = True,
|
| 184 |
-
norm_eps: float = 1e-5,
|
| 185 |
-
attention_type: str = "default",
|
| 186 |
-
caption_channels: int = None,
|
| 187 |
-
):
|
| 188 |
-
super().__init__()
|
| 189 |
-
self.use_linear_projection = use_linear_projection
|
| 190 |
-
self.num_attention_heads = num_attention_heads
|
| 191 |
-
self.attention_head_dim = attention_head_dim
|
| 192 |
-
inner_dim = num_attention_heads * attention_head_dim
|
| 193 |
-
|
| 194 |
-
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
| 195 |
-
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
| 196 |
-
|
| 197 |
-
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
| 198 |
-
# Define whether input is continuous or discrete depending on configuration
|
| 199 |
-
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
|
| 200 |
-
self.is_input_vectorized = num_vector_embeds is not None
|
| 201 |
-
self.is_input_patches = in_channels is not None and patch_size is not None
|
| 202 |
-
|
| 203 |
-
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
|
| 204 |
-
deprecation_message = (
|
| 205 |
-
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
|
| 206 |
-
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
|
| 207 |
-
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
|
| 208 |
-
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
|
| 209 |
-
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
|
| 210 |
-
)
|
| 211 |
-
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
|
| 212 |
-
norm_type = "ada_norm"
|
| 213 |
-
|
| 214 |
-
if self.is_input_continuous and self.is_input_vectorized:
|
| 215 |
-
raise ValueError(
|
| 216 |
-
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
| 217 |
-
" sure that either `in_channels` or `num_vector_embeds` is None."
|
| 218 |
-
)
|
| 219 |
-
elif self.is_input_vectorized and self.is_input_patches:
|
| 220 |
-
raise ValueError(
|
| 221 |
-
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
|
| 222 |
-
" sure that either `num_vector_embeds` or `num_patches` is None."
|
| 223 |
-
)
|
| 224 |
-
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
|
| 225 |
-
raise ValueError(
|
| 226 |
-
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
|
| 227 |
-
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
|
| 228 |
-
)
|
| 229 |
-
|
| 230 |
-
# 2. Define input layers
|
| 231 |
-
if self.is_input_continuous:
|
| 232 |
-
self.in_channels = in_channels
|
| 233 |
-
|
| 234 |
-
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
| 235 |
-
if use_linear_projection:
|
| 236 |
-
self.proj_in = linear_cls(in_channels, inner_dim)
|
| 237 |
-
else:
|
| 238 |
-
self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
| 239 |
-
elif self.is_input_vectorized:
|
| 240 |
-
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
| 241 |
-
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
| 242 |
-
|
| 243 |
-
self.height = sample_size
|
| 244 |
-
self.width = sample_size
|
| 245 |
-
self.num_vector_embeds = num_vector_embeds
|
| 246 |
-
self.num_latent_pixels = self.height * self.width
|
| 247 |
-
|
| 248 |
-
self.latent_image_embedding = ImagePositionalEmbeddings(
|
| 249 |
-
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
| 250 |
-
)
|
| 251 |
-
elif self.is_input_patches:
|
| 252 |
-
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
|
| 253 |
-
|
| 254 |
-
self.height = sample_size
|
| 255 |
-
self.width = sample_size
|
| 256 |
-
|
| 257 |
-
self.patch_size = patch_size
|
| 258 |
-
interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
|
| 259 |
-
interpolation_scale = max(interpolation_scale, 1)
|
| 260 |
-
self.pos_embed = PatchEmbed(
|
| 261 |
-
height=sample_size,
|
| 262 |
-
width=sample_size,
|
| 263 |
-
patch_size=patch_size,
|
| 264 |
-
in_channels=in_channels,
|
| 265 |
-
embed_dim=inner_dim,
|
| 266 |
-
interpolation_scale=interpolation_scale,
|
| 267 |
-
)
|
| 268 |
-
|
| 269 |
-
# 3. Define transformers blocks
|
| 270 |
-
self.transformer_blocks = nn.ModuleList(
|
| 271 |
-
[
|
| 272 |
-
BasicTransformerBlock(
|
| 273 |
-
inner_dim,
|
| 274 |
-
num_attention_heads,
|
| 275 |
-
attention_head_dim,
|
| 276 |
-
dropout=dropout,
|
| 277 |
-
cross_attention_dim=cross_attention_dim,
|
| 278 |
-
activation_fn=activation_fn,
|
| 279 |
-
num_embeds_ada_norm=num_embeds_ada_norm,
|
| 280 |
-
attention_bias=attention_bias,
|
| 281 |
-
only_cross_attention=only_cross_attention,
|
| 282 |
-
double_self_attention=double_self_attention,
|
| 283 |
-
upcast_attention=upcast_attention,
|
| 284 |
-
norm_type=norm_type,
|
| 285 |
-
norm_elementwise_affine=norm_elementwise_affine,
|
| 286 |
-
norm_eps=norm_eps,
|
| 287 |
-
attention_type=attention_type,
|
| 288 |
-
)
|
| 289 |
-
for d in range(num_layers)
|
| 290 |
-
]
|
| 291 |
-
)
|
| 292 |
-
|
| 293 |
-
# 4. Define output layers
|
| 294 |
-
self.out_channels = in_channels if out_channels is None else out_channels
|
| 295 |
-
if self.is_input_continuous:
|
| 296 |
-
# TODO: should use out_channels for continuous projections
|
| 297 |
-
if use_linear_projection:
|
| 298 |
-
self.proj_out = linear_cls(inner_dim, in_channels)
|
| 299 |
-
else:
|
| 300 |
-
self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
| 301 |
-
elif self.is_input_vectorized:
|
| 302 |
-
self.norm_out = nn.LayerNorm(inner_dim)
|
| 303 |
-
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
| 304 |
-
elif self.is_input_patches and norm_type != "ada_norm_single":
|
| 305 |
-
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
| 306 |
-
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
| 307 |
-
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
| 308 |
-
elif self.is_input_patches and norm_type == "ada_norm_single":
|
| 309 |
-
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
| 310 |
-
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
| 311 |
-
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
| 312 |
-
|
| 313 |
-
# 5. PixArt-Alpha blocks.
|
| 314 |
-
self.adaln_single = None
|
| 315 |
-
self.use_additional_conditions = False
|
| 316 |
-
if norm_type == "ada_norm_single":
|
| 317 |
-
self.use_additional_conditions = self.config.sample_size == 128
|
| 318 |
-
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
|
| 319 |
-
# additional conditions until we find better name
|
| 320 |
-
self.adaln_single = AdaLayerNormSingleFlow(inner_dim, use_additional_conditions=self.use_additional_conditions)
|
| 321 |
-
|
| 322 |
-
self.caption_projection = None
|
| 323 |
-
if caption_channels is not None:
|
| 324 |
-
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
| 325 |
-
|
| 326 |
-
self.gradient_checkpointing = False
|
| 327 |
-
|
| 328 |
-
def _set_gradient_checkpointing(self, module, value=False):
|
| 329 |
-
if hasattr(module, "gradient_checkpointing"):
|
| 330 |
-
module.gradient_checkpointing = value
|
| 331 |
-
|
| 332 |
-
def forward(
|
| 333 |
-
self,
|
| 334 |
-
hidden_states: torch.Tensor,
|
| 335 |
-
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 336 |
-
timestep: Optional[torch.LongTensor] = None,
|
| 337 |
-
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
| 338 |
-
class_labels: Optional[torch.LongTensor] = None,
|
| 339 |
-
cross_attention_kwargs: Dict[str, Any] = None,
|
| 340 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 341 |
-
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 342 |
-
return_dict: bool = True,
|
| 343 |
-
):
|
| 344 |
-
"""
|
| 345 |
-
The [`Transformer2DModel`] forward method.
|
| 346 |
-
|
| 347 |
-
Args:
|
| 348 |
-
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
| 349 |
-
Input `hidden_states`.
|
| 350 |
-
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
| 351 |
-
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
| 352 |
-
self-attention.
|
| 353 |
-
timestep ( `torch.LongTensor`, *optional*):
|
| 354 |
-
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
| 355 |
-
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
| 356 |
-
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
| 357 |
-
`AdaLayerZeroNorm`.
|
| 358 |
-
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
| 359 |
-
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 360 |
-
`self.processor` in
|
| 361 |
-
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 362 |
-
attention_mask ( `torch.Tensor`, *optional*):
|
| 363 |
-
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
| 364 |
-
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
| 365 |
-
negative values to the attention scores corresponding to "discard" tokens.
|
| 366 |
-
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
| 367 |
-
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
| 368 |
-
|
| 369 |
-
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
| 370 |
-
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
| 371 |
-
|
| 372 |
-
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
| 373 |
-
above. This bias will be added to the cross-attention scores.
|
| 374 |
-
return_dict (`bool`, *optional*, defaults to `True`):
|
| 375 |
-
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
| 376 |
-
tuple.
|
| 377 |
-
|
| 378 |
-
Returns:
|
| 379 |
-
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
| 380 |
-
`tuple` where the first element is the sample tensor.
|
| 381 |
-
"""
|
| 382 |
-
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
| 383 |
-
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
| 384 |
-
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
| 385 |
-
# expects mask of shape:
|
| 386 |
-
# [batch, key_tokens]
|
| 387 |
-
# adds singleton query_tokens dimension:
|
| 388 |
-
# [batch, 1, key_tokens]
|
| 389 |
-
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
| 390 |
-
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
| 391 |
-
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
| 392 |
-
if attention_mask is not None and attention_mask.ndim == 2:
|
| 393 |
-
# assume that mask is expressed as:
|
| 394 |
-
# (1 = keep, 0 = discard)
|
| 395 |
-
# convert mask into a bias that can be added to attention scores:
|
| 396 |
-
# (keep = +0, discard = -10000.0)
|
| 397 |
-
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
| 398 |
-
attention_mask = attention_mask.unsqueeze(1)
|
| 399 |
-
|
| 400 |
-
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
| 401 |
-
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
| 402 |
-
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
| 403 |
-
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
| 404 |
-
|
| 405 |
-
# Retrieve lora scale.
|
| 406 |
-
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
| 407 |
-
|
| 408 |
-
# 1. Input
|
| 409 |
-
if self.is_input_continuous:
|
| 410 |
-
batch, _, height, width = hidden_states.shape
|
| 411 |
-
residual = hidden_states
|
| 412 |
-
|
| 413 |
-
hidden_states = self.norm(hidden_states)
|
| 414 |
-
if not self.use_linear_projection:
|
| 415 |
-
hidden_states = (
|
| 416 |
-
self.proj_in(hidden_states, scale=lora_scale)
|
| 417 |
-
if not USE_PEFT_BACKEND
|
| 418 |
-
else self.proj_in(hidden_states)
|
| 419 |
-
)
|
| 420 |
-
inner_dim = hidden_states.shape[1]
|
| 421 |
-
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
| 422 |
-
else:
|
| 423 |
-
inner_dim = hidden_states.shape[1]
|
| 424 |
-
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
| 425 |
-
hidden_states = (
|
| 426 |
-
self.proj_in(hidden_states, scale=lora_scale)
|
| 427 |
-
if not USE_PEFT_BACKEND
|
| 428 |
-
else self.proj_in(hidden_states)
|
| 429 |
-
)
|
| 430 |
-
|
| 431 |
-
elif self.is_input_vectorized:
|
| 432 |
-
hidden_states = self.latent_image_embedding(hidden_states)
|
| 433 |
-
elif self.is_input_patches:
|
| 434 |
-
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
|
| 435 |
-
hidden_states = self.pos_embed(hidden_states)
|
| 436 |
-
|
| 437 |
-
if self.adaln_single is not None:
|
| 438 |
-
if self.use_additional_conditions and added_cond_kwargs is None:
|
| 439 |
-
raise ValueError(
|
| 440 |
-
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
|
| 441 |
-
)
|
| 442 |
-
batch_size = hidden_states.shape[0]
|
| 443 |
-
timestep, embedded_timestep = self.adaln_single(
|
| 444 |
-
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
| 445 |
-
)
|
| 446 |
-
|
| 447 |
-
# 2. Blocks
|
| 448 |
-
if self.caption_projection is not None:
|
| 449 |
-
batch_size = hidden_states.shape[0]
|
| 450 |
-
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
| 451 |
-
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
| 452 |
-
|
| 453 |
-
for block in self.transformer_blocks:
|
| 454 |
-
if self.training and self.gradient_checkpointing:
|
| 455 |
-
|
| 456 |
-
def create_custom_forward(module, return_dict=None):
|
| 457 |
-
def custom_forward(*inputs):
|
| 458 |
-
if return_dict is not None:
|
| 459 |
-
return module(*inputs, return_dict=return_dict)
|
| 460 |
-
else:
|
| 461 |
-
return module(*inputs)
|
| 462 |
-
|
| 463 |
-
return custom_forward
|
| 464 |
-
|
| 465 |
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 466 |
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 467 |
-
create_custom_forward(block),
|
| 468 |
-
hidden_states,
|
| 469 |
-
attention_mask,
|
| 470 |
-
encoder_hidden_states,
|
| 471 |
-
encoder_attention_mask,
|
| 472 |
-
timestep,
|
| 473 |
-
cross_attention_kwargs,
|
| 474 |
-
class_labels,
|
| 475 |
-
**ckpt_kwargs,
|
| 476 |
-
)
|
| 477 |
-
else:
|
| 478 |
-
hidden_states = block(
|
| 479 |
-
hidden_states,
|
| 480 |
-
attention_mask=attention_mask,
|
| 481 |
-
encoder_hidden_states=encoder_hidden_states,
|
| 482 |
-
encoder_attention_mask=encoder_attention_mask,
|
| 483 |
-
timestep=timestep,
|
| 484 |
-
cross_attention_kwargs=cross_attention_kwargs,
|
| 485 |
-
class_labels=class_labels,
|
| 486 |
-
)
|
| 487 |
-
|
| 488 |
-
# 3. Output
|
| 489 |
-
if self.is_input_continuous:
|
| 490 |
-
if not self.use_linear_projection:
|
| 491 |
-
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
| 492 |
-
hidden_states = (
|
| 493 |
-
self.proj_out(hidden_states, scale=lora_scale)
|
| 494 |
-
if not USE_PEFT_BACKEND
|
| 495 |
-
else self.proj_out(hidden_states)
|
| 496 |
-
)
|
| 497 |
-
else:
|
| 498 |
-
hidden_states = (
|
| 499 |
-
self.proj_out(hidden_states, scale=lora_scale)
|
| 500 |
-
if not USE_PEFT_BACKEND
|
| 501 |
-
else self.proj_out(hidden_states)
|
| 502 |
-
)
|
| 503 |
-
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
| 504 |
-
|
| 505 |
-
output = hidden_states + residual
|
| 506 |
-
elif self.is_input_vectorized:
|
| 507 |
-
hidden_states = self.norm_out(hidden_states)
|
| 508 |
-
logits = self.out(hidden_states)
|
| 509 |
-
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
| 510 |
-
logits = logits.permute(0, 2, 1)
|
| 511 |
-
|
| 512 |
-
# log(p(x_0))
|
| 513 |
-
output = F.log_softmax(logits.double(), dim=1).float()
|
| 514 |
-
|
| 515 |
-
if self.is_input_patches:
|
| 516 |
-
if self.config.norm_type != "ada_norm_single":
|
| 517 |
-
conditioning = self.transformer_blocks[0].norm1.emb(
|
| 518 |
-
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
| 519 |
-
)
|
| 520 |
-
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
| 521 |
-
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
| 522 |
-
hidden_states = self.proj_out_2(hidden_states)
|
| 523 |
-
elif self.config.norm_type == "ada_norm_single":
|
| 524 |
-
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
|
| 525 |
-
hidden_states = self.norm_out(hidden_states)
|
| 526 |
-
# Modulation
|
| 527 |
-
hidden_states = hidden_states * (1 + scale) + shift
|
| 528 |
-
hidden_states = self.proj_out(hidden_states)
|
| 529 |
-
hidden_states = hidden_states.squeeze(1)
|
| 530 |
-
|
| 531 |
-
# unpatchify
|
| 532 |
-
if self.adaln_single is None:
|
| 533 |
-
height = width = int(hidden_states.shape[1] ** 0.5)
|
| 534 |
-
hidden_states = hidden_states.reshape(
|
| 535 |
-
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
| 536 |
-
)
|
| 537 |
-
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
| 538 |
-
output = hidden_states.reshape(
|
| 539 |
-
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
| 540 |
-
)
|
| 541 |
-
|
| 542 |
-
if not return_dict:
|
| 543 |
-
return (output,)
|
| 544 |
-
|
| 545 |
-
return Transformer2DModelOutput(sample=output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/muq_dev/muq_fairseq/data/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
from .mert_dataset import MERTDataset
|
|
|
|
|
|
MuCodec/muq_dev/muq_fairseq/data/ark_dataset.py
DELETED
|
@@ -1,71 +0,0 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
import torch
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
-
from fairseq.data.audio.raw_audio_dataset import RawAudioDataset
|
| 5 |
-
from typing import Tuple
|
| 6 |
-
try:
|
| 7 |
-
import kaldiio
|
| 8 |
-
except:
|
| 9 |
-
kaldiio = None
|
| 10 |
-
import warnings
|
| 11 |
-
|
| 12 |
-
logger = logging.getLogger(__name__)
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
class ArkDataset(RawAudioDataset):
|
| 16 |
-
def __init__(
|
| 17 |
-
self,
|
| 18 |
-
wav_scp,
|
| 19 |
-
dur_scp,
|
| 20 |
-
sr = 24000,
|
| 21 |
-
max_dur = 20,
|
| 22 |
-
num_buckets=0,
|
| 23 |
-
normalize=False,
|
| 24 |
-
):
|
| 25 |
-
super().__init__(
|
| 26 |
-
sample_rate=sr,
|
| 27 |
-
max_sample_size=max_dur*sr,
|
| 28 |
-
min_sample_size=1200,
|
| 29 |
-
shuffle=True,
|
| 30 |
-
pad=True,
|
| 31 |
-
normalize=normalize,
|
| 32 |
-
compute_mask=False,
|
| 33 |
-
)
|
| 34 |
-
self.sr = sr
|
| 35 |
-
self.max_dur = max_dur
|
| 36 |
-
self.normalize = normalize
|
| 37 |
-
|
| 38 |
-
logger.info("Loading Kaldi scp files from {}".format(wav_scp))
|
| 39 |
-
|
| 40 |
-
self.wav_data = kaldiio.load_scp(wav_scp)
|
| 41 |
-
self.keys = list(self.wav_data.keys())
|
| 42 |
-
dur_data = {}
|
| 43 |
-
keys_set = set(self.keys)
|
| 44 |
-
|
| 45 |
-
with open(dur_scp, 'r') as f:
|
| 46 |
-
for line in f:
|
| 47 |
-
line = line.strip().split()
|
| 48 |
-
if line[0] in keys_set:
|
| 49 |
-
dur_data[line[0]] = float(line[-1])
|
| 50 |
-
self.sizes = [int(dur_data[k]*self.sr/100) for k in self.keys]
|
| 51 |
-
|
| 52 |
-
logger.info("Loading Kaldi scp files done")
|
| 53 |
-
|
| 54 |
-
self.dataset_len = len(self.keys)
|
| 55 |
-
self.set_bucket_info(num_buckets)
|
| 56 |
-
|
| 57 |
-
def __len__(self):
|
| 58 |
-
return self.dataset_len
|
| 59 |
-
|
| 60 |
-
def __getitem__(self, idx):
|
| 61 |
-
pass
|
| 62 |
-
|
| 63 |
-
def size(self, idx):
|
| 64 |
-
pass
|
| 65 |
-
|
| 66 |
-
def postprocess(self, wav):
|
| 67 |
-
pass
|
| 68 |
-
|
| 69 |
-
def collater(self, samples):
|
| 70 |
-
pass
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/muq_dev/muq_fairseq/data/mert_dataset.py
DELETED
|
@@ -1,295 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the MIT license found in the
|
| 4 |
-
# LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import itertools
|
| 7 |
-
import logging
|
| 8 |
-
import os
|
| 9 |
-
import sys
|
| 10 |
-
from typing import Any, List, Optional, Union
|
| 11 |
-
|
| 12 |
-
import numpy as np
|
| 13 |
-
from typing import Tuple
|
| 14 |
-
import torch
|
| 15 |
-
import torch.nn.functional as F
|
| 16 |
-
from fairseq.data import data_utils
|
| 17 |
-
from fairseq.data.fairseq_dataset import FairseqDataset
|
| 18 |
-
from fairseq.data.audio.audio_utils import (
|
| 19 |
-
parse_path,
|
| 20 |
-
read_from_stored_zip,
|
| 21 |
-
)
|
| 22 |
-
|
| 23 |
-
import math
|
| 24 |
-
import io
|
| 25 |
-
import torchaudio
|
| 26 |
-
# this is in the user_dir
|
| 27 |
-
from nnAudio import features as nnAudioFeatures
|
| 28 |
-
|
| 29 |
-
# from tqdm import tqdm
|
| 30 |
-
import tqdm
|
| 31 |
-
import json
|
| 32 |
-
import random
|
| 33 |
-
import traceback
|
| 34 |
-
from einops import rearrange
|
| 35 |
-
# from scripts.prepare_codecs_from_manifest import *
|
| 36 |
-
|
| 37 |
-
logger = logging.getLogger(__name__)
|
| 38 |
-
|
| 39 |
-
class model_cqt_pred(torch.nn.Module):
|
| 40 |
-
def __init__(self, n_bins=84, sr=16000, freq=50):
|
| 41 |
-
super().__init__()
|
| 42 |
-
self.epsilon=1e-10
|
| 43 |
-
# Getting Mel Spectrogram on the fly
|
| 44 |
-
self.spec_layer = nnAudioFeatures.cqt.CQT(sr=sr, hop_length=sr//freq, fmin=32.7,
|
| 45 |
-
fmax=None, n_bins=n_bins, bins_per_octave=n_bins//7,
|
| 46 |
-
filter_scale=1, norm=1, window='hann', center=True,
|
| 47 |
-
pad_mode='constant', trainable=False,
|
| 48 |
-
output_format='Magnitude', verbose=True)
|
| 49 |
-
|
| 50 |
-
# self.fc = nn.Linear(input_dim, n_bins)
|
| 51 |
-
|
| 52 |
-
# self.criterion = nn.MSELoss()
|
| 53 |
-
self.forward_dict = {
|
| 54 |
-
# 'masked_transformer_output': self.plain_forward
|
| 55 |
-
'compute_cqt': self.compute_cqt
|
| 56 |
-
}
|
| 57 |
-
def compute_cqt(self, x):
|
| 58 |
-
'''
|
| 59 |
-
convert waveform to CQT -> [batch, bins, len] -> transpose
|
| 60 |
-
'''
|
| 61 |
-
# align with the padding of HuBERT model,
|
| 62 |
-
# the truncation is calculated by bruteforce search since the nnAudio padding strategy and fairseq models are different
|
| 63 |
-
# x = x[..., :-560]
|
| 64 |
-
return torch.transpose(self.spec_layer(x), -1, -2)
|
| 65 |
-
|
| 66 |
-
def forward(self, x, forward_type='masked_transformer_output'):
|
| 67 |
-
'''
|
| 68 |
-
take input from transformer hidden states: [batch, len_seq, channel]
|
| 69 |
-
output: [batch, len_seq, n_bins]
|
| 70 |
-
'''
|
| 71 |
-
|
| 72 |
-
return self.forward_dict[forward_type](x)
|
| 73 |
-
|
| 74 |
-
def load_audio_by_json(json_path, max_keep, min_keep, tgt_sample_rate, clip_secs=5):
|
| 75 |
-
# read json file
|
| 76 |
-
print(json_path)
|
| 77 |
-
datas = []
|
| 78 |
-
inds = []
|
| 79 |
-
sizes = []
|
| 80 |
-
with open(json_path) as fp:
|
| 81 |
-
for ind,line in enumerate(fp):
|
| 82 |
-
data = json.loads(line)
|
| 83 |
-
if 'duration' in data and min_keep is not None and tgt_sample_rate*data['duration'] < min_keep:
|
| 84 |
-
continue
|
| 85 |
-
datas.append(data)
|
| 86 |
-
inds.append(ind)
|
| 87 |
-
# sz = int(data['duration'] * data['sample_rate'])
|
| 88 |
-
if clip_secs > 0:
|
| 89 |
-
sz = int(tgt_sample_rate * clip_secs)
|
| 90 |
-
else:
|
| 91 |
-
sz = int(tgt_sample_rate * data['duration'])
|
| 92 |
-
sizes.append(sz)
|
| 93 |
-
tot = ind + 1
|
| 94 |
-
return datas,inds,tot,sizes
|
| 95 |
-
def load_audio(manifest_path, max_keep, min_keep):
|
| 96 |
-
pass
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
def load_label(label_path, inds, tot):
|
| 100 |
-
pass
|
| 101 |
-
|
| 102 |
-
def load_numpy_label(label_path, inds, tot):
|
| 103 |
-
labels = np.load(label_path, mmap_mode='r')
|
| 104 |
-
assert (labels.shape[0] == tot), f"number of labels does not match ({labels.shape[0]} != {tot})"
|
| 105 |
-
return labels
|
| 106 |
-
|
| 107 |
-
def verify_label_lengths(
|
| 108 |
-
audio_sizes,
|
| 109 |
-
audio_rate,
|
| 110 |
-
label_path,
|
| 111 |
-
label_rate,
|
| 112 |
-
inds,
|
| 113 |
-
tot,
|
| 114 |
-
tol=0.1, # tolerance in seconds
|
| 115 |
-
):
|
| 116 |
-
pass
|
| 117 |
-
|
| 118 |
-
class Read_and_PadCrop_Normalized_T(torch.nn.Module):
|
| 119 |
-
def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
|
| 120 |
-
|
| 121 |
-
super().__init__()
|
| 122 |
-
|
| 123 |
-
self.n_samples = n_samples
|
| 124 |
-
self.sample_rate = sample_rate
|
| 125 |
-
self.randomize = randomize
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
def __call__(self, filename: str, duration: float, cur_sample_rate: int, fixed_offset_duration=None) -> Tuple[torch.Tensor, float, float, int, int]:
|
| 129 |
-
pass
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
class MERTDataset(FairseqDataset):
|
| 133 |
-
def __init__(
|
| 134 |
-
self,
|
| 135 |
-
manifest_path: str,
|
| 136 |
-
sample_rate: float,
|
| 137 |
-
label_paths: List[str],
|
| 138 |
-
label_rates: Union[List[float], float], # -1 for sequence labels
|
| 139 |
-
pad_list: List[str],
|
| 140 |
-
eos_list: List[str],
|
| 141 |
-
label_scp_path: Optional[str] = None,
|
| 142 |
-
label_scp_clip_duration: float = -1,
|
| 143 |
-
label_processors: Optional[List[Any]] = None,
|
| 144 |
-
max_keep_sample_size: Optional[int] = None,
|
| 145 |
-
min_keep_sample_size: Optional[int] = None,
|
| 146 |
-
max_sample_size: Optional[int] = None,
|
| 147 |
-
shuffle: bool = True,
|
| 148 |
-
pad_audio: bool = False,
|
| 149 |
-
normalize: bool = False,
|
| 150 |
-
store_labels: bool = True,
|
| 151 |
-
npmemmap: bool = False,
|
| 152 |
-
random_crop: bool = False,
|
| 153 |
-
single_target: bool = False,
|
| 154 |
-
augmentation_effects: List[str] = [],
|
| 155 |
-
augmentation_probs: List[float] = [],
|
| 156 |
-
inbatch_noise_augment_len_range: List[int] = [8000, 24000],
|
| 157 |
-
inbatch_noise_augment_number_range: List[int] = [1, 3],
|
| 158 |
-
inbatch_noise_augment_volume: float = 1.0,
|
| 159 |
-
cqt_prediction_bin: int = -1,
|
| 160 |
-
dataset_len:int = 128*3000,
|
| 161 |
-
clip_secs = 5,
|
| 162 |
-
):
|
| 163 |
-
self.sample_rate = sample_rate
|
| 164 |
-
self.shuffle = shuffle
|
| 165 |
-
self.random_crop = random_crop
|
| 166 |
-
self.datas,inds,tot,self.sizes = load_audio_by_json(manifest_path,max_keep_sample_size,min_keep_sample_size, self.sample_rate, clip_secs)
|
| 167 |
-
self.inds = inds
|
| 168 |
-
|
| 169 |
-
self.num_labels = len(label_paths)
|
| 170 |
-
self.pad_list = pad_list
|
| 171 |
-
self.eos_list = eos_list
|
| 172 |
-
self.label_processors = label_processors
|
| 173 |
-
self.single_target = single_target
|
| 174 |
-
self.label_rates = (
|
| 175 |
-
[label_rates for _ in range(len(label_paths))]
|
| 176 |
-
if isinstance(label_rates, float)
|
| 177 |
-
else label_rates
|
| 178 |
-
)
|
| 179 |
-
self.store_labels = store_labels
|
| 180 |
-
self.npmemmap = npmemmap
|
| 181 |
-
self.label_scp_path = label_scp_path
|
| 182 |
-
self.label_scp_clip_duration = label_scp_clip_duration
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
if self.label_scp_path is not None:
|
| 186 |
-
from kaldiio import load_scp
|
| 187 |
-
self.label_scp = load_scp(self.label_scp_path)
|
| 188 |
-
|
| 189 |
-
# self.dataset_len = dataset_len
|
| 190 |
-
self.dataset_len = len(self.datas)
|
| 191 |
-
logger.info('preparing labels')
|
| 192 |
-
logger.info('========dataset len: {}=========='.format(self.dataset_len))
|
| 193 |
-
if store_labels:
|
| 194 |
-
if self.npmemmap:
|
| 195 |
-
self.label_list = [load_numpy_label(p+'.npy', inds, tot) for p in label_paths]
|
| 196 |
-
else:
|
| 197 |
-
self.label_list = [load_label(p, inds, tot) for p in label_paths]
|
| 198 |
-
else:
|
| 199 |
-
self.label_paths = label_paths
|
| 200 |
-
# self.label_offsets_list = [
|
| 201 |
-
# load_label_offset(p, inds, tot) for p in label_paths
|
| 202 |
-
# ]
|
| 203 |
-
assert label_processors is None or len(label_processors) == self.num_labels
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
self.max_sample_size = (
|
| 207 |
-
max_sample_size if max_sample_size is not None else sys.maxsize
|
| 208 |
-
)
|
| 209 |
-
self.pad_audio = pad_audio
|
| 210 |
-
self.normalize = normalize
|
| 211 |
-
logger.info(
|
| 212 |
-
f"pad_audio={pad_audio}, random_crop={random_crop}, "
|
| 213 |
-
f"normalize={normalize}, max_sample_size={self.max_sample_size}"
|
| 214 |
-
)
|
| 215 |
-
|
| 216 |
-
self.augmentation_effects = augmentation_effects
|
| 217 |
-
self.augmentation_probs = augmentation_probs
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
self.inbatch_noise_augment_len_range = inbatch_noise_augment_len_range
|
| 221 |
-
self.inbatch_noise_augment_number_range = inbatch_noise_augment_number_range
|
| 222 |
-
self.inbatch_noise_augment_volume = inbatch_noise_augment_volume
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
self.cqt_prediction_bin = cqt_prediction_bin
|
| 226 |
-
if self.cqt_prediction_bin > 0:
|
| 227 |
-
self.encoder_cqt_model = model_cqt_pred(n_bins=self.cqt_prediction_bin)
|
| 228 |
-
logger.info('preparing cqt loss objective in dataloader with cpu')
|
| 229 |
-
|
| 230 |
-
self.epoch = -1
|
| 231 |
-
|
| 232 |
-
self.reader = Read_and_PadCrop_Normalized_T(n_samples=clip_secs*sample_rate if clip_secs>0 else None, sample_rate = self.sample_rate)
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
@property
|
| 237 |
-
def can_reuse_epoch_itr_across_epochs(self):
|
| 238 |
-
pass
|
| 239 |
-
def set_epoch(self, epoch):
|
| 240 |
-
pass
|
| 241 |
-
|
| 242 |
-
def inbatch_noise_augment(self,
|
| 243 |
-
target_audio: torch.Tensor, target_audio_idx: int ,
|
| 244 |
-
batch_audios: torch.Tensor, # [bsz, audio_lengths]
|
| 245 |
-
noise_len_min: int, noise_len_max: int,
|
| 246 |
-
n_noise_min: int, n_noise_max: int,
|
| 247 |
-
noise_vol: float = 1.0):
|
| 248 |
-
pass
|
| 249 |
-
|
| 250 |
-
def get_audio_by_slice(self,index):
|
| 251 |
-
pass
|
| 252 |
-
def get_audio(self, index):
|
| 253 |
-
pass
|
| 254 |
-
|
| 255 |
-
def get_label(self, index, label_idx):
|
| 256 |
-
pass
|
| 257 |
-
|
| 258 |
-
def get_labels(self, index):
|
| 259 |
-
pass
|
| 260 |
-
|
| 261 |
-
def __getitem__(self, i):
|
| 262 |
-
pass
|
| 263 |
-
|
| 264 |
-
def __len__(self):
|
| 265 |
-
return self.dataset_len
|
| 266 |
-
|
| 267 |
-
def crop_to_max_size(self, wav, target_size):
|
| 268 |
-
pass
|
| 269 |
-
|
| 270 |
-
def collater(self, samples):
|
| 271 |
-
pass
|
| 272 |
-
|
| 273 |
-
def collater_audio(self, audios, audio_size):
|
| 274 |
-
pass
|
| 275 |
-
|
| 276 |
-
def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
|
| 277 |
-
pass
|
| 278 |
-
|
| 279 |
-
def collater_seq_label(self, targets, pad):
|
| 280 |
-
pass
|
| 281 |
-
|
| 282 |
-
def collater_label(self, targets_by_label, audio_size, audio_starts):
|
| 283 |
-
pass
|
| 284 |
-
|
| 285 |
-
def num_tokens(self, index):
|
| 286 |
-
pass
|
| 287 |
-
|
| 288 |
-
def size(self, index):
|
| 289 |
-
pass
|
| 290 |
-
|
| 291 |
-
def ordered_indices(self):
|
| 292 |
-
pass
|
| 293 |
-
|
| 294 |
-
def postprocess(self, wav, cur_sample_rate):
|
| 295 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/muq_dev/muq_fairseq/data/utils/data_utils.py
DELETED
|
@@ -1,535 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
-
#
|
| 3 |
-
# This source code is licensed under the MIT license found in the
|
| 4 |
-
# LICENSE file in the root directory of this source tree.
|
| 5 |
-
|
| 6 |
-
import logging
|
| 7 |
-
import math
|
| 8 |
-
import numpy as np
|
| 9 |
-
import torch
|
| 10 |
-
|
| 11 |
-
from typing import Optional, Tuple
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
logger = logging.getLogger(__name__)
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def compute_mask_indices(
|
| 20 |
-
shape: Tuple[int, int],
|
| 21 |
-
padding_mask: Optional[torch.Tensor],
|
| 22 |
-
mask_prob: float,
|
| 23 |
-
mask_length: int,
|
| 24 |
-
mask_type: str = "static",
|
| 25 |
-
mask_other: float = 0.0,
|
| 26 |
-
min_masks: int = 0,
|
| 27 |
-
no_overlap: bool = False,
|
| 28 |
-
min_space: int = 0,
|
| 29 |
-
require_same_masks: bool = True,
|
| 30 |
-
mask_dropout: float = 0.0,
|
| 31 |
-
add_masks: bool = False,
|
| 32 |
-
seed: Optional[int] = None,
|
| 33 |
-
epoch: Optional[int] = None,
|
| 34 |
-
indices: Optional[torch.Tensor] = None,
|
| 35 |
-
idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset
|
| 36 |
-
num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset
|
| 37 |
-
) -> np.ndarray:
|
| 38 |
-
"""
|
| 39 |
-
Computes random mask spans for a given shape
|
| 40 |
-
|
| 41 |
-
Args:
|
| 42 |
-
shape: the the shape for which to compute masks.
|
| 43 |
-
should be of size 2 where first element is batch size and 2nd is timesteps
|
| 44 |
-
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
| 45 |
-
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
| 46 |
-
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
| 47 |
-
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
| 48 |
-
mask_type: how to compute mask lengths
|
| 49 |
-
static = fixed size
|
| 50 |
-
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
| 51 |
-
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
| 52 |
-
poisson = sample from possion distribution with lambda = mask length
|
| 53 |
-
min_masks: minimum number of masked spans
|
| 54 |
-
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
|
| 55 |
-
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
| 56 |
-
require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
|
| 57 |
-
mask_dropout: randomly dropout this percentage of masks in each example
|
| 58 |
-
"""
|
| 59 |
-
|
| 60 |
-
bsz, all_sz = shape
|
| 61 |
-
mask = np.full((bsz, all_sz), False)
|
| 62 |
-
|
| 63 |
-
if num_mask_ver == 1:
|
| 64 |
-
all_num_mask = int(
|
| 65 |
-
# add a random number for probabilistic rounding
|
| 66 |
-
mask_prob * all_sz / float(mask_length)
|
| 67 |
-
+ np.random.rand()
|
| 68 |
-
)
|
| 69 |
-
all_num_mask = max(min_masks, all_num_mask)
|
| 70 |
-
|
| 71 |
-
mask_idcs = []
|
| 72 |
-
for i in range(bsz):
|
| 73 |
-
if seed is not None and epoch is not None and indices is not None:
|
| 74 |
-
seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
|
| 75 |
-
else:
|
| 76 |
-
seed_i = None
|
| 77 |
-
|
| 78 |
-
rng = np.random.default_rng(seed_i)
|
| 79 |
-
|
| 80 |
-
if padding_mask is not None:
|
| 81 |
-
sz = all_sz - padding_mask[i].long().sum().item()
|
| 82 |
-
assert sz >= 0, sz
|
| 83 |
-
else:
|
| 84 |
-
sz = all_sz
|
| 85 |
-
|
| 86 |
-
if num_mask_ver == 1:
|
| 87 |
-
if padding_mask is not None:
|
| 88 |
-
num_mask = int(
|
| 89 |
-
# add a random number for probabilistic rounding
|
| 90 |
-
mask_prob * sz / float(mask_length)
|
| 91 |
-
+ np.random.rand()
|
| 92 |
-
)
|
| 93 |
-
num_mask = max(min_masks, num_mask)
|
| 94 |
-
else:
|
| 95 |
-
num_mask = all_num_mask
|
| 96 |
-
elif num_mask_ver == 2:
|
| 97 |
-
num_mask = int(
|
| 98 |
-
# add a random number for probabilistic rounding
|
| 99 |
-
mask_prob * sz / float(mask_length)
|
| 100 |
-
+ rng.random()
|
| 101 |
-
)
|
| 102 |
-
num_mask = max(min_masks, num_mask)
|
| 103 |
-
else:
|
| 104 |
-
raise ValueError()
|
| 105 |
-
|
| 106 |
-
if mask_type == "static":
|
| 107 |
-
lengths = np.full(num_mask, mask_length)
|
| 108 |
-
elif mask_type == "uniform":
|
| 109 |
-
lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask)
|
| 110 |
-
elif mask_type == "normal":
|
| 111 |
-
lengths = rng.normal(mask_length, mask_other, size=num_mask)
|
| 112 |
-
lengths = [max(1, int(round(x))) for x in lengths]
|
| 113 |
-
elif mask_type == "poisson":
|
| 114 |
-
lengths = rng.poisson(mask_length, size=num_mask)
|
| 115 |
-
lengths = [int(round(x)) for x in lengths]
|
| 116 |
-
else:
|
| 117 |
-
raise Exception("unknown mask selection " + mask_type)
|
| 118 |
-
|
| 119 |
-
if sum(lengths) == 0:
|
| 120 |
-
if mask_type == "static":
|
| 121 |
-
raise ValueError(f"this should never happens")
|
| 122 |
-
else:
|
| 123 |
-
lengths = [min(mask_length, sz - 1)]
|
| 124 |
-
|
| 125 |
-
if no_overlap:
|
| 126 |
-
mask_idc = []
|
| 127 |
-
|
| 128 |
-
def arrange(s, e, length, keep_length):
|
| 129 |
-
span_start = rng.randint(s, e - length)
|
| 130 |
-
mask_idc.extend(span_start + i for i in range(length))
|
| 131 |
-
|
| 132 |
-
new_parts = []
|
| 133 |
-
if span_start - s - min_space >= keep_length:
|
| 134 |
-
new_parts.append((s, span_start - min_space + 1))
|
| 135 |
-
if e - span_start - length - min_space > keep_length:
|
| 136 |
-
new_parts.append((span_start + length + min_space, e))
|
| 137 |
-
return new_parts
|
| 138 |
-
|
| 139 |
-
parts = [(0, sz)]
|
| 140 |
-
min_length = min(lengths)
|
| 141 |
-
for length in sorted(lengths, reverse=True):
|
| 142 |
-
lens = np.fromiter(
|
| 143 |
-
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
| 144 |
-
np.int,
|
| 145 |
-
)
|
| 146 |
-
l_sum = np.sum(lens)
|
| 147 |
-
if l_sum == 0:
|
| 148 |
-
break
|
| 149 |
-
probs = lens / np.sum(lens)
|
| 150 |
-
c = rng.choice(len(parts), p=probs)
|
| 151 |
-
s, e = parts.pop(c)
|
| 152 |
-
parts.extend(arrange(s, e, length, min_length))
|
| 153 |
-
mask_idc = np.asarray(mask_idc)
|
| 154 |
-
else:
|
| 155 |
-
if idc_select_ver == 1:
|
| 156 |
-
min_len = min(lengths)
|
| 157 |
-
if sz - min_len <= num_mask:
|
| 158 |
-
min_len = sz - num_mask - 1
|
| 159 |
-
mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
|
| 160 |
-
elif idc_select_ver == 2:
|
| 161 |
-
mask_idc = rng.choice(sz, num_mask, replace=False)
|
| 162 |
-
else:
|
| 163 |
-
raise ValueError()
|
| 164 |
-
|
| 165 |
-
mask_idc = np.asarray(
|
| 166 |
-
[
|
| 167 |
-
mask_idc[j] + offset
|
| 168 |
-
for j in range(len(mask_idc))
|
| 169 |
-
for offset in range(lengths[j])
|
| 170 |
-
]
|
| 171 |
-
)
|
| 172 |
-
|
| 173 |
-
mask_idc = np.unique(mask_idc[mask_idc < sz])
|
| 174 |
-
if len(mask_idc) >= sz:
|
| 175 |
-
raise ValueError(
|
| 176 |
-
(
|
| 177 |
-
f"the entire sequence is masked. "
|
| 178 |
-
f"sz={sz}; mask_idc[mask_idc]; "
|
| 179 |
-
f"index={indices[i] if indices is not None else None}"
|
| 180 |
-
)
|
| 181 |
-
)
|
| 182 |
-
mask_idcs.append(mask_idc)
|
| 183 |
-
|
| 184 |
-
target_len = None
|
| 185 |
-
if require_same_masks:
|
| 186 |
-
if add_masks:
|
| 187 |
-
target_len = max([len(m) for m in mask_idcs])
|
| 188 |
-
else:
|
| 189 |
-
target_len = min([len(m) for m in mask_idcs])
|
| 190 |
-
|
| 191 |
-
for i, mask_idc in enumerate(mask_idcs):
|
| 192 |
-
if target_len is not None and len(mask_idc) > target_len:
|
| 193 |
-
mask_idc = rng.choice(mask_idc, target_len, replace=False)
|
| 194 |
-
|
| 195 |
-
mask[i, mask_idc] = True
|
| 196 |
-
|
| 197 |
-
if target_len is not None and len(mask_idc) < target_len:
|
| 198 |
-
unmasked = np.flatnonzero(~mask[i])
|
| 199 |
-
to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
|
| 200 |
-
mask[i, to_mask] = True
|
| 201 |
-
|
| 202 |
-
if mask_dropout > 0:
|
| 203 |
-
masked = np.flatnonzero(mask[i])
|
| 204 |
-
num_holes = np.rint(len(masked) * mask_dropout).astype(int)
|
| 205 |
-
to_drop = rng.choice(masked, num_holes, replace=False)
|
| 206 |
-
mask[i, to_drop] = False
|
| 207 |
-
|
| 208 |
-
return mask
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
def compute_block_mask_2d(
|
| 212 |
-
shape: Tuple[int, int],
|
| 213 |
-
mask_prob: float,
|
| 214 |
-
mask_length: int,
|
| 215 |
-
mask_prob_adjust: float = 0,
|
| 216 |
-
inverse_mask: bool = False,
|
| 217 |
-
require_same_masks: bool = True,
|
| 218 |
-
expand_adjcent: bool = False,
|
| 219 |
-
mask_dropout: float = 0,
|
| 220 |
-
non_overlapping: bool = False,
|
| 221 |
-
img_shape: tuple = None, # For the situation when d[0] != d[1], especially in audio spce ways
|
| 222 |
-
flexible_mask: bool = False,
|
| 223 |
-
) -> torch.Tensor:
|
| 224 |
-
|
| 225 |
-
assert mask_length > 1
|
| 226 |
-
|
| 227 |
-
B, L = shape
|
| 228 |
-
|
| 229 |
-
d = (int(L**0.5),int(L**0.5))
|
| 230 |
-
|
| 231 |
-
if img_shape:
|
| 232 |
-
d = (img_shape[0],img_shape[1])
|
| 233 |
-
|
| 234 |
-
if flexible_mask:
|
| 235 |
-
index = np.random.randint(0,3)
|
| 236 |
-
block_size_options = np.array([(6, 4), (5, 5), (8, 3)])
|
| 237 |
-
block_size = block_size_options[index]
|
| 238 |
-
|
| 239 |
-
if inverse_mask:
|
| 240 |
-
mask_prob = 1 - mask_prob
|
| 241 |
-
|
| 242 |
-
if flexible_mask:
|
| 243 |
-
mask = torch.zeros((B, d[0], d[1]))
|
| 244 |
-
mask_inds = torch.randint(
|
| 245 |
-
0,
|
| 246 |
-
L,
|
| 247 |
-
size=(
|
| 248 |
-
B,
|
| 249 |
-
int(
|
| 250 |
-
L
|
| 251 |
-
* ((mask_prob + mask_prob_adjust) / (block_size[0]*block_size[1]))
|
| 252 |
-
* (1 + mask_dropout)
|
| 253 |
-
),
|
| 254 |
-
),
|
| 255 |
-
)
|
| 256 |
-
mask.view(B, -1).scatter_(1, mask_inds, 1)
|
| 257 |
-
centers = mask.nonzero(as_tuple=True)
|
| 258 |
-
|
| 259 |
-
inds = ([], [], [])
|
| 260 |
-
|
| 261 |
-
offset = mask_length // 2
|
| 262 |
-
for i in range(block_size[0]):
|
| 263 |
-
for j in range(block_size[1]):
|
| 264 |
-
k1 = i - offset
|
| 265 |
-
k2 = j - offset
|
| 266 |
-
inds[0].append(centers[0])
|
| 267 |
-
inds[1].append(centers[1] + k1)
|
| 268 |
-
inds[2].append(centers[2] + k2)
|
| 269 |
-
|
| 270 |
-
i0 = torch.cat(inds[0])
|
| 271 |
-
i1 = torch.cat(inds[1]).clamp_(min=0, max=d[0] - 1)
|
| 272 |
-
i2 = torch.cat(inds[2]).clamp_(min=0, max=d[1] - 1)
|
| 273 |
-
|
| 274 |
-
mask[(i0, i1, i2)] = 1
|
| 275 |
-
|
| 276 |
-
elif non_overlapping:
|
| 277 |
-
sz = math.ceil(d[0] / mask_length)
|
| 278 |
-
inp_len = sz * sz
|
| 279 |
-
|
| 280 |
-
inp = torch.zeros((B, 1, sz, sz))
|
| 281 |
-
w = torch.ones((1, 1, mask_length, mask_length))
|
| 282 |
-
|
| 283 |
-
mask_inds = torch.multinomial(
|
| 284 |
-
1 - inp.view(B, -1),
|
| 285 |
-
int(inp_len * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)),
|
| 286 |
-
replacement=False,
|
| 287 |
-
)
|
| 288 |
-
inp.view(B, -1).scatter_(1, mask_inds, 1)
|
| 289 |
-
|
| 290 |
-
mask = torch.nn.functional.conv_transpose2d(inp, w, stride=mask_length).squeeze(
|
| 291 |
-
1
|
| 292 |
-
)
|
| 293 |
-
if mask.size(-1) > d[0]:
|
| 294 |
-
mask = mask[..., :d, :d]
|
| 295 |
-
else:
|
| 296 |
-
mask = torch.zeros((B, d[0], d[1]))
|
| 297 |
-
mask_inds = torch.randint(
|
| 298 |
-
0,
|
| 299 |
-
L,
|
| 300 |
-
size=(
|
| 301 |
-
B,
|
| 302 |
-
int(
|
| 303 |
-
L
|
| 304 |
-
* ((mask_prob + mask_prob_adjust) / mask_length**2)
|
| 305 |
-
* (1 + mask_dropout)
|
| 306 |
-
),
|
| 307 |
-
),
|
| 308 |
-
)
|
| 309 |
-
mask.view(B, -1).scatter_(1, mask_inds, 1)
|
| 310 |
-
centers = mask.nonzero(as_tuple=True)
|
| 311 |
-
|
| 312 |
-
inds = ([], [], [])
|
| 313 |
-
|
| 314 |
-
offset = mask_length // 2
|
| 315 |
-
for i in range(mask_length):
|
| 316 |
-
for j in range(mask_length):
|
| 317 |
-
k1 = i - offset
|
| 318 |
-
k2 = j - offset
|
| 319 |
-
inds[0].append(centers[0])
|
| 320 |
-
inds[1].append(centers[1] + k1)
|
| 321 |
-
inds[2].append(centers[2] + k2)
|
| 322 |
-
|
| 323 |
-
i0 = torch.cat(inds[0])
|
| 324 |
-
i1 = torch.cat(inds[1]).clamp_(min=0, max=d[0] - 1)
|
| 325 |
-
i2 = torch.cat(inds[2]).clamp_(min=0, max=d[1] - 1)
|
| 326 |
-
|
| 327 |
-
mask[(i0, i1, i2)] = 1
|
| 328 |
-
|
| 329 |
-
def get_nbs(b, m, w):
|
| 330 |
-
all_nbs = torch.nn.functional.conv2d(m.unsqueeze(1), w, padding="same")
|
| 331 |
-
all_nbs = all_nbs.clamp_max_(1).view(b, -1)
|
| 332 |
-
return all_nbs
|
| 333 |
-
|
| 334 |
-
if require_same_masks and expand_adjcent:
|
| 335 |
-
w = torch.zeros((1, 1, 3, 3))
|
| 336 |
-
w[..., 0, 1] = 1
|
| 337 |
-
w[..., 2, 1] = 1
|
| 338 |
-
w[..., 1, 0] = 1
|
| 339 |
-
w[..., 1, 2] = 1
|
| 340 |
-
|
| 341 |
-
all_nbs = get_nbs(B, mask, w)
|
| 342 |
-
|
| 343 |
-
mask = mask.reshape(B, -1)
|
| 344 |
-
|
| 345 |
-
if require_same_masks:
|
| 346 |
-
n_masks = mask.sum(dim=-1)
|
| 347 |
-
final_target_len = int(L * (mask_prob))
|
| 348 |
-
target_len = int(final_target_len * (1 + mask_dropout))
|
| 349 |
-
|
| 350 |
-
for i in range(len(mask)):
|
| 351 |
-
n = n_masks[i]
|
| 352 |
-
m = mask[i]
|
| 353 |
-
r = 0
|
| 354 |
-
while expand_adjcent and n < target_len:
|
| 355 |
-
if r == 0:
|
| 356 |
-
nbs = all_nbs[i]
|
| 357 |
-
else:
|
| 358 |
-
nbs = get_nbs(1, m.view(1, d[0], d[1]), w).flatten()
|
| 359 |
-
|
| 360 |
-
cands = (1 - m + nbs) > 1
|
| 361 |
-
cand_sz = int(cands.sum().item())
|
| 362 |
-
|
| 363 |
-
assert cand_sz > 0, f"{nbs} {cand_sz}"
|
| 364 |
-
|
| 365 |
-
to_mask = torch.multinomial(
|
| 366 |
-
cands.float(), min(cand_sz, int(target_len - n)), replacement=False
|
| 367 |
-
)
|
| 368 |
-
m[to_mask] = 1
|
| 369 |
-
assert to_mask.numel() > 0
|
| 370 |
-
n += to_mask.numel()
|
| 371 |
-
r += 1
|
| 372 |
-
|
| 373 |
-
if n > final_target_len:
|
| 374 |
-
to_unmask = torch.multinomial(
|
| 375 |
-
m, int(n - final_target_len), replacement=False
|
| 376 |
-
)
|
| 377 |
-
m[to_unmask] = 0
|
| 378 |
-
elif n < final_target_len:
|
| 379 |
-
to_mask = torch.multinomial(
|
| 380 |
-
(1 - m), int(final_target_len - n), replacement=False
|
| 381 |
-
)
|
| 382 |
-
m[to_mask] = 1
|
| 383 |
-
|
| 384 |
-
if inverse_mask:
|
| 385 |
-
mask = 1 - mask
|
| 386 |
-
|
| 387 |
-
return mask
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
def compute_block_mask_1d(
|
| 391 |
-
shape: Tuple[int, int],
|
| 392 |
-
mask_prob: float,
|
| 393 |
-
mask_length: int,
|
| 394 |
-
mask_prob_adjust: float = 0,
|
| 395 |
-
inverse_mask: bool = False,
|
| 396 |
-
require_same_masks: bool = True,
|
| 397 |
-
expand_adjcent: bool = False,
|
| 398 |
-
mask_dropout: float = 0,
|
| 399 |
-
non_overlapping: bool = False,
|
| 400 |
-
) -> torch.Tensor:
|
| 401 |
-
|
| 402 |
-
B, L = shape
|
| 403 |
-
|
| 404 |
-
if inverse_mask:
|
| 405 |
-
mask_prob = 1 - mask_prob
|
| 406 |
-
|
| 407 |
-
if non_overlapping:
|
| 408 |
-
sz = math.ceil(L / mask_length)
|
| 409 |
-
|
| 410 |
-
inp = torch.zeros((B, 1, sz))
|
| 411 |
-
w = torch.ones((1, 1, mask_length))
|
| 412 |
-
|
| 413 |
-
mask_inds = torch.multinomial(
|
| 414 |
-
1 - inp.view(B, -1),
|
| 415 |
-
int(sz * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)),
|
| 416 |
-
replacement=False,
|
| 417 |
-
)
|
| 418 |
-
inp.view(B, -1).scatter_(1, mask_inds, 1)
|
| 419 |
-
|
| 420 |
-
mask = torch.nn.functional.conv_transpose1d(inp, w, stride=mask_length).squeeze(
|
| 421 |
-
1
|
| 422 |
-
)
|
| 423 |
-
if mask.size(-1) > L:
|
| 424 |
-
mask = mask[..., :L]
|
| 425 |
-
|
| 426 |
-
else:
|
| 427 |
-
mask = torch.zeros((B, L))
|
| 428 |
-
mask_inds = torch.randint(
|
| 429 |
-
0,
|
| 430 |
-
L,
|
| 431 |
-
size=(
|
| 432 |
-
B,
|
| 433 |
-
int(
|
| 434 |
-
L
|
| 435 |
-
* ((mask_prob + mask_prob_adjust) / mask_length)
|
| 436 |
-
* (1 + mask_dropout)
|
| 437 |
-
),
|
| 438 |
-
),
|
| 439 |
-
)
|
| 440 |
-
|
| 441 |
-
mask.view(B, -1).scatter_(1, mask_inds, 1)
|
| 442 |
-
centers = mask.nonzero(as_tuple=True)
|
| 443 |
-
|
| 444 |
-
inds = ([], [])
|
| 445 |
-
|
| 446 |
-
offset = mask_length // 2
|
| 447 |
-
for i in range(mask_length):
|
| 448 |
-
k1 = i - offset
|
| 449 |
-
inds[0].append(centers[0])
|
| 450 |
-
inds[1].append(centers[1] + k1)
|
| 451 |
-
|
| 452 |
-
i0 = torch.cat(inds[0])
|
| 453 |
-
i1 = torch.cat(inds[1]).clamp_(min=0, max=L - 1)
|
| 454 |
-
|
| 455 |
-
mask[(i0, i1)] = 1
|
| 456 |
-
|
| 457 |
-
def get_nbs(b, m, w):
|
| 458 |
-
all_nbs = torch.nn.functional.conv1d(m.unsqueeze(1), w, padding="same")
|
| 459 |
-
all_nbs = all_nbs.clamp_max_(1).view(b, -1)
|
| 460 |
-
return all_nbs
|
| 461 |
-
|
| 462 |
-
if require_same_masks and expand_adjcent:
|
| 463 |
-
w = torch.ones((1, 1, 3))
|
| 464 |
-
w[..., 1] = 0
|
| 465 |
-
all_nbs = get_nbs(B, mask, w)
|
| 466 |
-
|
| 467 |
-
mask = mask.view(B, -1)
|
| 468 |
-
|
| 469 |
-
if require_same_masks:
|
| 470 |
-
n_masks = mask.sum(dim=-1)
|
| 471 |
-
final_target_len = int(L * (mask_prob))
|
| 472 |
-
target_len = int(final_target_len * (1 + mask_dropout))
|
| 473 |
-
|
| 474 |
-
for i in range(len(mask)):
|
| 475 |
-
n = n_masks[i]
|
| 476 |
-
m = mask[i]
|
| 477 |
-
r = 0
|
| 478 |
-
while expand_adjcent and n < target_len:
|
| 479 |
-
if r == 0:
|
| 480 |
-
nbs = all_nbs[i]
|
| 481 |
-
else:
|
| 482 |
-
nbs = get_nbs(1, m.unsqueeze(0), w).squeeze(0)
|
| 483 |
-
|
| 484 |
-
cands = (1 - m + nbs) > 1
|
| 485 |
-
cand_sz = int(cands.sum().item())
|
| 486 |
-
|
| 487 |
-
assert cand_sz > 0, f"{nbs} {cand_sz}"
|
| 488 |
-
|
| 489 |
-
to_mask = torch.multinomial(
|
| 490 |
-
cands.float(), min(cand_sz, int(target_len - n)), replacement=False
|
| 491 |
-
)
|
| 492 |
-
m[to_mask] = 1
|
| 493 |
-
assert to_mask.numel() > 0
|
| 494 |
-
n += to_mask.numel()
|
| 495 |
-
r += 1
|
| 496 |
-
|
| 497 |
-
if n > final_target_len:
|
| 498 |
-
to_unmask = torch.multinomial(
|
| 499 |
-
m, int(n - final_target_len), replacement=False
|
| 500 |
-
)
|
| 501 |
-
m[to_unmask] = 0
|
| 502 |
-
elif n < final_target_len:
|
| 503 |
-
to_mask = torch.multinomial(
|
| 504 |
-
(1 - m), int(final_target_len - n), replacement=False
|
| 505 |
-
)
|
| 506 |
-
m[to_mask] = 1
|
| 507 |
-
|
| 508 |
-
if inverse_mask:
|
| 509 |
-
mask = 1 - mask
|
| 510 |
-
|
| 511 |
-
return mask
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
def get_buckets(sizes, num_buckets):
|
| 515 |
-
buckets = np.unique(
|
| 516 |
-
np.percentile(
|
| 517 |
-
sizes,
|
| 518 |
-
np.linspace(0, 100, num_buckets + 1),
|
| 519 |
-
interpolation="lower",
|
| 520 |
-
)[1:]
|
| 521 |
-
)
|
| 522 |
-
return buckets
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
def get_bucketed_sizes(orig_sizes, buckets):
|
| 526 |
-
sizes = np.copy(orig_sizes)
|
| 527 |
-
assert np.min(sizes) >= 0
|
| 528 |
-
start_val = -1
|
| 529 |
-
for end_val in buckets:
|
| 530 |
-
mask = (sizes > start_val) & (sizes <= end_val)
|
| 531 |
-
sizes[mask] = end_val
|
| 532 |
-
start_val = end_val
|
| 533 |
-
return sizes
|
| 534 |
-
|
| 535 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
from .muq_model import *
|
|
|
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/model/__init__.py
DELETED
|
@@ -1,2 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
|
|
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/model/muq.py
DELETED
|
@@ -1,520 +0,0 @@
|
|
| 1 |
-
import json
|
| 2 |
-
import random
|
| 3 |
-
import torch
|
| 4 |
-
from torch import nn
|
| 5 |
-
from einops import rearrange
|
| 6 |
-
import os
|
| 7 |
-
from fairseq.data.data_utils import compute_mask_indices
|
| 8 |
-
from fairseq.models.wav2vec.wav2vec2 import ConvFeatureExtractionModel
|
| 9 |
-
from fairseq.modules import LayerNorm
|
| 10 |
-
|
| 11 |
-
try:
|
| 12 |
-
from ..modules.random_quantizer import RandomProjectionQuantizer
|
| 13 |
-
from ..modules.features import MelSTFT
|
| 14 |
-
from ..modules.conv import Conv2dSubsampling
|
| 15 |
-
except:
|
| 16 |
-
import sys, os
|
| 17 |
-
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
| 18 |
-
from modules.random_quantizer import RandomProjectionQuantizer
|
| 19 |
-
from modules.features import MelSTFT
|
| 20 |
-
from modules.conv import Conv2dSubsampling
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
class MuQ(nn.Module):
|
| 24 |
-
"""
|
| 25 |
-
MuQ
|
| 26 |
-
|
| 27 |
-
Input: 128-band mel spectrogram
|
| 28 |
-
Frontend: 2-layer Residual convolution
|
| 29 |
-
Backend: 12-layer Conformer
|
| 30 |
-
Quantizer: a codebook for mel spectrogram
|
| 31 |
-
"""
|
| 32 |
-
|
| 33 |
-
def __init__(
|
| 34 |
-
self,
|
| 35 |
-
num_codebooks=1,
|
| 36 |
-
codebook_dim=16,
|
| 37 |
-
codebook_size=4096,
|
| 38 |
-
features=["melspec_2048"],
|
| 39 |
-
hop_length=240,
|
| 40 |
-
n_mels=128,
|
| 41 |
-
conv_dim=512,
|
| 42 |
-
encoder_dim=1024,
|
| 43 |
-
encoder_depth=12,
|
| 44 |
-
mask_hop=0.4,
|
| 45 |
-
mask_prob=0.6,
|
| 46 |
-
is_flash=False,
|
| 47 |
-
stat_path=None, #"./data/fma_stats.json",
|
| 48 |
-
model_path=None, #"./data/pretrained_fma.pt",
|
| 49 |
-
w2v2_config_path=None, #"facebook/wav2vec2-conformer-rope-large-960h-ft",
|
| 50 |
-
use_rvq_target=False,
|
| 51 |
-
use_vq_target=False,
|
| 52 |
-
rvq_ckpt_path=None,
|
| 53 |
-
recon_loss_ratio=None,
|
| 54 |
-
label_rate=25,
|
| 55 |
-
use_hubert_masking_strategy=False,
|
| 56 |
-
use_hubert_featurizer=False,
|
| 57 |
-
hubert_conv_feature_layers="[(512,10,5)] + [(512,3,2)] * 3 + [(512,3,3)] + [(512,2,2)] * 2",
|
| 58 |
-
use_hubert_nce_loss=False,
|
| 59 |
-
hubert_final_dim=256,
|
| 60 |
-
rvq_n_codebooks=8,
|
| 61 |
-
rvq_multi_layer_num=1,
|
| 62 |
-
use_encodec_target=False,
|
| 63 |
-
):
|
| 64 |
-
super(MuQ, self).__init__()
|
| 65 |
-
|
| 66 |
-
# global variables
|
| 67 |
-
self.hop_length = hop_length
|
| 68 |
-
self.mask_hop = mask_hop
|
| 69 |
-
self.mask_prob = mask_prob
|
| 70 |
-
self.num_codebooks = num_codebooks
|
| 71 |
-
self.codebook_size = codebook_size
|
| 72 |
-
self.features = features
|
| 73 |
-
self.recon_loss_ratio = recon_loss_ratio
|
| 74 |
-
self.n_fold = int(100//label_rate)
|
| 75 |
-
self.label_rate = label_rate
|
| 76 |
-
self.use_hubert_masking_strategy = use_hubert_masking_strategy
|
| 77 |
-
self.use_hubert_featurizer = use_hubert_featurizer
|
| 78 |
-
self.use_hubert_nce_loss = use_hubert_nce_loss
|
| 79 |
-
|
| 80 |
-
# load feature mean / std stats
|
| 81 |
-
import os
|
| 82 |
-
if stat_path is not None and os.path.exists(stat_path):
|
| 83 |
-
with open(stat_path, "r") as f:
|
| 84 |
-
self.stat = json.load(f)
|
| 85 |
-
else:
|
| 86 |
-
# print("No stats file found at `{}`, use default from msd.".format(stat_path))
|
| 87 |
-
self.stat = {"spec_256_cnt": 14394344256, "spec_256_mean": -23.34296658431829, "spec_256_std": 26.189295587132637, "spec_512_cnt": 28677104448, "spec_512_mean": -21.31267396860235, "spec_512_std": 26.52644536245769, "spec_1024_cnt": 57242624832, "spec_1024_mean": -18.852271129208273, "spec_1024_std": 26.443154583585663, "spec_2048_cnt": 114373665600, "spec_2048_mean": -15.638743433896792, "spec_2048_std": 26.115825961611545, "spec_4096_cnt": 228635747136, "spec_4096_mean": -11.715532502794836, "spec_4096_std": 25.763972210234062, "melspec_256_cnt": 14282760192, "melspec_256_mean": -26.962600400166156, "melspec_256_std": 36.13614100912126, "melspec_512_cnt": 14282760192, "melspec_512_mean": -9.108344167718862, "melspec_512_std": 24.71910937988429, "melspec_1024_cnt": 14282760192, "melspec_1024_mean": 0.37302579246531126, "melspec_1024_std": 18.684082325919388, "melspec_2048_cnt": 14282760192, "melspec_2048_mean": 6.768444971712967, "melspec_2048_std": 18.417922652295623, "melspec_4096_cnt": 14282760192, "melspec_4096_mean": 13.617164614990036, "melspec_4096_std": 18.08552130124525, "cqt_cnt": 9373061376, "cqt_mean": 0.46341379757927165, "cqt_std": 0.9543998080910191, "mfcc_256_cnt": 1339008768, "mfcc_256_mean": -11.681755459447485, "mfcc_256_std": 29.183186444668316, "mfcc_512_cnt": 1339008768, "mfcc_512_mean": -2.540581461792183, "mfcc_512_std": 31.93752185832081, "mfcc_1024_cnt": 1339008768, "mfcc_1024_mean": 6.606636263169779, "mfcc_1024_std": 34.151644801729624, "mfcc_2048_cnt": 1339008768, "mfcc_2048_mean": 5.281600844245184, "mfcc_2048_std": 33.12784541220003, "mfcc_4096_cnt": 1339008768, "mfcc_4096_mean": 4.7616569480166095, "mfcc_4096_std": 32.61458906894133, "chromagram_256_cnt": 1339008768, "chromagram_256_mean": 55.15596556703181, "chromagram_256_std": 73.91858278719991, "chromagram_512_cnt": 1339008768, "chromagram_512_mean": 175.73092252759895, "chromagram_512_std": 248.48485148525953, "chromagram_1024_cnt": 1339008768, "chromagram_1024_mean": 589.2947481634608, "chromagram_1024_std": 913.857929063196, "chromagram_2048_cnt": 1339008768, "chromagram_2048_mean": 2062.286388327397, "chromagram_2048_std": 3458.92657915397, "chromagram_4096_cnt": 1339008768, "chromagram_4096_mean": 7673.039107997085, "chromagram_4096_std": 13009.883158267234}
|
| 88 |
-
|
| 89 |
-
# feature extractor
|
| 90 |
-
self.preprocessor_melspec_2048 = MelSTFT(
|
| 91 |
-
n_fft=2048, hop_length=hop_length, is_db=True
|
| 92 |
-
)
|
| 93 |
-
|
| 94 |
-
# random quantizer
|
| 95 |
-
self.use_rvq_target = use_rvq_target
|
| 96 |
-
self.use_vq_target = use_vq_target
|
| 97 |
-
self.use_encodec_target = use_encodec_target
|
| 98 |
-
|
| 99 |
-
seed = 142
|
| 100 |
-
if self.use_rvq_like_target:
|
| 101 |
-
if use_rvq_target:
|
| 102 |
-
try:
|
| 103 |
-
from .rvq_muq import ResidualVectorQuantize
|
| 104 |
-
except:
|
| 105 |
-
import sys, os
|
| 106 |
-
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 107 |
-
from rvq_muq import ResidualVectorQuantize
|
| 108 |
-
|
| 109 |
-
inp_dim = 128*self.n_fold
|
| 110 |
-
self.rvq = ResidualVectorQuantize(
|
| 111 |
-
input_dim = inp_dim,
|
| 112 |
-
n_codebooks = rvq_n_codebooks,
|
| 113 |
-
codebook_size = 1024,
|
| 114 |
-
codebook_dim = 16,
|
| 115 |
-
quantizer_dropout = 0.0,
|
| 116 |
-
use_multi_layer_num = rvq_multi_layer_num,
|
| 117 |
-
)
|
| 118 |
-
elif use_vq_target:
|
| 119 |
-
try:
|
| 120 |
-
from .rvq_muq import VectorQuantize
|
| 121 |
-
except:
|
| 122 |
-
import sys, os
|
| 123 |
-
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 124 |
-
from rvq_muq import VectorQuantize
|
| 125 |
-
|
| 126 |
-
self.rvq = VectorQuantize(
|
| 127 |
-
input_dim = 128*self.n_fold,
|
| 128 |
-
codebook_size = 1024,
|
| 129 |
-
codebook_dim = 8,
|
| 130 |
-
stale_tolerance = 1000,
|
| 131 |
-
mfcc_clustering = False
|
| 132 |
-
)
|
| 133 |
-
elif use_encodec_target:
|
| 134 |
-
from encodec import EncodecModel
|
| 135 |
-
self.rvq = EncodecModel.encodec_model_24khz()
|
| 136 |
-
self.rvq.set_target_bandwidth(6.0)
|
| 137 |
-
for param in self.rvq.parameters():
|
| 138 |
-
param.requires_grad = False
|
| 139 |
-
|
| 140 |
-
import os
|
| 141 |
-
if rvq_ckpt_path is not None and os.path.exists(rvq_ckpt_path):
|
| 142 |
-
state_dict = torch.load(rvq_ckpt_path, map_location="cpu")
|
| 143 |
-
self.rvq.load_state_dict(state_dict)
|
| 144 |
-
else:
|
| 145 |
-
print(f'Checkpoint for rvq `{rvq_ckpt_path}` not found. Using random initialization.')
|
| 146 |
-
else:
|
| 147 |
-
for feature in self.features:
|
| 148 |
-
for i in range(num_codebooks):
|
| 149 |
-
setattr(
|
| 150 |
-
self,
|
| 151 |
-
f"quantizer_{feature}", # _{i}
|
| 152 |
-
RandomProjectionQuantizer(
|
| 153 |
-
n_mels * self.n_fold, codebook_dim, codebook_size, seed=seed + i
|
| 154 |
-
),
|
| 155 |
-
)
|
| 156 |
-
|
| 157 |
-
if use_hubert_masking_strategy:
|
| 158 |
-
self.mask_emb = nn.Parameter(
|
| 159 |
-
torch.FloatTensor(encoder_dim).uniform_()
|
| 160 |
-
)
|
| 161 |
-
|
| 162 |
-
if use_hubert_featurizer:
|
| 163 |
-
feature_enc_layers = eval(hubert_conv_feature_layers) # noqa
|
| 164 |
-
hubert_feat_embed = feature_enc_layers[-1][0]
|
| 165 |
-
self.hubert_feature_extractor = ConvFeatureExtractionModel(
|
| 166 |
-
conv_layers=feature_enc_layers,
|
| 167 |
-
dropout=0.0,
|
| 168 |
-
mode='default', #cfg.extractor_mode,
|
| 169 |
-
conv_bias=False, #cfg.conv_bias,
|
| 170 |
-
)
|
| 171 |
-
self.post_extract_proj = (
|
| 172 |
-
nn.Linear(hubert_feat_embed, encoder_dim)
|
| 173 |
-
if hubert_feat_embed != encoder_dim
|
| 174 |
-
else None
|
| 175 |
-
)
|
| 176 |
-
self.layer_norm = LayerNorm(hubert_feat_embed)
|
| 177 |
-
else:
|
| 178 |
-
# two residual convolution layers + one projection layer
|
| 179 |
-
strides_factory = {
|
| 180 |
-
4: [2, 2],
|
| 181 |
-
2: [2, 1]
|
| 182 |
-
}
|
| 183 |
-
self.conv = Conv2dSubsampling(
|
| 184 |
-
1, conv_dim, encoder_dim, strides=strides_factory.get(self.n_fold), n_bands=n_mels
|
| 185 |
-
)
|
| 186 |
-
|
| 187 |
-
# Conformer
|
| 188 |
-
if is_flash:
|
| 189 |
-
from modules.flash_conformer import (
|
| 190 |
-
Wav2Vec2ConformerEncoder,
|
| 191 |
-
Wav2Vec2ConformerConfig,
|
| 192 |
-
)
|
| 193 |
-
else:
|
| 194 |
-
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
|
| 195 |
-
Wav2Vec2ConformerEncoder,
|
| 196 |
-
Wav2Vec2ConformerConfig,
|
| 197 |
-
)
|
| 198 |
-
import os
|
| 199 |
-
if w2v2_config_path is None or not os.path.exists(w2v2_config_path):
|
| 200 |
-
w2v2_config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "w2v2_config.json")
|
| 201 |
-
print("load w2v2 config from:", w2v2_config_path)
|
| 202 |
-
config = Wav2Vec2ConformerConfig.from_pretrained(
|
| 203 |
-
w2v2_config_path
|
| 204 |
-
)
|
| 205 |
-
config.num_hidden_layers = encoder_depth
|
| 206 |
-
config.hidden_size = encoder_dim
|
| 207 |
-
|
| 208 |
-
self.conformer = Wav2Vec2ConformerEncoder(config)
|
| 209 |
-
|
| 210 |
-
if self.use_hubert_nce_loss:
|
| 211 |
-
self.label_embs_concat = nn.Parameter(
|
| 212 |
-
torch.FloatTensor(codebook_size, hubert_final_dim)
|
| 213 |
-
) # embeddings of codes
|
| 214 |
-
nn.init.uniform_(self.label_embs_concat)
|
| 215 |
-
self.linear = nn.Linear(encoder_dim, hubert_final_dim) # final_proj
|
| 216 |
-
else:
|
| 217 |
-
# projection
|
| 218 |
-
self.linear = nn.Linear(encoder_dim, codebook_size) # N_SubSpec=8
|
| 219 |
-
|
| 220 |
-
# reconstruct melspec
|
| 221 |
-
if self.recon_loss_ratio is not None and self.recon_loss_ratio > 0:
|
| 222 |
-
self.recon_proj = nn.Linear(encoder_dim, n_mels * self.n_fold)
|
| 223 |
-
self.recon_loss = nn.MSELoss()
|
| 224 |
-
|
| 225 |
-
# loss function
|
| 226 |
-
self.loss = nn.CrossEntropyLoss()
|
| 227 |
-
|
| 228 |
-
# cls token (used for sequence classification)
|
| 229 |
-
random.seed(seed)
|
| 230 |
-
self.cls_token = nn.Parameter(torch.randn(encoder_dim))
|
| 231 |
-
|
| 232 |
-
# load model
|
| 233 |
-
if model_path:
|
| 234 |
-
S = torch.load(model_path)["state_dict"]
|
| 235 |
-
SS = {k[6:]: v for k, v in S.items()}
|
| 236 |
-
SS['quantizer_melspec_2048.random_projection'] = SS['quantizer_melspec_2048_0.random_projection']
|
| 237 |
-
SS['quantizer_melspec_2048.codebook'] = SS['quantizer_melspec_2048_0.codebook']
|
| 238 |
-
del SS['quantizer_melspec_2048_0.random_projection']
|
| 239 |
-
del SS['quantizer_melspec_2048_0.codebook']
|
| 240 |
-
unmatch = self.load_state_dict(SS, strict=False)
|
| 241 |
-
if len(unmatch.missing_keys) > 0:
|
| 242 |
-
print(f'Missing keys: {unmatch.missing_keys}')
|
| 243 |
-
|
| 244 |
-
@property
|
| 245 |
-
def use_rvq_like_target(self):
|
| 246 |
-
return self.use_rvq_target or self.use_vq_target or self.use_encodec_target
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
def apply_hubert_mask(self, x, padding_mask=None, target_list=None):
|
| 250 |
-
B, T, C = x.shape
|
| 251 |
-
if self.mask_prob > 0:
|
| 252 |
-
mask_length = int(self.mask_hop / (1/self.label_rate))
|
| 253 |
-
mask_indices = compute_mask_indices(
|
| 254 |
-
(B, T),
|
| 255 |
-
padding_mask,
|
| 256 |
-
self.mask_prob,
|
| 257 |
-
mask_length, # self.mask_length,
|
| 258 |
-
"static", #self.mask_selection,
|
| 259 |
-
0, #self.mask_other,
|
| 260 |
-
min_masks=2,
|
| 261 |
-
no_overlap=False, #self.no_mask_overlap,
|
| 262 |
-
min_space=1, #self.mask_min_space,
|
| 263 |
-
)
|
| 264 |
-
mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
| 265 |
-
x[mask_indices] = self.mask_emb
|
| 266 |
-
mask_indices = torch.nonzero(mask_indices)
|
| 267 |
-
else:
|
| 268 |
-
mask_indices = None
|
| 269 |
-
|
| 270 |
-
return x, mask_indices
|
| 271 |
-
|
| 272 |
-
def masking(self, x, attention_mask=None):
|
| 273 |
-
"""random masking of 400ms with given probability"""
|
| 274 |
-
if self.use_hubert_masking_strategy:
|
| 275 |
-
return x, None
|
| 276 |
-
mx = x.clone()
|
| 277 |
-
b, t = mx.shape
|
| 278 |
-
len_masking_raw = int(24000 * self.mask_hop) # 9600 = 24000 * 0.4
|
| 279 |
-
len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop) # 10 = 25Hz * 0.4
|
| 280 |
-
|
| 281 |
-
# get random mask indices
|
| 282 |
-
start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
|
| 283 |
-
time_domain_masked_indices = torch.nonzero(
|
| 284 |
-
start_indices.repeat_interleave(len_masking_raw, dim=1)
|
| 285 |
-
)
|
| 286 |
-
token_domain_masked_indices = torch.nonzero(
|
| 287 |
-
start_indices.repeat_interleave(len_masking_token, dim=1)
|
| 288 |
-
)
|
| 289 |
-
|
| 290 |
-
# mask with random values
|
| 291 |
-
masking_noise = (
|
| 292 |
-
torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
|
| 293 |
-
) # 0 mean 0.1 std
|
| 294 |
-
mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)
|
| 295 |
-
|
| 296 |
-
return mx, token_domain_masked_indices
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
@torch.no_grad()
|
| 300 |
-
def preprocessing(self, x, features):
|
| 301 |
-
"""extract classic audio features"""
|
| 302 |
-
# check precision
|
| 303 |
-
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
|
| 304 |
-
precision = 16
|
| 305 |
-
else:
|
| 306 |
-
precision = 32
|
| 307 |
-
|
| 308 |
-
out = {}
|
| 309 |
-
for key in features:
|
| 310 |
-
layer = getattr(self, "preprocessor_%s" % key)
|
| 311 |
-
layer.to(x.device)
|
| 312 |
-
dtype = x.dtype
|
| 313 |
-
out[key] = layer.float()(x.float())[..., :-1]
|
| 314 |
-
if precision == 16:
|
| 315 |
-
out[key] = out[key].half()
|
| 316 |
-
if out[key].dtype != dtype:
|
| 317 |
-
out[key].to(dtype=dtype)
|
| 318 |
-
return out
|
| 319 |
-
|
| 320 |
-
def encoder(self, x, *, attention_mask=None, is_features_only=False):
|
| 321 |
-
"""2-layer conv + w2v-conformer"""
|
| 322 |
-
if not self.use_hubert_featurizer:
|
| 323 |
-
x = self.conv(x) # [3, 128, 3000] -> [3, 750, 1024]
|
| 324 |
-
if self.training and self.use_hubert_masking_strategy and not is_features_only:
|
| 325 |
-
x, mask_indices = self.apply_hubert_mask(x)
|
| 326 |
-
else:
|
| 327 |
-
mask_indices = None
|
| 328 |
-
if attention_mask is None:
|
| 329 |
-
out = self.conformer(x, output_hidden_states=True)
|
| 330 |
-
else:
|
| 331 |
-
attention_mask = attention_mask.bool()
|
| 332 |
-
skip_n = int(attention_mask.size(-1) / x.size(1))
|
| 333 |
-
attention_mask = attention_mask[:, ::skip_n]
|
| 334 |
-
attention_mask = attention_mask[:, :x.size(1)]
|
| 335 |
-
out = self.conformer(x, attention_mask=attention_mask, output_hidden_states=True)
|
| 336 |
-
hidden_emb = out["hidden_states"]
|
| 337 |
-
last_emb = out["last_hidden_state"]
|
| 338 |
-
logits = self.linear(last_emb)
|
| 339 |
-
interval = self.codebook_size
|
| 340 |
-
logits = {
|
| 341 |
-
key: logits[:, :, i * interval : (i + 1) * interval]
|
| 342 |
-
for i, key in enumerate(self.features)
|
| 343 |
-
}
|
| 344 |
-
return logits, hidden_emb, mask_indices
|
| 345 |
-
|
| 346 |
-
@torch.no_grad()
|
| 347 |
-
def normalize(self, x):
|
| 348 |
-
"""normalize the input audio to have zero mean unit variance"""
|
| 349 |
-
for key in x.keys():
|
| 350 |
-
x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key] # {'melspec_2048_cnt': 14282760192, 'melspec_2048_mean': 6.768444971712967}
|
| 351 |
-
return x
|
| 352 |
-
|
| 353 |
-
@torch.no_grad()
|
| 354 |
-
def rearrange(self, x):
|
| 355 |
-
"""rearrange the batch to flatten every 4 steps"""
|
| 356 |
-
for key in x.keys():
|
| 357 |
-
if key == "chromagram":
|
| 358 |
-
x[key] = rearrange(x[key], "b f t -> b t f")
|
| 359 |
-
else:
|
| 360 |
-
x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=self.n_fold)
|
| 361 |
-
return x
|
| 362 |
-
|
| 363 |
-
def get_rvq_codes(self, inp, raw_wav):
|
| 364 |
-
if self.use_rvq_target:
|
| 365 |
-
quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = self.rvq(inp)
|
| 366 |
-
return codes
|
| 367 |
-
if self.use_vq_target:
|
| 368 |
-
quantized_prompt_embeds, commitment_loss, codebook_loss, codes, _ = self.rvq(inp)
|
| 369 |
-
return codes.unsqueeze(1)
|
| 370 |
-
if self.use_encodec_target:
|
| 371 |
-
encoded_frames = self.rvq.encode(raw_wav.unsqueeze(1)) #list, B,[ 8,T ]
|
| 372 |
-
codes = torch.cat([encoded[0].detach() for encoded in encoded_frames], dim=-1)
|
| 373 |
-
if self.label_rate == 25:
|
| 374 |
-
codes = codes[:, :, ::3]
|
| 375 |
-
return codes
|
| 376 |
-
|
| 377 |
-
@torch.no_grad()
|
| 378 |
-
def tokenize(self, x, raw_wav):
|
| 379 |
-
out = {}
|
| 380 |
-
for key in x.keys():
|
| 381 |
-
if self.use_rvq_like_target:
|
| 382 |
-
self.rvq.eval()
|
| 383 |
-
inp = x[key].permute((0, 2, 1))
|
| 384 |
-
codes = self.get_rvq_codes(inp, raw_wav)
|
| 385 |
-
out[key] = torch.cat([codes[:, idx, ...] for idx in range(int(self.codebook_size//1024))], dim=-1) # (when use freq mask)->[Batch, N_SubSpec, SeqLen=8*750]
|
| 386 |
-
else:
|
| 387 |
-
layer = getattr(self, "quantizer_%s" % key)
|
| 388 |
-
out[key] = layer(x[key])
|
| 389 |
-
return out
|
| 390 |
-
|
| 391 |
-
def to_spec_wise_quad(self, x):
|
| 392 |
-
Batch, QuadSpec, Time = x.shape
|
| 393 |
-
SubSpec, N_SubSpec = 16, 8
|
| 394 |
-
assert 4 * SubSpec * N_SubSpec == QuadSpec == 4*128
|
| 395 |
-
x = rearrange(x, "b (q n s) t -> b (q s) (n t)", q=4, n=N_SubSpec, s=SubSpec)
|
| 396 |
-
return x # [Batch, SubSpec=16, N_SubSpec*Time=8*100Hz]
|
| 397 |
-
|
| 398 |
-
def get_targets(self, x, label=None):
|
| 399 |
-
if self.use_encodec_target:
|
| 400 |
-
raw_x = x.clone()
|
| 401 |
-
else:
|
| 402 |
-
raw_x = None
|
| 403 |
-
x = self.preprocessing(x, features=self.features) # -> {'melspec_2048': Tensor{Size([3, 128, 3000]) cuda:0 f32}}
|
| 404 |
-
x = self.normalize(x)
|
| 405 |
-
x = self.rearrange(x) # -> {'melspec_2048': Tensor{Size([3, 750, 512]) cuda:0 f32}}
|
| 406 |
-
melspec = x['melspec_2048']
|
| 407 |
-
if label is None:
|
| 408 |
-
target_tokens = self.tokenize(x, raw_x) # -> {'melspec_2048': Tensor{Size([3, 750]) cuda:0 i64}}
|
| 409 |
-
else:
|
| 410 |
-
# print("use_target from label")
|
| 411 |
-
target_tokens = {'melspec_2048': rearrange(label, "b n s -> b (n s)").long()}
|
| 412 |
-
return target_tokens, melspec
|
| 413 |
-
|
| 414 |
-
def get_predictions(self, x, *, mask=None, attention_mask=None, return_new_mask=False, is_features_only=False):
|
| 415 |
-
# preprocessing
|
| 416 |
-
if not self.use_hubert_featurizer:
|
| 417 |
-
x = self.preprocessing(x, features=["melspec_2048"])
|
| 418 |
-
x = self.normalize(x) # -> {'melspec_2048': Tensor{Size([3, 128, 3000]) cuda:0 f32}}
|
| 419 |
-
else:
|
| 420 |
-
features = self.hubert_feature_extractor(x)
|
| 421 |
-
features = self.layer_norm(features.transpose(1, 2))
|
| 422 |
-
if self.post_extract_proj is not None:
|
| 423 |
-
features = self.post_extract_proj(features)
|
| 424 |
-
x = {"melspec_2048": features}
|
| 425 |
-
|
| 426 |
-
# encoding
|
| 427 |
-
logits, hidden_emb, new_mask = self.encoder(x["melspec_2048"], attention_mask=attention_mask, is_features_only=is_features_only)
|
| 428 |
-
|
| 429 |
-
if return_new_mask:
|
| 430 |
-
return logits, hidden_emb, mask if new_mask is None else new_mask
|
| 431 |
-
else:
|
| 432 |
-
return logits, hidden_emb
|
| 433 |
-
|
| 434 |
-
def get_latent(self, x, layer_ix=12):
|
| 435 |
-
_, hidden_states = self.get_predictions(x)
|
| 436 |
-
emb = hidden_states[layer_ix]
|
| 437 |
-
return emb
|
| 438 |
-
|
| 439 |
-
def compute_nce(self, x, pos, negs):
|
| 440 |
-
neg_is_pos = (pos == negs).all(-1)
|
| 441 |
-
pos = pos.unsqueeze(0)
|
| 442 |
-
targets = torch.cat([pos, negs], dim=0)
|
| 443 |
-
|
| 444 |
-
logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
|
| 445 |
-
logits /= 0.1
|
| 446 |
-
if neg_is_pos.any():
|
| 447 |
-
logits[1:][neg_is_pos] = float("-inf")
|
| 448 |
-
logits = logits.transpose(0, 1) # (num_x, num_cls+1)
|
| 449 |
-
return logits
|
| 450 |
-
|
| 451 |
-
def compute_hubert_nce_loss(self, proj_xs, targets):
|
| 452 |
-
|
| 453 |
-
label_embs_list = self.label_embs_concat.split(self.codebook_size, 0) # (self.num_classes, 0)
|
| 454 |
-
|
| 455 |
-
def compute_pred(proj_x, target, label_embs):
|
| 456 |
-
# compute logits for the i-th label set
|
| 457 |
-
y = torch.index_select(label_embs, 0, target.long())
|
| 458 |
-
negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
|
| 459 |
-
return self.compute_nce(proj_x, y, negs)
|
| 460 |
-
|
| 461 |
-
logit_list = [
|
| 462 |
-
compute_pred(proj_x, t, label_embs_list[i])
|
| 463 |
-
for i, (proj_x, t) in enumerate(zip(proj_xs, targets))
|
| 464 |
-
]
|
| 465 |
-
|
| 466 |
-
return sum(logit_list)
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
def get_loss(self, logits, target_tokens, masked_indices):
|
| 470 |
-
losses = {}
|
| 471 |
-
accuracies = {}
|
| 472 |
-
for key in logits.keys():
|
| 473 |
-
if not self.use_rvq_like_target:
|
| 474 |
-
masked_logits = logits[key][tuple(masked_indices.t())]
|
| 475 |
-
masked_tokens = target_tokens[key][tuple(masked_indices.t())]
|
| 476 |
-
else:
|
| 477 |
-
Batch, SeqLen, N_Codebook_x_CodebookSize = logits[key].shape # CodebookSize=4096
|
| 478 |
-
Batch, N_Codebook_x_SeqLen = target_tokens[key].shape # N_Codebook*SeqLen=4*750
|
| 479 |
-
N_Codebook = int(N_Codebook_x_SeqLen // SeqLen)
|
| 480 |
-
# print("not use_virtual, n codebook = ", N_Codebook)
|
| 481 |
-
target_tokens[key] = rearrange(target_tokens[key], "b (n s) -> b s n", n=N_Codebook) # Batch, SeqLen=750, N_Codebook=4
|
| 482 |
-
masked_logits = logits[key][tuple(masked_indices.t())]
|
| 483 |
-
masked_tokens = target_tokens[key][tuple(masked_indices.t())]
|
| 484 |
-
masked_logits = rearrange(masked_logits, "b (n c) -> (b n) c", n=N_Codebook)
|
| 485 |
-
masked_tokens = rearrange(masked_tokens, "b n -> (b n)", n=N_Codebook)
|
| 486 |
-
|
| 487 |
-
if self.use_hubert_nce_loss:
|
| 488 |
-
losses[key] = self.compute_hubert_nce_loss(masked_logits, masked_tokens)
|
| 489 |
-
else:
|
| 490 |
-
losses[key] = self.loss(masked_logits, masked_tokens)
|
| 491 |
-
accuracies[key] = (
|
| 492 |
-
torch.sum(masked_logits.argmax(-1) == masked_tokens)
|
| 493 |
-
/ masked_tokens.numel()
|
| 494 |
-
)
|
| 495 |
-
return losses, accuracies
|
| 496 |
-
|
| 497 |
-
def get_recon_loss(self, last_hidden_emb, melspec, masked_indices):
|
| 498 |
-
pred_melspec = self.recon_proj(last_hidden_emb[tuple(masked_indices.t())])
|
| 499 |
-
target_melspec = melspec[tuple(masked_indices.t())]
|
| 500 |
-
recon_loss = self.recon_loss(pred_melspec, target_melspec)
|
| 501 |
-
return recon_loss
|
| 502 |
-
|
| 503 |
-
def forward(self, x, attention_mask=None, label=None):
|
| 504 |
-
dtype = x.dtype
|
| 505 |
-
# get target feature tokens
|
| 506 |
-
target_tokens, melspec = self.get_targets(x, label=label)
|
| 507 |
-
|
| 508 |
-
# masking
|
| 509 |
-
x, masked_indices = self.masking(x, attention_mask=attention_mask)
|
| 510 |
-
|
| 511 |
-
# forward
|
| 512 |
-
logits, hidden_emb, masked_indices = self.get_predictions(x, mask=masked_indices, attention_mask=attention_mask, return_new_mask=True)
|
| 513 |
-
|
| 514 |
-
# get loss
|
| 515 |
-
losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)
|
| 516 |
-
|
| 517 |
-
if self.recon_loss_ratio:
|
| 518 |
-
losses["recon_loss"] = self.get_recon_loss(hidden_emb[-1], melspec, masked_indices) * self.recon_loss_ratio
|
| 519 |
-
|
| 520 |
-
return logits, hidden_emb, losses, accuracies
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/model/pred_ark_target_with_model.py
DELETED
|
@@ -1,151 +0,0 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
import torch
|
| 4 |
-
import sys, os
|
| 5 |
-
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 6 |
-
from rvq_musicfm import PreprocessorWithModel, ResidualVectorQuantize
|
| 7 |
-
|
| 8 |
-
class RVQ(nn.Module):
|
| 9 |
-
def __init__(self,
|
| 10 |
-
model_config,
|
| 11 |
-
rvq_ckpt_path,
|
| 12 |
-
preprocess,
|
| 13 |
-
):
|
| 14 |
-
super().__init__()
|
| 15 |
-
self.rvq = ResidualVectorQuantize(**model_config)
|
| 16 |
-
if rvq_ckpt_path is not None:
|
| 17 |
-
self.rvq.load_state_dict(torch.load(rvq_ckpt_path, map_location='cpu'))
|
| 18 |
-
self.preprocess = preprocess
|
| 19 |
-
|
| 20 |
-
def get_targets(self, x):
|
| 21 |
-
self.rvq.eval()
|
| 22 |
-
x = self.preprocess(x)
|
| 23 |
-
quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = self.rvq(x)
|
| 24 |
-
return codes.permute(1,0,2)
|
| 25 |
-
|
| 26 |
-
@torch.no_grad()
|
| 27 |
-
def encode_wavs(self, wavs):
|
| 28 |
-
wavs = wavs[..., :int((wavs.shape[-1]//320)*320)]
|
| 29 |
-
return self.get_targets(wavs)
|
| 30 |
-
|
| 31 |
-
def This_Music_ModelTarget_Config():
|
| 32 |
-
config = dict(
|
| 33 |
-
model = dict(
|
| 34 |
-
input_dim = 1024,
|
| 35 |
-
n_codebooks = 8,
|
| 36 |
-
codebook_size = 1024,
|
| 37 |
-
codebook_dim = 16,
|
| 38 |
-
quantizer_dropout = 0.0,
|
| 39 |
-
),
|
| 40 |
-
train = dict(
|
| 41 |
-
batch_size = 32,
|
| 42 |
-
num_workers = 6,
|
| 43 |
-
valid_interval = 10,
|
| 44 |
-
save_interval = 100,
|
| 45 |
-
max_updates = 500000,
|
| 46 |
-
lr = 1e-4,
|
| 47 |
-
# device = 'cuda:1',
|
| 48 |
-
loss = 'commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()',
|
| 49 |
-
preprocess = PreprocessorWithModel(
|
| 50 |
-
model_dir= 'path/to/muq_fairseq',
|
| 51 |
-
checkpoint_dir='path/to/muq_m4a_75K.pt',
|
| 52 |
-
use_layer_idx=9,
|
| 53 |
-
)
|
| 54 |
-
),
|
| 55 |
-
pred = dict(
|
| 56 |
-
rvq_ckpt_path='path/to/runs/Aug07_18-09-24_ts-828fa13e58384d0bba4144fda78ecc92-launcher/ckpt/RVQ_8100.pth',
|
| 57 |
-
sr=24000,
|
| 58 |
-
data_jsonl_path='path/to/data/music4all/train.json',
|
| 59 |
-
save_target_dir= 'path/to/data/music4all_ark/reiter_musicssl_m4a',
|
| 60 |
-
),
|
| 61 |
-
)
|
| 62 |
-
return config
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
CLEN = 30
|
| 66 |
-
N_GPU_PER = 8
|
| 67 |
-
N_NODE = 4
|
| 68 |
-
|
| 69 |
-
def parse_lr(wave_length, sr):
|
| 70 |
-
n_step = int( wave_length // (sr*CLEN) )
|
| 71 |
-
if n_step == 0:
|
| 72 |
-
n_step = 1
|
| 73 |
-
print('wave_length: ', wave_length, 'sr: ', sr, 'n_step: ', n_step)
|
| 74 |
-
starts = torch.arange(n_step) * CLEN * sr
|
| 75 |
-
left_rights = torch.stack((starts, starts+CLEN*sr)).T
|
| 76 |
-
return left_rights[:10, ...]
|
| 77 |
-
|
| 78 |
-
@torch.no_grad()
|
| 79 |
-
def main(index, rank):
|
| 80 |
-
device = f'cuda:{rank}'
|
| 81 |
-
config = This_Music_ModelTarget_Config()
|
| 82 |
-
preprocess = config['train']['preprocess']
|
| 83 |
-
model = RVQ(
|
| 84 |
-
model_config = config['model'],
|
| 85 |
-
rvq_ckpt_path = config['pred']['rvq_ckpt_path'],
|
| 86 |
-
preprocess = preprocess
|
| 87 |
-
).to(device)
|
| 88 |
-
model.eval()
|
| 89 |
-
sr = config['pred']['sr']
|
| 90 |
-
|
| 91 |
-
fname_nobase = os.path.basename(config['pred']['data_jsonl_path']).split('.')[0]
|
| 92 |
-
scp_dir = os.path.join(config['pred']['save_target_dir'], 'scp')
|
| 93 |
-
ark_dir = os.path.join(config['pred']['save_target_dir'], 'ark')
|
| 94 |
-
os.makedirs(scp_dir, exist_ok=True)
|
| 95 |
-
os.makedirs(ark_dir, exist_ok=True)
|
| 96 |
-
|
| 97 |
-
scp_path = os.path.join(scp_dir, f'{fname_nobase}.{index}_{rank}.scp')
|
| 98 |
-
ark_path = os.path.join(ark_dir, f'{fname_nobase}.{index}_{rank}.ark')
|
| 99 |
-
|
| 100 |
-
from kaldiio import WriteHelper
|
| 101 |
-
|
| 102 |
-
with open(config['pred']['data_jsonl_path']) as f:
|
| 103 |
-
lines = f.readlines()
|
| 104 |
-
|
| 105 |
-
print("Total:", len(lines))
|
| 106 |
-
|
| 107 |
-
from tqdm import tqdm
|
| 108 |
-
import json
|
| 109 |
-
import librosa
|
| 110 |
-
import time
|
| 111 |
-
from einops import rearrange
|
| 112 |
-
import numpy as np
|
| 113 |
-
|
| 114 |
-
# lines = lines[(index*N_GPU_PER+rank)::(N_GPU_PER*N_NODE)]
|
| 115 |
-
|
| 116 |
-
with WriteHelper(f'ark,scp:{ark_path},{scp_path}') as writer:
|
| 117 |
-
for idx, line in tqdm(enumerate(lines)):
|
| 118 |
-
try:
|
| 119 |
-
if idx % (N_GPU_PER*N_NODE) != (index*N_GPU_PER+rank):
|
| 120 |
-
continue
|
| 121 |
-
item = json.loads(line)
|
| 122 |
-
path = item['path']
|
| 123 |
-
wave, _ = librosa.load(path, sr=sr)
|
| 124 |
-
wave = torch.from_numpy(wave)
|
| 125 |
-
wave_length = wave.shape[-1]
|
| 126 |
-
if wave_length < sr*CLEN:
|
| 127 |
-
continue
|
| 128 |
-
left_rights = parse_lr(wave_length, sr)
|
| 129 |
-
lr = left_rights.tolist()
|
| 130 |
-
wavs = torch.stack(
|
| 131 |
-
[wave[l:r] for l,r in lr]
|
| 132 |
-
).to(device)
|
| 133 |
-
targets = model.encode_wavs(wavs) # [Codebook=8, N_Steps, Feature]
|
| 134 |
-
|
| 135 |
-
final_target = rearrange(targets, "c n f -> n (c f)").cpu().numpy().astype(np.int32)
|
| 136 |
-
for j in range(final_target.shape[0]):
|
| 137 |
-
writer(f'{idx}:{j}', final_target[j])
|
| 138 |
-
except Exception as e:
|
| 139 |
-
print(e)
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
if __name__ == '__main__':
|
| 143 |
-
import sys
|
| 144 |
-
index = int(sys.argv[1])
|
| 145 |
-
import multiprocessing
|
| 146 |
-
pool = multiprocessing.Pool(processes=N_GPU_PER)
|
| 147 |
-
for rank in range(8):
|
| 148 |
-
pool.apply_async(main, (index, rank))
|
| 149 |
-
pool.close()
|
| 150 |
-
pool.join()
|
| 151 |
-
print("Done.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/model/rvq.py
DELETED
|
@@ -1,459 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
from typing import Union
|
| 3 |
-
|
| 4 |
-
import numpy as np
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn as nn
|
| 7 |
-
import torch.nn.functional as F
|
| 8 |
-
from einops import rearrange
|
| 9 |
-
from torch.nn.utils import weight_norm
|
| 10 |
-
|
| 11 |
-
def WNConv1d(*args, **kwargs):
|
| 12 |
-
return weight_norm(nn.Conv1d(*args, **kwargs))
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
class VectorQuantize(nn.Module):
|
| 16 |
-
"""
|
| 17 |
-
Implementation of VQ similar to Karpathy's repo:
|
| 18 |
-
https://github.com/karpathy/deep-vector-quantization
|
| 19 |
-
Additionally uses following tricks from Improved VQGAN
|
| 20 |
-
(https://arxiv.org/pdf/2110.04627.pdf):
|
| 21 |
-
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
|
| 22 |
-
for improved codebook usage
|
| 23 |
-
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
|
| 24 |
-
improves training stability
|
| 25 |
-
"""
|
| 26 |
-
|
| 27 |
-
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 1000, mfcc_clustering=False, n_layer=1):
|
| 28 |
-
super().__init__()
|
| 29 |
-
self.codebook_size = codebook_size
|
| 30 |
-
self.codebook_dim = codebook_dim
|
| 31 |
-
self.mfcc_clustering = mfcc_clustering
|
| 32 |
-
|
| 33 |
-
ProjClass = nn.Identity if mfcc_clustering else WNConv1d
|
| 34 |
-
if n_layer==1:
|
| 35 |
-
self.in_proj = ProjClass(input_dim, codebook_dim, kernel_size=1)
|
| 36 |
-
self.out_proj = ProjClass(codebook_dim, input_dim, kernel_size=1)
|
| 37 |
-
elif n_layer >= 2:
|
| 38 |
-
ndim_hidden = 128
|
| 39 |
-
self.in_proj = nn.Sequential(
|
| 40 |
-
ProjClass(input_dim, ndim_hidden, kernel_size=1),
|
| 41 |
-
*[nn.Sequential(nn.ReLU(), ProjClass(ndim_hidden, ndim_hidden, kernel_size=1),) for _ in range(n_layer-2)],
|
| 42 |
-
nn.ReLU(),
|
| 43 |
-
ProjClass(ndim_hidden, codebook_dim, kernel_size=1)
|
| 44 |
-
)
|
| 45 |
-
self.out_proj = nn.Sequential(
|
| 46 |
-
ProjClass(codebook_dim, ndim_hidden, kernel_size=1),
|
| 47 |
-
nn.ReLU(),
|
| 48 |
-
*[nn.Sequential(ProjClass(ndim_hidden, ndim_hidden, kernel_size=1), nn.ReLU()) for _ in range(n_layer-2)],
|
| 49 |
-
ProjClass(ndim_hidden, input_dim, kernel_size=1),
|
| 50 |
-
)
|
| 51 |
-
self.codebook = nn.Embedding(codebook_size, codebook_dim)
|
| 52 |
-
self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
|
| 53 |
-
self.stale_tolerance = stale_tolerance
|
| 54 |
-
|
| 55 |
-
def forward(self, z):
|
| 56 |
-
"""Quantized the input tensor using a fixed codebook and returns
|
| 57 |
-
the corresponding codebook vectors
|
| 58 |
-
|
| 59 |
-
Parameters
|
| 60 |
-
----------
|
| 61 |
-
z : Tensor[B x D x T]
|
| 62 |
-
|
| 63 |
-
Returns
|
| 64 |
-
-------
|
| 65 |
-
Tensor[B x D x T]
|
| 66 |
-
Quantized continuous representation of input
|
| 67 |
-
Tensor[1]
|
| 68 |
-
Commitment loss to train encoder to predict vectors closer to codebook
|
| 69 |
-
entries
|
| 70 |
-
Tensor[1]
|
| 71 |
-
Codebook loss to update the codebook
|
| 72 |
-
Tensor[B x T]
|
| 73 |
-
Codebook indices (quantized discrete representation of input)
|
| 74 |
-
Tensor[B x D x T]
|
| 75 |
-
Projected latents (continuous representation of input before quantization)
|
| 76 |
-
"""
|
| 77 |
-
|
| 78 |
-
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
|
| 79 |
-
|
| 80 |
-
z_e = self.in_proj(z) # z_e : (B x D x T)
|
| 81 |
-
z_q, indices = self.decode_latents(z_e)
|
| 82 |
-
|
| 83 |
-
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
| 84 |
-
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
| 85 |
-
|
| 86 |
-
z_q = (
|
| 87 |
-
z_e + (z_q - z_e).detach()
|
| 88 |
-
) # noop in forward pass, straight-through gradient estimator in backward pass
|
| 89 |
-
|
| 90 |
-
z_q = self.out_proj(z_q)
|
| 91 |
-
|
| 92 |
-
return z_q, commitment_loss, codebook_loss, indices, z_e
|
| 93 |
-
|
| 94 |
-
def embed_code(self, embed_id):
|
| 95 |
-
return F.embedding(embed_id, self.codebook.weight)
|
| 96 |
-
|
| 97 |
-
def decode_code(self, embed_id):
|
| 98 |
-
return self.embed_code(embed_id).transpose(1, 2)
|
| 99 |
-
|
| 100 |
-
def decode_latents(self, latents):
|
| 101 |
-
encodings = rearrange(latents, "b d t -> (b t) d")
|
| 102 |
-
codebook = self.codebook.weight # codebook: (N x D)
|
| 103 |
-
|
| 104 |
-
# L2 normalize encodings and codebook (ViT-VQGAN)
|
| 105 |
-
encodings = F.normalize(encodings)
|
| 106 |
-
codebook = F.normalize(codebook)
|
| 107 |
-
|
| 108 |
-
# Compute euclidean distance with codebook
|
| 109 |
-
dist = (
|
| 110 |
-
encodings.pow(2).sum(1, keepdim=True)
|
| 111 |
-
- 2 * encodings @ codebook.t()
|
| 112 |
-
+ codebook.pow(2).sum(1, keepdim=True).t()
|
| 113 |
-
)
|
| 114 |
-
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
| 115 |
-
z_q = self.decode_code(indices)
|
| 116 |
-
|
| 117 |
-
if(self.training):
|
| 118 |
-
onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size
|
| 119 |
-
stale_codes = (onehots.sum(0).sum(0) == 0).float()
|
| 120 |
-
self.stale_counter = self.stale_counter * stale_codes + stale_codes
|
| 121 |
-
|
| 122 |
-
# random replace codes that haven't been used for a while
|
| 123 |
-
replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size
|
| 124 |
-
if replace_code.sum(-1) > 0:
|
| 125 |
-
print("Replace {} codes".format(replace_code.sum(-1)))
|
| 126 |
-
random_input_idx = torch.randperm(encodings.shape[0])
|
| 127 |
-
random_input = encodings[random_input_idx].view(encodings.shape)
|
| 128 |
-
if random_input.shape[0] < self.codebook_size:
|
| 129 |
-
random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
|
| 130 |
-
random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim
|
| 131 |
-
|
| 132 |
-
self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
|
| 133 |
-
self.stale_counter = self.stale_counter * (1 - replace_code)
|
| 134 |
-
|
| 135 |
-
return z_q, indices
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
class ResidualVectorQuantize(nn.Module):
|
| 139 |
-
"""
|
| 140 |
-
Introduced in SoundStream: An end2end neural audio codec
|
| 141 |
-
https://arxiv.org/abs/2107.03312
|
| 142 |
-
"""
|
| 143 |
-
|
| 144 |
-
def __init__(
|
| 145 |
-
self,
|
| 146 |
-
input_dim: int = 512,
|
| 147 |
-
n_codebooks: int = 9,
|
| 148 |
-
codebook_size: int = 1024,
|
| 149 |
-
codebook_dim: Union[int, list] = 8,
|
| 150 |
-
quantizer_dropout: float = 0.0,
|
| 151 |
-
stale_tolerance: int = 100,
|
| 152 |
-
use_multi_layer_num:int = 1,
|
| 153 |
-
):
|
| 154 |
-
super().__init__()
|
| 155 |
-
if isinstance(codebook_dim, int):
|
| 156 |
-
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
|
| 157 |
-
|
| 158 |
-
self.n_codebooks = n_codebooks
|
| 159 |
-
self.codebook_dim = codebook_dim
|
| 160 |
-
self.codebook_size = codebook_size
|
| 161 |
-
|
| 162 |
-
self.quantizers = nn.ModuleList(
|
| 163 |
-
[
|
| 164 |
-
VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance, n_layer=use_multi_layer_num)
|
| 165 |
-
for i in range(n_codebooks)
|
| 166 |
-
]
|
| 167 |
-
)
|
| 168 |
-
self.quantizer_dropout = quantizer_dropout
|
| 169 |
-
|
| 170 |
-
def forward(self, z, n_quantizers: int = None):
|
| 171 |
-
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
|
| 172 |
-
the corresponding codebook vectors
|
| 173 |
-
Parameters
|
| 174 |
-
----------
|
| 175 |
-
z : Tensor[B x D x T]
|
| 176 |
-
n_quantizers : int, optional
|
| 177 |
-
No. of quantizers to use
|
| 178 |
-
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
| 179 |
-
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
| 180 |
-
when in training mode, and a random number of quantizers is used.
|
| 181 |
-
Returns
|
| 182 |
-
-------
|
| 183 |
-
dict
|
| 184 |
-
A dictionary with the following keys:
|
| 185 |
-
|
| 186 |
-
"z" : Tensor[B x D x T]
|
| 187 |
-
Quantized continuous representation of input
|
| 188 |
-
"codes" : Tensor[B x N x T]
|
| 189 |
-
Codebook indices for each codebook
|
| 190 |
-
(quantized discrete representation of input)
|
| 191 |
-
"latents" : Tensor[B x N*D x T]
|
| 192 |
-
Projected latents (continuous representation of input before quantization)
|
| 193 |
-
"vq/commitment_loss" : Tensor[1]
|
| 194 |
-
Commitment loss to train encoder to predict vectors closer to codebook
|
| 195 |
-
entries
|
| 196 |
-
"vq/codebook_loss" : Tensor[1]
|
| 197 |
-
Codebook loss to update the codebook
|
| 198 |
-
"""
|
| 199 |
-
z_q = 0
|
| 200 |
-
residual = z
|
| 201 |
-
commitment_loss = 0
|
| 202 |
-
codebook_loss = 0
|
| 203 |
-
|
| 204 |
-
codebook_indices = []
|
| 205 |
-
latents = []
|
| 206 |
-
|
| 207 |
-
if n_quantizers is None:
|
| 208 |
-
n_quantizers = self.n_codebooks
|
| 209 |
-
if self.training:
|
| 210 |
-
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
|
| 211 |
-
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
|
| 212 |
-
n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
| 213 |
-
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
| 214 |
-
n_quantizers = n_quantizers.to(z.device)
|
| 215 |
-
else:
|
| 216 |
-
n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1
|
| 217 |
-
n_quantizers = n_quantizers.to(z.device)
|
| 218 |
-
|
| 219 |
-
for i, quantizer in enumerate(self.quantizers):
|
| 220 |
-
# if self.training is False and i >= n_quantizers:
|
| 221 |
-
# break
|
| 222 |
-
|
| 223 |
-
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
|
| 224 |
-
residual
|
| 225 |
-
)
|
| 226 |
-
|
| 227 |
-
# Create mask to apply quantizer dropout
|
| 228 |
-
mask = (
|
| 229 |
-
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
| 230 |
-
)
|
| 231 |
-
z_q = z_q + z_q_i * mask[:, None, None]
|
| 232 |
-
residual = residual - z_q_i
|
| 233 |
-
|
| 234 |
-
# Sum losses
|
| 235 |
-
commitment_loss += (commitment_loss_i * mask).mean()
|
| 236 |
-
codebook_loss += (codebook_loss_i * mask).mean()
|
| 237 |
-
|
| 238 |
-
codebook_indices.append(indices_i)
|
| 239 |
-
latents.append(z_e_i)
|
| 240 |
-
|
| 241 |
-
codes = torch.stack(codebook_indices, dim=1)
|
| 242 |
-
latents = torch.cat(latents, dim=1)
|
| 243 |
-
|
| 244 |
-
encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
|
| 245 |
-
|
| 246 |
-
return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1
|
| 247 |
-
|
| 248 |
-
def from_codes(self, codes: torch.Tensor):
|
| 249 |
-
"""Given the quantized codes, reconstruct the continuous representation
|
| 250 |
-
Parameters
|
| 251 |
-
----------
|
| 252 |
-
codes : Tensor[B x N x T]
|
| 253 |
-
Quantized discrete representation of input
|
| 254 |
-
Returns
|
| 255 |
-
-------
|
| 256 |
-
Tensor[B x D x T]
|
| 257 |
-
Quantized continuous representation of input
|
| 258 |
-
"""
|
| 259 |
-
z_q = 0.0
|
| 260 |
-
z_p = []
|
| 261 |
-
n_codebooks = codes.shape[1]
|
| 262 |
-
for i in range(n_codebooks):
|
| 263 |
-
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
|
| 264 |
-
z_p.append(z_p_i)
|
| 265 |
-
|
| 266 |
-
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
| 267 |
-
z_q = z_q + z_q_i
|
| 268 |
-
return z_q, torch.cat(z_p, dim=1), codes
|
| 269 |
-
|
| 270 |
-
def from_latents(self, latents: torch.Tensor):
|
| 271 |
-
"""Given the unquantized latents, reconstruct the
|
| 272 |
-
continuous representation after quantization.
|
| 273 |
-
|
| 274 |
-
Parameters
|
| 275 |
-
----------
|
| 276 |
-
latents : Tensor[B x N x T]
|
| 277 |
-
Continuous representation of input after projection
|
| 278 |
-
|
| 279 |
-
Returns
|
| 280 |
-
-------
|
| 281 |
-
Tensor[B x D x T]
|
| 282 |
-
Quantized representation of full-projected space
|
| 283 |
-
Tensor[B x D x T]
|
| 284 |
-
Quantized representation of latent space
|
| 285 |
-
"""
|
| 286 |
-
z_q = 0
|
| 287 |
-
z_p = []
|
| 288 |
-
codes = []
|
| 289 |
-
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
|
| 290 |
-
|
| 291 |
-
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
|
| 292 |
-
0
|
| 293 |
-
]
|
| 294 |
-
for i in range(n_codebooks):
|
| 295 |
-
j, k = dims[i], dims[i + 1]
|
| 296 |
-
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
|
| 297 |
-
z_p.append(z_p_i)
|
| 298 |
-
codes.append(codes_i)
|
| 299 |
-
|
| 300 |
-
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
| 301 |
-
z_q = z_q + z_q_i
|
| 302 |
-
|
| 303 |
-
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
| 304 |
-
|
| 305 |
-
from torch.utils.data import Dataset, DataLoader
|
| 306 |
-
import json, traceback
|
| 307 |
-
import torchaudio
|
| 308 |
-
import math
|
| 309 |
-
|
| 310 |
-
from typing import List, Tuple, Dict, Any
|
| 311 |
-
|
| 312 |
-
CLIPSECS = 5
|
| 313 |
-
def load_audio_by_json(json_path, max_keep, min_keep, tgt_sample_rate):
|
| 314 |
-
# read json file
|
| 315 |
-
print(json_path)
|
| 316 |
-
datas = []
|
| 317 |
-
inds = []
|
| 318 |
-
sizes = []
|
| 319 |
-
with open(json_path) as fp:
|
| 320 |
-
for ind,line in enumerate(fp):
|
| 321 |
-
data = json.loads(line)
|
| 322 |
-
datas.append(data)
|
| 323 |
-
inds.append(ind)
|
| 324 |
-
# sz = int(data['duration'] * data['sample_rate'])
|
| 325 |
-
sz = int(tgt_sample_rate * CLIPSECS)
|
| 326 |
-
sizes.append(sz)
|
| 327 |
-
tot = ind + 1
|
| 328 |
-
return datas,inds,tot,sizes
|
| 329 |
-
|
| 330 |
-
class Read_and_PadCrop_Normalized_T(torch.nn.Module):
|
| 331 |
-
def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
|
| 332 |
-
|
| 333 |
-
super().__init__()
|
| 334 |
-
|
| 335 |
-
self.n_samples = n_samples
|
| 336 |
-
self.sample_rate = sample_rate
|
| 337 |
-
self.randomize = randomize
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
|
| 341 |
-
if(duration<(float(self.n_samples)/self.sample_rate+1)):
|
| 342 |
-
# print(duration,(float(self.n_samples)/self.sample_rate+1))
|
| 343 |
-
chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
|
| 344 |
-
t_start = 0.
|
| 345 |
-
t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
|
| 346 |
-
offset = 0
|
| 347 |
-
# print('c1:',chunk.shape)
|
| 348 |
-
else:
|
| 349 |
-
offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
|
| 350 |
-
t_start = offset / float(cur_sample_rate) / duration
|
| 351 |
-
t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
|
| 352 |
-
chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
|
| 353 |
-
# print('offset:',offset)
|
| 354 |
-
# print('c0:',chunk.shape)
|
| 355 |
-
# Pad with silence if necessary.
|
| 356 |
-
if(chunk.shape[0]>1):
|
| 357 |
-
chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
|
| 358 |
-
else:
|
| 359 |
-
chunk = chunk[[0],:].float()
|
| 360 |
-
if(cur_sample_rate!=self.sample_rate):
|
| 361 |
-
# print('a:',cur_sample_rate,chunk.shape)
|
| 362 |
-
chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
|
| 363 |
-
# print('b:',self.sample_rate,chunk.shape)
|
| 364 |
-
if chunk.shape[-1] < self.n_samples:
|
| 365 |
-
chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
|
| 366 |
-
else:
|
| 367 |
-
chunk = chunk[:,0:self.n_samples]
|
| 368 |
-
seconds_start = math.floor(offset / cur_sample_rate)
|
| 369 |
-
seconds_total = math.floor(duration)
|
| 370 |
-
|
| 371 |
-
return (
|
| 372 |
-
chunk,
|
| 373 |
-
t_start,
|
| 374 |
-
t_end,
|
| 375 |
-
seconds_start,
|
| 376 |
-
seconds_total
|
| 377 |
-
)
|
| 378 |
-
|
| 379 |
-
class RVQDataset(Dataset):
|
| 380 |
-
def __init__(
|
| 381 |
-
self,
|
| 382 |
-
manifest_path: str,
|
| 383 |
-
sample_rate: float,
|
| 384 |
-
normalize: bool = False,
|
| 385 |
-
):
|
| 386 |
-
self.sample_rate = sample_rate
|
| 387 |
-
self.datas,inds,tot,self.sizes = load_audio_by_json(manifest_path, None, None, self.sample_rate)
|
| 388 |
-
self.dataset_len = len(self.datas)
|
| 389 |
-
|
| 390 |
-
self.reader = Read_and_PadCrop_Normalized_T(n_samples=CLIPSECS*sample_rate,sample_rate = self.sample_rate)
|
| 391 |
-
self.normalize = normalize
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
def __getitem__(self, i):
|
| 395 |
-
# WORLD_SIZE = int(torch.distributed.get_world_size())
|
| 396 |
-
# WORLD_RANK = int(torch.distributed.get_rank())
|
| 397 |
-
# np.random.seed(1337 + self.epoch * WORLD_SIZE + WORLD_RANK + i)
|
| 398 |
-
# index = random.randint(0,len(self.sizes) - 1)
|
| 399 |
-
index = i
|
| 400 |
-
item = None
|
| 401 |
-
while item is None:
|
| 402 |
-
try:
|
| 403 |
-
wav = self.get_audio_by_slice(index)
|
| 404 |
-
# labels = self.get_labels(index)
|
| 405 |
-
# labels = None
|
| 406 |
-
# item = {"id": index, "source": wav, "label_list": labels}
|
| 407 |
-
item = {"id": index, "source": wav}
|
| 408 |
-
except Exception as e:
|
| 409 |
-
# print(e)
|
| 410 |
-
traceback.print_exc()
|
| 411 |
-
print(f'skip damaged data {index}')
|
| 412 |
-
index = np.random.randint(0,len(self.sizes)-1)
|
| 413 |
-
return item
|
| 414 |
-
|
| 415 |
-
def __len__(self):
|
| 416 |
-
return self.dataset_len
|
| 417 |
-
|
| 418 |
-
def get_audio_by_slice(self,index):
|
| 419 |
-
|
| 420 |
-
wav_path = self.datas[index]['path']
|
| 421 |
-
# print(wav_path)
|
| 422 |
-
audio_info = torchaudio.info(wav_path)
|
| 423 |
-
origin_sample_rate = audio_info.sample_rate
|
| 424 |
-
origin_duration = audio_info.num_frames / origin_sample_rate
|
| 425 |
-
|
| 426 |
-
wav, *ignored = self.reader(wav_path, origin_duration,origin_sample_rate)
|
| 427 |
-
wav = wav.float()
|
| 428 |
-
|
| 429 |
-
# _path, slice_ptr = parse_path(wav_path)
|
| 430 |
-
# original way
|
| 431 |
-
# if len(slice_ptr) == 0:
|
| 432 |
-
# wav, cur_sample_rate = sf.read(_path)
|
| 433 |
-
# else:
|
| 434 |
-
# assert _path.endswith(".zip")
|
| 435 |
-
# data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
|
| 436 |
-
# f = io.BytesIO(data)
|
| 437 |
-
# wav, cur_sample_rate = sf.read(f)
|
| 438 |
-
# wav = torch.from_numpy(wav).float()
|
| 439 |
-
# print(wav.shape)
|
| 440 |
-
wav = wav.permute(1,0)
|
| 441 |
-
wav = self.postprocess(wav, self.sample_rate)
|
| 442 |
-
# print(wav.shape)
|
| 443 |
-
|
| 444 |
-
# wav = wav.squeeze(0)
|
| 445 |
-
return wav
|
| 446 |
-
|
| 447 |
-
def postprocess(self, wav, cur_sample_rate):
|
| 448 |
-
if wav.dim() == 2:
|
| 449 |
-
wav = wav.mean(-1)
|
| 450 |
-
assert wav.dim() == 1, wav.dim()
|
| 451 |
-
|
| 452 |
-
if cur_sample_rate != self.sample_rate:
|
| 453 |
-
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
|
| 454 |
-
|
| 455 |
-
if self.normalize:
|
| 456 |
-
with torch.no_grad():
|
| 457 |
-
wav = F.layer_norm(wav, wav.shape)
|
| 458 |
-
return wav
|
| 459 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/model/rvq_muq.py
DELETED
|
@@ -1,394 +0,0 @@
|
|
| 1 |
-
try:
|
| 2 |
-
from .rvq import *
|
| 3 |
-
except:
|
| 4 |
-
import sys, os
|
| 5 |
-
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 6 |
-
from rvq import *
|
| 7 |
-
|
| 8 |
-
try:
|
| 9 |
-
from ..modules.random_quantizer import RandomProjectionQuantizer
|
| 10 |
-
from ..modules.features import MelSTFT
|
| 11 |
-
from ..modules.conv import Conv2dSubsampling
|
| 12 |
-
except:
|
| 13 |
-
import sys, os
|
| 14 |
-
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
| 15 |
-
from modules.random_quantizer import RandomProjectionQuantizer
|
| 16 |
-
from modules.features import MelSTFT
|
| 17 |
-
from modules.conv import Conv2dSubsampling
|
| 18 |
-
|
| 19 |
-
import fairseq
|
| 20 |
-
|
| 21 |
-
CLIPSECS = 5 # 5 for rvq, 30 for model
|
| 22 |
-
|
| 23 |
-
class RVQDataset(Dataset):
|
| 24 |
-
def __init__(
|
| 25 |
-
self,
|
| 26 |
-
manifest_path: str,
|
| 27 |
-
sample_rate: float,
|
| 28 |
-
normalize: bool = False,
|
| 29 |
-
):
|
| 30 |
-
self.sample_rate = sample_rate
|
| 31 |
-
self.datas,inds,tot,self.sizes = load_audio_by_json(manifest_path, None, None, self.sample_rate)
|
| 32 |
-
self.dataset_len = len(self.datas)
|
| 33 |
-
|
| 34 |
-
self.reader = Read_and_PadCrop_Normalized_T(n_samples=CLIPSECS*sample_rate,sample_rate = self.sample_rate)
|
| 35 |
-
self.normalize = normalize
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def __getitem__(self, i):
|
| 39 |
-
# WORLD_SIZE = int(torch.distributed.get_world_size())
|
| 40 |
-
# WORLD_RANK = int(torch.distributed.get_rank())
|
| 41 |
-
# np.random.seed(1337 + self.epoch * WORLD_SIZE + WORLD_RANK + i)
|
| 42 |
-
# index = random.randint(0,len(self.sizes) - 1)
|
| 43 |
-
index = i
|
| 44 |
-
item = None
|
| 45 |
-
while item is None:
|
| 46 |
-
try:
|
| 47 |
-
wav = self.get_audio_by_slice(index)
|
| 48 |
-
item = {"id": index, "source": wav}
|
| 49 |
-
except Exception as e:
|
| 50 |
-
# print(e)
|
| 51 |
-
traceback.print_exc()
|
| 52 |
-
print(f'skip damaged data {index}')
|
| 53 |
-
index = np.random.randint(0,len(self.sizes)-1)
|
| 54 |
-
return item
|
| 55 |
-
|
| 56 |
-
def __len__(self):
|
| 57 |
-
return self.dataset_len
|
| 58 |
-
|
| 59 |
-
def get_audio_by_slice(self,index):
|
| 60 |
-
|
| 61 |
-
wav_path = self.datas[index]['path']
|
| 62 |
-
audio_info = torchaudio.info(wav_path)
|
| 63 |
-
origin_sample_rate = audio_info.sample_rate
|
| 64 |
-
origin_duration = audio_info.num_frames / origin_sample_rate
|
| 65 |
-
|
| 66 |
-
wav, *ignored = self.reader(wav_path, origin_duration,origin_sample_rate)
|
| 67 |
-
wav = wav.float()
|
| 68 |
-
|
| 69 |
-
# _path, slice_ptr = parse_path(wav_path)
|
| 70 |
-
# original way
|
| 71 |
-
# if len(slice_ptr) == 0:
|
| 72 |
-
# wav, cur_sample_rate = sf.read(_path)
|
| 73 |
-
# else:
|
| 74 |
-
# assert _path.endswith(".zip")
|
| 75 |
-
# data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
|
| 76 |
-
# f = io.BytesIO(data)
|
| 77 |
-
# wav, cur_sample_rate = sf.read(f)
|
| 78 |
-
# wav = torch.from_numpy(wav).float()
|
| 79 |
-
# print(wav.shape)
|
| 80 |
-
wav = wav.permute(1,0)
|
| 81 |
-
wav = self.postprocess(wav, self.sample_rate)
|
| 82 |
-
# print(wav.shape)
|
| 83 |
-
|
| 84 |
-
# wav = wav.squeeze(0)
|
| 85 |
-
return wav
|
| 86 |
-
|
| 87 |
-
def postprocess(self, wav, cur_sample_rate):
|
| 88 |
-
if wav.dim() == 2:
|
| 89 |
-
wav = wav.mean(-1)
|
| 90 |
-
assert wav.dim() == 1, wav.dim()
|
| 91 |
-
|
| 92 |
-
if cur_sample_rate != self.sample_rate:
|
| 93 |
-
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
|
| 94 |
-
|
| 95 |
-
if self.normalize:
|
| 96 |
-
with torch.no_grad():
|
| 97 |
-
wav = F.layer_norm(wav, wav.shape)
|
| 98 |
-
return wav
|
| 99 |
-
|
| 100 |
-
class Preprocessor(nn.Module):
|
| 101 |
-
def __init__(self,
|
| 102 |
-
codebook_dim=16,
|
| 103 |
-
codebook_size=4096,
|
| 104 |
-
hop_length=240,
|
| 105 |
-
n_mels=128,
|
| 106 |
-
stat_path=None,
|
| 107 |
-
is_spec_wise=False,
|
| 108 |
-
s=4,
|
| 109 |
-
) -> None:
|
| 110 |
-
super().__init__()
|
| 111 |
-
|
| 112 |
-
self.features=["melspec_2048"]
|
| 113 |
-
self.s = s
|
| 114 |
-
|
| 115 |
-
# load feature mean / std stats
|
| 116 |
-
import os
|
| 117 |
-
if stat_path is not None and os.path.exists(stat_path):
|
| 118 |
-
with open(stat_path, "r") as f:
|
| 119 |
-
self.stat = json.load(f)
|
| 120 |
-
else:
|
| 121 |
-
# print("No stats file found at `{}`, use default from msd.".format(stat_path))
|
| 122 |
-
self.stat = {"spec_256_cnt": 14394344256, "spec_256_mean": -23.34296658431829, "spec_256_std": 26.189295587132637, "spec_512_cnt": 28677104448, "spec_512_mean": -21.31267396860235, "spec_512_std": 26.52644536245769, "spec_1024_cnt": 57242624832, "spec_1024_mean": -18.852271129208273, "spec_1024_std": 26.443154583585663, "spec_2048_cnt": 114373665600, "spec_2048_mean": -15.638743433896792, "spec_2048_std": 26.115825961611545, "spec_4096_cnt": 228635747136, "spec_4096_mean": -11.715532502794836, "spec_4096_std": 25.763972210234062, "melspec_256_cnt": 14282760192, "melspec_256_mean": -26.962600400166156, "melspec_256_std": 36.13614100912126, "melspec_512_cnt": 14282760192, "melspec_512_mean": -9.108344167718862, "melspec_512_std": 24.71910937988429, "melspec_1024_cnt": 14282760192, "melspec_1024_mean": 0.37302579246531126, "melspec_1024_std": 18.684082325919388, "melspec_2048_cnt": 14282760192, "melspec_2048_mean": 6.768444971712967, "melspec_2048_std": 18.417922652295623, "melspec_4096_cnt": 14282760192, "melspec_4096_mean": 13.617164614990036, "melspec_4096_std": 18.08552130124525, "cqt_cnt": 9373061376, "cqt_mean": 0.46341379757927165, "cqt_std": 0.9543998080910191, "mfcc_256_cnt": 1339008768, "mfcc_256_mean": -11.681755459447485, "mfcc_256_std": 29.183186444668316, "mfcc_512_cnt": 1339008768, "mfcc_512_mean": -2.540581461792183, "mfcc_512_std": 31.93752185832081, "mfcc_1024_cnt": 1339008768, "mfcc_1024_mean": 6.606636263169779, "mfcc_1024_std": 34.151644801729624, "mfcc_2048_cnt": 1339008768, "mfcc_2048_mean": 5.281600844245184, "mfcc_2048_std": 33.12784541220003, "mfcc_4096_cnt": 1339008768, "mfcc_4096_mean": 4.7616569480166095, "mfcc_4096_std": 32.61458906894133, "chromagram_256_cnt": 1339008768, "chromagram_256_mean": 55.15596556703181, "chromagram_256_std": 73.91858278719991, "chromagram_512_cnt": 1339008768, "chromagram_512_mean": 175.73092252759895, "chromagram_512_std": 248.48485148525953, "chromagram_1024_cnt": 1339008768, "chromagram_1024_mean": 589.2947481634608, "chromagram_1024_std": 913.857929063196, "chromagram_2048_cnt": 1339008768, "chromagram_2048_mean": 2062.286388327397, "chromagram_2048_std": 3458.92657915397, "chromagram_4096_cnt": 1339008768, "chromagram_4096_mean": 7673.039107997085, "chromagram_4096_std": 13009.883158267234}
|
| 123 |
-
|
| 124 |
-
# feature extractor
|
| 125 |
-
self.preprocessor_melspec_2048 = MelSTFT(
|
| 126 |
-
n_fft=2048, hop_length=hop_length, is_db=True
|
| 127 |
-
)
|
| 128 |
-
|
| 129 |
-
self.is_spec_wise = is_spec_wise
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
@torch.no_grad()
|
| 133 |
-
def normalize(self, x):
|
| 134 |
-
"""normalize the input audio to have zero mean unit variance"""
|
| 135 |
-
for key in x.keys():
|
| 136 |
-
x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key] # {'melspec_2048_cnt': 14282760192, 'melspec_2048_mean': 6.768444971712967}
|
| 137 |
-
return x
|
| 138 |
-
|
| 139 |
-
@torch.no_grad()
|
| 140 |
-
def rearrange(self, x):
|
| 141 |
-
"""rearrange the batch to flatten every 4 steps"""
|
| 142 |
-
for key in x.keys():
|
| 143 |
-
if key == "chromagram":
|
| 144 |
-
x[key] = rearrange(x[key], "b f t -> b t f")
|
| 145 |
-
else:
|
| 146 |
-
x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=self.s)
|
| 147 |
-
return x
|
| 148 |
-
|
| 149 |
-
@torch.no_grad()
|
| 150 |
-
def preprocessing(self, x, features):
|
| 151 |
-
"""extract classic audio features"""
|
| 152 |
-
# check precision
|
| 153 |
-
if x.dtype == torch.float16:
|
| 154 |
-
precision = 16
|
| 155 |
-
else:
|
| 156 |
-
precision = 32
|
| 157 |
-
|
| 158 |
-
out = {}
|
| 159 |
-
for key in features:
|
| 160 |
-
layer = getattr(self, "preprocessor_%s" % key)
|
| 161 |
-
out[key] = layer.float()(x.float())[..., :-1]
|
| 162 |
-
if precision == 16:
|
| 163 |
-
out[key] = out[key].half()
|
| 164 |
-
return out
|
| 165 |
-
|
| 166 |
-
@torch.no_grad()
|
| 167 |
-
def tokenize(self, x):
|
| 168 |
-
out = {}
|
| 169 |
-
for key in x.keys():
|
| 170 |
-
layer = getattr(self, "quantizer_%s" % key)
|
| 171 |
-
out[key] = layer(x[key])
|
| 172 |
-
return out
|
| 173 |
-
|
| 174 |
-
def to_spec_wise(self, x):
|
| 175 |
-
Batch, Spec, Time = x.shape
|
| 176 |
-
SubSpec, N_SubSpec = 16, 8
|
| 177 |
-
assert SubSpec * N_SubSpec == Spec == 128
|
| 178 |
-
x = rearrange(x, "b (n s) t -> b s (n t)", n=N_SubSpec, s=SubSpec)
|
| 179 |
-
return x # [Batch, SubSpec=16, N_SubSpec*Time=8*100Hz]
|
| 180 |
-
|
| 181 |
-
@torch.no_grad()
|
| 182 |
-
def __call__(self, x):
|
| 183 |
-
x = self.preprocessing(x, features=self.features) # -> {'melspec_2048': Tensor{Size([3, 128, 3000]) cuda:0 f32}}
|
| 184 |
-
x = self.normalize(x)
|
| 185 |
-
if self.is_spec_wise:
|
| 186 |
-
x = {k:self.to_spec_wise(v) for k,v in x.items()}
|
| 187 |
-
x = self.rearrange(x) # -> {'melspec_2048': Tensor{Size([3, 750, 512]) cuda:0 f32}}
|
| 188 |
-
return x['melspec_2048'].permute((0, 2, 1))
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
class CQTPreprocessor(nn.Module):
|
| 192 |
-
def __init__(self,
|
| 193 |
-
sr=24000,
|
| 194 |
-
hop=960,
|
| 195 |
-
nb=84,
|
| 196 |
-
to_db = True,
|
| 197 |
-
) -> None:
|
| 198 |
-
super().__init__()
|
| 199 |
-
|
| 200 |
-
from nnAudio.features.cqt import CQT
|
| 201 |
-
import torchaudio
|
| 202 |
-
self.cqt_fn = CQT(
|
| 203 |
-
sr=sr,
|
| 204 |
-
hop_length=hop,
|
| 205 |
-
n_bins=nb,
|
| 206 |
-
fmin=32.7 if nb == 84 else 27.5, # 84 or 88
|
| 207 |
-
bins_per_octave=12,
|
| 208 |
-
filter_scale=1,
|
| 209 |
-
norm=1,
|
| 210 |
-
window='hann',
|
| 211 |
-
center=True,
|
| 212 |
-
pad_mode='constant',
|
| 213 |
-
trainable=False,
|
| 214 |
-
output_format='Magnitude',
|
| 215 |
-
verbose=True,
|
| 216 |
-
)
|
| 217 |
-
if to_db:
|
| 218 |
-
self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
|
| 219 |
-
else:
|
| 220 |
-
self.amplitude_to_db = lambda x:x
|
| 221 |
-
|
| 222 |
-
@torch.no_grad()
|
| 223 |
-
def __call__(self, x):
|
| 224 |
-
return self.amplitude_to_db(self.cqt_fn(x))
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
from dataclasses import dataclass
|
| 228 |
-
|
| 229 |
-
@dataclass
|
| 230 |
-
class UserDirModule:
|
| 231 |
-
user_dir: str
|
| 232 |
-
|
| 233 |
-
def load_model(model_dir, checkpoint_dir):
|
| 234 |
-
'''Load Fairseq SSL model'''
|
| 235 |
-
|
| 236 |
-
if model_dir is not None:
|
| 237 |
-
model_path = UserDirModule(model_dir)
|
| 238 |
-
fairseq.utils.import_user_module(model_path)
|
| 239 |
-
|
| 240 |
-
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_dir], strict=False)
|
| 241 |
-
model = model[0]
|
| 242 |
-
|
| 243 |
-
return model
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
class PreprocessorWithModel(nn.Module):
|
| 248 |
-
def __init__(self, model_dir, checkpoint_dir, use_layer_idx=9) -> None:
|
| 249 |
-
super().__init__()
|
| 250 |
-
self.model = load_model(model_dir=model_dir, checkpoint_dir=checkpoint_dir)
|
| 251 |
-
self.model.eval()
|
| 252 |
-
self.use_layer_idx = use_layer_idx
|
| 253 |
-
|
| 254 |
-
def forward(self, x):
|
| 255 |
-
with torch.no_grad():
|
| 256 |
-
self.model.eval()
|
| 257 |
-
res = self.model(x, features_only = True)
|
| 258 |
-
layer_results = res['layer_results']
|
| 259 |
-
return layer_results[self.use_layer_idx].permute(0,2,1)
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
def Music_Mel_Target_Config():
|
| 264 |
-
config = dict(
|
| 265 |
-
train_dataset = dict(
|
| 266 |
-
manifest_path = 'path/to/data/music4all/train.json',
|
| 267 |
-
sample_rate = 24000,
|
| 268 |
-
normalize = False,
|
| 269 |
-
),
|
| 270 |
-
valid_dataset = dict(
|
| 271 |
-
manifest_path = 'path/to/data/music4all/valid.json',
|
| 272 |
-
sample_rate = 24000,
|
| 273 |
-
normalize = False,
|
| 274 |
-
),
|
| 275 |
-
model = dict(
|
| 276 |
-
input_dim = 128*4,
|
| 277 |
-
n_codebooks = 8,
|
| 278 |
-
codebook_size = 1024,
|
| 279 |
-
codebook_dim = 16,
|
| 280 |
-
quantizer_dropout = 0.0,
|
| 281 |
-
),
|
| 282 |
-
train = dict(
|
| 283 |
-
batch_size = 32,
|
| 284 |
-
num_workers = 6,
|
| 285 |
-
valid_interval = 10,
|
| 286 |
-
save_interval = 100,
|
| 287 |
-
max_updates = 500000,
|
| 288 |
-
lr = 1e-4,
|
| 289 |
-
device = 'cuda:0',
|
| 290 |
-
loss = 'commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()',
|
| 291 |
-
preprocess = Preprocessor()
|
| 292 |
-
)
|
| 293 |
-
)
|
| 294 |
-
return config
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
def main(config):
|
| 298 |
-
train_dataset = RVQDataset(**config['train_dataset'])
|
| 299 |
-
if config['valid_dataset']['manifest_path'] is None:
|
| 300 |
-
# split train and valid dataset
|
| 301 |
-
from torch.utils.data import random_split
|
| 302 |
-
train_dataset, valid_dataset = random_split(
|
| 303 |
-
train_dataset, lengths=[len(train_dataset) - 500, 500]
|
| 304 |
-
)
|
| 305 |
-
else:
|
| 306 |
-
valid_dataset = RVQDataset(**config['valid_dataset'])
|
| 307 |
-
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=config['train']['batch_size'], drop_last=True, num_workers=config['train']['num_workers'])
|
| 308 |
-
valid_dataloader = DataLoader(valid_dataset, shuffle=False, batch_size=config['train']['batch_size'], drop_last=True, num_workers=config['train']['num_workers'])
|
| 309 |
-
model = ResidualVectorQuantize(**config['model'])
|
| 310 |
-
|
| 311 |
-
device = config['train']['device']
|
| 312 |
-
preprocess = config['train']['preprocess'].to(device)
|
| 313 |
-
model = model.to(device)
|
| 314 |
-
|
| 315 |
-
optimizer = torch.optim.Adam(model.parameters(), lr=config['train']['lr'])
|
| 316 |
-
cur_updates = 0
|
| 317 |
-
is_running = True
|
| 318 |
-
result = {}
|
| 319 |
-
from tqdm import tqdm
|
| 320 |
-
from tensorboardX import SummaryWriter
|
| 321 |
-
writer = SummaryWriter()
|
| 322 |
-
from collections import defaultdict
|
| 323 |
-
import os
|
| 324 |
-
from logging import getLogger
|
| 325 |
-
logger = getLogger()
|
| 326 |
-
|
| 327 |
-
while is_running:
|
| 328 |
-
results = defaultdict(lambda:0)
|
| 329 |
-
for item in tqdm(train_dataloader, desc='train'):
|
| 330 |
-
wavs = item['source']
|
| 331 |
-
optimizer.zero_grad()
|
| 332 |
-
wavs = wavs.to(device)
|
| 333 |
-
x = preprocess(wavs)
|
| 334 |
-
model.train()
|
| 335 |
-
quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = model(x)
|
| 336 |
-
loss = eval(config['train']['loss'])
|
| 337 |
-
loss.backward()
|
| 338 |
-
optimizer.step()
|
| 339 |
-
|
| 340 |
-
results['loss/train'] += loss.item()
|
| 341 |
-
results['commitment_loss/train'] += commitment_loss.item()
|
| 342 |
-
results['codebook_loss/train'] += codebook_loss.item()
|
| 343 |
-
results['rvq_usage/train'] += rvq_usage.float().mean().item()
|
| 344 |
-
|
| 345 |
-
if cur_updates % config['train']['valid_interval'] == 0:
|
| 346 |
-
model.eval()
|
| 347 |
-
with torch.no_grad():
|
| 348 |
-
for item in tqdm(valid_dataloader, desc='valid'):
|
| 349 |
-
wavs = item['source']
|
| 350 |
-
wavs = wavs.to(device)
|
| 351 |
-
x = preprocess(wavs)
|
| 352 |
-
quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = model(x)
|
| 353 |
-
valid_loss = eval(config['train']['loss'])
|
| 354 |
-
|
| 355 |
-
results['loss/valid'] += valid_loss.item()
|
| 356 |
-
results['commitment_loss/valid'] += commitment_loss.item()
|
| 357 |
-
results['codebook_loss/valid'] += codebook_loss.item()
|
| 358 |
-
results['rvq_usage/valid'] += rvq_usage.float().mean().item()
|
| 359 |
-
|
| 360 |
-
results['cur_updates'] = cur_updates
|
| 361 |
-
results['loss/train'] /= config['train']['valid_interval']
|
| 362 |
-
results['commitment_loss/train'] /= config['train']['valid_interval']
|
| 363 |
-
results['codebook_loss/train'] /= config['train']['valid_interval']
|
| 364 |
-
results['rvq_usage/train'] /= config['train']['valid_interval']
|
| 365 |
-
|
| 366 |
-
results['loss/valid'] /= len(valid_dataloader)
|
| 367 |
-
results['commitment_loss/valid'] /= len(valid_dataloader)
|
| 368 |
-
results['codebook_loss/valid'] /= len(valid_dataloader)
|
| 369 |
-
results['rvq_usage/valid'] /= len(valid_dataloader)
|
| 370 |
-
|
| 371 |
-
print('')
|
| 372 |
-
logger.info(str(results))
|
| 373 |
-
for k,v in results.items():
|
| 374 |
-
writer.add_scalar(k, v, cur_updates)
|
| 375 |
-
|
| 376 |
-
results.clear()
|
| 377 |
-
|
| 378 |
-
if cur_updates % config['train']['save_interval'] == 0:
|
| 379 |
-
os.makedirs(f'{writer.logdir}/ckpt/', exist_ok=True)
|
| 380 |
-
logger.info(f'saving checkpoint to {writer.logdir}/ckpt/RVQ_{cur_updates}.pth')
|
| 381 |
-
torch.save(model.state_dict(), f'{writer.logdir}/ckpt/RVQ_{cur_updates}.pth')
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
if cur_updates < config['train']['max_updates']:
|
| 385 |
-
cur_updates += 1
|
| 386 |
-
else:
|
| 387 |
-
is_running = False
|
| 388 |
-
break
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
if __name__ == '__main__':
|
| 393 |
-
config = Music_Mel_Target_Config()
|
| 394 |
-
main(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/model/w2v2_config.json
DELETED
|
@@ -1,113 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"activation_dropout": 0.1,
|
| 3 |
-
"adapter_kernel_size": 3,
|
| 4 |
-
"adapter_stride": 2,
|
| 5 |
-
"add_adapter": false,
|
| 6 |
-
"apply_spec_augment": true,
|
| 7 |
-
"architectures": [
|
| 8 |
-
"Wav2Vec2ConformerForCTC"
|
| 9 |
-
],
|
| 10 |
-
"attention_dropout": 0.1,
|
| 11 |
-
"bos_token_id": 1,
|
| 12 |
-
"classifier_proj_size": 256,
|
| 13 |
-
"codevector_dim": 768,
|
| 14 |
-
"conformer_conv_dropout": 0.1,
|
| 15 |
-
"contrastive_logits_temperature": 0.1,
|
| 16 |
-
"conv_bias": true,
|
| 17 |
-
"conv_depthwise_kernel_size": 31,
|
| 18 |
-
"conv_dim": [
|
| 19 |
-
512,
|
| 20 |
-
512,
|
| 21 |
-
512,
|
| 22 |
-
512,
|
| 23 |
-
512,
|
| 24 |
-
512,
|
| 25 |
-
512
|
| 26 |
-
],
|
| 27 |
-
"conv_kernel": [
|
| 28 |
-
10,
|
| 29 |
-
3,
|
| 30 |
-
3,
|
| 31 |
-
3,
|
| 32 |
-
3,
|
| 33 |
-
2,
|
| 34 |
-
2
|
| 35 |
-
],
|
| 36 |
-
"conv_stride": [
|
| 37 |
-
5,
|
| 38 |
-
2,
|
| 39 |
-
2,
|
| 40 |
-
2,
|
| 41 |
-
2,
|
| 42 |
-
2,
|
| 43 |
-
2
|
| 44 |
-
],
|
| 45 |
-
"ctc_loss_reduction": "sum",
|
| 46 |
-
"ctc_zero_infinity": false,
|
| 47 |
-
"diversity_loss_weight": 0.1,
|
| 48 |
-
"do_stable_layer_norm": true,
|
| 49 |
-
"eos_token_id": 2,
|
| 50 |
-
"feat_extract_activation": "gelu",
|
| 51 |
-
"feat_extract_dropout": 0.0,
|
| 52 |
-
"feat_extract_norm": "layer",
|
| 53 |
-
"feat_proj_dropout": 0.1,
|
| 54 |
-
"feat_quantizer_dropout": 0.0,
|
| 55 |
-
"final_dropout": 0.1,
|
| 56 |
-
"gradient_checkpointing": false,
|
| 57 |
-
"hidden_act": "swish",
|
| 58 |
-
"hidden_dropout": 0.1,
|
| 59 |
-
"hidden_dropout_prob": 0.1,
|
| 60 |
-
"hidden_size": 1024,
|
| 61 |
-
"initializer_range": 0.02,
|
| 62 |
-
"intermediate_size": 4096,
|
| 63 |
-
"layer_norm_eps": 1e-05,
|
| 64 |
-
"layerdrop": 0.0,
|
| 65 |
-
"mask_feature_length": 10,
|
| 66 |
-
"mask_feature_min_masks": 0,
|
| 67 |
-
"mask_feature_prob": 0.0,
|
| 68 |
-
"mask_time_length": 10,
|
| 69 |
-
"mask_time_min_masks": 2,
|
| 70 |
-
"mask_time_prob": 0.05,
|
| 71 |
-
"max_source_positions": 5000,
|
| 72 |
-
"model_type": "wav2vec2-conformer",
|
| 73 |
-
"num_adapter_layers": 3,
|
| 74 |
-
"num_attention_heads": 16,
|
| 75 |
-
"num_codevector_groups": 2,
|
| 76 |
-
"num_codevectors_per_group": 320,
|
| 77 |
-
"num_conv_pos_embedding_groups": 16,
|
| 78 |
-
"num_conv_pos_embeddings": 128,
|
| 79 |
-
"num_feat_extract_layers": 7,
|
| 80 |
-
"num_hidden_layers": 24,
|
| 81 |
-
"num_negatives": 100,
|
| 82 |
-
"output_hidden_size": 1024,
|
| 83 |
-
"pad_token_id": 0,
|
| 84 |
-
"position_embeddings_type": "rotary",
|
| 85 |
-
"proj_codevector_dim": 768,
|
| 86 |
-
"rotary_embedding_base": 10000,
|
| 87 |
-
"tdnn_dilation": [
|
| 88 |
-
1,
|
| 89 |
-
2,
|
| 90 |
-
3,
|
| 91 |
-
1,
|
| 92 |
-
1
|
| 93 |
-
],
|
| 94 |
-
"tdnn_dim": [
|
| 95 |
-
512,
|
| 96 |
-
512,
|
| 97 |
-
512,
|
| 98 |
-
512,
|
| 99 |
-
1500
|
| 100 |
-
],
|
| 101 |
-
"tdnn_kernel": [
|
| 102 |
-
5,
|
| 103 |
-
3,
|
| 104 |
-
3,
|
| 105 |
-
1,
|
| 106 |
-
1
|
| 107 |
-
],
|
| 108 |
-
"torch_dtype": "float32",
|
| 109 |
-
"transformers_version": "4.19.0.dev0",
|
| 110 |
-
"use_weighted_layer_sum": false,
|
| 111 |
-
"vocab_size": 32,
|
| 112 |
-
"xvector_output_dim": 512
|
| 113 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/modules/__init__.py
DELETED
|
@@ -1,2 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
|
|
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/modules/conv.py
DELETED
|
@@ -1,77 +0,0 @@
|
|
| 1 |
-
from torch import nn
|
| 2 |
-
from einops import rearrange
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
class Res2dModule(nn.Module):
|
| 6 |
-
def __init__(self, idim, odim, stride=(2, 2)):
|
| 7 |
-
super(Res2dModule, self).__init__()
|
| 8 |
-
self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
|
| 9 |
-
self.bn1 = nn.BatchNorm2d(odim)
|
| 10 |
-
self.conv2 = nn.Conv2d(odim, odim, 3, padding=1)
|
| 11 |
-
self.bn2 = nn.BatchNorm2d(odim)
|
| 12 |
-
self.relu = nn.ReLU()
|
| 13 |
-
|
| 14 |
-
# residual
|
| 15 |
-
self.diff = False
|
| 16 |
-
if (idim != odim) or (stride[0] > 1):
|
| 17 |
-
self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
|
| 18 |
-
self.bn3 = nn.BatchNorm2d(odim)
|
| 19 |
-
self.diff = True
|
| 20 |
-
|
| 21 |
-
def forward(self, x):
|
| 22 |
-
out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
|
| 23 |
-
if self.diff:
|
| 24 |
-
x = self.bn3(self.conv3(x))
|
| 25 |
-
out = x + out
|
| 26 |
-
out = self.relu(out)
|
| 27 |
-
return out
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
class Conv2dSubsampling(nn.Module):
|
| 31 |
-
"""Convolutional 2D subsampling (to 1/4 length).
|
| 32 |
-
|
| 33 |
-
Args:
|
| 34 |
-
idim (int): Input dimension.
|
| 35 |
-
hdim (int): Hidden dimension.
|
| 36 |
-
odim (int): Output dimension.
|
| 37 |
-
strides (list): Sizes of strides.
|
| 38 |
-
n_bands (int): Number of frequency bands.
|
| 39 |
-
"""
|
| 40 |
-
|
| 41 |
-
def __init__(self, idim, hdim, odim, strides=[2, 2], n_bands=64):
|
| 42 |
-
"""Construct an Conv2dSubsampling object."""
|
| 43 |
-
super(Conv2dSubsampling, self).__init__()
|
| 44 |
-
|
| 45 |
-
self.conv = nn.Sequential(
|
| 46 |
-
Res2dModule(idim, hdim, (2, strides[0])),
|
| 47 |
-
Res2dModule(hdim, hdim, (2, strides[1])),
|
| 48 |
-
)
|
| 49 |
-
self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim)
|
| 50 |
-
|
| 51 |
-
def forward(self, x):
|
| 52 |
-
"""Subsample x.
|
| 53 |
-
|
| 54 |
-
Args:
|
| 55 |
-
x (torch.Tensor): Input tensor (#batch, idim, time).
|
| 56 |
-
|
| 57 |
-
Returns:
|
| 58 |
-
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
| 59 |
-
where time' = time // 4.
|
| 60 |
-
"""
|
| 61 |
-
|
| 62 |
-
if x.dim() == 3:
|
| 63 |
-
x = x.unsqueeze(1) # (b, c, f, t)
|
| 64 |
-
x = self.conv(x)
|
| 65 |
-
x = rearrange(x, "b c f t -> b t (c f)")
|
| 66 |
-
x = self.linear(x)
|
| 67 |
-
return x
|
| 68 |
-
|
| 69 |
-
if __name__ == '__main__':
|
| 70 |
-
import torch
|
| 71 |
-
conv_dim, encoder_dim = 512, 1024
|
| 72 |
-
conv = Conv2dSubsampling(
|
| 73 |
-
1, conv_dim, encoder_dim, strides=[2, 1], n_bands=128
|
| 74 |
-
)
|
| 75 |
-
inp = torch.randn((1, 128, 3000))
|
| 76 |
-
out = conv(inp)
|
| 77 |
-
print(out.shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/modules/features.py
DELETED
|
@@ -1,67 +0,0 @@
|
|
| 1 |
-
import torchaudio
|
| 2 |
-
from torch import nn
|
| 3 |
-
import torch
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
class MelSTFT(nn.Module):
|
| 7 |
-
def __init__(
|
| 8 |
-
self,
|
| 9 |
-
sample_rate=24000,
|
| 10 |
-
n_fft=2048,
|
| 11 |
-
hop_length=240,
|
| 12 |
-
n_mels=128,
|
| 13 |
-
is_db=False,
|
| 14 |
-
):
|
| 15 |
-
super(MelSTFT, self).__init__()
|
| 16 |
-
|
| 17 |
-
# spectrogram
|
| 18 |
-
self.mel_stft = torchaudio.transforms.MelSpectrogram(
|
| 19 |
-
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
-
# amplitude to decibel
|
| 23 |
-
self.is_db = is_db
|
| 24 |
-
if is_db:
|
| 25 |
-
self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
|
| 26 |
-
|
| 27 |
-
def forward(self, waveform):
|
| 28 |
-
if self.is_db:
|
| 29 |
-
return self.amplitude_to_db(self.mel_stft(waveform))
|
| 30 |
-
else:
|
| 31 |
-
return self.mel_stft(waveform)
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
class CQTPreprocessor(nn.Module):
|
| 35 |
-
def __init__(self,
|
| 36 |
-
sr=24000,
|
| 37 |
-
hop=960,
|
| 38 |
-
nb=84,
|
| 39 |
-
to_db = True,
|
| 40 |
-
) -> None:
|
| 41 |
-
super().__init__()
|
| 42 |
-
|
| 43 |
-
from nnAudio.features.cqt import CQT
|
| 44 |
-
import torchaudio
|
| 45 |
-
self.cqt_fn = CQT(
|
| 46 |
-
sr=sr,
|
| 47 |
-
hop_length=hop,
|
| 48 |
-
n_bins=nb,
|
| 49 |
-
fmin=32.7 if nb == 84 else 27.5, # 84 or 88
|
| 50 |
-
bins_per_octave=12,
|
| 51 |
-
filter_scale=1,
|
| 52 |
-
norm=1,
|
| 53 |
-
window='hann',
|
| 54 |
-
center=True,
|
| 55 |
-
pad_mode='constant',
|
| 56 |
-
trainable=False,
|
| 57 |
-
output_format='Magnitude',
|
| 58 |
-
verbose=True,
|
| 59 |
-
)
|
| 60 |
-
if to_db:
|
| 61 |
-
self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
|
| 62 |
-
else:
|
| 63 |
-
self.amplitude_to_db = lambda x:x
|
| 64 |
-
|
| 65 |
-
@torch.no_grad()
|
| 66 |
-
def __call__(self, x):
|
| 67 |
-
return self.amplitude_to_db(self.cqt_fn(x))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/modules/flash_conformer.py
DELETED
|
@@ -1,2114 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
""" PyTorch Wav2Vec2-Conformer model."""
|
| 16 |
-
|
| 17 |
-
import math
|
| 18 |
-
from dataclasses import dataclass
|
| 19 |
-
from typing import Optional, Tuple, Union
|
| 20 |
-
|
| 21 |
-
import numpy as np
|
| 22 |
-
import torch
|
| 23 |
-
import torch.utils.checkpoint
|
| 24 |
-
from torch import nn
|
| 25 |
-
from torch.nn import CrossEntropyLoss
|
| 26 |
-
from torch.nn import functional as F
|
| 27 |
-
|
| 28 |
-
from transformers.activations import ACT2FN
|
| 29 |
-
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
| 30 |
-
from transformers.modeling_outputs import (
|
| 31 |
-
BaseModelOutput,
|
| 32 |
-
CausalLMOutput,
|
| 33 |
-
SequenceClassifierOutput,
|
| 34 |
-
TokenClassifierOutput,
|
| 35 |
-
Wav2Vec2BaseModelOutput,
|
| 36 |
-
XVectorOutput,
|
| 37 |
-
)
|
| 38 |
-
from transformers.modeling_utils import PreTrainedModel
|
| 39 |
-
from transformers.utils import (
|
| 40 |
-
ModelOutput,
|
| 41 |
-
add_code_sample_docstrings,
|
| 42 |
-
add_start_docstrings,
|
| 43 |
-
add_start_docstrings_to_model_forward,
|
| 44 |
-
logging,
|
| 45 |
-
replace_return_docstrings,
|
| 46 |
-
)
|
| 47 |
-
from transformers.models.wav2vec2_conformer.configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
logger = logging.get_logger(__name__)
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
_HIDDEN_STATES_START_POSITION = 2
|
| 54 |
-
|
| 55 |
-
# General docstring
|
| 56 |
-
_CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
|
| 57 |
-
|
| 58 |
-
# Base docstring
|
| 59 |
-
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
|
| 60 |
-
_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
|
| 61 |
-
|
| 62 |
-
# CTC docstring
|
| 63 |
-
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
|
| 64 |
-
_CTC_EXPECTED_LOSS = 64.21
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
| 68 |
-
"facebook/wav2vec2-conformer-rel-pos-large",
|
| 69 |
-
# See all Wav2Vec2Conformer models at https://huggingface.co/models?filter=wav2vec2-conformer
|
| 70 |
-
]
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
@dataclass
|
| 74 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer
|
| 75 |
-
class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
|
| 76 |
-
"""
|
| 77 |
-
Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
|
| 78 |
-
|
| 79 |
-
Args:
|
| 80 |
-
loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
|
| 81 |
-
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
|
| 82 |
-
paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss.
|
| 83 |
-
projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
|
| 84 |
-
Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
|
| 85 |
-
projected quantized states.
|
| 86 |
-
projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
|
| 87 |
-
Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
|
| 88 |
-
target vectors for contrastive loss.
|
| 89 |
-
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 90 |
-
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 91 |
-
shape `(batch_size, sequence_length, hidden_size)`.
|
| 92 |
-
|
| 93 |
-
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 94 |
-
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 95 |
-
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 96 |
-
sequence_length)`.
|
| 97 |
-
|
| 98 |
-
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 99 |
-
heads.
|
| 100 |
-
contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
|
| 101 |
-
The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
|
| 102 |
-
diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
|
| 103 |
-
The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
|
| 104 |
-
"""
|
| 105 |
-
|
| 106 |
-
loss: Optional[torch.FloatTensor] = None
|
| 107 |
-
projected_states: torch.FloatTensor = None
|
| 108 |
-
projected_quantized_states: torch.FloatTensor = None
|
| 109 |
-
codevector_perplexity: torch.FloatTensor = None
|
| 110 |
-
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 111 |
-
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 112 |
-
contrastive_loss: Optional[torch.FloatTensor] = None
|
| 113 |
-
diversity_loss: Optional[torch.FloatTensor] = None
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
|
| 117 |
-
def _compute_mask_indices(
|
| 118 |
-
shape: Tuple[int, int],
|
| 119 |
-
mask_prob: float,
|
| 120 |
-
mask_length: int,
|
| 121 |
-
attention_mask: Optional[torch.LongTensor] = None,
|
| 122 |
-
min_masks: int = 0,
|
| 123 |
-
) -> np.ndarray:
|
| 124 |
-
"""
|
| 125 |
-
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
|
| 126 |
-
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
|
| 127 |
-
CPU as part of the preprocessing during training.
|
| 128 |
-
|
| 129 |
-
Args:
|
| 130 |
-
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
|
| 131 |
-
the first element is the batch size and the second element is the length of the axis to span.
|
| 132 |
-
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
|
| 133 |
-
independently generated mask spans of length `mask_length` is computed by
|
| 134 |
-
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
|
| 135 |
-
actual percentage will be smaller.
|
| 136 |
-
mask_length: size of the mask
|
| 137 |
-
min_masks: minimum number of masked spans
|
| 138 |
-
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
|
| 139 |
-
each batch dimension.
|
| 140 |
-
"""
|
| 141 |
-
batch_size, sequence_length = shape
|
| 142 |
-
|
| 143 |
-
if mask_length < 1:
|
| 144 |
-
raise ValueError("`mask_length` has to be bigger than 0.")
|
| 145 |
-
|
| 146 |
-
if mask_length > sequence_length:
|
| 147 |
-
raise ValueError(
|
| 148 |
-
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
|
| 149 |
-
f" and `sequence_length`: {sequence_length}`"
|
| 150 |
-
)
|
| 151 |
-
|
| 152 |
-
# epsilon is used for probabilistic rounding
|
| 153 |
-
epsilon = np.random.rand(1).item()
|
| 154 |
-
|
| 155 |
-
def compute_num_masked_span(input_length):
|
| 156 |
-
"""Given input length, compute how many spans should be masked"""
|
| 157 |
-
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
|
| 158 |
-
num_masked_span = max(num_masked_span, min_masks)
|
| 159 |
-
|
| 160 |
-
# make sure num masked span <= sequence_length
|
| 161 |
-
if num_masked_span * mask_length > sequence_length:
|
| 162 |
-
num_masked_span = sequence_length // mask_length
|
| 163 |
-
|
| 164 |
-
# make sure num_masked span is also <= input_length - (mask_length - 1)
|
| 165 |
-
if input_length - (mask_length - 1) < num_masked_span:
|
| 166 |
-
num_masked_span = max(input_length - (mask_length - 1), 0)
|
| 167 |
-
|
| 168 |
-
return num_masked_span
|
| 169 |
-
|
| 170 |
-
# compute number of masked spans in batch
|
| 171 |
-
input_lengths = (
|
| 172 |
-
attention_mask.sum(-1).detach().tolist()
|
| 173 |
-
if attention_mask is not None
|
| 174 |
-
else [sequence_length for _ in range(batch_size)]
|
| 175 |
-
)
|
| 176 |
-
|
| 177 |
-
# SpecAugment mask to fill
|
| 178 |
-
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
|
| 179 |
-
spec_aug_mask_idxs = []
|
| 180 |
-
|
| 181 |
-
max_num_masked_span = compute_num_masked_span(sequence_length)
|
| 182 |
-
|
| 183 |
-
if max_num_masked_span == 0:
|
| 184 |
-
return spec_aug_mask
|
| 185 |
-
|
| 186 |
-
for input_length in input_lengths:
|
| 187 |
-
# compute num of masked spans for this input
|
| 188 |
-
num_masked_span = compute_num_masked_span(input_length)
|
| 189 |
-
|
| 190 |
-
# get random indices to mask
|
| 191 |
-
spec_aug_mask_idx = np.random.choice(
|
| 192 |
-
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
|
| 193 |
-
)
|
| 194 |
-
|
| 195 |
-
# pick first sampled index that will serve as a dummy index to pad vector
|
| 196 |
-
# to ensure same dimension for all batches due to probabilistic rounding
|
| 197 |
-
# Picking first sample just pads those vectors twice.
|
| 198 |
-
if len(spec_aug_mask_idx) == 0:
|
| 199 |
-
# this case can only happen if `input_length` is strictly smaller then
|
| 200 |
-
# `sequence_length` in which case the last token has to be a padding
|
| 201 |
-
# token which we can use as a dummy mask id
|
| 202 |
-
dummy_mask_idx = sequence_length - 1
|
| 203 |
-
else:
|
| 204 |
-
dummy_mask_idx = spec_aug_mask_idx[0]
|
| 205 |
-
|
| 206 |
-
spec_aug_mask_idx = np.concatenate(
|
| 207 |
-
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
|
| 208 |
-
)
|
| 209 |
-
spec_aug_mask_idxs.append(spec_aug_mask_idx)
|
| 210 |
-
|
| 211 |
-
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
|
| 212 |
-
|
| 213 |
-
# expand masked indices to masked spans
|
| 214 |
-
spec_aug_mask_idxs = np.broadcast_to(
|
| 215 |
-
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
|
| 216 |
-
)
|
| 217 |
-
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
|
| 218 |
-
|
| 219 |
-
# add offset to the starting indexes so that indexes now create a span
|
| 220 |
-
offsets = np.arange(mask_length)[None, None, :]
|
| 221 |
-
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
|
| 222 |
-
batch_size, max_num_masked_span * mask_length
|
| 223 |
-
)
|
| 224 |
-
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
|
| 225 |
-
|
| 226 |
-
# ensure that we cannot have indices larger than sequence_length
|
| 227 |
-
if spec_aug_mask_idxs.max() > sequence_length - 1:
|
| 228 |
-
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
|
| 229 |
-
|
| 230 |
-
# scatter indices to mask
|
| 231 |
-
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
|
| 232 |
-
|
| 233 |
-
return spec_aug_mask
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
|
| 237 |
-
def _sample_negative_indices(
|
| 238 |
-
features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
|
| 239 |
-
):
|
| 240 |
-
"""
|
| 241 |
-
Sample `num_negatives` vectors from feature vectors.
|
| 242 |
-
"""
|
| 243 |
-
batch_size, sequence_length = features_shape
|
| 244 |
-
|
| 245 |
-
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
|
| 246 |
-
sequence_length_range = np.arange(sequence_length)
|
| 247 |
-
|
| 248 |
-
# get `num_negatives` random vector indices from the same utterance
|
| 249 |
-
sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
|
| 250 |
-
|
| 251 |
-
mask_time_indices = (
|
| 252 |
-
mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
|
| 253 |
-
)
|
| 254 |
-
|
| 255 |
-
for batch_idx in range(batch_size):
|
| 256 |
-
high = mask_time_indices[batch_idx].sum() - 1
|
| 257 |
-
mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
|
| 258 |
-
|
| 259 |
-
feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
|
| 260 |
-
sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
|
| 261 |
-
# avoid sampling the same positive vector, but keep the distribution uniform
|
| 262 |
-
sampled_indices[sampled_indices >= feature_indices] += 1
|
| 263 |
-
|
| 264 |
-
# remap to actual indices
|
| 265 |
-
sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
|
| 266 |
-
|
| 267 |
-
# correct for batch size
|
| 268 |
-
sampled_negative_indices[batch_idx] += batch_idx * sequence_length
|
| 269 |
-
|
| 270 |
-
return sampled_negative_indices
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
|
| 274 |
-
class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
|
| 275 |
-
def __init__(self, config, layer_id=0):
|
| 276 |
-
super().__init__()
|
| 277 |
-
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 278 |
-
self.out_conv_dim = config.conv_dim[layer_id]
|
| 279 |
-
|
| 280 |
-
self.conv = nn.Conv1d(
|
| 281 |
-
self.in_conv_dim,
|
| 282 |
-
self.out_conv_dim,
|
| 283 |
-
kernel_size=config.conv_kernel[layer_id],
|
| 284 |
-
stride=config.conv_stride[layer_id],
|
| 285 |
-
bias=config.conv_bias,
|
| 286 |
-
)
|
| 287 |
-
self.activation = ACT2FN[config.feat_extract_activation]
|
| 288 |
-
|
| 289 |
-
def forward(self, hidden_states):
|
| 290 |
-
hidden_states = self.conv(hidden_states)
|
| 291 |
-
hidden_states = self.activation(hidden_states)
|
| 292 |
-
return hidden_states
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
|
| 296 |
-
class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
|
| 297 |
-
def __init__(self, config, layer_id=0):
|
| 298 |
-
super().__init__()
|
| 299 |
-
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 300 |
-
self.out_conv_dim = config.conv_dim[layer_id]
|
| 301 |
-
|
| 302 |
-
self.conv = nn.Conv1d(
|
| 303 |
-
self.in_conv_dim,
|
| 304 |
-
self.out_conv_dim,
|
| 305 |
-
kernel_size=config.conv_kernel[layer_id],
|
| 306 |
-
stride=config.conv_stride[layer_id],
|
| 307 |
-
bias=config.conv_bias,
|
| 308 |
-
)
|
| 309 |
-
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
|
| 310 |
-
self.activation = ACT2FN[config.feat_extract_activation]
|
| 311 |
-
|
| 312 |
-
def forward(self, hidden_states):
|
| 313 |
-
hidden_states = self.conv(hidden_states)
|
| 314 |
-
|
| 315 |
-
hidden_states = hidden_states.transpose(-2, -1)
|
| 316 |
-
hidden_states = self.layer_norm(hidden_states)
|
| 317 |
-
hidden_states = hidden_states.transpose(-2, -1)
|
| 318 |
-
|
| 319 |
-
hidden_states = self.activation(hidden_states)
|
| 320 |
-
return hidden_states
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
|
| 324 |
-
class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
|
| 325 |
-
def __init__(self, config, layer_id=0):
|
| 326 |
-
super().__init__()
|
| 327 |
-
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
|
| 328 |
-
self.out_conv_dim = config.conv_dim[layer_id]
|
| 329 |
-
|
| 330 |
-
self.conv = nn.Conv1d(
|
| 331 |
-
self.in_conv_dim,
|
| 332 |
-
self.out_conv_dim,
|
| 333 |
-
kernel_size=config.conv_kernel[layer_id],
|
| 334 |
-
stride=config.conv_stride[layer_id],
|
| 335 |
-
bias=config.conv_bias,
|
| 336 |
-
)
|
| 337 |
-
self.activation = ACT2FN[config.feat_extract_activation]
|
| 338 |
-
|
| 339 |
-
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
|
| 340 |
-
|
| 341 |
-
def forward(self, hidden_states):
|
| 342 |
-
hidden_states = self.conv(hidden_states)
|
| 343 |
-
hidden_states = self.layer_norm(hidden_states)
|
| 344 |
-
hidden_states = self.activation(hidden_states)
|
| 345 |
-
return hidden_states
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer
|
| 349 |
-
class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
|
| 350 |
-
def __init__(self, config):
|
| 351 |
-
super().__init__()
|
| 352 |
-
self.conv = nn.Conv1d(
|
| 353 |
-
config.hidden_size,
|
| 354 |
-
config.hidden_size,
|
| 355 |
-
kernel_size=config.num_conv_pos_embeddings,
|
| 356 |
-
padding=config.num_conv_pos_embeddings // 2,
|
| 357 |
-
groups=config.num_conv_pos_embedding_groups,
|
| 358 |
-
)
|
| 359 |
-
|
| 360 |
-
if is_deepspeed_zero3_enabled():
|
| 361 |
-
import deepspeed
|
| 362 |
-
|
| 363 |
-
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
|
| 364 |
-
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
| 365 |
-
deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
|
| 366 |
-
deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
|
| 367 |
-
else:
|
| 368 |
-
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
| 369 |
-
|
| 370 |
-
self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings)
|
| 371 |
-
self.activation = ACT2FN[config.feat_extract_activation]
|
| 372 |
-
|
| 373 |
-
def forward(self, hidden_states):
|
| 374 |
-
hidden_states = hidden_states.transpose(1, 2)
|
| 375 |
-
|
| 376 |
-
hidden_states = self.conv(hidden_states)
|
| 377 |
-
hidden_states = self.padding(hidden_states)
|
| 378 |
-
hidden_states = self.activation(hidden_states)
|
| 379 |
-
|
| 380 |
-
hidden_states = hidden_states.transpose(1, 2)
|
| 381 |
-
return hidden_states
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
|
| 385 |
-
"""Rotary positional embedding
|
| 386 |
-
Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
|
| 387 |
-
"""
|
| 388 |
-
|
| 389 |
-
def __init__(self, config):
|
| 390 |
-
super().__init__()
|
| 391 |
-
dim = config.hidden_size // config.num_attention_heads
|
| 392 |
-
base = config.rotary_embedding_base
|
| 393 |
-
|
| 394 |
-
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 395 |
-
self.register_buffer("inv_freq", inv_freq)
|
| 396 |
-
self.cached_sequence_length = None
|
| 397 |
-
self.cached_rotary_positional_embedding = None
|
| 398 |
-
|
| 399 |
-
def forward(self, hidden_states):
|
| 400 |
-
sequence_length = hidden_states.shape[1]
|
| 401 |
-
|
| 402 |
-
if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
|
| 403 |
-
return self.cached_rotary_positional_embedding
|
| 404 |
-
|
| 405 |
-
self.cached_sequence_length = sequence_length
|
| 406 |
-
time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
|
| 407 |
-
freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
|
| 408 |
-
embeddings = torch.cat((freqs, freqs), dim=-1)
|
| 409 |
-
|
| 410 |
-
cos_embeddings = embeddings.cos()[:, None, None, :]
|
| 411 |
-
sin_embeddings = embeddings.sin()[:, None, None, :]
|
| 412 |
-
self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
|
| 413 |
-
return self.cached_rotary_positional_embedding
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
|
| 417 |
-
"""Relative positional encoding module."""
|
| 418 |
-
|
| 419 |
-
def __init__(self, config):
|
| 420 |
-
super().__init__()
|
| 421 |
-
self.max_len = config.max_source_positions
|
| 422 |
-
self.d_model = config.hidden_size
|
| 423 |
-
self.pe = None
|
| 424 |
-
self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
|
| 425 |
-
|
| 426 |
-
def extend_pe(self, x):
|
| 427 |
-
# Reset the positional encodings
|
| 428 |
-
if self.pe is not None:
|
| 429 |
-
# self.pe contains both positive and negative parts
|
| 430 |
-
# the length of self.pe is 2 * input_len - 1
|
| 431 |
-
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
| 432 |
-
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
| 433 |
-
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
| 434 |
-
return
|
| 435 |
-
# Suppose `i` is the position of query vector and `j` is the
|
| 436 |
-
# position of key vector. We use positive relative positions when keys
|
| 437 |
-
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
| 438 |
-
pe_positive = torch.zeros(x.size(1), self.d_model)
|
| 439 |
-
pe_negative = torch.zeros(x.size(1), self.d_model)
|
| 440 |
-
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
| 441 |
-
div_term = torch.exp(
|
| 442 |
-
torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)
|
| 443 |
-
)
|
| 444 |
-
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
| 445 |
-
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
| 446 |
-
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
| 447 |
-
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
| 448 |
-
|
| 449 |
-
# Reverse the order of positive indices and concat both positive and
|
| 450 |
-
# negative indices. This is used to support the shifting trick
|
| 451 |
-
# as in https://arxiv.org/abs/1901.02860
|
| 452 |
-
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
| 453 |
-
pe_negative = pe_negative[1:].unsqueeze(0)
|
| 454 |
-
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
| 455 |
-
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
| 456 |
-
|
| 457 |
-
def forward(self, hidden_states: torch.Tensor):
|
| 458 |
-
self.extend_pe(hidden_states)
|
| 459 |
-
start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
|
| 460 |
-
end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
|
| 461 |
-
relative_position_embeddings = self.pe[:, start_idx:end_idx]
|
| 462 |
-
|
| 463 |
-
return relative_position_embeddings
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Wav2Vec2Conformer
|
| 467 |
-
class Wav2Vec2ConformerSamePadLayer(nn.Module):
|
| 468 |
-
def __init__(self, num_conv_pos_embeddings):
|
| 469 |
-
super().__init__()
|
| 470 |
-
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
|
| 471 |
-
|
| 472 |
-
def forward(self, hidden_states):
|
| 473 |
-
if self.num_pad_remove > 0:
|
| 474 |
-
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
|
| 475 |
-
return hidden_states
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer
|
| 479 |
-
class Wav2Vec2ConformerFeatureEncoder(nn.Module):
|
| 480 |
-
"""Construct the features from raw audio waveform"""
|
| 481 |
-
|
| 482 |
-
def __init__(self, config):
|
| 483 |
-
super().__init__()
|
| 484 |
-
|
| 485 |
-
if config.feat_extract_norm == "group":
|
| 486 |
-
conv_layers = [Wav2Vec2ConformerGroupNormConvLayer(config, layer_id=0)] + [
|
| 487 |
-
Wav2Vec2ConformerNoLayerNormConvLayer(config, layer_id=i + 1)
|
| 488 |
-
for i in range(config.num_feat_extract_layers - 1)
|
| 489 |
-
]
|
| 490 |
-
elif config.feat_extract_norm == "layer":
|
| 491 |
-
conv_layers = [
|
| 492 |
-
Wav2Vec2ConformerLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
|
| 493 |
-
]
|
| 494 |
-
else:
|
| 495 |
-
raise ValueError(
|
| 496 |
-
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
|
| 497 |
-
)
|
| 498 |
-
self.conv_layers = nn.ModuleList(conv_layers)
|
| 499 |
-
self.gradient_checkpointing = False
|
| 500 |
-
self._requires_grad = True
|
| 501 |
-
|
| 502 |
-
def _freeze_parameters(self):
|
| 503 |
-
for param in self.parameters():
|
| 504 |
-
param.requires_grad = False
|
| 505 |
-
self._requires_grad = False
|
| 506 |
-
|
| 507 |
-
def forward(self, input_values):
|
| 508 |
-
hidden_states = input_values[:, None]
|
| 509 |
-
|
| 510 |
-
# make sure hidden_states require grad for gradient_checkpointing
|
| 511 |
-
if self._requires_grad and self.training:
|
| 512 |
-
hidden_states.requires_grad = True
|
| 513 |
-
|
| 514 |
-
for conv_layer in self.conv_layers:
|
| 515 |
-
if self._requires_grad and self.gradient_checkpointing and self.training:
|
| 516 |
-
|
| 517 |
-
def create_custom_forward(module):
|
| 518 |
-
def custom_forward(*inputs):
|
| 519 |
-
return module(*inputs)
|
| 520 |
-
|
| 521 |
-
return custom_forward
|
| 522 |
-
|
| 523 |
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 524 |
-
create_custom_forward(conv_layer),
|
| 525 |
-
hidden_states,
|
| 526 |
-
)
|
| 527 |
-
else:
|
| 528 |
-
hidden_states = conv_layer(hidden_states)
|
| 529 |
-
|
| 530 |
-
return hidden_states
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer
|
| 534 |
-
class Wav2Vec2ConformerFeatureProjection(nn.Module):
|
| 535 |
-
def __init__(self, config):
|
| 536 |
-
super().__init__()
|
| 537 |
-
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
|
| 538 |
-
self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
|
| 539 |
-
self.dropout = nn.Dropout(config.feat_proj_dropout)
|
| 540 |
-
|
| 541 |
-
def forward(self, hidden_states):
|
| 542 |
-
# non-projected hidden states are needed for quantization
|
| 543 |
-
norm_hidden_states = self.layer_norm(hidden_states)
|
| 544 |
-
hidden_states = self.projection(norm_hidden_states)
|
| 545 |
-
hidden_states = self.dropout(hidden_states)
|
| 546 |
-
return hidden_states, norm_hidden_states
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer
|
| 550 |
-
class Wav2Vec2ConformerFeedForward(nn.Module):
|
| 551 |
-
def __init__(self, config):
|
| 552 |
-
super().__init__()
|
| 553 |
-
self.intermediate_dropout = nn.Dropout(config.activation_dropout)
|
| 554 |
-
|
| 555 |
-
self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 556 |
-
if isinstance(config.hidden_act, str):
|
| 557 |
-
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 558 |
-
else:
|
| 559 |
-
self.intermediate_act_fn = config.hidden_act
|
| 560 |
-
|
| 561 |
-
self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 562 |
-
self.output_dropout = nn.Dropout(config.hidden_dropout)
|
| 563 |
-
|
| 564 |
-
def forward(self, hidden_states):
|
| 565 |
-
hidden_states = self.intermediate_dense(hidden_states)
|
| 566 |
-
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 567 |
-
hidden_states = self.intermediate_dropout(hidden_states)
|
| 568 |
-
|
| 569 |
-
hidden_states = self.output_dense(hidden_states)
|
| 570 |
-
hidden_states = self.output_dropout(hidden_states)
|
| 571 |
-
return hidden_states
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
class Wav2Vec2ConformerConvolutionModule(nn.Module):
|
| 575 |
-
"""Convolution block used in the conformer block"""
|
| 576 |
-
|
| 577 |
-
def __init__(self, config):
|
| 578 |
-
super().__init__()
|
| 579 |
-
if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
|
| 580 |
-
raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding")
|
| 581 |
-
self.layer_norm = nn.LayerNorm(config.hidden_size)
|
| 582 |
-
self.pointwise_conv1 = torch.nn.Conv1d(
|
| 583 |
-
config.hidden_size,
|
| 584 |
-
2 * config.hidden_size,
|
| 585 |
-
kernel_size=1,
|
| 586 |
-
stride=1,
|
| 587 |
-
padding=0,
|
| 588 |
-
bias=False,
|
| 589 |
-
)
|
| 590 |
-
self.glu = torch.nn.GLU(dim=1)
|
| 591 |
-
self.depthwise_conv = torch.nn.Conv1d(
|
| 592 |
-
config.hidden_size,
|
| 593 |
-
config.hidden_size,
|
| 594 |
-
config.conv_depthwise_kernel_size,
|
| 595 |
-
stride=1,
|
| 596 |
-
padding=(config.conv_depthwise_kernel_size - 1) // 2,
|
| 597 |
-
groups=config.hidden_size,
|
| 598 |
-
bias=False,
|
| 599 |
-
)
|
| 600 |
-
self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size)
|
| 601 |
-
self.activation = ACT2FN[config.hidden_act]
|
| 602 |
-
self.pointwise_conv2 = torch.nn.Conv1d(
|
| 603 |
-
config.hidden_size,
|
| 604 |
-
config.hidden_size,
|
| 605 |
-
kernel_size=1,
|
| 606 |
-
stride=1,
|
| 607 |
-
padding=0,
|
| 608 |
-
bias=False,
|
| 609 |
-
)
|
| 610 |
-
self.dropout = torch.nn.Dropout(config.conformer_conv_dropout)
|
| 611 |
-
|
| 612 |
-
def forward(self, hidden_states):
|
| 613 |
-
hidden_states = self.layer_norm(hidden_states)
|
| 614 |
-
# exchange the temporal dimension and the feature dimension
|
| 615 |
-
hidden_states = hidden_states.transpose(1, 2)
|
| 616 |
-
|
| 617 |
-
# GLU mechanism
|
| 618 |
-
# => (batch, 2*channel, dim)
|
| 619 |
-
hidden_states = self.pointwise_conv1(hidden_states)
|
| 620 |
-
# => (batch, channel, dim)
|
| 621 |
-
hidden_states = self.glu(hidden_states)
|
| 622 |
-
|
| 623 |
-
# 1D Depthwise Conv
|
| 624 |
-
hidden_states = self.depthwise_conv(hidden_states)
|
| 625 |
-
hidden_states = self.batch_norm(hidden_states)
|
| 626 |
-
hidden_states = self.activation(hidden_states)
|
| 627 |
-
|
| 628 |
-
hidden_states = self.pointwise_conv2(hidden_states)
|
| 629 |
-
hidden_states = self.dropout(hidden_states)
|
| 630 |
-
hidden_states = hidden_states.transpose(1, 2)
|
| 631 |
-
return hidden_states
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
class Wav2Vec2ConformerSelfAttention(nn.Module):
|
| 635 |
-
"""Construct an Wav2Vec2ConformerSelfAttention object.
|
| 636 |
-
Can be enhanced with rotary or relative position embeddings.
|
| 637 |
-
"""
|
| 638 |
-
|
| 639 |
-
def __init__(self, config):
|
| 640 |
-
super().__init__()
|
| 641 |
-
|
| 642 |
-
self.head_size = config.hidden_size // config.num_attention_heads
|
| 643 |
-
self.num_heads = config.num_attention_heads
|
| 644 |
-
self.position_embeddings_type = config.position_embeddings_type
|
| 645 |
-
|
| 646 |
-
self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
|
| 647 |
-
self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
|
| 648 |
-
self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
|
| 649 |
-
self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
|
| 650 |
-
|
| 651 |
-
self.dropout = nn.Dropout(p=config.attention_dropout)
|
| 652 |
-
self.dropout_p = config.attention_dropout
|
| 653 |
-
|
| 654 |
-
self.is_causal = config.is_causal
|
| 655 |
-
|
| 656 |
-
if self.position_embeddings_type == "relative":
|
| 657 |
-
# linear transformation for positional encoding
|
| 658 |
-
self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
| 659 |
-
# these two learnable bias are used in matrix c and matrix d
|
| 660 |
-
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
| 661 |
-
self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
|
| 662 |
-
self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
|
| 663 |
-
|
| 664 |
-
def forward(
|
| 665 |
-
self,
|
| 666 |
-
hidden_states: torch.Tensor,
|
| 667 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 668 |
-
relative_position_embeddings: Optional[torch.Tensor] = None,
|
| 669 |
-
output_attentions: bool = False,
|
| 670 |
-
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 671 |
-
# self-attention mechanism
|
| 672 |
-
batch_size, sequence_length, hidden_size = hidden_states.size()
|
| 673 |
-
|
| 674 |
-
# make sure query/key states can be != value states
|
| 675 |
-
query_key_states = hidden_states
|
| 676 |
-
value_states = hidden_states
|
| 677 |
-
|
| 678 |
-
if self.position_embeddings_type == "rotary":
|
| 679 |
-
if relative_position_embeddings is None:
|
| 680 |
-
raise ValueError(
|
| 681 |
-
"`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
|
| 682 |
-
)
|
| 683 |
-
query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
|
| 684 |
-
|
| 685 |
-
# project query_key_states and value_states
|
| 686 |
-
query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
|
| 687 |
-
key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
|
| 688 |
-
value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
|
| 689 |
-
|
| 690 |
-
# => (batch, head, time1, d_k)
|
| 691 |
-
query = query.transpose(1, 2)
|
| 692 |
-
key = key.transpose(1, 2)
|
| 693 |
-
value = value.transpose(1, 2)
|
| 694 |
-
|
| 695 |
-
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
|
| 696 |
-
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=self.dropout_p, is_causal=self.is_causal)
|
| 697 |
-
probs = None
|
| 698 |
-
|
| 699 |
-
# # apply attention_mask if necessary
|
| 700 |
-
# if attention_mask is not None:
|
| 701 |
-
# scores = scores + attention_mask
|
| 702 |
-
|
| 703 |
-
# # => (batch, head, time1, time2)
|
| 704 |
-
# probs = torch.softmax(scores, dim=-1)
|
| 705 |
-
# probs = self.dropout(probs)
|
| 706 |
-
|
| 707 |
-
# # => (batch, head, time1, d_k)
|
| 708 |
-
# hidden_states = torch.matmul(probs, value)
|
| 709 |
-
|
| 710 |
-
# => (batch, time1, hidden_size)
|
| 711 |
-
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
|
| 712 |
-
hidden_states = self.linear_out(hidden_states)
|
| 713 |
-
|
| 714 |
-
return hidden_states, probs
|
| 715 |
-
|
| 716 |
-
def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
|
| 717 |
-
batch_size, sequence_length, hidden_size = hidden_states.size()
|
| 718 |
-
hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
|
| 719 |
-
|
| 720 |
-
cos = relative_position_embeddings[0, :sequence_length, ...]
|
| 721 |
-
sin = relative_position_embeddings[1, :sequence_length, ...]
|
| 722 |
-
|
| 723 |
-
# rotate hidden_states with rotary embeddings
|
| 724 |
-
hidden_states = hidden_states.transpose(0, 1)
|
| 725 |
-
rotated_states_begin = hidden_states[..., : self.head_size // 2]
|
| 726 |
-
rotated_states_end = hidden_states[..., self.head_size // 2 :]
|
| 727 |
-
rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
|
| 728 |
-
hidden_states = (hidden_states * cos) + (rotated_states * sin)
|
| 729 |
-
hidden_states = hidden_states.transpose(0, 1)
|
| 730 |
-
|
| 731 |
-
hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
|
| 732 |
-
|
| 733 |
-
return hidden_states
|
| 734 |
-
|
| 735 |
-
def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
|
| 736 |
-
# 1. project positional embeddings
|
| 737 |
-
# => (batch, head, 2*time1-1, d_k)
|
| 738 |
-
proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
|
| 739 |
-
proj_relative_position_embeddings = proj_relative_position_embeddings.view(
|
| 740 |
-
relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
|
| 741 |
-
)
|
| 742 |
-
proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
|
| 743 |
-
proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
|
| 744 |
-
|
| 745 |
-
# 2. Add bias to query
|
| 746 |
-
# => (batch, head, time1, d_k)
|
| 747 |
-
query = query.transpose(1, 2)
|
| 748 |
-
q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
|
| 749 |
-
q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
|
| 750 |
-
|
| 751 |
-
# 3. attention score: first compute matrix a and matrix c
|
| 752 |
-
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
| 753 |
-
# => (batch, head, time1, time2)
|
| 754 |
-
scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
|
| 755 |
-
|
| 756 |
-
# 4. then compute matrix b and matrix d
|
| 757 |
-
# => (batch, head, time1, 2*time1-1)
|
| 758 |
-
scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
|
| 759 |
-
|
| 760 |
-
# 5. shift matrix b and matrix d
|
| 761 |
-
zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
|
| 762 |
-
scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
|
| 763 |
-
scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
|
| 764 |
-
scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
|
| 765 |
-
scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
|
| 766 |
-
scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
|
| 767 |
-
|
| 768 |
-
# 6. sum matrices
|
| 769 |
-
# => (batch, head, time1, time2)
|
| 770 |
-
scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
|
| 771 |
-
|
| 772 |
-
return scores
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
class Wav2Vec2ConformerEncoderLayer(nn.Module):
|
| 776 |
-
"""Conformer block based on https://arxiv.org/abs/2005.08100."""
|
| 777 |
-
|
| 778 |
-
def __init__(self, config):
|
| 779 |
-
super().__init__()
|
| 780 |
-
embed_dim = config.hidden_size
|
| 781 |
-
dropout = config.attention_dropout
|
| 782 |
-
|
| 783 |
-
# Feed-forward 1
|
| 784 |
-
self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
|
| 785 |
-
self.ffn1 = Wav2Vec2ConformerFeedForward(config)
|
| 786 |
-
|
| 787 |
-
# Self-Attention
|
| 788 |
-
self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
|
| 789 |
-
self.self_attn_dropout = torch.nn.Dropout(dropout)
|
| 790 |
-
self.self_attn = Wav2Vec2ConformerSelfAttention(config)
|
| 791 |
-
|
| 792 |
-
# Conformer Convolution
|
| 793 |
-
self.conv_module = Wav2Vec2ConformerConvolutionModule(config)
|
| 794 |
-
|
| 795 |
-
# Feed-forward 2
|
| 796 |
-
self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
|
| 797 |
-
self.ffn2 = Wav2Vec2ConformerFeedForward(config)
|
| 798 |
-
self.final_layer_norm = nn.LayerNorm(embed_dim)
|
| 799 |
-
|
| 800 |
-
def forward(
|
| 801 |
-
self,
|
| 802 |
-
hidden_states,
|
| 803 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 804 |
-
relative_position_embeddings: Optional[torch.Tensor] = None,
|
| 805 |
-
output_attentions: bool = False,
|
| 806 |
-
):
|
| 807 |
-
hidden_states = hidden_states
|
| 808 |
-
|
| 809 |
-
# 1. Feed-Forward 1 layer
|
| 810 |
-
residual = hidden_states
|
| 811 |
-
hidden_states = self.ffn1_layer_norm(hidden_states)
|
| 812 |
-
hidden_states = self.ffn1(hidden_states)
|
| 813 |
-
hidden_states = hidden_states * 0.5 + residual
|
| 814 |
-
residual = hidden_states
|
| 815 |
-
|
| 816 |
-
# 2. Self-Attention layer
|
| 817 |
-
hidden_states = self.self_attn_layer_norm(hidden_states)
|
| 818 |
-
hidden_states, attn_weigts = self.self_attn(
|
| 819 |
-
hidden_states=hidden_states,
|
| 820 |
-
attention_mask=attention_mask,
|
| 821 |
-
relative_position_embeddings=relative_position_embeddings,
|
| 822 |
-
output_attentions=output_attentions,
|
| 823 |
-
)
|
| 824 |
-
hidden_states = self.self_attn_dropout(hidden_states)
|
| 825 |
-
hidden_states = hidden_states + residual
|
| 826 |
-
|
| 827 |
-
# 3. Convolutional Layer
|
| 828 |
-
residual = hidden_states
|
| 829 |
-
hidden_states = self.conv_module(hidden_states)
|
| 830 |
-
hidden_states = residual + hidden_states
|
| 831 |
-
|
| 832 |
-
# 4. Feed-Forward 2 Layer
|
| 833 |
-
residual = hidden_states
|
| 834 |
-
hidden_states = self.ffn2_layer_norm(hidden_states)
|
| 835 |
-
hidden_states = self.ffn2(hidden_states)
|
| 836 |
-
hidden_states = hidden_states * 0.5 + residual
|
| 837 |
-
hidden_states = self.final_layer_norm(hidden_states)
|
| 838 |
-
|
| 839 |
-
return hidden_states, attn_weigts
|
| 840 |
-
|
| 841 |
-
|
| 842 |
-
class Wav2Vec2ConformerEncoder(nn.Module):
|
| 843 |
-
def __init__(self, config, is_causal=False):
|
| 844 |
-
super().__init__()
|
| 845 |
-
config.is_causal = is_causal
|
| 846 |
-
self.config = config
|
| 847 |
-
|
| 848 |
-
if config.position_embeddings_type == "relative":
|
| 849 |
-
self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config)
|
| 850 |
-
elif config.position_embeddings_type == "rotary":
|
| 851 |
-
self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
|
| 852 |
-
else:
|
| 853 |
-
self.embed_positions = None
|
| 854 |
-
|
| 855 |
-
self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config)
|
| 856 |
-
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 857 |
-
self.dropout = nn.Dropout(config.hidden_dropout)
|
| 858 |
-
self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
| 859 |
-
self.gradient_checkpointing = False
|
| 860 |
-
|
| 861 |
-
def forward(
|
| 862 |
-
self,
|
| 863 |
-
hidden_states,
|
| 864 |
-
attention_mask=None,
|
| 865 |
-
output_attentions=False,
|
| 866 |
-
output_hidden_states=False,
|
| 867 |
-
return_dict=True,
|
| 868 |
-
):
|
| 869 |
-
all_hidden_states = () if output_hidden_states else None
|
| 870 |
-
all_self_attentions = () if output_attentions else None
|
| 871 |
-
|
| 872 |
-
if attention_mask is not None:
|
| 873 |
-
# make sure padded tokens output 0
|
| 874 |
-
hidden_states[~attention_mask] = 0.0
|
| 875 |
-
|
| 876 |
-
# extend attention_mask
|
| 877 |
-
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
|
| 878 |
-
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
|
| 879 |
-
attention_mask = attention_mask.expand(
|
| 880 |
-
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
|
| 881 |
-
)
|
| 882 |
-
|
| 883 |
-
hidden_states = self.dropout(hidden_states)
|
| 884 |
-
|
| 885 |
-
if self.embed_positions is not None:
|
| 886 |
-
relative_position_embeddings = self.embed_positions(hidden_states)
|
| 887 |
-
else:
|
| 888 |
-
relative_position_embeddings = None
|
| 889 |
-
|
| 890 |
-
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
| 891 |
-
|
| 892 |
-
for i, layer in enumerate(self.layers):
|
| 893 |
-
if output_hidden_states:
|
| 894 |
-
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 895 |
-
|
| 896 |
-
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
| 897 |
-
dropout_probability = np.random.uniform(0, 1)
|
| 898 |
-
|
| 899 |
-
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
| 900 |
-
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
| 901 |
-
# under deepspeed zero3 all gpus must run in sync
|
| 902 |
-
if self.gradient_checkpointing and self.training:
|
| 903 |
-
# create gradient checkpointing function
|
| 904 |
-
def create_custom_forward(module):
|
| 905 |
-
def custom_forward(*inputs):
|
| 906 |
-
return module(*inputs, output_attentions)
|
| 907 |
-
|
| 908 |
-
return custom_forward
|
| 909 |
-
|
| 910 |
-
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 911 |
-
create_custom_forward(layer),
|
| 912 |
-
hidden_states,
|
| 913 |
-
attention_mask,
|
| 914 |
-
relative_position_embeddings,
|
| 915 |
-
)
|
| 916 |
-
else:
|
| 917 |
-
layer_outputs = layer(
|
| 918 |
-
hidden_states,
|
| 919 |
-
attention_mask=attention_mask,
|
| 920 |
-
relative_position_embeddings=relative_position_embeddings,
|
| 921 |
-
output_attentions=output_attentions,
|
| 922 |
-
)
|
| 923 |
-
hidden_states = layer_outputs[0]
|
| 924 |
-
|
| 925 |
-
if skip_the_layer:
|
| 926 |
-
layer_outputs = (None, None)
|
| 927 |
-
|
| 928 |
-
if output_attentions:
|
| 929 |
-
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 930 |
-
|
| 931 |
-
hidden_states = self.layer_norm(hidden_states)
|
| 932 |
-
if output_hidden_states:
|
| 933 |
-
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 934 |
-
|
| 935 |
-
if not return_dict:
|
| 936 |
-
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
| 937 |
-
return BaseModelOutput(
|
| 938 |
-
last_hidden_state=hidden_states,
|
| 939 |
-
hidden_states=all_hidden_states,
|
| 940 |
-
attentions=all_self_attentions,
|
| 941 |
-
)
|
| 942 |
-
|
| 943 |
-
|
| 944 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer
|
| 945 |
-
class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
|
| 946 |
-
"""
|
| 947 |
-
Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
|
| 948 |
-
GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
|
| 949 |
-
"""
|
| 950 |
-
|
| 951 |
-
def __init__(self, config):
|
| 952 |
-
super().__init__()
|
| 953 |
-
self.num_groups = config.num_codevector_groups
|
| 954 |
-
self.num_vars = config.num_codevectors_per_group
|
| 955 |
-
|
| 956 |
-
if config.codevector_dim % self.num_groups != 0:
|
| 957 |
-
raise ValueError(
|
| 958 |
-
f"`config.codevector_dim {config.codevector_dim} must be divisible "
|
| 959 |
-
f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
|
| 960 |
-
)
|
| 961 |
-
|
| 962 |
-
# storage for codebook variables (codewords)
|
| 963 |
-
self.codevectors = nn.Parameter(
|
| 964 |
-
torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
|
| 965 |
-
)
|
| 966 |
-
self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
|
| 967 |
-
|
| 968 |
-
# can be decayed for training
|
| 969 |
-
self.temperature = 2
|
| 970 |
-
|
| 971 |
-
@staticmethod
|
| 972 |
-
def _compute_perplexity(probs, mask=None):
|
| 973 |
-
if mask is not None:
|
| 974 |
-
mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
|
| 975 |
-
probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
|
| 976 |
-
marginal_probs = probs.sum(dim=0) / mask.sum()
|
| 977 |
-
else:
|
| 978 |
-
marginal_probs = probs.mean(dim=0)
|
| 979 |
-
|
| 980 |
-
perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
|
| 981 |
-
return perplexity
|
| 982 |
-
|
| 983 |
-
def forward(self, hidden_states, mask_time_indices=None):
|
| 984 |
-
batch_size, sequence_length, hidden_size = hidden_states.shape
|
| 985 |
-
|
| 986 |
-
# project to codevector dim
|
| 987 |
-
hidden_states = self.weight_proj(hidden_states)
|
| 988 |
-
hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
|
| 989 |
-
|
| 990 |
-
if self.training:
|
| 991 |
-
# sample code vector probs via gumbel in differentiateable way
|
| 992 |
-
codevector_probs = nn.functional.gumbel_softmax(
|
| 993 |
-
hidden_states.float(), tau=self.temperature, hard=True
|
| 994 |
-
).type_as(hidden_states)
|
| 995 |
-
|
| 996 |
-
# compute perplexity
|
| 997 |
-
codevector_soft_dist = torch.softmax(
|
| 998 |
-
hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
|
| 999 |
-
)
|
| 1000 |
-
perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
|
| 1001 |
-
else:
|
| 1002 |
-
# take argmax in non-differentiable way
|
| 1003 |
-
# comptute hard codevector distribution (one hot)
|
| 1004 |
-
codevector_idx = hidden_states.argmax(dim=-1)
|
| 1005 |
-
codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
|
| 1006 |
-
-1, codevector_idx.view(-1, 1), 1.0
|
| 1007 |
-
)
|
| 1008 |
-
codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
|
| 1009 |
-
|
| 1010 |
-
perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
|
| 1011 |
-
|
| 1012 |
-
codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
|
| 1013 |
-
# use probs to retrieve codevectors
|
| 1014 |
-
codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
|
| 1015 |
-
codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
|
| 1016 |
-
codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
|
| 1017 |
-
|
| 1018 |
-
return codevectors, perplexity
|
| 1019 |
-
|
| 1020 |
-
|
| 1021 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer
|
| 1022 |
-
class Wav2Vec2ConformerAdapter(nn.Module):
|
| 1023 |
-
def __init__(self, config):
|
| 1024 |
-
super().__init__()
|
| 1025 |
-
|
| 1026 |
-
# feature dim might need to be down-projected
|
| 1027 |
-
if config.output_hidden_size != config.hidden_size:
|
| 1028 |
-
self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
|
| 1029 |
-
self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
|
| 1030 |
-
else:
|
| 1031 |
-
self.proj = self.proj_layer_norm = None
|
| 1032 |
-
|
| 1033 |
-
self.layers = nn.ModuleList(Wav2Vec2ConformerAdapterLayer(config) for _ in range(config.num_adapter_layers))
|
| 1034 |
-
self.layerdrop = config.layerdrop
|
| 1035 |
-
|
| 1036 |
-
def forward(self, hidden_states):
|
| 1037 |
-
# down project hidden_states if necessary
|
| 1038 |
-
if self.proj is not None and self.proj_layer_norm is not None:
|
| 1039 |
-
hidden_states = self.proj(hidden_states)
|
| 1040 |
-
hidden_states = self.proj_layer_norm(hidden_states)
|
| 1041 |
-
|
| 1042 |
-
hidden_states = hidden_states.transpose(1, 2)
|
| 1043 |
-
|
| 1044 |
-
for layer in self.layers:
|
| 1045 |
-
layerdrop_prob = np.random.random()
|
| 1046 |
-
if not self.training or (layerdrop_prob > self.layerdrop):
|
| 1047 |
-
hidden_states = layer(hidden_states)
|
| 1048 |
-
|
| 1049 |
-
hidden_states = hidden_states.transpose(1, 2)
|
| 1050 |
-
return hidden_states
|
| 1051 |
-
|
| 1052 |
-
|
| 1053 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer
|
| 1054 |
-
class Wav2Vec2ConformerAdapterLayer(nn.Module):
|
| 1055 |
-
def __init__(self, config):
|
| 1056 |
-
super().__init__()
|
| 1057 |
-
self.conv = nn.Conv1d(
|
| 1058 |
-
config.output_hidden_size,
|
| 1059 |
-
2 * config.output_hidden_size,
|
| 1060 |
-
config.adapter_kernel_size,
|
| 1061 |
-
stride=config.adapter_stride,
|
| 1062 |
-
padding=1,
|
| 1063 |
-
)
|
| 1064 |
-
|
| 1065 |
-
def forward(self, hidden_states):
|
| 1066 |
-
hidden_states = self.conv(hidden_states)
|
| 1067 |
-
hidden_states = nn.functional.glu(hidden_states, dim=1)
|
| 1068 |
-
|
| 1069 |
-
return hidden_states
|
| 1070 |
-
|
| 1071 |
-
|
| 1072 |
-
class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
|
| 1073 |
-
"""
|
| 1074 |
-
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 1075 |
-
models.
|
| 1076 |
-
"""
|
| 1077 |
-
|
| 1078 |
-
config_class = Wav2Vec2ConformerConfig
|
| 1079 |
-
base_model_prefix = "wav2vec2_conformer"
|
| 1080 |
-
main_input_name = "input_values"
|
| 1081 |
-
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
| 1082 |
-
supports_gradient_checkpointing = True
|
| 1083 |
-
|
| 1084 |
-
def _init_weights(self, module):
|
| 1085 |
-
"""Initialize the weights"""
|
| 1086 |
-
# Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
|
| 1087 |
-
if isinstance(module, Wav2Vec2ConformerForPreTraining):
|
| 1088 |
-
module.project_hid.reset_parameters()
|
| 1089 |
-
module.project_q.reset_parameters()
|
| 1090 |
-
module.project_hid._is_hf_initialized = True
|
| 1091 |
-
module.project_q._is_hf_initialized = True
|
| 1092 |
-
# gumbel softmax requires special init
|
| 1093 |
-
elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
|
| 1094 |
-
module.weight_proj.weight.data.normal_(mean=0.0, std=1)
|
| 1095 |
-
module.weight_proj.bias.data.zero_()
|
| 1096 |
-
nn.init.uniform_(module.codevectors)
|
| 1097 |
-
elif isinstance(module, Wav2Vec2ConformerSelfAttention):
|
| 1098 |
-
if hasattr(module, "pos_bias_u"):
|
| 1099 |
-
nn.init.xavier_uniform_(module.pos_bias_u)
|
| 1100 |
-
if hasattr(module, "pos_bias_v"):
|
| 1101 |
-
nn.init.xavier_uniform_(module.pos_bias_v)
|
| 1102 |
-
elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding):
|
| 1103 |
-
nn.init.normal_(
|
| 1104 |
-
module.conv.weight,
|
| 1105 |
-
mean=0,
|
| 1106 |
-
std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
|
| 1107 |
-
)
|
| 1108 |
-
nn.init.constant_(module.conv.bias, 0)
|
| 1109 |
-
elif isinstance(module, Wav2Vec2ConformerFeatureProjection):
|
| 1110 |
-
k = math.sqrt(1 / module.projection.in_features)
|
| 1111 |
-
nn.init.uniform_(module.projection.weight, a=-k, b=k)
|
| 1112 |
-
nn.init.uniform_(module.projection.bias, a=-k, b=k)
|
| 1113 |
-
elif isinstance(module, nn.Linear):
|
| 1114 |
-
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 1115 |
-
|
| 1116 |
-
if module.bias is not None:
|
| 1117 |
-
module.bias.data.zero_()
|
| 1118 |
-
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
|
| 1119 |
-
module.bias.data.zero_()
|
| 1120 |
-
module.weight.data.fill_(1.0)
|
| 1121 |
-
elif isinstance(module, nn.Conv1d):
|
| 1122 |
-
nn.init.kaiming_normal_(module.weight)
|
| 1123 |
-
|
| 1124 |
-
if module.bias is not None:
|
| 1125 |
-
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
|
| 1126 |
-
nn.init.uniform_(module.bias, a=-k, b=k)
|
| 1127 |
-
|
| 1128 |
-
def _get_feat_extract_output_lengths(
|
| 1129 |
-
self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
|
| 1130 |
-
):
|
| 1131 |
-
"""
|
| 1132 |
-
Computes the output length of the convolutional layers
|
| 1133 |
-
"""
|
| 1134 |
-
|
| 1135 |
-
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
|
| 1136 |
-
|
| 1137 |
-
def _conv_out_length(input_length, kernel_size, stride):
|
| 1138 |
-
# 1D convolutional layer output length formula taken
|
| 1139 |
-
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
| 1140 |
-
return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
|
| 1141 |
-
|
| 1142 |
-
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
|
| 1143 |
-
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
| 1144 |
-
|
| 1145 |
-
if add_adapter:
|
| 1146 |
-
for _ in range(self.config.num_adapter_layers):
|
| 1147 |
-
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
|
| 1148 |
-
|
| 1149 |
-
return input_lengths
|
| 1150 |
-
|
| 1151 |
-
def _get_feature_vector_attention_mask(
|
| 1152 |
-
self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
|
| 1153 |
-
):
|
| 1154 |
-
# Effectively attention_mask.sum(-1), but not inplace to be able to run
|
| 1155 |
-
# on inference mode.
|
| 1156 |
-
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
|
| 1157 |
-
|
| 1158 |
-
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
|
| 1159 |
-
output_lengths = output_lengths.to(torch.long)
|
| 1160 |
-
|
| 1161 |
-
batch_size = attention_mask.shape[0]
|
| 1162 |
-
|
| 1163 |
-
attention_mask = torch.zeros(
|
| 1164 |
-
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
|
| 1165 |
-
)
|
| 1166 |
-
# these two operations makes sure that all values before the output lengths idxs are attended to
|
| 1167 |
-
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
|
| 1168 |
-
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
| 1169 |
-
return attention_mask
|
| 1170 |
-
|
| 1171 |
-
def _set_gradient_checkpointing(self, module, value=False):
|
| 1172 |
-
if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)):
|
| 1173 |
-
module.gradient_checkpointing = value
|
| 1174 |
-
|
| 1175 |
-
|
| 1176 |
-
WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
|
| 1177 |
-
Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
|
| 1178 |
-
Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
|
| 1179 |
-
Auli.
|
| 1180 |
-
|
| 1181 |
-
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 1182 |
-
library implements for all its model (such as downloading or saving etc.).
|
| 1183 |
-
|
| 1184 |
-
This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
|
| 1185 |
-
regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
|
| 1186 |
-
|
| 1187 |
-
Parameters:
|
| 1188 |
-
config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model.
|
| 1189 |
-
Initializing with a config file does not load the weights associated with the model, only the
|
| 1190 |
-
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 1191 |
-
"""
|
| 1192 |
-
|
| 1193 |
-
|
| 1194 |
-
WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
|
| 1195 |
-
Args:
|
| 1196 |
-
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
| 1197 |
-
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
|
| 1198 |
-
into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
|
| 1199 |
-
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
|
| 1200 |
-
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
|
| 1201 |
-
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1202 |
-
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
|
| 1203 |
-
1]`:
|
| 1204 |
-
|
| 1205 |
-
- 1 for tokens that are **not masked**,
|
| 1206 |
-
- 0 for tokens that are **masked**.
|
| 1207 |
-
|
| 1208 |
-
[What are attention masks?](../glossary#attention-mask)
|
| 1209 |
-
|
| 1210 |
-
<Tip warning={true}>
|
| 1211 |
-
|
| 1212 |
-
`attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
|
| 1213 |
-
True`. For all models whose processor has `config.return_attention_mask == False`, such as
|
| 1214 |
-
[wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large),
|
| 1215 |
-
`attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
|
| 1216 |
-
such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
|
| 1217 |
-
that these models also yield slightly different results depending on whether `input_values` is padded or
|
| 1218 |
-
not.
|
| 1219 |
-
|
| 1220 |
-
</Tip>
|
| 1221 |
-
|
| 1222 |
-
output_attentions (`bool`, *optional*):
|
| 1223 |
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 1224 |
-
tensors for more detail.
|
| 1225 |
-
output_hidden_states (`bool`, *optional*):
|
| 1226 |
-
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1227 |
-
more detail.
|
| 1228 |
-
return_dict (`bool`, *optional*):
|
| 1229 |
-
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1230 |
-
"""
|
| 1231 |
-
|
| 1232 |
-
|
| 1233 |
-
@add_start_docstrings(
|
| 1234 |
-
"The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
|
| 1235 |
-
WAV2VEC2_CONFORMER_START_DOCSTRING,
|
| 1236 |
-
)
|
| 1237 |
-
class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
|
| 1238 |
-
def __init__(self, config: Wav2Vec2ConformerConfig):
|
| 1239 |
-
super().__init__(config)
|
| 1240 |
-
self.config = config
|
| 1241 |
-
self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config)
|
| 1242 |
-
self.feature_projection = Wav2Vec2ConformerFeatureProjection(config)
|
| 1243 |
-
|
| 1244 |
-
# model only needs masking vector if mask prob is > 0.0
|
| 1245 |
-
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
|
| 1246 |
-
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
|
| 1247 |
-
|
| 1248 |
-
self.encoder = Wav2Vec2ConformerEncoder(config)
|
| 1249 |
-
|
| 1250 |
-
self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None
|
| 1251 |
-
|
| 1252 |
-
# Initialize weights and apply final processing
|
| 1253 |
-
self.post_init()
|
| 1254 |
-
|
| 1255 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder
|
| 1256 |
-
def freeze_feature_encoder(self):
|
| 1257 |
-
"""
|
| 1258 |
-
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1259 |
-
not be updated during training.
|
| 1260 |
-
"""
|
| 1261 |
-
self.feature_extractor._freeze_parameters()
|
| 1262 |
-
|
| 1263 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
|
| 1264 |
-
def _mask_hidden_states(
|
| 1265 |
-
self,
|
| 1266 |
-
hidden_states: torch.FloatTensor,
|
| 1267 |
-
mask_time_indices: Optional[torch.FloatTensor] = None,
|
| 1268 |
-
attention_mask: Optional[torch.LongTensor] = None,
|
| 1269 |
-
):
|
| 1270 |
-
"""
|
| 1271 |
-
Masks extracted features along time axis and/or along feature axis according to
|
| 1272 |
-
[SpecAugment](https://arxiv.org/abs/1904.08779).
|
| 1273 |
-
"""
|
| 1274 |
-
|
| 1275 |
-
# `config.apply_spec_augment` can set masking to False
|
| 1276 |
-
if not getattr(self.config, "apply_spec_augment", True):
|
| 1277 |
-
return hidden_states
|
| 1278 |
-
|
| 1279 |
-
# generate indices & apply SpecAugment along time axis
|
| 1280 |
-
batch_size, sequence_length, hidden_size = hidden_states.size()
|
| 1281 |
-
|
| 1282 |
-
if mask_time_indices is not None:
|
| 1283 |
-
# apply SpecAugment along time axis with given mask_time_indices
|
| 1284 |
-
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
| 1285 |
-
elif self.config.mask_time_prob > 0 and self.training:
|
| 1286 |
-
mask_time_indices = _compute_mask_indices(
|
| 1287 |
-
(batch_size, sequence_length),
|
| 1288 |
-
mask_prob=self.config.mask_time_prob,
|
| 1289 |
-
mask_length=self.config.mask_time_length,
|
| 1290 |
-
attention_mask=attention_mask,
|
| 1291 |
-
min_masks=self.config.mask_time_min_masks,
|
| 1292 |
-
)
|
| 1293 |
-
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
|
| 1294 |
-
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
| 1295 |
-
|
| 1296 |
-
if self.config.mask_feature_prob > 0 and self.training:
|
| 1297 |
-
# generate indices & apply SpecAugment along feature axis
|
| 1298 |
-
mask_feature_indices = _compute_mask_indices(
|
| 1299 |
-
(batch_size, hidden_size),
|
| 1300 |
-
mask_prob=self.config.mask_feature_prob,
|
| 1301 |
-
mask_length=self.config.mask_feature_length,
|
| 1302 |
-
min_masks=self.config.mask_feature_min_masks,
|
| 1303 |
-
)
|
| 1304 |
-
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
|
| 1305 |
-
mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
|
| 1306 |
-
hidden_states[mask_feature_indices] = 0
|
| 1307 |
-
|
| 1308 |
-
return hidden_states
|
| 1309 |
-
|
| 1310 |
-
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
| 1311 |
-
@add_code_sample_docstrings(
|
| 1312 |
-
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1313 |
-
output_type=Wav2Vec2BaseModelOutput,
|
| 1314 |
-
config_class=_CONFIG_FOR_DOC,
|
| 1315 |
-
modality="audio",
|
| 1316 |
-
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
| 1317 |
-
)
|
| 1318 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer
|
| 1319 |
-
def forward(
|
| 1320 |
-
self,
|
| 1321 |
-
input_values: Optional[torch.Tensor],
|
| 1322 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 1323 |
-
mask_time_indices: Optional[torch.FloatTensor] = None,
|
| 1324 |
-
output_attentions: Optional[bool] = None,
|
| 1325 |
-
output_hidden_states: Optional[bool] = None,
|
| 1326 |
-
return_dict: Optional[bool] = None,
|
| 1327 |
-
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
|
| 1328 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1329 |
-
output_hidden_states = (
|
| 1330 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1331 |
-
)
|
| 1332 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1333 |
-
|
| 1334 |
-
extract_features = self.feature_extractor(input_values)
|
| 1335 |
-
extract_features = extract_features.transpose(1, 2)
|
| 1336 |
-
|
| 1337 |
-
if attention_mask is not None:
|
| 1338 |
-
# compute reduced attention_mask corresponding to feature vectors
|
| 1339 |
-
attention_mask = self._get_feature_vector_attention_mask(
|
| 1340 |
-
extract_features.shape[1], attention_mask, add_adapter=False
|
| 1341 |
-
)
|
| 1342 |
-
|
| 1343 |
-
hidden_states, extract_features = self.feature_projection(extract_features)
|
| 1344 |
-
hidden_states = self._mask_hidden_states(
|
| 1345 |
-
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
|
| 1346 |
-
)
|
| 1347 |
-
|
| 1348 |
-
encoder_outputs = self.encoder(
|
| 1349 |
-
hidden_states,
|
| 1350 |
-
attention_mask=attention_mask,
|
| 1351 |
-
output_attentions=output_attentions,
|
| 1352 |
-
output_hidden_states=output_hidden_states,
|
| 1353 |
-
return_dict=return_dict,
|
| 1354 |
-
)
|
| 1355 |
-
|
| 1356 |
-
hidden_states = encoder_outputs[0]
|
| 1357 |
-
|
| 1358 |
-
if self.adapter is not None:
|
| 1359 |
-
hidden_states = self.adapter(hidden_states)
|
| 1360 |
-
|
| 1361 |
-
if not return_dict:
|
| 1362 |
-
return (hidden_states, extract_features) + encoder_outputs[1:]
|
| 1363 |
-
|
| 1364 |
-
return Wav2Vec2BaseModelOutput(
|
| 1365 |
-
last_hidden_state=hidden_states,
|
| 1366 |
-
extract_features=extract_features,
|
| 1367 |
-
hidden_states=encoder_outputs.hidden_states,
|
| 1368 |
-
attentions=encoder_outputs.attentions,
|
| 1369 |
-
)
|
| 1370 |
-
|
| 1371 |
-
|
| 1372 |
-
@add_start_docstrings(
|
| 1373 |
-
"""Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
|
| 1374 |
-
)
|
| 1375 |
-
class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
|
| 1376 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
|
| 1377 |
-
def __init__(self, config: Wav2Vec2ConformerConfig):
|
| 1378 |
-
super().__init__(config)
|
| 1379 |
-
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
| 1380 |
-
self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
|
| 1381 |
-
|
| 1382 |
-
self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)
|
| 1383 |
-
|
| 1384 |
-
self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
|
| 1385 |
-
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
|
| 1386 |
-
|
| 1387 |
-
# Initialize weights and apply final processing
|
| 1388 |
-
self.post_init()
|
| 1389 |
-
|
| 1390 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
|
| 1391 |
-
def set_gumbel_temperature(self, temperature: int):
|
| 1392 |
-
"""
|
| 1393 |
-
Set the Gumbel softmax temperature to a given value. Only necessary for training
|
| 1394 |
-
"""
|
| 1395 |
-
self.quantizer.temperature = temperature
|
| 1396 |
-
|
| 1397 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
|
| 1398 |
-
def freeze_feature_encoder(self):
|
| 1399 |
-
"""
|
| 1400 |
-
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1401 |
-
not be updated during training.
|
| 1402 |
-
"""
|
| 1403 |
-
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
|
| 1404 |
-
|
| 1405 |
-
@staticmethod
|
| 1406 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits
|
| 1407 |
-
def compute_contrastive_logits(
|
| 1408 |
-
target_features: torch.FloatTensor,
|
| 1409 |
-
negative_features: torch.FloatTensor,
|
| 1410 |
-
predicted_features: torch.FloatTensor,
|
| 1411 |
-
temperature: int = 0.1,
|
| 1412 |
-
):
|
| 1413 |
-
"""
|
| 1414 |
-
Compute logits for contrastive loss based using cosine similarity as the distance measure between
|
| 1415 |
-
`[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
|
| 1416 |
-
"""
|
| 1417 |
-
target_features = torch.cat([target_features, negative_features], dim=0)
|
| 1418 |
-
|
| 1419 |
-
logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
|
| 1420 |
-
target_features
|
| 1421 |
-
)
|
| 1422 |
-
|
| 1423 |
-
# apply temperature
|
| 1424 |
-
logits = logits / temperature
|
| 1425 |
-
return logits
|
| 1426 |
-
|
| 1427 |
-
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
| 1428 |
-
@replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
| 1429 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
|
| 1430 |
-
def forward(
|
| 1431 |
-
self,
|
| 1432 |
-
input_values: Optional[torch.Tensor],
|
| 1433 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 1434 |
-
mask_time_indices: Optional[torch.BoolTensor] = None,
|
| 1435 |
-
sampled_negative_indices: Optional[torch.BoolTensor] = None,
|
| 1436 |
-
output_attentions: Optional[bool] = None,
|
| 1437 |
-
output_hidden_states: Optional[bool] = None,
|
| 1438 |
-
return_dict: Optional[bool] = None,
|
| 1439 |
-
) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]:
|
| 1440 |
-
r"""
|
| 1441 |
-
mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1442 |
-
Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
|
| 1443 |
-
masked extracted features in *config.proj_codevector_dim* space.
|
| 1444 |
-
sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
|
| 1445 |
-
Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
|
| 1446 |
-
Required input for pre-training.
|
| 1447 |
-
|
| 1448 |
-
Returns:
|
| 1449 |
-
|
| 1450 |
-
Example:
|
| 1451 |
-
|
| 1452 |
-
```python
|
| 1453 |
-
>>> import torch
|
| 1454 |
-
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
|
| 1455 |
-
>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
|
| 1456 |
-
... _compute_mask_indices,
|
| 1457 |
-
... _sample_negative_indices,
|
| 1458 |
-
... )
|
| 1459 |
-
>>> from datasets import load_dataset
|
| 1460 |
-
|
| 1461 |
-
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
|
| 1462 |
-
>>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
|
| 1463 |
-
|
| 1464 |
-
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
| 1465 |
-
>>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
|
| 1466 |
-
|
| 1467 |
-
>>> # compute masked indices
|
| 1468 |
-
>>> batch_size, raw_sequence_length = input_values.shape
|
| 1469 |
-
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
|
| 1470 |
-
>>> mask_time_indices = _compute_mask_indices(
|
| 1471 |
-
... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
|
| 1472 |
-
... )
|
| 1473 |
-
>>> sampled_negative_indices = _sample_negative_indices(
|
| 1474 |
-
... features_shape=(batch_size, sequence_length),
|
| 1475 |
-
... num_negatives=model.config.num_negatives,
|
| 1476 |
-
... mask_time_indices=mask_time_indices,
|
| 1477 |
-
... )
|
| 1478 |
-
>>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
|
| 1479 |
-
>>> sampled_negative_indices = torch.tensor(
|
| 1480 |
-
... data=sampled_negative_indices, device=input_values.device, dtype=torch.long
|
| 1481 |
-
... )
|
| 1482 |
-
|
| 1483 |
-
>>> with torch.no_grad():
|
| 1484 |
-
... outputs = model(input_values, mask_time_indices=mask_time_indices)
|
| 1485 |
-
|
| 1486 |
-
>>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
|
| 1487 |
-
>>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
|
| 1488 |
-
|
| 1489 |
-
>>> # show that cosine similarity is much higher than random
|
| 1490 |
-
>>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
|
| 1491 |
-
tensor(True)
|
| 1492 |
-
|
| 1493 |
-
>>> # for contrastive loss training model should be put into train mode
|
| 1494 |
-
>>> model = model.train()
|
| 1495 |
-
>>> loss = model(
|
| 1496 |
-
... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
|
| 1497 |
-
... ).loss
|
| 1498 |
-
```"""
|
| 1499 |
-
|
| 1500 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1501 |
-
|
| 1502 |
-
if mask_time_indices is not None:
|
| 1503 |
-
mask_time_indices = mask_time_indices.to(torch.bool)
|
| 1504 |
-
|
| 1505 |
-
outputs = self.wav2vec2_conformer(
|
| 1506 |
-
input_values,
|
| 1507 |
-
attention_mask=attention_mask,
|
| 1508 |
-
output_attentions=output_attentions,
|
| 1509 |
-
output_hidden_states=output_hidden_states,
|
| 1510 |
-
mask_time_indices=mask_time_indices,
|
| 1511 |
-
return_dict=return_dict,
|
| 1512 |
-
)
|
| 1513 |
-
|
| 1514 |
-
# 1. project all transformed features (including masked) to final vq dim
|
| 1515 |
-
transformer_features = self.project_hid(outputs[0])
|
| 1516 |
-
|
| 1517 |
-
# 2. quantize all (unmasked) extracted features and project to final vq dim
|
| 1518 |
-
extract_features = self.dropout_features(outputs[1])
|
| 1519 |
-
|
| 1520 |
-
if attention_mask is not None:
|
| 1521 |
-
# compute reduced attention_mask correponding to feature vectors
|
| 1522 |
-
attention_mask = self._get_feature_vector_attention_mask(
|
| 1523 |
-
extract_features.shape[1], attention_mask, add_adapter=False
|
| 1524 |
-
)
|
| 1525 |
-
|
| 1526 |
-
quantized_features, codevector_perplexity = self.quantizer(
|
| 1527 |
-
extract_features, mask_time_indices=mask_time_indices
|
| 1528 |
-
)
|
| 1529 |
-
quantized_features = self.project_q(quantized_features)
|
| 1530 |
-
|
| 1531 |
-
loss = contrastive_loss = diversity_loss = None
|
| 1532 |
-
if sampled_negative_indices is not None:
|
| 1533 |
-
batch_size, sequence_length, hidden_size = quantized_features.shape
|
| 1534 |
-
|
| 1535 |
-
# for training, we sample negatives
|
| 1536 |
-
# 3. sample K negatives (distractors) quantized states for contrastive loss
|
| 1537 |
-
# if attention_mask is passed, make sure that padded feature vectors cannot be sampled
|
| 1538 |
-
# sample negative quantized vectors BTC => (BxT)C
|
| 1539 |
-
negative_quantized_features = quantized_features.view(-1, hidden_size)[
|
| 1540 |
-
sampled_negative_indices.long().view(-1)
|
| 1541 |
-
]
|
| 1542 |
-
negative_quantized_features = negative_quantized_features.view(
|
| 1543 |
-
batch_size, sequence_length, -1, hidden_size
|
| 1544 |
-
).permute(2, 0, 1, 3)
|
| 1545 |
-
|
| 1546 |
-
# 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
|
| 1547 |
-
# of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
|
| 1548 |
-
logits = self.compute_contrastive_logits(
|
| 1549 |
-
quantized_features[None, :],
|
| 1550 |
-
negative_quantized_features,
|
| 1551 |
-
transformer_features,
|
| 1552 |
-
self.config.contrastive_logits_temperature,
|
| 1553 |
-
)
|
| 1554 |
-
|
| 1555 |
-
# 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
|
| 1556 |
-
# its cosine similarity will be masked
|
| 1557 |
-
neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
|
| 1558 |
-
|
| 1559 |
-
if neg_is_pos.any():
|
| 1560 |
-
logits[1:][neg_is_pos] = float("-inf")
|
| 1561 |
-
|
| 1562 |
-
# 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
|
| 1563 |
-
# -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
|
| 1564 |
-
logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
|
| 1565 |
-
target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
|
| 1566 |
-
|
| 1567 |
-
contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
|
| 1568 |
-
# 7. compute diversity loss: \mathbf{L}_d
|
| 1569 |
-
num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
|
| 1570 |
-
diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
|
| 1571 |
-
|
| 1572 |
-
# 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
|
| 1573 |
-
loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
|
| 1574 |
-
|
| 1575 |
-
if not return_dict:
|
| 1576 |
-
if loss is not None:
|
| 1577 |
-
return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
|
| 1578 |
-
return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
|
| 1579 |
-
|
| 1580 |
-
return Wav2Vec2ConformerForPreTrainingOutput(
|
| 1581 |
-
loss=loss,
|
| 1582 |
-
projected_states=transformer_features,
|
| 1583 |
-
projected_quantized_states=quantized_features,
|
| 1584 |
-
codevector_perplexity=codevector_perplexity,
|
| 1585 |
-
hidden_states=outputs.hidden_states,
|
| 1586 |
-
attentions=outputs.attentions,
|
| 1587 |
-
contrastive_loss=contrastive_loss,
|
| 1588 |
-
diversity_loss=diversity_loss,
|
| 1589 |
-
)
|
| 1590 |
-
|
| 1591 |
-
|
| 1592 |
-
@add_start_docstrings(
|
| 1593 |
-
"""Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
|
| 1594 |
-
WAV2VEC2_CONFORMER_START_DOCSTRING,
|
| 1595 |
-
)
|
| 1596 |
-
class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
|
| 1597 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
|
| 1598 |
-
def __init__(self, config):
|
| 1599 |
-
super().__init__(config)
|
| 1600 |
-
|
| 1601 |
-
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
| 1602 |
-
self.dropout = nn.Dropout(config.final_dropout)
|
| 1603 |
-
|
| 1604 |
-
if config.vocab_size is None:
|
| 1605 |
-
raise ValueError(
|
| 1606 |
-
f"You are trying to instantiate {self.__class__} with a configuration that "
|
| 1607 |
-
"does not define the vocabulary size of the language model head. Please "
|
| 1608 |
-
"instantiate the model as follows: `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
|
| 1609 |
-
"or define `vocab_size` of your model's configuration."
|
| 1610 |
-
)
|
| 1611 |
-
output_hidden_size = (
|
| 1612 |
-
config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
|
| 1613 |
-
)
|
| 1614 |
-
self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
|
| 1615 |
-
|
| 1616 |
-
# Initialize weights and apply final processing
|
| 1617 |
-
self.post_init()
|
| 1618 |
-
|
| 1619 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
|
| 1620 |
-
def freeze_feature_encoder(self):
|
| 1621 |
-
"""
|
| 1622 |
-
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1623 |
-
not be updated during training.
|
| 1624 |
-
"""
|
| 1625 |
-
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
|
| 1626 |
-
|
| 1627 |
-
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
| 1628 |
-
@add_code_sample_docstrings(
|
| 1629 |
-
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1630 |
-
output_type=CausalLMOutput,
|
| 1631 |
-
config_class=_CONFIG_FOR_DOC,
|
| 1632 |
-
expected_output=_CTC_EXPECTED_OUTPUT,
|
| 1633 |
-
expected_loss=_CTC_EXPECTED_LOSS,
|
| 1634 |
-
)
|
| 1635 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
|
| 1636 |
-
def forward(
|
| 1637 |
-
self,
|
| 1638 |
-
input_values: Optional[torch.Tensor],
|
| 1639 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 1640 |
-
output_attentions: Optional[bool] = None,
|
| 1641 |
-
output_hidden_states: Optional[bool] = None,
|
| 1642 |
-
return_dict: Optional[bool] = None,
|
| 1643 |
-
labels: Optional[torch.Tensor] = None,
|
| 1644 |
-
) -> Union[Tuple, CausalLMOutput]:
|
| 1645 |
-
r"""
|
| 1646 |
-
labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
|
| 1647 |
-
Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
|
| 1648 |
-
the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
|
| 1649 |
-
All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
|
| 1650 |
-
config.vocab_size - 1]`.
|
| 1651 |
-
"""
|
| 1652 |
-
|
| 1653 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1654 |
-
|
| 1655 |
-
outputs = self.wav2vec2_conformer(
|
| 1656 |
-
input_values,
|
| 1657 |
-
attention_mask=attention_mask,
|
| 1658 |
-
output_attentions=output_attentions,
|
| 1659 |
-
output_hidden_states=output_hidden_states,
|
| 1660 |
-
return_dict=return_dict,
|
| 1661 |
-
)
|
| 1662 |
-
|
| 1663 |
-
hidden_states = outputs[0]
|
| 1664 |
-
hidden_states = self.dropout(hidden_states)
|
| 1665 |
-
|
| 1666 |
-
logits = self.lm_head(hidden_states)
|
| 1667 |
-
|
| 1668 |
-
loss = None
|
| 1669 |
-
if labels is not None:
|
| 1670 |
-
if labels.max() >= self.config.vocab_size:
|
| 1671 |
-
raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
|
| 1672 |
-
|
| 1673 |
-
# retrieve loss input_lengths from attention_mask
|
| 1674 |
-
attention_mask = (
|
| 1675 |
-
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
|
| 1676 |
-
)
|
| 1677 |
-
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
| 1678 |
-
|
| 1679 |
-
# assuming that padded tokens are filled with -100
|
| 1680 |
-
# when not being attended to
|
| 1681 |
-
labels_mask = labels >= 0
|
| 1682 |
-
target_lengths = labels_mask.sum(-1)
|
| 1683 |
-
flattened_targets = labels.masked_select(labels_mask)
|
| 1684 |
-
|
| 1685 |
-
# ctc_loss doesn't support fp16
|
| 1686 |
-
log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
|
| 1687 |
-
|
| 1688 |
-
with torch.backends.cudnn.flags(enabled=False):
|
| 1689 |
-
loss = nn.functional.ctc_loss(
|
| 1690 |
-
log_probs,
|
| 1691 |
-
flattened_targets,
|
| 1692 |
-
input_lengths,
|
| 1693 |
-
target_lengths,
|
| 1694 |
-
blank=self.config.pad_token_id,
|
| 1695 |
-
reduction=self.config.ctc_loss_reduction,
|
| 1696 |
-
zero_infinity=self.config.ctc_zero_infinity,
|
| 1697 |
-
)
|
| 1698 |
-
|
| 1699 |
-
if not return_dict:
|
| 1700 |
-
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 1701 |
-
return ((loss,) + output) if loss is not None else output
|
| 1702 |
-
|
| 1703 |
-
return CausalLMOutput(
|
| 1704 |
-
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
| 1705 |
-
)
|
| 1706 |
-
|
| 1707 |
-
|
| 1708 |
-
@add_start_docstrings(
|
| 1709 |
-
"""
|
| 1710 |
-
Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for
|
| 1711 |
-
tasks like SUPERB Keyword Spotting.
|
| 1712 |
-
""",
|
| 1713 |
-
WAV2VEC2_CONFORMER_START_DOCSTRING,
|
| 1714 |
-
)
|
| 1715 |
-
class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel):
|
| 1716 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
|
| 1717 |
-
def __init__(self, config):
|
| 1718 |
-
super().__init__(config)
|
| 1719 |
-
|
| 1720 |
-
if hasattr(config, "add_adapter") and config.add_adapter:
|
| 1721 |
-
raise ValueError(
|
| 1722 |
-
"Sequence classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
|
| 1723 |
-
)
|
| 1724 |
-
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
| 1725 |
-
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
| 1726 |
-
if config.use_weighted_layer_sum:
|
| 1727 |
-
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
| 1728 |
-
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
| 1729 |
-
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
|
| 1730 |
-
|
| 1731 |
-
# Initialize weights and apply final processing
|
| 1732 |
-
self.post_init()
|
| 1733 |
-
|
| 1734 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
|
| 1735 |
-
def freeze_feature_encoder(self):
|
| 1736 |
-
"""
|
| 1737 |
-
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1738 |
-
not be updated during training.
|
| 1739 |
-
"""
|
| 1740 |
-
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
|
| 1741 |
-
|
| 1742 |
-
def freeze_base_model(self):
|
| 1743 |
-
"""
|
| 1744 |
-
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
| 1745 |
-
be updated during training. Only the classification head will be updated.
|
| 1746 |
-
"""
|
| 1747 |
-
for param in self.wav2vec2_conformer.parameters():
|
| 1748 |
-
param.requires_grad = False
|
| 1749 |
-
|
| 1750 |
-
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
| 1751 |
-
@add_code_sample_docstrings(
|
| 1752 |
-
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1753 |
-
output_type=SequenceClassifierOutput,
|
| 1754 |
-
config_class=_CONFIG_FOR_DOC,
|
| 1755 |
-
modality="audio",
|
| 1756 |
-
)
|
| 1757 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
|
| 1758 |
-
def forward(
|
| 1759 |
-
self,
|
| 1760 |
-
input_values: Optional[torch.Tensor],
|
| 1761 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 1762 |
-
output_attentions: Optional[bool] = None,
|
| 1763 |
-
output_hidden_states: Optional[bool] = None,
|
| 1764 |
-
return_dict: Optional[bool] = None,
|
| 1765 |
-
labels: Optional[torch.Tensor] = None,
|
| 1766 |
-
) -> Union[Tuple, SequenceClassifierOutput]:
|
| 1767 |
-
r"""
|
| 1768 |
-
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1769 |
-
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1770 |
-
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1771 |
-
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1772 |
-
"""
|
| 1773 |
-
|
| 1774 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1775 |
-
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
| 1776 |
-
|
| 1777 |
-
outputs = self.wav2vec2_conformer(
|
| 1778 |
-
input_values,
|
| 1779 |
-
attention_mask=attention_mask,
|
| 1780 |
-
output_attentions=output_attentions,
|
| 1781 |
-
output_hidden_states=output_hidden_states,
|
| 1782 |
-
return_dict=return_dict,
|
| 1783 |
-
)
|
| 1784 |
-
|
| 1785 |
-
if self.config.use_weighted_layer_sum:
|
| 1786 |
-
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
| 1787 |
-
hidden_states = torch.stack(hidden_states, dim=1)
|
| 1788 |
-
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
| 1789 |
-
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
| 1790 |
-
else:
|
| 1791 |
-
hidden_states = outputs[0]
|
| 1792 |
-
|
| 1793 |
-
hidden_states = self.projector(hidden_states)
|
| 1794 |
-
if attention_mask is None:
|
| 1795 |
-
pooled_output = hidden_states.mean(dim=1)
|
| 1796 |
-
else:
|
| 1797 |
-
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
|
| 1798 |
-
hidden_states[~padding_mask] = 0.0
|
| 1799 |
-
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
| 1800 |
-
|
| 1801 |
-
logits = self.classifier(pooled_output)
|
| 1802 |
-
|
| 1803 |
-
loss = None
|
| 1804 |
-
if labels is not None:
|
| 1805 |
-
loss_fct = CrossEntropyLoss()
|
| 1806 |
-
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
| 1807 |
-
|
| 1808 |
-
if not return_dict:
|
| 1809 |
-
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 1810 |
-
return ((loss,) + output) if loss is not None else output
|
| 1811 |
-
|
| 1812 |
-
return SequenceClassifierOutput(
|
| 1813 |
-
loss=loss,
|
| 1814 |
-
logits=logits,
|
| 1815 |
-
hidden_states=outputs.hidden_states,
|
| 1816 |
-
attentions=outputs.attentions,
|
| 1817 |
-
)
|
| 1818 |
-
|
| 1819 |
-
|
| 1820 |
-
@add_start_docstrings(
|
| 1821 |
-
"""
|
| 1822 |
-
Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization.
|
| 1823 |
-
""",
|
| 1824 |
-
WAV2VEC2_CONFORMER_START_DOCSTRING,
|
| 1825 |
-
)
|
| 1826 |
-
class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
|
| 1827 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
|
| 1828 |
-
def __init__(self, config):
|
| 1829 |
-
super().__init__(config)
|
| 1830 |
-
|
| 1831 |
-
if hasattr(config, "add_adapter") and config.add_adapter:
|
| 1832 |
-
raise ValueError(
|
| 1833 |
-
"Audio frame classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
|
| 1834 |
-
)
|
| 1835 |
-
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
| 1836 |
-
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
| 1837 |
-
if config.use_weighted_layer_sum:
|
| 1838 |
-
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
| 1839 |
-
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 1840 |
-
self.num_labels = config.num_labels
|
| 1841 |
-
|
| 1842 |
-
self.init_weights()
|
| 1843 |
-
|
| 1844 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
|
| 1845 |
-
def freeze_feature_encoder(self):
|
| 1846 |
-
"""
|
| 1847 |
-
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 1848 |
-
not be updated during training.
|
| 1849 |
-
"""
|
| 1850 |
-
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
|
| 1851 |
-
|
| 1852 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer
|
| 1853 |
-
def freeze_base_model(self):
|
| 1854 |
-
"""
|
| 1855 |
-
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
| 1856 |
-
be updated during training. Only the classification head will be updated.
|
| 1857 |
-
"""
|
| 1858 |
-
for param in self.wav2vec2_conformer.parameters():
|
| 1859 |
-
param.requires_grad = False
|
| 1860 |
-
|
| 1861 |
-
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
| 1862 |
-
@add_code_sample_docstrings(
|
| 1863 |
-
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 1864 |
-
output_type=TokenClassifierOutput,
|
| 1865 |
-
config_class=_CONFIG_FOR_DOC,
|
| 1866 |
-
modality="audio",
|
| 1867 |
-
)
|
| 1868 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer
|
| 1869 |
-
def forward(
|
| 1870 |
-
self,
|
| 1871 |
-
input_values: Optional[torch.Tensor],
|
| 1872 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 1873 |
-
labels: Optional[torch.Tensor] = None,
|
| 1874 |
-
output_attentions: Optional[bool] = None,
|
| 1875 |
-
output_hidden_states: Optional[bool] = None,
|
| 1876 |
-
return_dict: Optional[bool] = None,
|
| 1877 |
-
) -> Union[Tuple, TokenClassifierOutput]:
|
| 1878 |
-
r"""
|
| 1879 |
-
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1880 |
-
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 1881 |
-
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 1882 |
-
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1883 |
-
"""
|
| 1884 |
-
|
| 1885 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1886 |
-
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
| 1887 |
-
|
| 1888 |
-
outputs = self.wav2vec2_conformer(
|
| 1889 |
-
input_values,
|
| 1890 |
-
attention_mask=attention_mask,
|
| 1891 |
-
output_attentions=output_attentions,
|
| 1892 |
-
output_hidden_states=output_hidden_states,
|
| 1893 |
-
return_dict=return_dict,
|
| 1894 |
-
)
|
| 1895 |
-
|
| 1896 |
-
if self.config.use_weighted_layer_sum:
|
| 1897 |
-
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
| 1898 |
-
hidden_states = torch.stack(hidden_states, dim=1)
|
| 1899 |
-
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
| 1900 |
-
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
| 1901 |
-
else:
|
| 1902 |
-
hidden_states = outputs[0]
|
| 1903 |
-
|
| 1904 |
-
logits = self.classifier(hidden_states)
|
| 1905 |
-
|
| 1906 |
-
loss = None
|
| 1907 |
-
if labels is not None:
|
| 1908 |
-
loss_fct = CrossEntropyLoss()
|
| 1909 |
-
loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
|
| 1910 |
-
|
| 1911 |
-
if not return_dict:
|
| 1912 |
-
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 1913 |
-
return output
|
| 1914 |
-
|
| 1915 |
-
return TokenClassifierOutput(
|
| 1916 |
-
loss=loss,
|
| 1917 |
-
logits=logits,
|
| 1918 |
-
hidden_states=outputs.hidden_states,
|
| 1919 |
-
attentions=outputs.attentions,
|
| 1920 |
-
)
|
| 1921 |
-
|
| 1922 |
-
|
| 1923 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
|
| 1924 |
-
class AMSoftmaxLoss(nn.Module):
|
| 1925 |
-
def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
|
| 1926 |
-
super(AMSoftmaxLoss, self).__init__()
|
| 1927 |
-
self.scale = scale
|
| 1928 |
-
self.margin = margin
|
| 1929 |
-
self.num_labels = num_labels
|
| 1930 |
-
self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
|
| 1931 |
-
self.loss = nn.CrossEntropyLoss()
|
| 1932 |
-
|
| 1933 |
-
def forward(self, hidden_states, labels):
|
| 1934 |
-
labels = labels.flatten()
|
| 1935 |
-
weight = nn.functional.normalize(self.weight, dim=0)
|
| 1936 |
-
hidden_states = nn.functional.normalize(hidden_states, dim=1)
|
| 1937 |
-
cos_theta = torch.mm(hidden_states, weight)
|
| 1938 |
-
psi = cos_theta - self.margin
|
| 1939 |
-
|
| 1940 |
-
onehot = nn.functional.one_hot(labels, self.num_labels)
|
| 1941 |
-
logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
|
| 1942 |
-
loss = self.loss(logits, labels)
|
| 1943 |
-
|
| 1944 |
-
return loss
|
| 1945 |
-
|
| 1946 |
-
|
| 1947 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
|
| 1948 |
-
class TDNNLayer(nn.Module):
|
| 1949 |
-
def __init__(self, config, layer_id=0):
|
| 1950 |
-
super().__init__()
|
| 1951 |
-
self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
|
| 1952 |
-
self.out_conv_dim = config.tdnn_dim[layer_id]
|
| 1953 |
-
self.kernel_size = config.tdnn_kernel[layer_id]
|
| 1954 |
-
self.dilation = config.tdnn_dilation[layer_id]
|
| 1955 |
-
|
| 1956 |
-
self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
|
| 1957 |
-
self.activation = nn.ReLU()
|
| 1958 |
-
|
| 1959 |
-
def forward(self, hidden_states):
|
| 1960 |
-
hidden_states = hidden_states.unsqueeze(1)
|
| 1961 |
-
hidden_states = nn.functional.unfold(
|
| 1962 |
-
hidden_states,
|
| 1963 |
-
(self.kernel_size, self.in_conv_dim),
|
| 1964 |
-
stride=(1, self.in_conv_dim),
|
| 1965 |
-
dilation=(self.dilation, 1),
|
| 1966 |
-
)
|
| 1967 |
-
hidden_states = hidden_states.transpose(1, 2)
|
| 1968 |
-
hidden_states = self.kernel(hidden_states)
|
| 1969 |
-
|
| 1970 |
-
hidden_states = self.activation(hidden_states)
|
| 1971 |
-
return hidden_states
|
| 1972 |
-
|
| 1973 |
-
|
| 1974 |
-
@add_start_docstrings(
|
| 1975 |
-
"""
|
| 1976 |
-
Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification.
|
| 1977 |
-
""",
|
| 1978 |
-
WAV2VEC2_CONFORMER_START_DOCSTRING,
|
| 1979 |
-
)
|
| 1980 |
-
class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
|
| 1981 |
-
def __init__(self, config):
|
| 1982 |
-
super().__init__(config)
|
| 1983 |
-
|
| 1984 |
-
self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
|
| 1985 |
-
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
| 1986 |
-
if config.use_weighted_layer_sum:
|
| 1987 |
-
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
| 1988 |
-
self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
|
| 1989 |
-
|
| 1990 |
-
tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
|
| 1991 |
-
self.tdnn = nn.ModuleList(tdnn_layers)
|
| 1992 |
-
|
| 1993 |
-
self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
|
| 1994 |
-
self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
|
| 1995 |
-
|
| 1996 |
-
self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
|
| 1997 |
-
|
| 1998 |
-
self.init_weights()
|
| 1999 |
-
|
| 2000 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
|
| 2001 |
-
def freeze_feature_encoder(self):
|
| 2002 |
-
"""
|
| 2003 |
-
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
| 2004 |
-
not be updated during training.
|
| 2005 |
-
"""
|
| 2006 |
-
self.wav2vec2_conformer.feature_extractor._freeze_parameters()
|
| 2007 |
-
|
| 2008 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer
|
| 2009 |
-
def freeze_base_model(self):
|
| 2010 |
-
"""
|
| 2011 |
-
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
| 2012 |
-
be updated during training. Only the classification head will be updated.
|
| 2013 |
-
"""
|
| 2014 |
-
for param in self.wav2vec2_conformer.parameters():
|
| 2015 |
-
param.requires_grad = False
|
| 2016 |
-
|
| 2017 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer
|
| 2018 |
-
def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
|
| 2019 |
-
"""
|
| 2020 |
-
Computes the output length of the TDNN layers
|
| 2021 |
-
"""
|
| 2022 |
-
|
| 2023 |
-
def _conv_out_length(input_length, kernel_size, stride):
|
| 2024 |
-
# 1D convolutional layer output length formula taken
|
| 2025 |
-
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
| 2026 |
-
return (input_length - kernel_size) // stride + 1
|
| 2027 |
-
|
| 2028 |
-
for kernel_size in self.config.tdnn_kernel:
|
| 2029 |
-
input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
|
| 2030 |
-
|
| 2031 |
-
return input_lengths
|
| 2032 |
-
|
| 2033 |
-
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
|
| 2034 |
-
@add_code_sample_docstrings(
|
| 2035 |
-
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 2036 |
-
output_type=XVectorOutput,
|
| 2037 |
-
config_class=_CONFIG_FOR_DOC,
|
| 2038 |
-
modality="audio",
|
| 2039 |
-
)
|
| 2040 |
-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
|
| 2041 |
-
def forward(
|
| 2042 |
-
self,
|
| 2043 |
-
input_values: Optional[torch.Tensor],
|
| 2044 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 2045 |
-
output_attentions: Optional[bool] = None,
|
| 2046 |
-
output_hidden_states: Optional[bool] = None,
|
| 2047 |
-
return_dict: Optional[bool] = None,
|
| 2048 |
-
labels: Optional[torch.Tensor] = None,
|
| 2049 |
-
) -> Union[Tuple, XVectorOutput]:
|
| 2050 |
-
r"""
|
| 2051 |
-
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 2052 |
-
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 2053 |
-
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 2054 |
-
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 2055 |
-
"""
|
| 2056 |
-
|
| 2057 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 2058 |
-
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
| 2059 |
-
|
| 2060 |
-
outputs = self.wav2vec2_conformer(
|
| 2061 |
-
input_values,
|
| 2062 |
-
attention_mask=attention_mask,
|
| 2063 |
-
output_attentions=output_attentions,
|
| 2064 |
-
output_hidden_states=output_hidden_states,
|
| 2065 |
-
return_dict=return_dict,
|
| 2066 |
-
)
|
| 2067 |
-
|
| 2068 |
-
if self.config.use_weighted_layer_sum:
|
| 2069 |
-
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
| 2070 |
-
hidden_states = torch.stack(hidden_states, dim=1)
|
| 2071 |
-
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
| 2072 |
-
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
| 2073 |
-
else:
|
| 2074 |
-
hidden_states = outputs[0]
|
| 2075 |
-
|
| 2076 |
-
hidden_states = self.projector(hidden_states)
|
| 2077 |
-
|
| 2078 |
-
for tdnn_layer in self.tdnn:
|
| 2079 |
-
hidden_states = tdnn_layer(hidden_states)
|
| 2080 |
-
|
| 2081 |
-
# Statistic Pooling
|
| 2082 |
-
if attention_mask is None:
|
| 2083 |
-
mean_features = hidden_states.mean(dim=1)
|
| 2084 |
-
std_features = hidden_states.std(dim=1)
|
| 2085 |
-
else:
|
| 2086 |
-
feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
|
| 2087 |
-
tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
|
| 2088 |
-
mean_features = []
|
| 2089 |
-
std_features = []
|
| 2090 |
-
for i, length in enumerate(tdnn_output_lengths):
|
| 2091 |
-
mean_features.append(hidden_states[i, :length].mean(dim=0))
|
| 2092 |
-
std_features.append(hidden_states[i, :length].std(dim=0))
|
| 2093 |
-
mean_features = torch.stack(mean_features)
|
| 2094 |
-
std_features = torch.stack(std_features)
|
| 2095 |
-
statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
|
| 2096 |
-
|
| 2097 |
-
output_embeddings = self.feature_extractor(statistic_pooling)
|
| 2098 |
-
logits = self.classifier(output_embeddings)
|
| 2099 |
-
|
| 2100 |
-
loss = None
|
| 2101 |
-
if labels is not None:
|
| 2102 |
-
loss = self.objective(logits, labels)
|
| 2103 |
-
|
| 2104 |
-
if not return_dict:
|
| 2105 |
-
output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
|
| 2106 |
-
return ((loss,) + output) if loss is not None else output
|
| 2107 |
-
|
| 2108 |
-
return XVectorOutput(
|
| 2109 |
-
loss=loss,
|
| 2110 |
-
logits=logits,
|
| 2111 |
-
embeddings=output_embeddings,
|
| 2112 |
-
hidden_states=outputs.hidden_states,
|
| 2113 |
-
attentions=outputs.attentions,
|
| 2114 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/modules/random_quantizer.py
DELETED
|
@@ -1,68 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from torch import nn, einsum
|
| 3 |
-
from einops import rearrange
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
class RandomProjectionQuantizer(nn.Module):
|
| 7 |
-
"""
|
| 8 |
-
Random projection and codebook lookup module
|
| 9 |
-
|
| 10 |
-
Some code is borrowed from:
|
| 11 |
-
https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py
|
| 12 |
-
But I did normalization using pre-computed global mean & variance instead of using layer norm.
|
| 13 |
-
"""
|
| 14 |
-
|
| 15 |
-
def __init__(
|
| 16 |
-
self,
|
| 17 |
-
input_dim,
|
| 18 |
-
codebook_dim,
|
| 19 |
-
codebook_size,
|
| 20 |
-
seed=142,
|
| 21 |
-
):
|
| 22 |
-
super().__init__()
|
| 23 |
-
|
| 24 |
-
# random seed
|
| 25 |
-
torch.manual_seed(seed)
|
| 26 |
-
|
| 27 |
-
# randomly initialized projection
|
| 28 |
-
random_projection = torch.empty(input_dim, codebook_dim)
|
| 29 |
-
nn.init.xavier_normal_(random_projection)
|
| 30 |
-
self.register_buffer("random_projection", random_projection)
|
| 31 |
-
|
| 32 |
-
# randomly initialized codebook
|
| 33 |
-
codebook = torch.empty(codebook_size, codebook_dim)
|
| 34 |
-
nn.init.normal_(codebook)
|
| 35 |
-
self.register_buffer("codebook", codebook)
|
| 36 |
-
|
| 37 |
-
def codebook_lookup(self, x):
|
| 38 |
-
# reshape
|
| 39 |
-
b = x.shape[0]
|
| 40 |
-
x = rearrange(x, "b n e -> (b n) e")
|
| 41 |
-
|
| 42 |
-
# L2 normalization
|
| 43 |
-
normalized_x = nn.functional.normalize(x, dim=1, p=2)
|
| 44 |
-
normalized_codebook = nn.functional.normalize(self.codebook, dim=1, p=2)
|
| 45 |
-
|
| 46 |
-
# compute distances
|
| 47 |
-
distances = torch.cdist(normalized_codebook, normalized_x)
|
| 48 |
-
|
| 49 |
-
# get nearest
|
| 50 |
-
nearest_indices = torch.argmin(distances, dim=0)
|
| 51 |
-
|
| 52 |
-
# reshape
|
| 53 |
-
xq = rearrange(nearest_indices, "(b n) -> b n", b=b)
|
| 54 |
-
|
| 55 |
-
return xq
|
| 56 |
-
|
| 57 |
-
@torch.no_grad()
|
| 58 |
-
def forward(self, x):
|
| 59 |
-
# always eval
|
| 60 |
-
self.eval()
|
| 61 |
-
|
| 62 |
-
# random projection [batch, length, input_dim] -> [batch, length, codebook_dim]
|
| 63 |
-
x = einsum("b n d, d e -> b n e", x, self.random_projection)
|
| 64 |
-
|
| 65 |
-
# codebook lookup
|
| 66 |
-
xq = self.codebook_lookup(x)
|
| 67 |
-
|
| 68 |
-
return xq
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/muq_dev/muq_fairseq/models/muq/muq_model.py
DELETED
|
@@ -1,139 +0,0 @@
|
|
| 1 |
-
try:
|
| 2 |
-
from .model.muq import MuQ
|
| 3 |
-
except:
|
| 4 |
-
import sys, os
|
| 5 |
-
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 6 |
-
from model.muq import MuQ
|
| 7 |
-
try:
|
| 8 |
-
from fairseq.fairseq.dataclass import FairseqDataclass
|
| 9 |
-
from fairseq.fairseq.models import BaseFairseqModel, register_model
|
| 10 |
-
from fairseq.fairseq.tasks.fairseq_task import FairseqTask
|
| 11 |
-
except:
|
| 12 |
-
from fairseq.dataclass import FairseqDataclass
|
| 13 |
-
from fairseq.models import BaseFairseqModel, register_model
|
| 14 |
-
from fairseq.tasks.fairseq_task import FairseqTask
|
| 15 |
-
|
| 16 |
-
from dataclasses import dataclass, field
|
| 17 |
-
from typing import List, Tuple, Optional
|
| 18 |
-
import torch
|
| 19 |
-
|
| 20 |
-
from logging import getLogger
|
| 21 |
-
|
| 22 |
-
logger = getLogger(__name__)
|
| 23 |
-
|
| 24 |
-
@dataclass
|
| 25 |
-
class MuQConfig(FairseqDataclass):
|
| 26 |
-
label_rate:int = field(default=25)
|
| 27 |
-
num_codebooks:int = field(default=1)
|
| 28 |
-
codebook_dim:int = field(default=16)
|
| 29 |
-
codebook_size:int = field(default=4096)
|
| 30 |
-
features:List[str] = field(default_factory=lambda:["melspec_2048"])
|
| 31 |
-
hop_length:int = field(default=240)
|
| 32 |
-
n_mels:int = field(default=128)
|
| 33 |
-
conv_dim:int = field(default=512)
|
| 34 |
-
encoder_dim:int = field(default=1024)
|
| 35 |
-
encoder_depth:int = field(default=12)
|
| 36 |
-
mask_hop:float = field(default=0.4)
|
| 37 |
-
mask_prob:float = field(default=0.6)
|
| 38 |
-
is_flash:bool = field(default=False)
|
| 39 |
-
stat_path:Optional[str] = field(default=None)
|
| 40 |
-
model_path:Optional[str] = field(default=None)
|
| 41 |
-
w2v2_config_path:Optional[str] = field(default=None)
|
| 42 |
-
use_rvq_target:bool = field(default=False)
|
| 43 |
-
use_vq_target:bool = field(default=False)
|
| 44 |
-
rvq_ckpt_path: Optional[str] = field(default=None)
|
| 45 |
-
recon_loss_ratio: Optional[float] = field(default=None)
|
| 46 |
-
resume_checkpoint: Optional[str] = None
|
| 47 |
-
use_hubert_masking_strategy:bool = field(default=False)
|
| 48 |
-
use_hubert_featurizer:bool = field(default=False)
|
| 49 |
-
hubert_conv_feature_layers:str = field(default_factory=lambda:"[(512,10,5)] + [(512,3,2)] * 3 + [(512,3,3)] + [(512,2,2)] * 2")
|
| 50 |
-
rvq_n_codebooks:int = field(default=8)
|
| 51 |
-
rvq_multi_layer_num:int = field(default=1)
|
| 52 |
-
use_encodec_target:bool = field(default=False)
|
| 53 |
-
|
| 54 |
-
SAMPLE_RATE = 24_000
|
| 55 |
-
|
| 56 |
-
@register_model("muq", dataclass=MuQConfig)
|
| 57 |
-
class MuQModel(BaseFairseqModel):
|
| 58 |
-
def __init__(self, cfg: MuQConfig, task_cfg: FairseqTask):
|
| 59 |
-
super().__init__()
|
| 60 |
-
self.cfg = cfg
|
| 61 |
-
self.model = MuQ(
|
| 62 |
-
num_codebooks=cfg.num_codebooks,
|
| 63 |
-
codebook_dim=cfg.codebook_dim,
|
| 64 |
-
codebook_size=cfg.codebook_size,
|
| 65 |
-
features=cfg.features,
|
| 66 |
-
n_mels=cfg.n_mels,
|
| 67 |
-
conv_dim=cfg.conv_dim,
|
| 68 |
-
encoder_dim=cfg.encoder_dim,
|
| 69 |
-
encoder_depth=cfg.encoder_depth,
|
| 70 |
-
mask_hop=cfg.mask_hop,
|
| 71 |
-
mask_prob=cfg.mask_prob,
|
| 72 |
-
is_flash=cfg.is_flash,
|
| 73 |
-
stat_path=cfg.stat_path,
|
| 74 |
-
model_path=cfg.model_path,
|
| 75 |
-
w2v2_config_path=cfg.w2v2_config_path,
|
| 76 |
-
use_rvq_target=cfg.use_rvq_target,
|
| 77 |
-
use_vq_target=cfg.use_vq_target,
|
| 78 |
-
rvq_ckpt_path=cfg.rvq_ckpt_path,
|
| 79 |
-
recon_loss_ratio=cfg.recon_loss_ratio,
|
| 80 |
-
label_rate=cfg.label_rate,
|
| 81 |
-
use_hubert_masking_strategy=cfg.use_hubert_masking_strategy,
|
| 82 |
-
use_hubert_featurizer=cfg.use_hubert_featurizer,
|
| 83 |
-
hubert_conv_feature_layers=cfg.hubert_conv_feature_layers,
|
| 84 |
-
rvq_n_codebooks=cfg.rvq_n_codebooks,
|
| 85 |
-
rvq_multi_layer_num=cfg.rvq_multi_layer_num,
|
| 86 |
-
use_encodec_target=cfg.use_encodec_target,
|
| 87 |
-
)
|
| 88 |
-
|
| 89 |
-
def forward(
|
| 90 |
-
self,
|
| 91 |
-
source: torch.Tensor, # B,L
|
| 92 |
-
features_only: bool = False,
|
| 93 |
-
label = None, # pre-extracted labeks, dim is [Batch, N_Codebook, SeqLen]
|
| 94 |
-
**kwargs,
|
| 95 |
-
):
|
| 96 |
-
source = source[..., :int((source.shape[-1]//(SAMPLE_RATE//self.cfg.label_rate))*(SAMPLE_RATE//self.cfg.label_rate)) ]
|
| 97 |
-
if features_only:
|
| 98 |
-
if 'attention_mask' in kwargs:
|
| 99 |
-
attention_mask = kwargs['attention_mask']
|
| 100 |
-
elif 'padding_mask' in kwargs:
|
| 101 |
-
attention_mask = ~kwargs['padding_mask'].bool()
|
| 102 |
-
else:
|
| 103 |
-
attention_mask = None
|
| 104 |
-
_, hidden_states = self.model.get_predictions(source, attention_mask=attention_mask, is_features_only=True)
|
| 105 |
-
result = {
|
| 106 |
-
"layer_results": hidden_states
|
| 107 |
-
}
|
| 108 |
-
return result
|
| 109 |
-
else:
|
| 110 |
-
result = {}
|
| 111 |
-
logits, hidden_emb, losses, accuracies = self.model(source, label=label)
|
| 112 |
-
result["losses"] = losses
|
| 113 |
-
result["accuracies"] = accuracies
|
| 114 |
-
result["logits"] = logits
|
| 115 |
-
result["hidden_emb"] = hidden_emb
|
| 116 |
-
for k, v in losses.items():
|
| 117 |
-
result[k] = v
|
| 118 |
-
return result
|
| 119 |
-
|
| 120 |
-
@classmethod
|
| 121 |
-
def build_model(cls, cfg: MuQConfig, task: FairseqTask):
|
| 122 |
-
"""Build a new model instance."""
|
| 123 |
-
|
| 124 |
-
model = MuQModel(cfg, task.cfg)
|
| 125 |
-
import numpy as np
|
| 126 |
-
s = 0
|
| 127 |
-
for param in model.parameters():
|
| 128 |
-
s += np.product(param.size())
|
| 129 |
-
# print('# of parameters: '+str(s/1024.0/1024.0))
|
| 130 |
-
|
| 131 |
-
if cfg.get("resume_checkpoint", None):
|
| 132 |
-
print("Loading checkpoint from {}".format(cfg.resume_checkpoint))
|
| 133 |
-
model.load_state_dict(torch.load(cfg.resume_checkpoint)['model'], strict=False)
|
| 134 |
-
|
| 135 |
-
return model
|
| 136 |
-
|
| 137 |
-
def get_losses(self, result, batch):
|
| 138 |
-
return result['losses']
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/muq_dev/muq_fairseq/tasks/muq_pretraining.py
DELETED
|
@@ -1,354 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2017-present, Facebook, Inc.
|
| 2 |
-
# All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# This source code is licensed under the license found in the LICENSE file in
|
| 5 |
-
# the root directory of this source tree. An additional grant of patent rights
|
| 6 |
-
# can be found in the PATENTS file in the same directory.
|
| 7 |
-
|
| 8 |
-
import logging
|
| 9 |
-
import os
|
| 10 |
-
import sys
|
| 11 |
-
from typing import Dict, List, Optional, Tuple
|
| 12 |
-
|
| 13 |
-
import numpy as np
|
| 14 |
-
import torch
|
| 15 |
-
|
| 16 |
-
from dataclasses import dataclass, field
|
| 17 |
-
from fairseq.data import Dictionary, HubertDataset
|
| 18 |
-
from fairseq.dataclass.configs import FairseqDataclass
|
| 19 |
-
from fairseq.tasks import register_task
|
| 20 |
-
from fairseq.tasks.fairseq_task import FairseqTask
|
| 21 |
-
from omegaconf import MISSING
|
| 22 |
-
|
| 23 |
-
from ..data.mert_dataset import MERTDataset
|
| 24 |
-
from ..data.ark_dataset import ArkDataset
|
| 25 |
-
|
| 26 |
-
logger = logging.getLogger(__name__)
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
class LabelEncoder(object):
|
| 30 |
-
def __init__(self, dictionary: Dictionary) -> None:
|
| 31 |
-
self.dictionary = dictionary
|
| 32 |
-
|
| 33 |
-
def __call__(self, label: str) -> List[str]:
|
| 34 |
-
# encode_line return a torch.IntTensor, should be all 1 for vanila HuBERT
|
| 35 |
-
return self.dictionary.encode_line(
|
| 36 |
-
label,
|
| 37 |
-
append_eos=False,
|
| 38 |
-
add_if_not_exist=False,
|
| 39 |
-
)
|
| 40 |
-
class PaddedNumpyLabelEncoder(object):
|
| 41 |
-
def __init__(self):
|
| 42 |
-
# self.dictionary = dictionary
|
| 43 |
-
pass
|
| 44 |
-
|
| 45 |
-
def __call__(self, label):
|
| 46 |
-
t = torch.IntTensor(np.asarray(label))
|
| 47 |
-
t = t[t>=0] # remove padded -1 values at the end
|
| 48 |
-
return t
|
| 49 |
-
|
| 50 |
-
@dataclass
|
| 51 |
-
class MuQPretrainingConfig(FairseqDataclass):
|
| 52 |
-
data: str = field(default=MISSING, metadata={"help": "path to data directory"})
|
| 53 |
-
sharding_data: int = field(
|
| 54 |
-
default=-1,
|
| 55 |
-
metadata={
|
| 56 |
-
"help": "set this para >1 to use sharding dataset to prevent OOM"
|
| 57 |
-
"prepare data tsv and label files by adding postfix for sharding 64 like:"
|
| 58 |
-
"train_28_64.tsv and train_28_64.encodec_6"
|
| 59 |
-
},
|
| 60 |
-
)
|
| 61 |
-
load_random_data_shard: bool = field(
|
| 62 |
-
default=True,
|
| 63 |
-
metadata={
|
| 64 |
-
"help": "whether to laod shards randomly or in order when use sharding_data"
|
| 65 |
-
},
|
| 66 |
-
)
|
| 67 |
-
fine_tuning: bool = field(
|
| 68 |
-
default=False, metadata={"help": "set to true if fine-tuning Hubert"}
|
| 69 |
-
)
|
| 70 |
-
labels: List[str] = field(
|
| 71 |
-
default_factory=lambda: ["ltr"],
|
| 72 |
-
metadata={
|
| 73 |
-
"help": (
|
| 74 |
-
"extension of the label files to load, frame-level labels for"
|
| 75 |
-
" pre-training, and sequence-level label for fine-tuning"
|
| 76 |
-
)
|
| 77 |
-
},
|
| 78 |
-
)
|
| 79 |
-
label_dir: Optional[str] = field(
|
| 80 |
-
default=None,
|
| 81 |
-
metadata={
|
| 82 |
-
"help": "if set, looks for labels in this directory instead",
|
| 83 |
-
},
|
| 84 |
-
)
|
| 85 |
-
label_scp_path: Optional[str] = field(
|
| 86 |
-
default=None,
|
| 87 |
-
metadata={
|
| 88 |
-
'help': 'if set, load label from scp file'
|
| 89 |
-
}
|
| 90 |
-
)
|
| 91 |
-
label_scp_clip_duration: float = field(
|
| 92 |
-
default=-1,
|
| 93 |
-
metadata={
|
| 94 |
-
'help': 'clip duration for loading scp label. if set to -1, this will not make effect.'
|
| 95 |
-
}
|
| 96 |
-
)
|
| 97 |
-
label_rate: float = field(
|
| 98 |
-
default=-1.0,
|
| 99 |
-
metadata={"help": "label frame rate. -1.0 for sequence label"},
|
| 100 |
-
)
|
| 101 |
-
sample_rate: int = field(
|
| 102 |
-
default=16_000,
|
| 103 |
-
metadata={
|
| 104 |
-
"help": "target sample rate. audio files will be up/down "
|
| 105 |
-
"sampled to this rate"
|
| 106 |
-
},
|
| 107 |
-
)
|
| 108 |
-
normalize: bool = field(
|
| 109 |
-
default=False,
|
| 110 |
-
metadata={"help": "if set, normalizes input to have 0 mean and unit variance"},
|
| 111 |
-
)
|
| 112 |
-
enable_padding: bool = field(
|
| 113 |
-
default=False,
|
| 114 |
-
metadata={"help": "pad shorter samples instead of cropping"},
|
| 115 |
-
)
|
| 116 |
-
max_keep_size: Optional[int] = field(
|
| 117 |
-
default=None,
|
| 118 |
-
metadata={"help": "exclude sample longer than this"},
|
| 119 |
-
)
|
| 120 |
-
max_sample_size: Optional[int] = field(
|
| 121 |
-
default=None,
|
| 122 |
-
metadata={"help": "max sample size to crop to for batching"},
|
| 123 |
-
)
|
| 124 |
-
min_sample_size: Optional[int] = field(
|
| 125 |
-
default=None,
|
| 126 |
-
metadata={"help": "min sample size to crop to for batching"},
|
| 127 |
-
)
|
| 128 |
-
single_target: Optional[bool] = field(
|
| 129 |
-
default=False,
|
| 130 |
-
metadata={
|
| 131 |
-
"help": "if set, AddTargetDatasets outputs same keys " "as AddTargetDataset"
|
| 132 |
-
},
|
| 133 |
-
)
|
| 134 |
-
random_crop: Optional[bool] = field(
|
| 135 |
-
default=True,
|
| 136 |
-
metadata={"help": "always crop from the beginning if false"},
|
| 137 |
-
)
|
| 138 |
-
pad_audio: Optional[bool] = field(
|
| 139 |
-
default=False,
|
| 140 |
-
metadata={"help": "pad audio to the longest one in the batch if true"},
|
| 141 |
-
)
|
| 142 |
-
|
| 143 |
-
store_labels: Optional[bool] = field(
|
| 144 |
-
default=False,
|
| 145 |
-
metadata={"help": "whether to load all of the label into memory"},
|
| 146 |
-
)
|
| 147 |
-
|
| 148 |
-
numpy_memmap_label: Optional[bool] = field(
|
| 149 |
-
default=False,
|
| 150 |
-
metadata={"help": "whether the label file is saved as a numpy file, each line is ended with padding -1"},
|
| 151 |
-
)
|
| 152 |
-
|
| 153 |
-
augmentation_effects: Optional[str] = field(
|
| 154 |
-
default="[]",
|
| 155 |
-
metadata={
|
| 156 |
-
"help": (
|
| 157 |
-
"a list of effects that might apply to the audios"
|
| 158 |
-
"example: \"['random_mute', 'random_Gaussian', 'reverse_polarity']\" "
|
| 159 |
-
"supported: random_mute,"
|
| 160 |
-
"todo: "
|
| 161 |
-
)
|
| 162 |
-
},
|
| 163 |
-
)
|
| 164 |
-
augmentation_probs: Optional[str] = field(
|
| 165 |
-
default="[]",
|
| 166 |
-
metadata={
|
| 167 |
-
"help": (
|
| 168 |
-
"the corresponding probabilities for the data augmentation effects"
|
| 169 |
-
"example: \"[0.1, 0.5, 0.8]\" "
|
| 170 |
-
"the sum is not necessarily need to be 1.0, and multiple effects can be applied to the same audio"
|
| 171 |
-
)
|
| 172 |
-
},
|
| 173 |
-
)
|
| 174 |
-
|
| 175 |
-
# inbatch_noise_augment_len_range: Optional[List[int]] = field(
|
| 176 |
-
# default_factory=lambda: [8000, 24000],
|
| 177 |
-
# default = [8000, 24000],
|
| 178 |
-
inbatch_noise_augment_len_range: Optional[str] = field(
|
| 179 |
-
default = "[8000, 24000]",
|
| 180 |
-
metadata={
|
| 181 |
-
"help": (
|
| 182 |
-
"the range of length of the mix-up noise augmentation, unit in smaples"
|
| 183 |
-
)
|
| 184 |
-
},
|
| 185 |
-
)
|
| 186 |
-
# inbatch_noise_augment_number_range: Optional[List[int]] = field(
|
| 187 |
-
# default_factory=lambda: [1, 3],
|
| 188 |
-
# default = [1, 3],
|
| 189 |
-
inbatch_noise_augment_number_range: Optional[str] = field(
|
| 190 |
-
default = "[1, 3]",
|
| 191 |
-
metadata={
|
| 192 |
-
"help": (
|
| 193 |
-
"the range of numbers of the mix-up noise augmentation"
|
| 194 |
-
)
|
| 195 |
-
},
|
| 196 |
-
)
|
| 197 |
-
inbatch_noise_augment_volume: float = field(
|
| 198 |
-
default = 1.0,
|
| 199 |
-
metadata={
|
| 200 |
-
"help": (
|
| 201 |
-
"the coefficient used to modify the volume of the noise audios wavs"
|
| 202 |
-
)
|
| 203 |
-
},
|
| 204 |
-
)
|
| 205 |
-
dynamic_crops: Optional[str] = field(
|
| 206 |
-
default="[]",
|
| 207 |
-
metadata={
|
| 208 |
-
"help": (
|
| 209 |
-
"used to set the maximum audio length setting, for training"
|
| 210 |
-
"example: \"[1, 2, 3, 4, 5, 10]\" "
|
| 211 |
-
)
|
| 212 |
-
},
|
| 213 |
-
)
|
| 214 |
-
dynamic_crops_epoches: Optional[str] = field(
|
| 215 |
-
default="[]",
|
| 216 |
-
metadata={
|
| 217 |
-
"help": (
|
| 218 |
-
"used to set training epoches of changing the maximum audio length"
|
| 219 |
-
"example: \"[1, 10, 20, 40, 80, 160,]\" "
|
| 220 |
-
"then len need to be equal to len(dynamic_crops)"
|
| 221 |
-
)
|
| 222 |
-
},
|
| 223 |
-
)
|
| 224 |
-
|
| 225 |
-
cqt_loss_bin_dataloader: Optional[int] = field(
|
| 226 |
-
default=-1,
|
| 227 |
-
metadata={
|
| 228 |
-
"help": (
|
| 229 |
-
"use this parameter to prepare cqt prediction objective in dataloader"
|
| 230 |
-
)
|
| 231 |
-
},
|
| 232 |
-
)
|
| 233 |
-
|
| 234 |
-
clip_secs: int = field(
|
| 235 |
-
default=5,
|
| 236 |
-
metadata={
|
| 237 |
-
"help": "clip secs for each audio"
|
| 238 |
-
}
|
| 239 |
-
)
|
| 240 |
-
|
| 241 |
-
dataset_shuffle: bool = field(
|
| 242 |
-
default=True,
|
| 243 |
-
metadata={
|
| 244 |
-
"help": (
|
| 245 |
-
"dataset shuffle when sample a batch"
|
| 246 |
-
)
|
| 247 |
-
},
|
| 248 |
-
)
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
@register_task("muq_pretraining", dataclass=MuQPretrainingConfig)
|
| 252 |
-
class MuQPretrainingTask(FairseqTask):
|
| 253 |
-
|
| 254 |
-
cfg: MuQPretrainingConfig
|
| 255 |
-
|
| 256 |
-
def __init__(
|
| 257 |
-
self,
|
| 258 |
-
cfg: MuQPretrainingConfig,
|
| 259 |
-
) -> None:
|
| 260 |
-
super().__init__(cfg)
|
| 261 |
-
|
| 262 |
-
logger.info(f"current directory is {os.getcwd()}")
|
| 263 |
-
logger.info(f"MuQPretrainingTask Config {cfg}")
|
| 264 |
-
|
| 265 |
-
self.cfg = cfg
|
| 266 |
-
self.fine_tuning = cfg.fine_tuning
|
| 267 |
-
|
| 268 |
-
if cfg.fine_tuning:
|
| 269 |
-
self.state.add_factory("target_dictionary", self.load_dictionaries)
|
| 270 |
-
else:
|
| 271 |
-
self.state.add_factory("dictionaries", self.load_dictionaries)
|
| 272 |
-
|
| 273 |
-
self.blank_symbol = "<s>"
|
| 274 |
-
|
| 275 |
-
# use eval() to pass list parameters, skirt the fairseq/torch error: Can't pickle <enum 'Choices'>: attribute lookup Choices on fairseq.dataclass.constants failed
|
| 276 |
-
self.augmentation_effects = eval(self.cfg.augmentation_effects)
|
| 277 |
-
self.augmentation_probs = eval(self.cfg.augmentation_probs)
|
| 278 |
-
if len(self.augmentation_effects) > 0:
|
| 279 |
-
assert len(self.augmentation_effects) == len(self.augmentation_probs)
|
| 280 |
-
logger.info(f"Applying audio augmentation {self.augmentation_effects}, probabilities: {self.augmentation_probs}")
|
| 281 |
-
|
| 282 |
-
self.inbatch_noise_augment_number_range = eval(self.cfg.inbatch_noise_augment_number_range)
|
| 283 |
-
self.inbatch_noise_augment_len_range = eval(self.cfg.inbatch_noise_augment_len_range)
|
| 284 |
-
|
| 285 |
-
self.max_sample_size = self.cfg.max_sample_size
|
| 286 |
-
|
| 287 |
-
self.dynamic_crops = eval(self.cfg.dynamic_crops)
|
| 288 |
-
self.dynamic_crops_epoches = eval(self.cfg.dynamic_crops_epoches)
|
| 289 |
-
assert len(self.dynamic_crops) == len(self.dynamic_crops_epoches)
|
| 290 |
-
if len(self.dynamic_crops) > 0:
|
| 291 |
-
assert self.dynamic_crops_epoches[0] == 1
|
| 292 |
-
|
| 293 |
-
self.cqt_loss_bin_dataloader = self.cfg.cqt_loss_bin_dataloader
|
| 294 |
-
|
| 295 |
-
self.numpy_memmap_label = self.cfg.numpy_memmap_label
|
| 296 |
-
self.store_labels = self.cfg.store_labels
|
| 297 |
-
if self.numpy_memmap_label:
|
| 298 |
-
assert self.store_labels
|
| 299 |
-
|
| 300 |
-
@property
|
| 301 |
-
def source_dictionary(self) -> Optional[Dictionary]:
|
| 302 |
-
return None
|
| 303 |
-
|
| 304 |
-
@property
|
| 305 |
-
def target_dictionary(self) -> Optional[Dictionary]:
|
| 306 |
-
return self.state.target_dictionary
|
| 307 |
-
|
| 308 |
-
@property
|
| 309 |
-
def dictionaries(self) -> List[Dictionary]:
|
| 310 |
-
return self.state.dictionaries
|
| 311 |
-
|
| 312 |
-
@classmethod
|
| 313 |
-
def setup_task(
|
| 314 |
-
cls, cfg: MuQPretrainingConfig, **kwargs
|
| 315 |
-
) -> "MuQPretrainingTask":
|
| 316 |
-
return cls(cfg)
|
| 317 |
-
|
| 318 |
-
def load_dictionaries(self):
|
| 319 |
-
label_dir = self.cfg.data if (self.cfg.label_dir is None or self.cfg.label_dir == '') else self.cfg.label_dir
|
| 320 |
-
print(label_dir)
|
| 321 |
-
dictionaries = [
|
| 322 |
-
Dictionary.load(f"{label_dir}/dict.{label}.txt")
|
| 323 |
-
for label in self.cfg.labels
|
| 324 |
-
]
|
| 325 |
-
return dictionaries[0] if self.cfg.fine_tuning else dictionaries
|
| 326 |
-
|
| 327 |
-
def get_label_dir(self) -> str:
|
| 328 |
-
if self.cfg.label_dir is None or self.cfg.label_dir=='':
|
| 329 |
-
return self.cfg.data
|
| 330 |
-
return self.cfg.label_dir
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
def is_force_load_dataset(self, epoch, training_restore=False):
|
| 334 |
-
# find the threshold that holds epoch \in [threshold, next_threshold)
|
| 335 |
-
return (epoch in self.dynamic_crops_epoches) or training_restore or (self.cfg.sharding_data > 1)
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
def set_dynamic_crop_max_sample(self, epoch):
|
| 339 |
-
pass
|
| 340 |
-
|
| 341 |
-
def load_dataset(self, split: str, **kwargs) -> None:
|
| 342 |
-
pass
|
| 343 |
-
|
| 344 |
-
def load_dataset_ark(self, split, **kwargs):
|
| 345 |
-
pass
|
| 346 |
-
|
| 347 |
-
def load_dataset_mert(self, split: str, **kwargs) -> None:
|
| 348 |
-
pass
|
| 349 |
-
|
| 350 |
-
def max_positions(self) -> Tuple[int, int]:
|
| 351 |
-
return (sys.maxsize, sys.maxsize)
|
| 352 |
-
|
| 353 |
-
def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array:
|
| 354 |
-
return indices
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/muq_dev/test.py
DELETED
|
@@ -1,22 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from dataclasses import dataclass
|
| 3 |
-
import fairseq
|
| 4 |
-
import os.path as op
|
| 5 |
-
|
| 6 |
-
root = op.dirname(op.abspath(__file__))
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
@dataclass
|
| 10 |
-
class UserDirModule:
|
| 11 |
-
user_dir: str
|
| 12 |
-
|
| 13 |
-
def load_model(model_dir, checkpoint_dir):
|
| 14 |
-
'''Load Fairseq SSL model'''
|
| 15 |
-
|
| 16 |
-
model_path = UserDirModule(model_dir)
|
| 17 |
-
fairseq.utils.import_user_module(model_path)
|
| 18 |
-
|
| 19 |
-
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_dir], strict=False)
|
| 20 |
-
model = model[0]
|
| 21 |
-
|
| 22 |
-
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/readme.md
DELETED
|
@@ -1,67 +0,0 @@
|
|
| 1 |
-
# MuCodec: Ultra Low-Bitrate Music Codec
|
| 2 |
-
|
| 3 |
-
This repository is the official code repository for MuCodec: Ultra Low-Bitrate Music Codec. You can find our paper on [arXiv] (https://arxiv.org/pdf/2409.13216). The demo page is available [here](https://xuyaoxun.github.io/MuCodec_demo/).
|
| 4 |
-
|
| 5 |
-
In this repository, we provide the Mucodec model, inference scripts, and the checkpoint that has been trained on the Million Song Dataset. Specifically, we have released the model and inference code corresponding to the lowest bitrate of 0.35 kbps as mentioned in the paper, to demonstrate the effectiveness of our work.
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
MuCodec supports 48kHz, dual-channel (stereo) audio reconstruction. If the original audio is in a different format, it will first be converted to 48kHz, dual-channel audio.
|
| 9 |
-
|
| 10 |
-
## Installation
|
| 11 |
-
|
| 12 |
-
You can install the necessary dependencies using the `requirements.txt` file with Python 3.8.12:
|
| 13 |
-
|
| 14 |
-
```bash
|
| 15 |
-
pip install -r requirements.txt
|
| 16 |
-
```
|
| 17 |
-
|
| 18 |
-
Due to storage limitations, we have saved the model checkpoints on Hugging Face at https://huggingface.co/yaoxunxu/mucodec. You can easily download the models from Hugging Face and save them in the following directories:
|
| 19 |
-
|
| 20 |
-
- Save `audioldm_48k.pth` in the `tools` folder.
|
| 21 |
-
- Save `muq.pt` in the `muq_dev` folder.
|
| 22 |
-
- Save `mucodec.pt` in the `ckpt` folder.
|
| 23 |
-
|
| 24 |
-
Please note that all three checkpoints must be downloaded completely for the model to load correctly. The final file paths should be:
|
| 25 |
-
|
| 26 |
-
```
|
| 27 |
-
tools/audioldm_48k.pth
|
| 28 |
-
muq_dev/muq.pt
|
| 29 |
-
ckpt/mucodec.pt
|
| 30 |
-
```
|
| 31 |
-
|
| 32 |
-
The file `audioldm_48k.pth` is sourced from https://huggingface.co/haoheliu/audioldm_48k/blob/main/audioldm_48k.pth.
|
| 33 |
-
|
| 34 |
-
## Inference
|
| 35 |
-
|
| 36 |
-
To run inference, use the following command:
|
| 37 |
-
|
| 38 |
-
```bash
|
| 39 |
-
python3 generate.py
|
| 40 |
-
```
|
| 41 |
-
|
| 42 |
-
We have provided a sample song `test.wav`, randomly sampled from the Million Song Dataset, in the `test_wav` folder. The default input path is `test_wav/test.wav`, and the output path for the reconstructed audio is `reconstruct/test.wav`.
|
| 43 |
-
|
| 44 |
-
In the `generate.py` file, we have implemented several functions to facilitate the music compression and reconstruction process. You can easily obtain compressed tokens from audio using the `sound2code` function, and reconstruct the audio from tokens using the `code2sound` function.
|
| 45 |
-
|
| 46 |
-
## Note
|
| 47 |
-
|
| 48 |
-
Please note that the open-sourced model was trained solely on the Million Song Dataset. Considering the quality issues of this dataset, the open-sourced model may not achieve the same performance as demonstrated in the demo. Unfortunately, due to copyright restrictions, we are unable to release the checkpoints trained on additional datasets. However, you can use your own dataset to further train the model and achieve better results.
|
| 49 |
-
|
| 50 |
-
## License
|
| 51 |
-
|
| 52 |
-
The code in this repository is released under the MIT license as found in the [LICENSE](LICENSE) file.
|
| 53 |
-
|
| 54 |
-
The model weights (muq.pt, mucodec.pt) in this repository are released under the CC-BY-NC 4.0 license, as detailed in the [LICENSE_weights](LICENSE_weights) file.
|
| 55 |
-
|
| 56 |
-
## Citation
|
| 57 |
-
|
| 58 |
-
If you find our work useful, please cite our paper:
|
| 59 |
-
|
| 60 |
-
```bibtex
|
| 61 |
-
@article{xu2024mucodec,
|
| 62 |
-
title={MuCodec: Ultra Low-Bitrate Music Codec},
|
| 63 |
-
author={Xu, Yaoxun and Chen, Hangting and Yu, Jianwei and Tan, Wei and Gu, Rongzhi and Lei, Shun and Lin, Zhiwei and Wu, Zhiyong},
|
| 64 |
-
journal={arXiv preprint arXiv:2409.13216},
|
| 65 |
-
year={2024}
|
| 66 |
-
}
|
| 67 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/reconstructed/test.wav
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:946e5815c7c3b8cab9f8eb6ca6707e821498fd59233d3ee356f6bb6f2fd2296b
|
| 3 |
-
size 99367376
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/requirements.txt
DELETED
|
@@ -1,335 +0,0 @@
|
|
| 1 |
-
absl-py==2.0.0
|
| 2 |
-
accelerate==0.30.1
|
| 3 |
-
aeiou==0.0.20
|
| 4 |
-
aiobotocore==2.13.1
|
| 5 |
-
aiofiles==23.2.1
|
| 6 |
-
aiohttp==3.9.3
|
| 7 |
-
aioitertools==0.11.0
|
| 8 |
-
aiosignal==1.3.1
|
| 9 |
-
alias-free-torch==0.0.6
|
| 10 |
-
altair==5.3.0
|
| 11 |
-
annotated-types==0.6.0
|
| 12 |
-
antlr4-python3-runtime==4.8
|
| 13 |
-
anyio==4.3.0
|
| 14 |
-
appdirs==1.4.4
|
| 15 |
-
argbind==0.3.9
|
| 16 |
-
asttokens==2.4.1
|
| 17 |
-
astunparse==1.6.3
|
| 18 |
-
async-timeout==4.0.3
|
| 19 |
-
attrs==23.1.0
|
| 20 |
-
audioread==3.0.1
|
| 21 |
-
auraloss==0.4.0
|
| 22 |
-
av==11.0.0
|
| 23 |
-
backcall==0.2.0
|
| 24 |
-
beartype==0.18.5
|
| 25 |
-
bitarray==2.9.2
|
| 26 |
-
bleach==6.1.0
|
| 27 |
-
blis==0.7.11
|
| 28 |
-
bokeh==3.1.1
|
| 29 |
-
botocore==1.34.131
|
| 30 |
-
braceexpand==0.1.7
|
| 31 |
-
cachetools==5.3.2
|
| 32 |
-
catalogue==2.0.10
|
| 33 |
-
certifi==2023.11.17
|
| 34 |
-
cffi==1.16.0
|
| 35 |
-
charset-normalizer==3.3.2
|
| 36 |
-
clean-fid==0.1.35
|
| 37 |
-
click==8.1.7
|
| 38 |
-
clip-anytorch==2.6.0
|
| 39 |
-
cloudpathlib==0.16.0
|
| 40 |
-
cloudpickle==3.0.0
|
| 41 |
-
cn2an==0.5.22
|
| 42 |
-
colorama==0.4.6
|
| 43 |
-
colorcet==3.1.0
|
| 44 |
-
colorlog==6.8.2
|
| 45 |
-
confection==0.1.4
|
| 46 |
-
configparser==7.0.0
|
| 47 |
-
contourpy==1.1.1
|
| 48 |
-
cycler==0.12.1
|
| 49 |
-
cymem==2.0.8
|
| 50 |
-
Cython==3.0.10
|
| 51 |
-
dataclasses==0.6
|
| 52 |
-
datasets
|
| 53 |
-
dctorch==0.1.2
|
| 54 |
-
decorator==5.1.1
|
| 55 |
-
decord==0.6.0
|
| 56 |
-
deepspeed==0.14.0
|
| 57 |
-
demucs==4.0.1
|
| 58 |
-
descript-audio-codec==1.0.0
|
| 59 |
-
descript-audiotools==0.7.2
|
| 60 |
-
diffusers==0.27.2
|
| 61 |
-
dill==0.3.8
|
| 62 |
-
Distance==0.1.3
|
| 63 |
-
docker-pycreds==0.4.0
|
| 64 |
-
docopt==0.6.2
|
| 65 |
-
docstring_parser==0.16
|
| 66 |
-
dora_search==0.1.12
|
| 67 |
-
einops==0.7.0
|
| 68 |
-
einops-exts==0.0.4
|
| 69 |
-
einx==0.3.0
|
| 70 |
-
ema-pytorch==0.2.3
|
| 71 |
-
encodec==0.1.1
|
| 72 |
-
exceptiongroup==1.2.0
|
| 73 |
-
executing==2.0.1
|
| 74 |
-
expecttest==0.1.6
|
| 75 |
-
fairseq==0.12.2
|
| 76 |
-
fastapi==0.110.3
|
| 77 |
-
fastcore==1.6.3
|
| 78 |
-
ffmpy==0.3.2
|
| 79 |
-
filelock==3.13.1
|
| 80 |
-
fire==0.6.0
|
| 81 |
-
flashy==0.0.2
|
| 82 |
-
flatten-dict==0.4.2
|
| 83 |
-
fonttools==4.49.0
|
| 84 |
-
frozendict==2.4.4
|
| 85 |
-
frozenlist==1.4.1
|
| 86 |
-
fsspec==2024.6.1
|
| 87 |
-
ftfy==6.1.3
|
| 88 |
-
future==1.0.0
|
| 89 |
-
g2p-en==2.1.0
|
| 90 |
-
gin-config==0.5.0
|
| 91 |
-
gitdb==4.0.11
|
| 92 |
-
GitPython==3.1.43
|
| 93 |
-
google-auth==2.23.4
|
| 94 |
-
google-auth-oauthlib==1.0.0
|
| 95 |
-
gradio==4.26.0
|
| 96 |
-
gradio_client==0.15.1
|
| 97 |
-
grpcio==1.59.3
|
| 98 |
-
h11==0.14.0
|
| 99 |
-
h5py==3.11.0
|
| 100 |
-
hjson==3.1.0
|
| 101 |
-
holoviews==1.17.1
|
| 102 |
-
httpcore==1.0.5
|
| 103 |
-
httpx==0.27.0
|
| 104 |
-
huggingface-hub==0.23.5
|
| 105 |
-
hydra-colorlog==1.2.0
|
| 106 |
-
hydra-core==1.0.7
|
| 107 |
-
hypothesis==6.90.0
|
| 108 |
-
idna==3.4
|
| 109 |
-
imageio==2.34.2
|
| 110 |
-
importlib-metadata==6.8.0
|
| 111 |
-
importlib-resources==5.12.0
|
| 112 |
-
inflect==7.0.0
|
| 113 |
-
ipython==8.12.3
|
| 114 |
-
jedi==0.19.1
|
| 115 |
-
jieba-fast==0.53
|
| 116 |
-
Jinja2==3.1.2
|
| 117 |
-
jmespath==1.0.1
|
| 118 |
-
joblib==1.3.2
|
| 119 |
-
json5==0.9.25
|
| 120 |
-
jsonlines==4.0.0
|
| 121 |
-
jsonmerge==1.9.2
|
| 122 |
-
jsonschema==4.22.0
|
| 123 |
-
jsonschema-specifications==2023.12.1
|
| 124 |
-
julius==0.2.7
|
| 125 |
-
k-diffusion==0.1.1
|
| 126 |
-
kaldiio==2.18.0
|
| 127 |
-
kiwisolver==1.4.5
|
| 128 |
-
kornia==0.7.3
|
| 129 |
-
kornia_rs==0.1.5
|
| 130 |
-
laion-clap==1.1.4
|
| 131 |
-
lameenc==1.7.0
|
| 132 |
-
langcodes==3.4.0
|
| 133 |
-
language_data==1.2.0
|
| 134 |
-
lazy_loader==0.3
|
| 135 |
-
librosa==0.9.2
|
| 136 |
-
lightning==2.2.1
|
| 137 |
-
lightning-utilities==0.10.1
|
| 138 |
-
linkify-it-py==2.0.3
|
| 139 |
-
lion-pytorch==0.2.2
|
| 140 |
-
llvmlite==0.41.1
|
| 141 |
-
local-attention==1.8.6
|
| 142 |
-
loguru==0.7.2
|
| 143 |
-
lxml==5.2.2
|
| 144 |
-
marisa-trie==1.1.1
|
| 145 |
-
Markdown==3.5.1
|
| 146 |
-
markdown-it-py==3.0.0
|
| 147 |
-
markdown2==2.5.0
|
| 148 |
-
MarkupSafe==2.1.3
|
| 149 |
-
matplotlib==3.7.5
|
| 150 |
-
matplotlib-inline==0.1.7
|
| 151 |
-
mdit-py-plugins==0.4.1
|
| 152 |
-
mdurl==0.1.2
|
| 153 |
-
mpmath==1.3.0
|
| 154 |
-
msgpack==1.0.8
|
| 155 |
-
multidict==6.0.5
|
| 156 |
-
multiprocess==0.70.16
|
| 157 |
-
murmurhash==1.0.10
|
| 158 |
-
mypy-extensions==1.0.0
|
| 159 |
-
networkx==3.1
|
| 160 |
-
ninja==1.11.1.1
|
| 161 |
-
nltk==3.8.1
|
| 162 |
-
nnAudio==0.3.3
|
| 163 |
-
num2words==0.5.13
|
| 164 |
-
numba==0.58.1
|
| 165 |
-
numpy==1.23.5
|
| 166 |
-
nvidia-cublas-cu11==11.11.3.6
|
| 167 |
-
nvidia-cuda-cupti-cu11==11.8.87
|
| 168 |
-
nvidia-cuda-nvrtc-cu11==11.8.89
|
| 169 |
-
nvidia-cuda-runtime-cu11==11.8.89
|
| 170 |
-
nvidia-cudnn-cu11==8.7.0.84
|
| 171 |
-
nvidia-cufft-cu11==10.9.0.58
|
| 172 |
-
nvidia-curand-cu11==10.3.0.86
|
| 173 |
-
nvidia-cusolver-cu11==11.4.1.48
|
| 174 |
-
nvidia-cusparse-cu11==11.7.5.86
|
| 175 |
-
nvidia-nccl-cu11==2.19.3
|
| 176 |
-
nvidia-nvtx-cu11==11.8.86
|
| 177 |
-
oauthlib==3.2.2
|
| 178 |
-
omegaconf
|
| 179 |
-
opencv-contrib-python==4.8.1.78
|
| 180 |
-
opencv-python==4.8.1.78
|
| 181 |
-
openunmix==1.2.1
|
| 182 |
-
orjson==3.10.3
|
| 183 |
-
packaging==23.2
|
| 184 |
-
pandas==2.0.2
|
| 185 |
-
panel==1.2.3
|
| 186 |
-
param==2.1.1
|
| 187 |
-
parso==0.8.4
|
| 188 |
-
pathtools==0.1.2
|
| 189 |
-
pedalboard==0.7.4
|
| 190 |
-
peft==0.10.0
|
| 191 |
-
pexpect==4.9.0
|
| 192 |
-
pickleshare==0.7.5
|
| 193 |
-
Pillow==10.1.0
|
| 194 |
-
pkgutil_resolve_name==1.3.10
|
| 195 |
-
platformdirs==4.2.0
|
| 196 |
-
plotly==5.23.0
|
| 197 |
-
pooch==1.8.1
|
| 198 |
-
portalocker==2.10.1
|
| 199 |
-
prefigure==0.0.9
|
| 200 |
-
preshed==3.0.9
|
| 201 |
-
proces==0.1.7
|
| 202 |
-
prodict==0.8.18
|
| 203 |
-
progressbar==2.5
|
| 204 |
-
prompt_toolkit==3.0.47
|
| 205 |
-
protobuf==3.19.6
|
| 206 |
-
psutil==5.9.6
|
| 207 |
-
ptyprocess==0.7.0
|
| 208 |
-
pure_eval==0.2.3
|
| 209 |
-
py-cpuinfo==9.0.0
|
| 210 |
-
pyarrow==17.0.0
|
| 211 |
-
pyarrow-hotfix==0.6
|
| 212 |
-
pyasn1==0.5.1
|
| 213 |
-
pyasn1-modules==0.3.0
|
| 214 |
-
pybind11==2.11.1
|
| 215 |
-
pycparser==2.21
|
| 216 |
-
pydantic==2.6.3
|
| 217 |
-
pydantic_core==2.16.3
|
| 218 |
-
pydub==0.25.1
|
| 219 |
-
Pygments==2.18.0
|
| 220 |
-
pyloudnorm==0.1.1
|
| 221 |
-
pynndescent==0.5.13
|
| 222 |
-
pynvml==11.5.0
|
| 223 |
-
pyparsing==3.1.2
|
| 224 |
-
pypinyin==0.51.0
|
| 225 |
-
pyre-extensions==0.0.29
|
| 226 |
-
pyreaper==0.0.10
|
| 227 |
-
pystoi==0.4.1
|
| 228 |
-
python-dateutil==2.8.2
|
| 229 |
-
python-multipart==0.0.9
|
| 230 |
-
pytorch-lightning==2.1.0
|
| 231 |
-
pytz==2023.3.post1
|
| 232 |
-
pyviz_comms==3.0.3
|
| 233 |
-
PyWavelets==1.4.1
|
| 234 |
-
PyYAML==6.0.1
|
| 235 |
-
randomname==0.2.1
|
| 236 |
-
referencing==0.35.1
|
| 237 |
-
regex==2023.10.3
|
| 238 |
-
requests==2.32.3
|
| 239 |
-
requests-oauthlib==1.3.1
|
| 240 |
-
resampy==0.4.3
|
| 241 |
-
retrying==1.3.4
|
| 242 |
-
rich==13.7.1
|
| 243 |
-
rpds-py==0.18.1
|
| 244 |
-
rsa==4.9
|
| 245 |
-
ruamel.yaml==0.18.5
|
| 246 |
-
ruamel.yaml.clib==0.2.8
|
| 247 |
-
ruff==0.4.4
|
| 248 |
-
s3fs==2024.6.1
|
| 249 |
-
s3transfer==0.7.0
|
| 250 |
-
sacrebleu==2.4.2
|
| 251 |
-
safetensors==0.4.3
|
| 252 |
-
scikit-image==0.21.0
|
| 253 |
-
scikit-learn==1.3.2
|
| 254 |
-
scipy==1.10.1
|
| 255 |
-
semantic-version==2.10.0
|
| 256 |
-
sentencepiece==0.1.99
|
| 257 |
-
sentry-sdk==2.10.0
|
| 258 |
-
setproctitle==1.3.3
|
| 259 |
-
shellingham==1.5.4
|
| 260 |
-
six==1.16.0
|
| 261 |
-
smart-open==6.4.0
|
| 262 |
-
smmap==5.0.1
|
| 263 |
-
sniffio==1.3.1
|
| 264 |
-
sortedcontainers==2.4.0
|
| 265 |
-
SoundFile==0.10.2
|
| 266 |
-
sox==1.4.1
|
| 267 |
-
soxr==0.3.7
|
| 268 |
-
spacy==3.7.4
|
| 269 |
-
spacy-legacy==3.0.12
|
| 270 |
-
spacy-loggers==1.0.5
|
| 271 |
-
srsly==2.4.8
|
| 272 |
-
stack-data==0.6.3
|
| 273 |
-
starlette==0.37.2
|
| 274 |
-
submitit==1.5.1
|
| 275 |
-
sympy==1.12
|
| 276 |
-
tabulate==0.9.0
|
| 277 |
-
tenacity==9.0.0
|
| 278 |
-
tensorboard==2.14.0
|
| 279 |
-
tensorboard-data-server==0.7.2
|
| 280 |
-
termcolor==2.3.0
|
| 281 |
-
thinc==8.2.3
|
| 282 |
-
threadpoolctl==3.3.0
|
| 283 |
-
tifffile==2023.7.10
|
| 284 |
-
timm==0.9.11
|
| 285 |
-
tokenizers==0.19.1
|
| 286 |
-
tomlkit==0.12.0
|
| 287 |
-
toolz==0.12.1
|
| 288 |
-
torch==2.2.0+cu118
|
| 289 |
-
torch-stoi==0.2.1
|
| 290 |
-
torchaudio==2.2.0+cu118
|
| 291 |
-
torchdata==0.7.1
|
| 292 |
-
torchdiffeq==0.2.4
|
| 293 |
-
torchlibrosa==0.1.0
|
| 294 |
-
torchmetrics==0.11.4
|
| 295 |
-
torchsde==0.2.6
|
| 296 |
-
torchtext==0.17.0
|
| 297 |
-
torchvision==0.17.0+cu118
|
| 298 |
-
tornado==6.4.1
|
| 299 |
-
tqdm==4.66.4
|
| 300 |
-
traitlets==5.14.3
|
| 301 |
-
trampoline==0.1.2
|
| 302 |
-
transformers==4.42.4
|
| 303 |
-
treetable==0.2.5
|
| 304 |
-
triton==2.2.0
|
| 305 |
-
typeguard==2.13.0
|
| 306 |
-
typer==0.9.4
|
| 307 |
-
types-dataclasses==0.6.6
|
| 308 |
-
typing-inspect==0.9.0
|
| 309 |
-
typing_extensions==4.8.0
|
| 310 |
-
tzdata==2023.3
|
| 311 |
-
uc-micro-py==1.0.3
|
| 312 |
-
umap-learn==0.5.6
|
| 313 |
-
Unidecode==1.3.8
|
| 314 |
-
urllib3==1.26.18
|
| 315 |
-
uvicorn==0.29.0
|
| 316 |
-
v-diffusion-pytorch==0.0.2
|
| 317 |
-
vector-quantize-pytorch==1.9.14
|
| 318 |
-
wandb==0.15.4
|
| 319 |
-
wasabi==1.1.2
|
| 320 |
-
wcwidth==0.2.12
|
| 321 |
-
weasel==0.3.4
|
| 322 |
-
webdataset==0.2.48
|
| 323 |
-
webencodings==0.5.1
|
| 324 |
-
websockets==11.0.3
|
| 325 |
-
Werkzeug==3.0.1
|
| 326 |
-
wget==3.2
|
| 327 |
-
wordsegment==1.3.1
|
| 328 |
-
wrapt==1.16.0
|
| 329 |
-
x-clip==0.14.4
|
| 330 |
-
x-transformers==1.26.6
|
| 331 |
-
xformers==0.0.24+cu118
|
| 332 |
-
xxhash==3.4.1
|
| 333 |
-
xyzservices==2024.6.0
|
| 334 |
-
yarl==1.9.4
|
| 335 |
-
zipp==3.17.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/test_wav/test.wav
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:8cd28fa4fc1e8695be47602407088fcc9c486ac27b0ac6712ad30b7c7bcef4f8
|
| 3 |
-
size 22823468
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/tools/get_melvaehifigan48k.py
DELETED
|
@@ -1,1551 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
import soundfile as sf
|
| 3 |
-
import os
|
| 4 |
-
from librosa.filters import mel as librosa_mel_fn
|
| 5 |
-
import sys
|
| 6 |
-
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
| 7 |
-
import tools.torch_tools as torch_tools
|
| 8 |
-
import torch.nn as nn
|
| 9 |
-
import torch
|
| 10 |
-
import numpy as np
|
| 11 |
-
from einops import rearrange
|
| 12 |
-
from scipy.signal import get_window
|
| 13 |
-
from librosa.util import pad_center, tiny
|
| 14 |
-
import librosa.util as librosa_util
|
| 15 |
-
|
| 16 |
-
class AttrDict(dict):
|
| 17 |
-
def __init__(self, *args, **kwargs):
|
| 18 |
-
super(AttrDict, self).__init__(*args, **kwargs)
|
| 19 |
-
self.__dict__ = self
|
| 20 |
-
|
| 21 |
-
def init_weights(m, mean=0.0, std=0.01):
|
| 22 |
-
classname = m.__class__.__name__
|
| 23 |
-
if classname.find("Conv") != -1:
|
| 24 |
-
m.weight.data.normal_(mean, std)
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def get_padding(kernel_size, dilation=1):
|
| 28 |
-
return int((kernel_size * dilation - dilation) / 2)
|
| 29 |
-
|
| 30 |
-
LRELU_SLOPE = 0.1
|
| 31 |
-
|
| 32 |
-
class ResBlock(torch.nn.Module):
|
| 33 |
-
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
|
| 34 |
-
super(ResBlock, self).__init__()
|
| 35 |
-
self.h = h
|
| 36 |
-
self.convs1 = nn.ModuleList(
|
| 37 |
-
[
|
| 38 |
-
torch.nn.utils.weight_norm(
|
| 39 |
-
nn.Conv1d(
|
| 40 |
-
channels,
|
| 41 |
-
channels,
|
| 42 |
-
kernel_size,
|
| 43 |
-
1,
|
| 44 |
-
dilation=dilation[0],
|
| 45 |
-
padding=get_padding(kernel_size, dilation[0]),
|
| 46 |
-
)
|
| 47 |
-
),
|
| 48 |
-
torch.nn.utils.weight_norm(
|
| 49 |
-
nn.Conv1d(
|
| 50 |
-
channels,
|
| 51 |
-
channels,
|
| 52 |
-
kernel_size,
|
| 53 |
-
1,
|
| 54 |
-
dilation=dilation[1],
|
| 55 |
-
padding=get_padding(kernel_size, dilation[1]),
|
| 56 |
-
)
|
| 57 |
-
),
|
| 58 |
-
torch.nn.utils.weight_norm(
|
| 59 |
-
nn.Conv1d(
|
| 60 |
-
channels,
|
| 61 |
-
channels,
|
| 62 |
-
kernel_size,
|
| 63 |
-
1,
|
| 64 |
-
dilation=dilation[2],
|
| 65 |
-
padding=get_padding(kernel_size, dilation[2]),
|
| 66 |
-
)
|
| 67 |
-
),
|
| 68 |
-
]
|
| 69 |
-
)
|
| 70 |
-
self.convs1.apply(init_weights)
|
| 71 |
-
|
| 72 |
-
self.convs2 = nn.ModuleList(
|
| 73 |
-
[
|
| 74 |
-
torch.nn.utils.weight_norm(
|
| 75 |
-
nn.Conv1d(
|
| 76 |
-
channels,
|
| 77 |
-
channels,
|
| 78 |
-
kernel_size,
|
| 79 |
-
1,
|
| 80 |
-
dilation=1,
|
| 81 |
-
padding=get_padding(kernel_size, 1),
|
| 82 |
-
)
|
| 83 |
-
),
|
| 84 |
-
torch.nn.utils.weight_norm(
|
| 85 |
-
nn.Conv1d(
|
| 86 |
-
channels,
|
| 87 |
-
channels,
|
| 88 |
-
kernel_size,
|
| 89 |
-
1,
|
| 90 |
-
dilation=1,
|
| 91 |
-
padding=get_padding(kernel_size, 1),
|
| 92 |
-
)
|
| 93 |
-
),
|
| 94 |
-
torch.nn.utils.weight_norm(
|
| 95 |
-
nn.Conv1d(
|
| 96 |
-
channels,
|
| 97 |
-
channels,
|
| 98 |
-
kernel_size,
|
| 99 |
-
1,
|
| 100 |
-
dilation=1,
|
| 101 |
-
padding=get_padding(kernel_size, 1),
|
| 102 |
-
)
|
| 103 |
-
),
|
| 104 |
-
]
|
| 105 |
-
)
|
| 106 |
-
self.convs2.apply(init_weights)
|
| 107 |
-
|
| 108 |
-
def forward(self, x):
|
| 109 |
-
for c1, c2 in zip(self.convs1, self.convs2):
|
| 110 |
-
xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
|
| 111 |
-
xt = c1(xt)
|
| 112 |
-
xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE)
|
| 113 |
-
xt = c2(xt)
|
| 114 |
-
x = xt + x
|
| 115 |
-
return x
|
| 116 |
-
|
| 117 |
-
def remove_weight_norm(self):
|
| 118 |
-
for l in self.convs1:
|
| 119 |
-
torch.nn.utils.remove_weight_norm(l)
|
| 120 |
-
for l in self.convs2:
|
| 121 |
-
torch.nn.utils.remove_weight_norm(l)
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
class Generator_old(torch.nn.Module):
|
| 125 |
-
def __init__(self, h):
|
| 126 |
-
super(Generator_old, self).__init__()
|
| 127 |
-
self.h = h
|
| 128 |
-
self.num_kernels = len(h.resblock_kernel_sizes)
|
| 129 |
-
self.num_upsamples = len(h.upsample_rates)
|
| 130 |
-
self.conv_pre = torch.nn.utils.weight_norm(
|
| 131 |
-
nn.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
|
| 132 |
-
)
|
| 133 |
-
resblock = ResBlock
|
| 134 |
-
|
| 135 |
-
self.ups = nn.ModuleList()
|
| 136 |
-
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
| 137 |
-
self.ups.append(
|
| 138 |
-
torch.nn.utils.weight_norm(
|
| 139 |
-
nn.ConvTranspose1d(
|
| 140 |
-
h.upsample_initial_channel // (2**i),
|
| 141 |
-
h.upsample_initial_channel // (2 ** (i + 1)),
|
| 142 |
-
k,
|
| 143 |
-
u,
|
| 144 |
-
padding=(k - u) // 2,
|
| 145 |
-
)
|
| 146 |
-
)
|
| 147 |
-
)
|
| 148 |
-
|
| 149 |
-
self.resblocks = nn.ModuleList()
|
| 150 |
-
for i in range(len(self.ups)):
|
| 151 |
-
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
| 152 |
-
for j, (k, d) in enumerate(
|
| 153 |
-
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
|
| 154 |
-
):
|
| 155 |
-
self.resblocks.append(resblock(h, ch, k, d))
|
| 156 |
-
|
| 157 |
-
self.conv_post = torch.nn.utils.weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
|
| 158 |
-
self.ups.apply(init_weights)
|
| 159 |
-
self.conv_post.apply(init_weights)
|
| 160 |
-
|
| 161 |
-
def forward(self, x):
|
| 162 |
-
x = self.conv_pre(x)
|
| 163 |
-
for i in range(self.num_upsamples):
|
| 164 |
-
x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
|
| 165 |
-
x = self.ups[i](x)
|
| 166 |
-
xs = None
|
| 167 |
-
for j in range(self.num_kernels):
|
| 168 |
-
if xs is None:
|
| 169 |
-
xs = self.resblocks[i * self.num_kernels + j](x)
|
| 170 |
-
else:
|
| 171 |
-
xs += self.resblocks[i * self.num_kernels + j](x)
|
| 172 |
-
x = xs / self.num_kernels
|
| 173 |
-
x = torch.nn.functional.leaky_relu(x)
|
| 174 |
-
x = self.conv_post(x)
|
| 175 |
-
x = torch.tanh(x)
|
| 176 |
-
|
| 177 |
-
return x
|
| 178 |
-
|
| 179 |
-
def remove_weight_norm(self):
|
| 180 |
-
# print("Removing weight norm...")
|
| 181 |
-
for l in self.ups:
|
| 182 |
-
torch.nn.utils.remove_weight_norm(l)
|
| 183 |
-
for l in self.resblocks:
|
| 184 |
-
l.remove_weight_norm()
|
| 185 |
-
torch.nn.utils.remove_weight_norm(self.conv_pre)
|
| 186 |
-
torch.nn.utils.remove_weight_norm(self.conv_post)
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
def nonlinearity(x):
|
| 191 |
-
# swish
|
| 192 |
-
return x * torch.sigmoid(x)
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
def Normalize(in_channels, num_groups=32):
|
| 196 |
-
return torch.nn.GroupNorm(
|
| 197 |
-
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
| 198 |
-
)
|
| 199 |
-
|
| 200 |
-
class Downsample(nn.Module):
|
| 201 |
-
def __init__(self, in_channels, with_conv):
|
| 202 |
-
super().__init__()
|
| 203 |
-
self.with_conv = with_conv
|
| 204 |
-
if self.with_conv:
|
| 205 |
-
# Do time downsampling here
|
| 206 |
-
# no asymmetric padding in torch conv, must do it ourselves
|
| 207 |
-
self.conv = torch.nn.Conv2d(
|
| 208 |
-
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
| 209 |
-
)
|
| 210 |
-
|
| 211 |
-
def forward(self, x):
|
| 212 |
-
if self.with_conv:
|
| 213 |
-
pad = (0, 1, 0, 1)
|
| 214 |
-
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
| 215 |
-
x = self.conv(x)
|
| 216 |
-
else:
|
| 217 |
-
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
| 218 |
-
return x
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
class DownsampleTimeStride4(nn.Module):
|
| 222 |
-
def __init__(self, in_channels, with_conv):
|
| 223 |
-
super().__init__()
|
| 224 |
-
self.with_conv = with_conv
|
| 225 |
-
if self.with_conv:
|
| 226 |
-
# Do time downsampling here
|
| 227 |
-
# no asymmetric padding in torch conv, must do it ourselves
|
| 228 |
-
self.conv = torch.nn.Conv2d(
|
| 229 |
-
in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
|
| 230 |
-
)
|
| 231 |
-
|
| 232 |
-
def forward(self, x):
|
| 233 |
-
if self.with_conv:
|
| 234 |
-
pad = (0, 1, 0, 1)
|
| 235 |
-
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
| 236 |
-
x = self.conv(x)
|
| 237 |
-
else:
|
| 238 |
-
x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
|
| 239 |
-
return x
|
| 240 |
-
|
| 241 |
-
class Upsample(nn.Module):
|
| 242 |
-
def __init__(self, in_channels, with_conv):
|
| 243 |
-
super().__init__()
|
| 244 |
-
self.with_conv = with_conv
|
| 245 |
-
if self.with_conv:
|
| 246 |
-
self.conv = torch.nn.Conv2d(
|
| 247 |
-
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
| 248 |
-
)
|
| 249 |
-
|
| 250 |
-
def forward(self, x):
|
| 251 |
-
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
| 252 |
-
if self.with_conv:
|
| 253 |
-
x = self.conv(x)
|
| 254 |
-
return x
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
class UpsampleTimeStride4(nn.Module):
|
| 258 |
-
def __init__(self, in_channels, with_conv):
|
| 259 |
-
super().__init__()
|
| 260 |
-
self.with_conv = with_conv
|
| 261 |
-
if self.with_conv:
|
| 262 |
-
self.conv = torch.nn.Conv2d(
|
| 263 |
-
in_channels, in_channels, kernel_size=5, stride=1, padding=2
|
| 264 |
-
)
|
| 265 |
-
|
| 266 |
-
def forward(self, x):
|
| 267 |
-
x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
|
| 268 |
-
if self.with_conv:
|
| 269 |
-
x = self.conv(x)
|
| 270 |
-
return x
|
| 271 |
-
|
| 272 |
-
class AttnBlock(nn.Module):
|
| 273 |
-
def __init__(self, in_channels):
|
| 274 |
-
super().__init__()
|
| 275 |
-
self.in_channels = in_channels
|
| 276 |
-
|
| 277 |
-
self.norm = Normalize(in_channels)
|
| 278 |
-
self.q = torch.nn.Conv2d(
|
| 279 |
-
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
| 280 |
-
)
|
| 281 |
-
self.k = torch.nn.Conv2d(
|
| 282 |
-
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
| 283 |
-
)
|
| 284 |
-
self.v = torch.nn.Conv2d(
|
| 285 |
-
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
| 286 |
-
)
|
| 287 |
-
self.proj_out = torch.nn.Conv2d(
|
| 288 |
-
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
| 289 |
-
)
|
| 290 |
-
|
| 291 |
-
def forward(self, x):
|
| 292 |
-
h_ = x
|
| 293 |
-
h_ = self.norm(h_)
|
| 294 |
-
q = self.q(h_)
|
| 295 |
-
k = self.k(h_)
|
| 296 |
-
v = self.v(h_)
|
| 297 |
-
|
| 298 |
-
# compute attention
|
| 299 |
-
b, c, h, w = q.shape
|
| 300 |
-
q = q.reshape(b, c, h * w).contiguous()
|
| 301 |
-
q = q.permute(0, 2, 1).contiguous() # b,hw,c
|
| 302 |
-
k = k.reshape(b, c, h * w).contiguous() # b,c,hw
|
| 303 |
-
w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
| 304 |
-
w_ = w_ * (int(c) ** (-0.5))
|
| 305 |
-
w_ = torch.nn.functional.softmax(w_, dim=2)
|
| 306 |
-
|
| 307 |
-
# attend to values
|
| 308 |
-
v = v.reshape(b, c, h * w).contiguous()
|
| 309 |
-
w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
|
| 310 |
-
h_ = torch.bmm(
|
| 311 |
-
v, w_
|
| 312 |
-
).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
| 313 |
-
h_ = h_.reshape(b, c, h, w).contiguous()
|
| 314 |
-
|
| 315 |
-
h_ = self.proj_out(h_)
|
| 316 |
-
|
| 317 |
-
return x + h_
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
def make_attn(in_channels, attn_type="vanilla"):
|
| 321 |
-
assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
|
| 322 |
-
# print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
| 323 |
-
if attn_type == "vanilla":
|
| 324 |
-
return AttnBlock(in_channels)
|
| 325 |
-
elif attn_type == "none":
|
| 326 |
-
return nn.Identity(in_channels)
|
| 327 |
-
else:
|
| 328 |
-
raise ValueError(attn_type)
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
class ResnetBlock(nn.Module):
|
| 332 |
-
def __init__(
|
| 333 |
-
self,
|
| 334 |
-
*,
|
| 335 |
-
in_channels,
|
| 336 |
-
out_channels=None,
|
| 337 |
-
conv_shortcut=False,
|
| 338 |
-
dropout,
|
| 339 |
-
temb_channels=512,
|
| 340 |
-
):
|
| 341 |
-
super().__init__()
|
| 342 |
-
self.in_channels = in_channels
|
| 343 |
-
out_channels = in_channels if out_channels is None else out_channels
|
| 344 |
-
self.out_channels = out_channels
|
| 345 |
-
self.use_conv_shortcut = conv_shortcut
|
| 346 |
-
|
| 347 |
-
self.norm1 = Normalize(in_channels)
|
| 348 |
-
self.conv1 = torch.nn.Conv2d(
|
| 349 |
-
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
| 350 |
-
)
|
| 351 |
-
if temb_channels > 0:
|
| 352 |
-
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
| 353 |
-
self.norm2 = Normalize(out_channels)
|
| 354 |
-
self.dropout = torch.nn.Dropout(dropout)
|
| 355 |
-
self.conv2 = torch.nn.Conv2d(
|
| 356 |
-
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
| 357 |
-
)
|
| 358 |
-
if self.in_channels != self.out_channels:
|
| 359 |
-
if self.use_conv_shortcut:
|
| 360 |
-
self.conv_shortcut = torch.nn.Conv2d(
|
| 361 |
-
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
| 362 |
-
)
|
| 363 |
-
else:
|
| 364 |
-
self.nin_shortcut = torch.nn.Conv2d(
|
| 365 |
-
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
| 366 |
-
)
|
| 367 |
-
|
| 368 |
-
def forward(self, x, temb):
|
| 369 |
-
h = x
|
| 370 |
-
h = self.norm1(h)
|
| 371 |
-
h = nonlinearity(h)
|
| 372 |
-
h = self.conv1(h)
|
| 373 |
-
|
| 374 |
-
if temb is not None:
|
| 375 |
-
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
| 376 |
-
|
| 377 |
-
h = self.norm2(h)
|
| 378 |
-
h = nonlinearity(h)
|
| 379 |
-
h = self.dropout(h)
|
| 380 |
-
h = self.conv2(h)
|
| 381 |
-
|
| 382 |
-
if self.in_channels != self.out_channels:
|
| 383 |
-
if self.use_conv_shortcut:
|
| 384 |
-
x = self.conv_shortcut(x)
|
| 385 |
-
else:
|
| 386 |
-
x = self.nin_shortcut(x)
|
| 387 |
-
|
| 388 |
-
return x + h
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
class Encoder(nn.Module):
|
| 392 |
-
def __init__(
|
| 393 |
-
self,
|
| 394 |
-
*,
|
| 395 |
-
ch,
|
| 396 |
-
out_ch,
|
| 397 |
-
ch_mult=(1, 2, 4, 8),
|
| 398 |
-
num_res_blocks,
|
| 399 |
-
attn_resolutions,
|
| 400 |
-
dropout=0.0,
|
| 401 |
-
resamp_with_conv=True,
|
| 402 |
-
in_channels,
|
| 403 |
-
resolution,
|
| 404 |
-
z_channels,
|
| 405 |
-
double_z=True,
|
| 406 |
-
use_linear_attn=False,
|
| 407 |
-
attn_type="vanilla",
|
| 408 |
-
downsample_time_stride4_levels=[],
|
| 409 |
-
**ignore_kwargs,
|
| 410 |
-
):
|
| 411 |
-
super().__init__()
|
| 412 |
-
if use_linear_attn:
|
| 413 |
-
attn_type = "linear"
|
| 414 |
-
self.ch = ch
|
| 415 |
-
self.temb_ch = 0
|
| 416 |
-
self.num_resolutions = len(ch_mult)
|
| 417 |
-
self.num_res_blocks = num_res_blocks
|
| 418 |
-
self.resolution = resolution
|
| 419 |
-
self.in_channels = in_channels
|
| 420 |
-
self.downsample_time_stride4_levels = downsample_time_stride4_levels
|
| 421 |
-
|
| 422 |
-
if len(self.downsample_time_stride4_levels) > 0:
|
| 423 |
-
assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
|
| 424 |
-
"The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
|
| 425 |
-
% str(self.num_resolutions)
|
| 426 |
-
)
|
| 427 |
-
|
| 428 |
-
# downsampling
|
| 429 |
-
self.conv_in = torch.nn.Conv2d(
|
| 430 |
-
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
| 431 |
-
)
|
| 432 |
-
|
| 433 |
-
curr_res = resolution
|
| 434 |
-
in_ch_mult = (1,) + tuple(ch_mult)
|
| 435 |
-
self.in_ch_mult = in_ch_mult
|
| 436 |
-
self.down = nn.ModuleList()
|
| 437 |
-
for i_level in range(self.num_resolutions):
|
| 438 |
-
block = nn.ModuleList()
|
| 439 |
-
attn = nn.ModuleList()
|
| 440 |
-
block_in = ch * in_ch_mult[i_level]
|
| 441 |
-
block_out = ch * ch_mult[i_level]
|
| 442 |
-
for i_block in range(self.num_res_blocks):
|
| 443 |
-
block.append(
|
| 444 |
-
ResnetBlock(
|
| 445 |
-
in_channels=block_in,
|
| 446 |
-
out_channels=block_out,
|
| 447 |
-
temb_channels=self.temb_ch,
|
| 448 |
-
dropout=dropout,
|
| 449 |
-
)
|
| 450 |
-
)
|
| 451 |
-
block_in = block_out
|
| 452 |
-
if curr_res in attn_resolutions:
|
| 453 |
-
attn.append(make_attn(block_in, attn_type=attn_type))
|
| 454 |
-
down = nn.Module()
|
| 455 |
-
down.block = block
|
| 456 |
-
down.attn = attn
|
| 457 |
-
if i_level != self.num_resolutions - 1:
|
| 458 |
-
if i_level in self.downsample_time_stride4_levels:
|
| 459 |
-
down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
|
| 460 |
-
else:
|
| 461 |
-
down.downsample = Downsample(block_in, resamp_with_conv)
|
| 462 |
-
curr_res = curr_res // 2
|
| 463 |
-
self.down.append(down)
|
| 464 |
-
|
| 465 |
-
# middle
|
| 466 |
-
self.mid = nn.Module()
|
| 467 |
-
self.mid.block_1 = ResnetBlock(
|
| 468 |
-
in_channels=block_in,
|
| 469 |
-
out_channels=block_in,
|
| 470 |
-
temb_channels=self.temb_ch,
|
| 471 |
-
dropout=dropout,
|
| 472 |
-
)
|
| 473 |
-
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
| 474 |
-
self.mid.block_2 = ResnetBlock(
|
| 475 |
-
in_channels=block_in,
|
| 476 |
-
out_channels=block_in,
|
| 477 |
-
temb_channels=self.temb_ch,
|
| 478 |
-
dropout=dropout,
|
| 479 |
-
)
|
| 480 |
-
|
| 481 |
-
# end
|
| 482 |
-
self.norm_out = Normalize(block_in)
|
| 483 |
-
self.conv_out = torch.nn.Conv2d(
|
| 484 |
-
block_in,
|
| 485 |
-
2 * z_channels if double_z else z_channels,
|
| 486 |
-
kernel_size=3,
|
| 487 |
-
stride=1,
|
| 488 |
-
padding=1,
|
| 489 |
-
)
|
| 490 |
-
|
| 491 |
-
def forward(self, x):
|
| 492 |
-
# timestep embedding
|
| 493 |
-
temb = None
|
| 494 |
-
# downsampling
|
| 495 |
-
hs = [self.conv_in(x)]
|
| 496 |
-
for i_level in range(self.num_resolutions):
|
| 497 |
-
for i_block in range(self.num_res_blocks):
|
| 498 |
-
h = self.down[i_level].block[i_block](hs[-1], temb)
|
| 499 |
-
if len(self.down[i_level].attn) > 0:
|
| 500 |
-
h = self.down[i_level].attn[i_block](h)
|
| 501 |
-
hs.append(h)
|
| 502 |
-
if i_level != self.num_resolutions - 1:
|
| 503 |
-
hs.append(self.down[i_level].downsample(hs[-1]))
|
| 504 |
-
|
| 505 |
-
# middle
|
| 506 |
-
h = hs[-1]
|
| 507 |
-
h = self.mid.block_1(h, temb)
|
| 508 |
-
h = self.mid.attn_1(h)
|
| 509 |
-
h = self.mid.block_2(h, temb)
|
| 510 |
-
|
| 511 |
-
# end
|
| 512 |
-
h = self.norm_out(h)
|
| 513 |
-
h = nonlinearity(h)
|
| 514 |
-
h = self.conv_out(h)
|
| 515 |
-
return h
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
class Decoder(nn.Module):
|
| 519 |
-
def __init__(
|
| 520 |
-
self,
|
| 521 |
-
*,
|
| 522 |
-
ch,
|
| 523 |
-
out_ch,
|
| 524 |
-
ch_mult=(1, 2, 4, 8),
|
| 525 |
-
num_res_blocks,
|
| 526 |
-
attn_resolutions,
|
| 527 |
-
dropout=0.0,
|
| 528 |
-
resamp_with_conv=True,
|
| 529 |
-
in_channels,
|
| 530 |
-
resolution,
|
| 531 |
-
z_channels,
|
| 532 |
-
give_pre_end=False,
|
| 533 |
-
tanh_out=False,
|
| 534 |
-
use_linear_attn=False,
|
| 535 |
-
downsample_time_stride4_levels=[],
|
| 536 |
-
attn_type="vanilla",
|
| 537 |
-
**ignorekwargs,
|
| 538 |
-
):
|
| 539 |
-
super().__init__()
|
| 540 |
-
if use_linear_attn:
|
| 541 |
-
attn_type = "linear"
|
| 542 |
-
self.ch = ch
|
| 543 |
-
self.temb_ch = 0
|
| 544 |
-
self.num_resolutions = len(ch_mult)
|
| 545 |
-
self.num_res_blocks = num_res_blocks
|
| 546 |
-
self.resolution = resolution
|
| 547 |
-
self.in_channels = in_channels
|
| 548 |
-
self.give_pre_end = give_pre_end
|
| 549 |
-
self.tanh_out = tanh_out
|
| 550 |
-
self.downsample_time_stride4_levels = downsample_time_stride4_levels
|
| 551 |
-
|
| 552 |
-
if len(self.downsample_time_stride4_levels) > 0:
|
| 553 |
-
assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
|
| 554 |
-
"The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
|
| 555 |
-
% str(self.num_resolutions)
|
| 556 |
-
)
|
| 557 |
-
|
| 558 |
-
# compute in_ch_mult, block_in and curr_res at lowest res
|
| 559 |
-
(1,) + tuple(ch_mult)
|
| 560 |
-
block_in = ch * ch_mult[self.num_resolutions - 1]
|
| 561 |
-
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
| 562 |
-
self.z_shape = (1, z_channels, curr_res, curr_res)
|
| 563 |
-
# print(
|
| 564 |
-
# "Working with z of shape {} = {} dimensions.".format(
|
| 565 |
-
# self.z_shape, np.prod(self.z_shape)
|
| 566 |
-
# )
|
| 567 |
-
# )
|
| 568 |
-
|
| 569 |
-
# z to block_in
|
| 570 |
-
self.conv_in = torch.nn.Conv2d(
|
| 571 |
-
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
| 572 |
-
)
|
| 573 |
-
|
| 574 |
-
# middle
|
| 575 |
-
self.mid = nn.Module()
|
| 576 |
-
self.mid.block_1 = ResnetBlock(
|
| 577 |
-
in_channels=block_in,
|
| 578 |
-
out_channels=block_in,
|
| 579 |
-
temb_channels=self.temb_ch,
|
| 580 |
-
dropout=dropout,
|
| 581 |
-
)
|
| 582 |
-
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
| 583 |
-
self.mid.block_2 = ResnetBlock(
|
| 584 |
-
in_channels=block_in,
|
| 585 |
-
out_channels=block_in,
|
| 586 |
-
temb_channels=self.temb_ch,
|
| 587 |
-
dropout=dropout,
|
| 588 |
-
)
|
| 589 |
-
|
| 590 |
-
# upsampling
|
| 591 |
-
self.up = nn.ModuleList()
|
| 592 |
-
for i_level in reversed(range(self.num_resolutions)):
|
| 593 |
-
block = nn.ModuleList()
|
| 594 |
-
attn = nn.ModuleList()
|
| 595 |
-
block_out = ch * ch_mult[i_level]
|
| 596 |
-
for i_block in range(self.num_res_blocks + 1):
|
| 597 |
-
block.append(
|
| 598 |
-
ResnetBlock(
|
| 599 |
-
in_channels=block_in,
|
| 600 |
-
out_channels=block_out,
|
| 601 |
-
temb_channels=self.temb_ch,
|
| 602 |
-
dropout=dropout,
|
| 603 |
-
)
|
| 604 |
-
)
|
| 605 |
-
block_in = block_out
|
| 606 |
-
if curr_res in attn_resolutions:
|
| 607 |
-
attn.append(make_attn(block_in, attn_type=attn_type))
|
| 608 |
-
up = nn.Module()
|
| 609 |
-
up.block = block
|
| 610 |
-
up.attn = attn
|
| 611 |
-
if i_level != 0:
|
| 612 |
-
if i_level - 1 in self.downsample_time_stride4_levels:
|
| 613 |
-
up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
|
| 614 |
-
else:
|
| 615 |
-
up.upsample = Upsample(block_in, resamp_with_conv)
|
| 616 |
-
curr_res = curr_res * 2
|
| 617 |
-
self.up.insert(0, up) # prepend to get consistent order
|
| 618 |
-
|
| 619 |
-
# end
|
| 620 |
-
self.norm_out = Normalize(block_in)
|
| 621 |
-
self.conv_out = torch.nn.Conv2d(
|
| 622 |
-
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
| 623 |
-
)
|
| 624 |
-
|
| 625 |
-
def forward(self, z):
|
| 626 |
-
# assert z.shape[1:] == self.z_shape[1:]
|
| 627 |
-
self.last_z_shape = z.shape
|
| 628 |
-
|
| 629 |
-
# timestep embedding
|
| 630 |
-
temb = None
|
| 631 |
-
|
| 632 |
-
# z to block_in
|
| 633 |
-
h = self.conv_in(z)
|
| 634 |
-
|
| 635 |
-
# middle
|
| 636 |
-
h = self.mid.block_1(h, temb)
|
| 637 |
-
h = self.mid.attn_1(h)
|
| 638 |
-
h = self.mid.block_2(h, temb)
|
| 639 |
-
|
| 640 |
-
# upsampling
|
| 641 |
-
for i_level in reversed(range(self.num_resolutions)):
|
| 642 |
-
for i_block in range(self.num_res_blocks + 1):
|
| 643 |
-
h = self.up[i_level].block[i_block](h, temb)
|
| 644 |
-
if len(self.up[i_level].attn) > 0:
|
| 645 |
-
h = self.up[i_level].attn[i_block](h)
|
| 646 |
-
if i_level != 0:
|
| 647 |
-
h = self.up[i_level].upsample(h)
|
| 648 |
-
|
| 649 |
-
# end
|
| 650 |
-
if self.give_pre_end:
|
| 651 |
-
return h
|
| 652 |
-
|
| 653 |
-
h = self.norm_out(h)
|
| 654 |
-
h = nonlinearity(h)
|
| 655 |
-
h = self.conv_out(h)
|
| 656 |
-
if self.tanh_out:
|
| 657 |
-
h = torch.tanh(h)
|
| 658 |
-
return h
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
class DiagonalGaussianDistribution(object):
|
| 662 |
-
def __init__(self, parameters, deterministic=False):
|
| 663 |
-
self.parameters = parameters
|
| 664 |
-
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
| 665 |
-
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
| 666 |
-
self.deterministic = deterministic
|
| 667 |
-
self.std = torch.exp(0.5 * self.logvar)
|
| 668 |
-
self.var = torch.exp(self.logvar)
|
| 669 |
-
if self.deterministic:
|
| 670 |
-
self.var = self.std = torch.zeros_like(self.mean).to(
|
| 671 |
-
device=self.parameters.device
|
| 672 |
-
)
|
| 673 |
-
|
| 674 |
-
def sample(self):
|
| 675 |
-
x = self.mean + self.std * torch.randn(self.mean.shape).to(
|
| 676 |
-
device=self.parameters.device
|
| 677 |
-
)
|
| 678 |
-
return x
|
| 679 |
-
|
| 680 |
-
def kl(self, other=None):
|
| 681 |
-
if self.deterministic:
|
| 682 |
-
return torch.Tensor([0.0])
|
| 683 |
-
else:
|
| 684 |
-
if other is None:
|
| 685 |
-
return 0.5 * torch.mean(
|
| 686 |
-
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
| 687 |
-
dim=[1, 2, 3],
|
| 688 |
-
)
|
| 689 |
-
else:
|
| 690 |
-
return 0.5 * torch.mean(
|
| 691 |
-
torch.pow(self.mean - other.mean, 2) / other.var
|
| 692 |
-
+ self.var / other.var
|
| 693 |
-
- 1.0
|
| 694 |
-
- self.logvar
|
| 695 |
-
+ other.logvar,
|
| 696 |
-
dim=[1, 2, 3],
|
| 697 |
-
)
|
| 698 |
-
|
| 699 |
-
def nll(self, sample, dims=[1, 2, 3]):
|
| 700 |
-
if self.deterministic:
|
| 701 |
-
return torch.Tensor([0.0])
|
| 702 |
-
logtwopi = np.log(2.0 * np.pi)
|
| 703 |
-
return 0.5 * torch.sum(
|
| 704 |
-
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
| 705 |
-
dim=dims,
|
| 706 |
-
)
|
| 707 |
-
|
| 708 |
-
def mode(self):
|
| 709 |
-
return self.mean
|
| 710 |
-
|
| 711 |
-
def get_vocoder_config_48k():
|
| 712 |
-
return {
|
| 713 |
-
"resblock": "1",
|
| 714 |
-
"num_gpus": 8,
|
| 715 |
-
"batch_size": 128,
|
| 716 |
-
"learning_rate": 0.0001,
|
| 717 |
-
"adam_b1": 0.8,
|
| 718 |
-
"adam_b2": 0.99,
|
| 719 |
-
"lr_decay": 0.999,
|
| 720 |
-
"seed": 1234,
|
| 721 |
-
|
| 722 |
-
"upsample_rates": [6,5,4,2,2],
|
| 723 |
-
"upsample_kernel_sizes": [12,10,8,4,4],
|
| 724 |
-
"upsample_initial_channel": 1536,
|
| 725 |
-
"resblock_kernel_sizes": [3,7,11,15],
|
| 726 |
-
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5], [1,3,5]],
|
| 727 |
-
|
| 728 |
-
"segment_size": 15360,
|
| 729 |
-
"num_mels": 256,
|
| 730 |
-
"n_fft": 2048,
|
| 731 |
-
"hop_size": 480,
|
| 732 |
-
"win_size": 2048,
|
| 733 |
-
|
| 734 |
-
"sampling_rate": 48000,
|
| 735 |
-
|
| 736 |
-
"fmin": 20,
|
| 737 |
-
"fmax": 24000,
|
| 738 |
-
"fmax_for_loss": None,
|
| 739 |
-
|
| 740 |
-
"num_workers": 8,
|
| 741 |
-
|
| 742 |
-
"dist_config": {
|
| 743 |
-
"dist_backend": "nccl",
|
| 744 |
-
"dist_url": "tcp://localhost:18273",
|
| 745 |
-
"world_size": 1
|
| 746 |
-
}
|
| 747 |
-
}
|
| 748 |
-
|
| 749 |
-
def get_vocoder(config, device, mel_bins):
|
| 750 |
-
name = "HiFi-GAN"
|
| 751 |
-
speaker = ""
|
| 752 |
-
if name == "MelGAN":
|
| 753 |
-
if speaker == "LJSpeech":
|
| 754 |
-
vocoder = torch.hub.load(
|
| 755 |
-
"descriptinc/melgan-neurips", "load_melgan", "linda_johnson"
|
| 756 |
-
)
|
| 757 |
-
elif speaker == "universal":
|
| 758 |
-
vocoder = torch.hub.load(
|
| 759 |
-
"descriptinc/melgan-neurips", "load_melgan", "multi_speaker"
|
| 760 |
-
)
|
| 761 |
-
vocoder.mel2wav.eval()
|
| 762 |
-
vocoder.mel2wav.to(device)
|
| 763 |
-
elif name == "HiFi-GAN":
|
| 764 |
-
if(mel_bins == 256):
|
| 765 |
-
config = get_vocoder_config_48k()
|
| 766 |
-
config = AttrDict(config)
|
| 767 |
-
vocoder = Generator_old(config)
|
| 768 |
-
# print("Load hifigan/g_01080000")
|
| 769 |
-
# ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000"))
|
| 770 |
-
# ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000"))
|
| 771 |
-
# ckpt = torch_version_orig_mod_remove(ckpt)
|
| 772 |
-
# vocoder.load_state_dict(ckpt["generator"])
|
| 773 |
-
vocoder.eval()
|
| 774 |
-
vocoder.remove_weight_norm()
|
| 775 |
-
vocoder.to(device)
|
| 776 |
-
else:
|
| 777 |
-
raise ValueError(mel_bins)
|
| 778 |
-
return vocoder
|
| 779 |
-
|
| 780 |
-
def vocoder_infer(mels, vocoder, lengths=None):
|
| 781 |
-
with torch.no_grad():
|
| 782 |
-
wavs = vocoder(mels).squeeze(1)
|
| 783 |
-
|
| 784 |
-
#wavs = (wavs.cpu().numpy() * 32768).astype("int16")
|
| 785 |
-
wavs = (wavs.cpu().numpy())
|
| 786 |
-
|
| 787 |
-
if lengths is not None:
|
| 788 |
-
wavs = wavs[:, :lengths]
|
| 789 |
-
|
| 790 |
-
# wavs = [wav for wav in wavs]
|
| 791 |
-
|
| 792 |
-
# for i in range(len(mels)):
|
| 793 |
-
# if lengths is not None:
|
| 794 |
-
# wavs[i] = wavs[i][: lengths[i]]
|
| 795 |
-
|
| 796 |
-
return wavs
|
| 797 |
-
|
| 798 |
-
@torch.no_grad()
|
| 799 |
-
def vocoder_chunk_infer(mels, vocoder, lengths=None):
|
| 800 |
-
chunk_size = 256*4
|
| 801 |
-
shift_size = 256*1
|
| 802 |
-
ov_size = chunk_size-shift_size
|
| 803 |
-
# import pdb;pdb.set_trace()
|
| 804 |
-
|
| 805 |
-
for cinx in range(0, mels.shape[2], shift_size):
|
| 806 |
-
if(cinx==0):
|
| 807 |
-
wavs = vocoder(mels[:,:,cinx:cinx+chunk_size]).squeeze(1).cpu()
|
| 808 |
-
num_samples = int(wavs.shape[-1]/chunk_size)*chunk_size
|
| 809 |
-
wavs = wavs[:,0:num_samples]
|
| 810 |
-
ov_sample = int(float(wavs.shape[-1]) * ov_size / chunk_size)
|
| 811 |
-
ov_win = torch.from_numpy(np.linspace(0,1,ov_sample)[None,:])
|
| 812 |
-
ov_win = torch.cat([ov_win,1-ov_win],-1)
|
| 813 |
-
if(cinx+chunk_size>=mels.shape[2]):
|
| 814 |
-
break
|
| 815 |
-
else:
|
| 816 |
-
cur_wav = vocoder(mels[:,:,cinx:cinx+chunk_size]).squeeze(1).cpu()[:,0:num_samples]
|
| 817 |
-
wavs[:,-ov_sample:] = wavs[:,-ov_sample:] * ov_win[:,-ov_sample:] + cur_wav[:,0:ov_sample] * ov_win[:,0:ov_sample]
|
| 818 |
-
# wavs[:,-ov_sample:] = wavs[:,-ov_sample:] * 1.0 + cur_wav[:,0:ov_sample] * 0.0
|
| 819 |
-
wavs = torch.cat([wavs, cur_wav[:,ov_sample:]],-1)
|
| 820 |
-
if(cinx+chunk_size>=mels.shape[2]):
|
| 821 |
-
break
|
| 822 |
-
# print(wavs.shape)
|
| 823 |
-
|
| 824 |
-
wavs = (wavs.cpu().numpy())
|
| 825 |
-
|
| 826 |
-
if lengths is not None:
|
| 827 |
-
wavs = wavs[:, :lengths]
|
| 828 |
-
# print(wavs.shape)
|
| 829 |
-
return wavs
|
| 830 |
-
|
| 831 |
-
def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
|
| 832 |
-
if vocoder is not None:
|
| 833 |
-
|
| 834 |
-
wav_reconstruction = vocoder_infer(
|
| 835 |
-
mel_input.permute(0, 2, 1),
|
| 836 |
-
vocoder,
|
| 837 |
-
)
|
| 838 |
-
wav_prediction = vocoder_infer(
|
| 839 |
-
mel_prediction.permute(0, 2, 1),
|
| 840 |
-
vocoder,
|
| 841 |
-
)
|
| 842 |
-
else:
|
| 843 |
-
wav_reconstruction = wav_prediction = None
|
| 844 |
-
|
| 845 |
-
return wav_reconstruction, wav_prediction
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
class AutoencoderKL(nn.Module):
|
| 849 |
-
def __init__(
|
| 850 |
-
self,
|
| 851 |
-
ddconfig=None,
|
| 852 |
-
lossconfig=None,
|
| 853 |
-
batchsize=None,
|
| 854 |
-
embed_dim=None,
|
| 855 |
-
time_shuffle=1,
|
| 856 |
-
subband=1,
|
| 857 |
-
sampling_rate=16000,
|
| 858 |
-
ckpt_path=None,
|
| 859 |
-
reload_from_ckpt=None,
|
| 860 |
-
ignore_keys=[],
|
| 861 |
-
image_key="fbank",
|
| 862 |
-
colorize_nlabels=None,
|
| 863 |
-
monitor=None,
|
| 864 |
-
base_learning_rate=1e-5,
|
| 865 |
-
scale_factor=1
|
| 866 |
-
):
|
| 867 |
-
super().__init__()
|
| 868 |
-
self.automatic_optimization = False
|
| 869 |
-
assert (
|
| 870 |
-
"mel_bins" in ddconfig.keys()
|
| 871 |
-
), "mel_bins is not specified in the Autoencoder config"
|
| 872 |
-
num_mel = ddconfig["mel_bins"]
|
| 873 |
-
self.image_key = image_key
|
| 874 |
-
self.sampling_rate = sampling_rate
|
| 875 |
-
self.encoder = Encoder(**ddconfig)
|
| 876 |
-
self.decoder = Decoder(**ddconfig)
|
| 877 |
-
|
| 878 |
-
self.loss = None
|
| 879 |
-
self.subband = int(subband)
|
| 880 |
-
|
| 881 |
-
if self.subband > 1:
|
| 882 |
-
print("Use subband decomposition %s" % self.subband)
|
| 883 |
-
|
| 884 |
-
assert ddconfig["double_z"]
|
| 885 |
-
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
|
| 886 |
-
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
| 887 |
-
|
| 888 |
-
if self.image_key == "fbank":
|
| 889 |
-
self.vocoder = get_vocoder(None, "cpu", num_mel)
|
| 890 |
-
self.embed_dim = embed_dim
|
| 891 |
-
if colorize_nlabels is not None:
|
| 892 |
-
assert type(colorize_nlabels) == int
|
| 893 |
-
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
| 894 |
-
if monitor is not None:
|
| 895 |
-
self.monitor = monitor
|
| 896 |
-
if ckpt_path is not None:
|
| 897 |
-
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
| 898 |
-
self.learning_rate = float(base_learning_rate)
|
| 899 |
-
# print("Initial learning rate %s" % self.learning_rate)
|
| 900 |
-
|
| 901 |
-
self.time_shuffle = time_shuffle
|
| 902 |
-
self.reload_from_ckpt = reload_from_ckpt
|
| 903 |
-
self.reloaded = False
|
| 904 |
-
self.mean, self.std = None, None
|
| 905 |
-
|
| 906 |
-
self.feature_cache = None
|
| 907 |
-
self.flag_first_run = True
|
| 908 |
-
self.train_step = 0
|
| 909 |
-
|
| 910 |
-
self.logger_save_dir = None
|
| 911 |
-
self.logger_exp_name = None
|
| 912 |
-
self.scale_factor = scale_factor
|
| 913 |
-
|
| 914 |
-
print("Num parameters:")
|
| 915 |
-
print("Encoder : ", sum(p.numel() for p in self.encoder.parameters()))
|
| 916 |
-
print("Decoder : ", sum(p.numel() for p in self.decoder.parameters()))
|
| 917 |
-
print("Vocoder : ", sum(p.numel() for p in self.vocoder.parameters()))
|
| 918 |
-
|
| 919 |
-
def get_log_dir(self):
|
| 920 |
-
if self.logger_save_dir is None and self.logger_exp_name is None:
|
| 921 |
-
return os.path.join(self.logger.save_dir, self.logger._project)
|
| 922 |
-
else:
|
| 923 |
-
return os.path.join(self.logger_save_dir, self.logger_exp_name)
|
| 924 |
-
|
| 925 |
-
def set_log_dir(self, save_dir, exp_name):
|
| 926 |
-
self.logger_save_dir = save_dir
|
| 927 |
-
self.logger_exp_name = exp_name
|
| 928 |
-
|
| 929 |
-
def init_from_ckpt(self, path, ignore_keys=list()):
|
| 930 |
-
sd = torch.load(path, map_location="cpu")["state_dict"]
|
| 931 |
-
keys = list(sd.keys())
|
| 932 |
-
for k in keys:
|
| 933 |
-
for ik in ignore_keys:
|
| 934 |
-
if k.startswith(ik):
|
| 935 |
-
print("Deleting key {} from state_dict.".format(k))
|
| 936 |
-
del sd[k]
|
| 937 |
-
self.load_state_dict(sd, strict=False)
|
| 938 |
-
print(f"Restored from {path}")
|
| 939 |
-
|
| 940 |
-
def encode(self, x):
|
| 941 |
-
# x = self.time_shuffle_operation(x)
|
| 942 |
-
# x = self.freq_split_subband(x)
|
| 943 |
-
h = self.encoder(x)
|
| 944 |
-
moments = self.quant_conv(h)
|
| 945 |
-
posterior = DiagonalGaussianDistribution(moments)
|
| 946 |
-
return posterior
|
| 947 |
-
|
| 948 |
-
def decode(self, z):
|
| 949 |
-
z = self.post_quant_conv(z)
|
| 950 |
-
dec = self.decoder(z)
|
| 951 |
-
# bs, ch, shuffled_timesteps, fbins = dec.size()
|
| 952 |
-
# dec = self.time_unshuffle_operation(dec, bs, int(ch*shuffled_timesteps), fbins)
|
| 953 |
-
# dec = self.freq_merge_subband(dec)
|
| 954 |
-
return dec
|
| 955 |
-
|
| 956 |
-
def decode_to_waveform(self, dec):
|
| 957 |
-
|
| 958 |
-
if self.image_key == "fbank":
|
| 959 |
-
dec = dec.squeeze(1).permute(0, 2, 1)
|
| 960 |
-
wav_reconstruction = vocoder_chunk_infer(dec, self.vocoder)
|
| 961 |
-
elif self.image_key == "stft":
|
| 962 |
-
dec = dec.squeeze(1).permute(0, 2, 1)
|
| 963 |
-
wav_reconstruction = self.wave_decoder(dec)
|
| 964 |
-
return wav_reconstruction
|
| 965 |
-
|
| 966 |
-
def mel_spectrogram_to_waveform(
|
| 967 |
-
self, mel, savepath=".", bs=None, name="outwav", save=True
|
| 968 |
-
):
|
| 969 |
-
# Mel: [bs, 1, t-steps, fbins]
|
| 970 |
-
if len(mel.size()) == 4:
|
| 971 |
-
mel = mel.squeeze(1)
|
| 972 |
-
mel = mel.permute(0, 2, 1)
|
| 973 |
-
waveform = self.vocoder(mel)
|
| 974 |
-
waveform = waveform.cpu().detach().numpy()
|
| 975 |
-
#if save:
|
| 976 |
-
# self.save_waveform(waveform, savepath, name)
|
| 977 |
-
return waveform
|
| 978 |
-
|
| 979 |
-
@torch.no_grad()
|
| 980 |
-
def encode_first_stage(self, x):
|
| 981 |
-
return self.encode(x)
|
| 982 |
-
|
| 983 |
-
@torch.no_grad()
|
| 984 |
-
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
|
| 985 |
-
if predict_cids:
|
| 986 |
-
if z.dim() == 4:
|
| 987 |
-
z = torch.argmax(z.exp(), dim=1).long()
|
| 988 |
-
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
|
| 989 |
-
z = rearrange(z, "b h w c -> b c h w").contiguous()
|
| 990 |
-
|
| 991 |
-
z = 1.0 / self.scale_factor * z
|
| 992 |
-
return self.decode(z)
|
| 993 |
-
|
| 994 |
-
def decode_first_stage_withgrad(self, z):
|
| 995 |
-
z = 1.0 / self.scale_factor * z
|
| 996 |
-
return self.decode(z)
|
| 997 |
-
|
| 998 |
-
def get_first_stage_encoding(self, encoder_posterior, use_mode=False):
|
| 999 |
-
if isinstance(encoder_posterior, DiagonalGaussianDistribution) and not use_mode:
|
| 1000 |
-
z = encoder_posterior.sample()
|
| 1001 |
-
elif isinstance(encoder_posterior, DiagonalGaussianDistribution) and use_mode:
|
| 1002 |
-
z = encoder_posterior.mode()
|
| 1003 |
-
elif isinstance(encoder_posterior, torch.Tensor):
|
| 1004 |
-
z = encoder_posterior
|
| 1005 |
-
else:
|
| 1006 |
-
raise NotImplementedError(
|
| 1007 |
-
f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
|
| 1008 |
-
)
|
| 1009 |
-
return self.scale_factor * z
|
| 1010 |
-
|
| 1011 |
-
def visualize_latent(self, input):
|
| 1012 |
-
import matplotlib.pyplot as plt
|
| 1013 |
-
|
| 1014 |
-
# for i in range(10):
|
| 1015 |
-
# zero_input = torch.zeros_like(input) - 11.59
|
| 1016 |
-
# zero_input[:,:,i * 16: i * 16 + 16,:16] += 13.59
|
| 1017 |
-
|
| 1018 |
-
# posterior = self.encode(zero_input)
|
| 1019 |
-
# latent = posterior.sample()
|
| 1020 |
-
# avg_latent = torch.mean(latent, dim=1)[0]
|
| 1021 |
-
# plt.imshow(avg_latent.cpu().detach().numpy().T)
|
| 1022 |
-
# plt.savefig("%s.png" % i)
|
| 1023 |
-
# plt.close()
|
| 1024 |
-
|
| 1025 |
-
np.save("input.npy", input.cpu().detach().numpy())
|
| 1026 |
-
# zero_input = torch.zeros_like(input) - 11.59
|
| 1027 |
-
time_input = input.clone()
|
| 1028 |
-
time_input[:, :, :, :32] *= 0
|
| 1029 |
-
time_input[:, :, :, :32] -= 11.59
|
| 1030 |
-
|
| 1031 |
-
np.save("time_input.npy", time_input.cpu().detach().numpy())
|
| 1032 |
-
|
| 1033 |
-
posterior = self.encode(time_input)
|
| 1034 |
-
latent = posterior.sample()
|
| 1035 |
-
np.save("time_latent.npy", latent.cpu().detach().numpy())
|
| 1036 |
-
avg_latent = torch.mean(latent, dim=1)
|
| 1037 |
-
for i in range(avg_latent.size(0)):
|
| 1038 |
-
plt.imshow(avg_latent[i].cpu().detach().numpy().T)
|
| 1039 |
-
plt.savefig("freq_%s.png" % i)
|
| 1040 |
-
plt.close()
|
| 1041 |
-
|
| 1042 |
-
freq_input = input.clone()
|
| 1043 |
-
freq_input[:, :, :512, :] *= 0
|
| 1044 |
-
freq_input[:, :, :512, :] -= 11.59
|
| 1045 |
-
|
| 1046 |
-
np.save("freq_input.npy", freq_input.cpu().detach().numpy())
|
| 1047 |
-
|
| 1048 |
-
posterior = self.encode(freq_input)
|
| 1049 |
-
latent = posterior.sample()
|
| 1050 |
-
np.save("freq_latent.npy", latent.cpu().detach().numpy())
|
| 1051 |
-
avg_latent = torch.mean(latent, dim=1)
|
| 1052 |
-
for i in range(avg_latent.size(0)):
|
| 1053 |
-
plt.imshow(avg_latent[i].cpu().detach().numpy().T)
|
| 1054 |
-
plt.savefig("time_%s.png" % i)
|
| 1055 |
-
plt.close()
|
| 1056 |
-
|
| 1057 |
-
def get_input(self, batch):
|
| 1058 |
-
fname, text, label_indices, waveform, stft, fbank = (
|
| 1059 |
-
batch["fname"],
|
| 1060 |
-
batch["text"],
|
| 1061 |
-
batch["label_vector"],
|
| 1062 |
-
batch["waveform"],
|
| 1063 |
-
batch["stft"],
|
| 1064 |
-
batch["log_mel_spec"],
|
| 1065 |
-
)
|
| 1066 |
-
# if(self.time_shuffle != 1):
|
| 1067 |
-
# if(fbank.size(1) % self.time_shuffle != 0):
|
| 1068 |
-
# pad_len = self.time_shuffle - (fbank.size(1) % self.time_shuffle)
|
| 1069 |
-
# fbank = torch.nn.functional.pad(fbank, (0,0,0,pad_len))
|
| 1070 |
-
|
| 1071 |
-
ret = {}
|
| 1072 |
-
|
| 1073 |
-
ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = (
|
| 1074 |
-
fbank.unsqueeze(1),
|
| 1075 |
-
stft.unsqueeze(1),
|
| 1076 |
-
fname,
|
| 1077 |
-
waveform.unsqueeze(1),
|
| 1078 |
-
)
|
| 1079 |
-
|
| 1080 |
-
return ret
|
| 1081 |
-
|
| 1082 |
-
def save_wave(self, batch_wav, fname, save_dir):
|
| 1083 |
-
os.makedirs(save_dir, exist_ok=True)
|
| 1084 |
-
|
| 1085 |
-
for wav, name in zip(batch_wav, fname):
|
| 1086 |
-
name = os.path.basename(name)
|
| 1087 |
-
|
| 1088 |
-
sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate)
|
| 1089 |
-
|
| 1090 |
-
def get_last_layer(self):
|
| 1091 |
-
return self.decoder.conv_out.weight
|
| 1092 |
-
|
| 1093 |
-
@torch.no_grad()
|
| 1094 |
-
def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs):
|
| 1095 |
-
log = dict()
|
| 1096 |
-
x = batch.to(self.device)
|
| 1097 |
-
if not only_inputs:
|
| 1098 |
-
xrec, posterior = self(x)
|
| 1099 |
-
log["samples"] = self.decode(posterior.sample())
|
| 1100 |
-
log["reconstructions"] = xrec
|
| 1101 |
-
|
| 1102 |
-
log["inputs"] = x
|
| 1103 |
-
wavs = self._log_img(log, train=train, index=0, waveform=waveform)
|
| 1104 |
-
return wavs
|
| 1105 |
-
|
| 1106 |
-
def _log_img(self, log, train=True, index=0, waveform=None):
|
| 1107 |
-
images_input = self.tensor2numpy(log["inputs"][index, 0]).T
|
| 1108 |
-
images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T
|
| 1109 |
-
images_samples = self.tensor2numpy(log["samples"][index, 0]).T
|
| 1110 |
-
|
| 1111 |
-
if train:
|
| 1112 |
-
name = "train"
|
| 1113 |
-
else:
|
| 1114 |
-
name = "val"
|
| 1115 |
-
|
| 1116 |
-
if self.logger is not None:
|
| 1117 |
-
self.logger.log_image(
|
| 1118 |
-
"img_%s" % name,
|
| 1119 |
-
[images_input, images_reconstruct, images_samples],
|
| 1120 |
-
caption=["input", "reconstruct", "samples"],
|
| 1121 |
-
)
|
| 1122 |
-
|
| 1123 |
-
inputs, reconstructions, samples = (
|
| 1124 |
-
log["inputs"],
|
| 1125 |
-
log["reconstructions"],
|
| 1126 |
-
log["samples"],
|
| 1127 |
-
)
|
| 1128 |
-
|
| 1129 |
-
if self.image_key == "fbank":
|
| 1130 |
-
wav_original, wav_prediction = synth_one_sample(
|
| 1131 |
-
inputs[index],
|
| 1132 |
-
reconstructions[index],
|
| 1133 |
-
labels="validation",
|
| 1134 |
-
vocoder=self.vocoder,
|
| 1135 |
-
)
|
| 1136 |
-
wav_original, wav_samples = synth_one_sample(
|
| 1137 |
-
inputs[index], samples[index], labels="validation", vocoder=self.vocoder
|
| 1138 |
-
)
|
| 1139 |
-
wav_original, wav_samples, wav_prediction = (
|
| 1140 |
-
wav_original[0],
|
| 1141 |
-
wav_samples[0],
|
| 1142 |
-
wav_prediction[0],
|
| 1143 |
-
)
|
| 1144 |
-
elif self.image_key == "stft":
|
| 1145 |
-
wav_prediction = (
|
| 1146 |
-
self.decode_to_waveform(reconstructions)[index, 0]
|
| 1147 |
-
.cpu()
|
| 1148 |
-
.detach()
|
| 1149 |
-
.numpy()
|
| 1150 |
-
)
|
| 1151 |
-
wav_samples = (
|
| 1152 |
-
self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy()
|
| 1153 |
-
)
|
| 1154 |
-
wav_original = waveform[index, 0].cpu().detach().numpy()
|
| 1155 |
-
|
| 1156 |
-
if self.logger is not None:
|
| 1157 |
-
self.logger.experiment.log(
|
| 1158 |
-
{
|
| 1159 |
-
"original_%s"
|
| 1160 |
-
% name: wandb.Audio(
|
| 1161 |
-
wav_original, caption="original", sample_rate=self.sampling_rate
|
| 1162 |
-
),
|
| 1163 |
-
"reconstruct_%s"
|
| 1164 |
-
% name: wandb.Audio(
|
| 1165 |
-
wav_prediction,
|
| 1166 |
-
caption="reconstruct",
|
| 1167 |
-
sample_rate=self.sampling_rate,
|
| 1168 |
-
),
|
| 1169 |
-
"samples_%s"
|
| 1170 |
-
% name: wandb.Audio(
|
| 1171 |
-
wav_samples, caption="samples", sample_rate=self.sampling_rate
|
| 1172 |
-
),
|
| 1173 |
-
}
|
| 1174 |
-
)
|
| 1175 |
-
|
| 1176 |
-
return wav_original, wav_prediction, wav_samples
|
| 1177 |
-
|
| 1178 |
-
def tensor2numpy(self, tensor):
|
| 1179 |
-
return tensor.cpu().detach().numpy()
|
| 1180 |
-
|
| 1181 |
-
def to_rgb(self, x):
|
| 1182 |
-
assert self.image_key == "segmentation"
|
| 1183 |
-
if not hasattr(self, "colorize"):
|
| 1184 |
-
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
| 1185 |
-
x = torch.nn.functional.conv2d(x, weight=self.colorize)
|
| 1186 |
-
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
|
| 1187 |
-
return x
|
| 1188 |
-
|
| 1189 |
-
|
| 1190 |
-
class IdentityFirstStage(torch.nn.Module):
|
| 1191 |
-
def __init__(self, *args, vq_interface=False, **kwargs):
|
| 1192 |
-
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
|
| 1193 |
-
super().__init__()
|
| 1194 |
-
|
| 1195 |
-
def encode(self, x, *args, **kwargs):
|
| 1196 |
-
return x
|
| 1197 |
-
|
| 1198 |
-
def decode(self, x, *args, **kwargs):
|
| 1199 |
-
return x
|
| 1200 |
-
|
| 1201 |
-
def quantize(self, x, *args, **kwargs):
|
| 1202 |
-
if self.vq_interface:
|
| 1203 |
-
return x, None, [None, None, None]
|
| 1204 |
-
return x
|
| 1205 |
-
|
| 1206 |
-
def forward(self, x, *args, **kwargs):
|
| 1207 |
-
return x
|
| 1208 |
-
|
| 1209 |
-
|
| 1210 |
-
def window_sumsquare(
|
| 1211 |
-
window,
|
| 1212 |
-
n_frames,
|
| 1213 |
-
hop_length,
|
| 1214 |
-
win_length,
|
| 1215 |
-
n_fft,
|
| 1216 |
-
dtype=np.float32,
|
| 1217 |
-
norm=None,
|
| 1218 |
-
):
|
| 1219 |
-
"""
|
| 1220 |
-
# from librosa 0.6
|
| 1221 |
-
Compute the sum-square envelope of a window function at a given hop length.
|
| 1222 |
-
|
| 1223 |
-
This is used to estimate modulation effects induced by windowing
|
| 1224 |
-
observations in short-time fourier transforms.
|
| 1225 |
-
|
| 1226 |
-
Parameters
|
| 1227 |
-
----------
|
| 1228 |
-
window : string, tuple, number, callable, or list-like
|
| 1229 |
-
Window specification, as in `get_window`
|
| 1230 |
-
|
| 1231 |
-
n_frames : int > 0
|
| 1232 |
-
The number of analysis frames
|
| 1233 |
-
|
| 1234 |
-
hop_length : int > 0
|
| 1235 |
-
The number of samples to advance between frames
|
| 1236 |
-
|
| 1237 |
-
win_length : [optional]
|
| 1238 |
-
The length of the window function. By default, this matches `n_fft`.
|
| 1239 |
-
|
| 1240 |
-
n_fft : int > 0
|
| 1241 |
-
The length of each analysis frame.
|
| 1242 |
-
|
| 1243 |
-
dtype : np.dtype
|
| 1244 |
-
The data type of the output
|
| 1245 |
-
|
| 1246 |
-
Returns
|
| 1247 |
-
-------
|
| 1248 |
-
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
|
| 1249 |
-
The sum-squared envelope of the window function
|
| 1250 |
-
"""
|
| 1251 |
-
if win_length is None:
|
| 1252 |
-
win_length = n_fft
|
| 1253 |
-
|
| 1254 |
-
n = n_fft + hop_length * (n_frames - 1)
|
| 1255 |
-
x = np.zeros(n, dtype=dtype)
|
| 1256 |
-
|
| 1257 |
-
# Compute the squared window at the desired length
|
| 1258 |
-
win_sq = get_window(window, win_length, fftbins=True)
|
| 1259 |
-
win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
|
| 1260 |
-
win_sq = librosa_util.pad_center(win_sq, n_fft)
|
| 1261 |
-
|
| 1262 |
-
# Fill the envelope
|
| 1263 |
-
for i in range(n_frames):
|
| 1264 |
-
sample = i * hop_length
|
| 1265 |
-
x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
|
| 1266 |
-
return x
|
| 1267 |
-
|
| 1268 |
-
def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
|
| 1269 |
-
"""
|
| 1270 |
-
PARAMS
|
| 1271 |
-
------
|
| 1272 |
-
C: compression factor
|
| 1273 |
-
"""
|
| 1274 |
-
return normalize_fun(torch.clamp(x, min=clip_val) * C)
|
| 1275 |
-
|
| 1276 |
-
|
| 1277 |
-
def dynamic_range_decompression(x, C=1):
|
| 1278 |
-
"""
|
| 1279 |
-
PARAMS
|
| 1280 |
-
------
|
| 1281 |
-
C: compression factor used to compress
|
| 1282 |
-
"""
|
| 1283 |
-
return torch.exp(x) / C
|
| 1284 |
-
|
| 1285 |
-
|
| 1286 |
-
class STFT(torch.nn.Module):
|
| 1287 |
-
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
| 1288 |
-
|
| 1289 |
-
def __init__(self, filter_length, hop_length, win_length, window="hann"):
|
| 1290 |
-
super(STFT, self).__init__()
|
| 1291 |
-
self.filter_length = filter_length
|
| 1292 |
-
self.hop_length = hop_length
|
| 1293 |
-
self.win_length = win_length
|
| 1294 |
-
self.window = window
|
| 1295 |
-
self.forward_transform = None
|
| 1296 |
-
scale = self.filter_length / self.hop_length
|
| 1297 |
-
fourier_basis = np.fft.fft(np.eye(self.filter_length))
|
| 1298 |
-
|
| 1299 |
-
cutoff = int((self.filter_length / 2 + 1))
|
| 1300 |
-
fourier_basis = np.vstack(
|
| 1301 |
-
[np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
|
| 1302 |
-
)
|
| 1303 |
-
|
| 1304 |
-
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
|
| 1305 |
-
inverse_basis = torch.FloatTensor(
|
| 1306 |
-
np.linalg.pinv(scale * fourier_basis).T[:, None, :]
|
| 1307 |
-
)
|
| 1308 |
-
|
| 1309 |
-
if window is not None:
|
| 1310 |
-
assert filter_length >= win_length
|
| 1311 |
-
# get window and zero center pad it to filter_length
|
| 1312 |
-
fft_window = get_window(window, win_length, fftbins=True)
|
| 1313 |
-
fft_window = pad_center(fft_window, size=filter_length)
|
| 1314 |
-
fft_window = torch.from_numpy(fft_window).float()
|
| 1315 |
-
|
| 1316 |
-
# window the bases
|
| 1317 |
-
forward_basis *= fft_window
|
| 1318 |
-
inverse_basis *= fft_window
|
| 1319 |
-
|
| 1320 |
-
self.register_buffer("forward_basis", forward_basis.float())
|
| 1321 |
-
self.register_buffer("inverse_basis", inverse_basis.float())
|
| 1322 |
-
|
| 1323 |
-
def transform(self, input_data):
|
| 1324 |
-
|
| 1325 |
-
device = self.forward_basis.device
|
| 1326 |
-
input_data = input_data.to(device)
|
| 1327 |
-
|
| 1328 |
-
num_batches = input_data.size(0)
|
| 1329 |
-
num_samples = input_data.size(1)
|
| 1330 |
-
|
| 1331 |
-
self.num_samples = num_samples
|
| 1332 |
-
|
| 1333 |
-
# similar to librosa, reflect-pad the input
|
| 1334 |
-
input_data = input_data.view(num_batches, 1, num_samples)
|
| 1335 |
-
input_data = torch.nn.functional.pad(
|
| 1336 |
-
input_data.unsqueeze(1),
|
| 1337 |
-
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
|
| 1338 |
-
mode="reflect",
|
| 1339 |
-
)
|
| 1340 |
-
input_data = input_data.squeeze(1)
|
| 1341 |
-
|
| 1342 |
-
forward_transform = torch.nn.functional.conv1d(
|
| 1343 |
-
input_data,
|
| 1344 |
-
torch.autograd.Variable(self.forward_basis, requires_grad=False),
|
| 1345 |
-
stride=self.hop_length,
|
| 1346 |
-
padding=0,
|
| 1347 |
-
)#.cpu()
|
| 1348 |
-
|
| 1349 |
-
cutoff = int((self.filter_length / 2) + 1)
|
| 1350 |
-
real_part = forward_transform[:, :cutoff, :]
|
| 1351 |
-
imag_part = forward_transform[:, cutoff:, :]
|
| 1352 |
-
|
| 1353 |
-
magnitude = torch.sqrt(real_part**2 + imag_part**2)
|
| 1354 |
-
phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
|
| 1355 |
-
|
| 1356 |
-
return magnitude, phase
|
| 1357 |
-
|
| 1358 |
-
def inverse(self, magnitude, phase):
|
| 1359 |
-
|
| 1360 |
-
device = self.forward_basis.device
|
| 1361 |
-
magnitude, phase = magnitude.to(device), phase.to(device)
|
| 1362 |
-
|
| 1363 |
-
recombine_magnitude_phase = torch.cat(
|
| 1364 |
-
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
|
| 1365 |
-
)
|
| 1366 |
-
|
| 1367 |
-
inverse_transform = torch.nn.functional.conv_transpose1d(
|
| 1368 |
-
recombine_magnitude_phase,
|
| 1369 |
-
torch.autograd.Variable(self.inverse_basis, requires_grad=False),
|
| 1370 |
-
stride=self.hop_length,
|
| 1371 |
-
padding=0,
|
| 1372 |
-
)
|
| 1373 |
-
|
| 1374 |
-
if self.window is not None:
|
| 1375 |
-
window_sum = window_sumsquare(
|
| 1376 |
-
self.window,
|
| 1377 |
-
magnitude.size(-1),
|
| 1378 |
-
hop_length=self.hop_length,
|
| 1379 |
-
win_length=self.win_length,
|
| 1380 |
-
n_fft=self.filter_length,
|
| 1381 |
-
dtype=np.float32,
|
| 1382 |
-
)
|
| 1383 |
-
# remove modulation effects
|
| 1384 |
-
approx_nonzero_indices = torch.from_numpy(
|
| 1385 |
-
np.where(window_sum > tiny(window_sum))[0]
|
| 1386 |
-
)
|
| 1387 |
-
window_sum = torch.autograd.Variable(
|
| 1388 |
-
torch.from_numpy(window_sum), requires_grad=False
|
| 1389 |
-
)
|
| 1390 |
-
window_sum = window_sum
|
| 1391 |
-
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
|
| 1392 |
-
approx_nonzero_indices
|
| 1393 |
-
]
|
| 1394 |
-
|
| 1395 |
-
# scale by hop ratio
|
| 1396 |
-
inverse_transform *= float(self.filter_length) / self.hop_length
|
| 1397 |
-
|
| 1398 |
-
inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
|
| 1399 |
-
inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
|
| 1400 |
-
|
| 1401 |
-
return inverse_transform
|
| 1402 |
-
|
| 1403 |
-
def forward(self, input_data):
|
| 1404 |
-
self.magnitude, self.phase = self.transform(input_data)
|
| 1405 |
-
reconstruction = self.inverse(self.magnitude, self.phase)
|
| 1406 |
-
return reconstruction
|
| 1407 |
-
|
| 1408 |
-
|
| 1409 |
-
class TacotronSTFT(torch.nn.Module):
|
| 1410 |
-
def __init__(
|
| 1411 |
-
self,
|
| 1412 |
-
filter_length,
|
| 1413 |
-
hop_length,
|
| 1414 |
-
win_length,
|
| 1415 |
-
n_mel_channels,
|
| 1416 |
-
sampling_rate,
|
| 1417 |
-
mel_fmin,
|
| 1418 |
-
mel_fmax,
|
| 1419 |
-
):
|
| 1420 |
-
super(TacotronSTFT, self).__init__()
|
| 1421 |
-
self.n_mel_channels = n_mel_channels
|
| 1422 |
-
self.sampling_rate = sampling_rate
|
| 1423 |
-
self.stft_fn = STFT(filter_length, hop_length, win_length)
|
| 1424 |
-
mel_basis = librosa_mel_fn(
|
| 1425 |
-
sr = sampling_rate, n_fft = filter_length, n_mels = n_mel_channels, fmin = mel_fmin, fmax = mel_fmax
|
| 1426 |
-
)
|
| 1427 |
-
mel_basis = torch.from_numpy(mel_basis).float()
|
| 1428 |
-
self.register_buffer("mel_basis", mel_basis)
|
| 1429 |
-
|
| 1430 |
-
def spectral_normalize(self, magnitudes, normalize_fun):
|
| 1431 |
-
output = dynamic_range_compression(magnitudes, normalize_fun)
|
| 1432 |
-
return output
|
| 1433 |
-
|
| 1434 |
-
def spectral_de_normalize(self, magnitudes):
|
| 1435 |
-
output = dynamic_range_decompression(magnitudes)
|
| 1436 |
-
return output
|
| 1437 |
-
|
| 1438 |
-
def mel_spectrogram(self, y, normalize_fun=torch.log):
|
| 1439 |
-
"""Computes mel-spectrograms from a batch of waves
|
| 1440 |
-
PARAMS
|
| 1441 |
-
------
|
| 1442 |
-
y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
|
| 1443 |
-
|
| 1444 |
-
RETURNS
|
| 1445 |
-
-------
|
| 1446 |
-
mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
|
| 1447 |
-
"""
|
| 1448 |
-
assert torch.min(y.data) >= -1, torch.min(y.data)
|
| 1449 |
-
assert torch.max(y.data) <= 1, torch.max(y.data)
|
| 1450 |
-
|
| 1451 |
-
magnitudes, phases = self.stft_fn.transform(y)
|
| 1452 |
-
magnitudes = magnitudes.data
|
| 1453 |
-
mel_output = torch.matmul(self.mel_basis, magnitudes)
|
| 1454 |
-
mel_output = self.spectral_normalize(mel_output, normalize_fun)
|
| 1455 |
-
energy = torch.norm(magnitudes, dim=1)
|
| 1456 |
-
|
| 1457 |
-
log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
|
| 1458 |
-
|
| 1459 |
-
return mel_output, log_magnitudes, energy
|
| 1460 |
-
|
| 1461 |
-
|
| 1462 |
-
def build_pretrained_models(ckpt):
|
| 1463 |
-
checkpoint = torch.load(ckpt, map_location="cpu")
|
| 1464 |
-
scale_factor = checkpoint["state_dict"]["scale_factor"].item()
|
| 1465 |
-
print("scale_factor: ", scale_factor)
|
| 1466 |
-
|
| 1467 |
-
vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." in k}
|
| 1468 |
-
|
| 1469 |
-
config = {
|
| 1470 |
-
"preprocessing": {
|
| 1471 |
-
"audio": {
|
| 1472 |
-
"sampling_rate": 48000,
|
| 1473 |
-
"max_wav_value": 32768,
|
| 1474 |
-
"duration": 10.24
|
| 1475 |
-
},
|
| 1476 |
-
"stft": {
|
| 1477 |
-
"filter_length": 2048,
|
| 1478 |
-
"hop_length": 480,
|
| 1479 |
-
"win_length": 2048
|
| 1480 |
-
},
|
| 1481 |
-
"mel": {
|
| 1482 |
-
"n_mel_channels": 256,
|
| 1483 |
-
"mel_fmin": 20,
|
| 1484 |
-
"mel_fmax": 24000
|
| 1485 |
-
}
|
| 1486 |
-
},
|
| 1487 |
-
"model": {
|
| 1488 |
-
"params": {
|
| 1489 |
-
"first_stage_config": {
|
| 1490 |
-
"params": {
|
| 1491 |
-
"sampling_rate": 48000,
|
| 1492 |
-
"batchsize": 4,
|
| 1493 |
-
"monitor": "val/rec_loss",
|
| 1494 |
-
"image_key": "fbank",
|
| 1495 |
-
"subband": 1,
|
| 1496 |
-
"embed_dim": 16,
|
| 1497 |
-
"time_shuffle": 1,
|
| 1498 |
-
"lossconfig": {
|
| 1499 |
-
"target": "audioldm2.latent_diffusion.modules.losses.LPIPSWithDiscriminator",
|
| 1500 |
-
"params": {
|
| 1501 |
-
"disc_start": 50001,
|
| 1502 |
-
"kl_weight": 1000,
|
| 1503 |
-
"disc_weight": 0.5,
|
| 1504 |
-
"disc_in_channels": 1
|
| 1505 |
-
}
|
| 1506 |
-
},
|
| 1507 |
-
"ddconfig": {
|
| 1508 |
-
"double_z": True,
|
| 1509 |
-
"mel_bins": 256,
|
| 1510 |
-
"z_channels": 16,
|
| 1511 |
-
"resolution": 256,
|
| 1512 |
-
"downsample_time": False,
|
| 1513 |
-
"in_channels": 1,
|
| 1514 |
-
"out_ch": 1,
|
| 1515 |
-
"ch": 128,
|
| 1516 |
-
"ch_mult": [
|
| 1517 |
-
1,
|
| 1518 |
-
2,
|
| 1519 |
-
4,
|
| 1520 |
-
8
|
| 1521 |
-
],
|
| 1522 |
-
"num_res_blocks": 2,
|
| 1523 |
-
"attn_resolutions": [],
|
| 1524 |
-
"dropout": 0
|
| 1525 |
-
}
|
| 1526 |
-
}
|
| 1527 |
-
},
|
| 1528 |
-
}
|
| 1529 |
-
}
|
| 1530 |
-
}
|
| 1531 |
-
vae_config = config["model"]["params"]["first_stage_config"]["params"]
|
| 1532 |
-
vae_config["scale_factor"] = scale_factor
|
| 1533 |
-
|
| 1534 |
-
vae = AutoencoderKL(**vae_config)
|
| 1535 |
-
vae.load_state_dict(vae_state_dict)
|
| 1536 |
-
|
| 1537 |
-
fn_STFT = TacotronSTFT(
|
| 1538 |
-
config["preprocessing"]["stft"]["filter_length"],
|
| 1539 |
-
config["preprocessing"]["stft"]["hop_length"],
|
| 1540 |
-
config["preprocessing"]["stft"]["win_length"],
|
| 1541 |
-
config["preprocessing"]["mel"]["n_mel_channels"],
|
| 1542 |
-
config["preprocessing"]["audio"]["sampling_rate"],
|
| 1543 |
-
config["preprocessing"]["mel"]["mel_fmin"],
|
| 1544 |
-
config["preprocessing"]["mel"]["mel_fmax"],
|
| 1545 |
-
)
|
| 1546 |
-
|
| 1547 |
-
vae.eval()
|
| 1548 |
-
fn_STFT.eval()
|
| 1549 |
-
return vae, fn_STFT
|
| 1550 |
-
|
| 1551 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MuCodec/tools/torch_tools.py
DELETED
|
@@ -1,100 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torchaudio
|
| 3 |
-
import random
|
| 4 |
-
import itertools
|
| 5 |
-
import numpy as np
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def normalize_wav(waveform):
|
| 10 |
-
waveform = waveform - torch.mean(waveform)
|
| 11 |
-
waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
|
| 12 |
-
return waveform * 0.5
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def pad_wav(waveform, segment_length):
|
| 16 |
-
waveform_length = len(waveform)
|
| 17 |
-
|
| 18 |
-
if segment_length is None or waveform_length == segment_length:
|
| 19 |
-
return waveform
|
| 20 |
-
elif waveform_length > segment_length:
|
| 21 |
-
return waveform[:segment_length]
|
| 22 |
-
else:
|
| 23 |
-
pad_wav = torch.zeros(segment_length - waveform_length).to(waveform.device)
|
| 24 |
-
waveform = torch.cat([waveform, pad_wav])
|
| 25 |
-
return waveform
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def _pad_spec(fbank, target_length=1024):
|
| 29 |
-
batch, n_frames, channels = fbank.shape
|
| 30 |
-
p = target_length - n_frames
|
| 31 |
-
if p > 0:
|
| 32 |
-
pad = torch.zeros(batch, p, channels).to(fbank.device)
|
| 33 |
-
fbank = torch.cat([fbank, pad], 1)
|
| 34 |
-
elif p < 0:
|
| 35 |
-
fbank = fbank[:, :target_length, :]
|
| 36 |
-
|
| 37 |
-
if channels % 2 != 0:
|
| 38 |
-
fbank = fbank[:, :, :-1]
|
| 39 |
-
|
| 40 |
-
return fbank
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
def read_wav_file(filename, segment_length):
|
| 44 |
-
waveform, sr = torchaudio.load(filename) # Faster!!!
|
| 45 |
-
waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)[0]
|
| 46 |
-
try:
|
| 47 |
-
waveform = normalize_wav(waveform)
|
| 48 |
-
except:
|
| 49 |
-
print ("Exception normalizing:", filename)
|
| 50 |
-
waveform = torch.ones(160000)
|
| 51 |
-
waveform = pad_wav(waveform, segment_length).unsqueeze(0)
|
| 52 |
-
waveform = waveform / torch.max(torch.abs(waveform))
|
| 53 |
-
waveform = 0.5 * waveform
|
| 54 |
-
return waveform
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def get_mel_from_wav(audio, _stft):
|
| 58 |
-
audio = torch.nan_to_num(torch.clip(audio, -1, 1))
|
| 59 |
-
audio = torch.autograd.Variable(audio, requires_grad=False)
|
| 60 |
-
melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
|
| 61 |
-
return melspec, log_magnitudes_stft, energy
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
def wav_to_fbank(paths, target_length=1024, fn_STFT=None):
|
| 65 |
-
assert fn_STFT is not None
|
| 66 |
-
|
| 67 |
-
waveform = torch.cat([read_wav_file(path, target_length * 160) for path in paths], 0) # hop size is 160
|
| 68 |
-
|
| 69 |
-
fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
|
| 70 |
-
fbank = fbank.transpose(1, 2)
|
| 71 |
-
log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
|
| 72 |
-
|
| 73 |
-
fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
|
| 74 |
-
log_magnitudes_stft, target_length
|
| 75 |
-
)
|
| 76 |
-
|
| 77 |
-
return fbank, log_magnitudes_stft, waveform
|
| 78 |
-
|
| 79 |
-
def wav_to_fbank2(waveform, target_length=-1, fn_STFT=None):
|
| 80 |
-
assert fn_STFT is not None
|
| 81 |
-
|
| 82 |
-
fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
|
| 83 |
-
fbank = fbank.transpose(1, 2)
|
| 84 |
-
log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
|
| 85 |
-
# print(fbank.shape, log_magnitudes_stft.shape)
|
| 86 |
-
|
| 87 |
-
if(target_length>0):
|
| 88 |
-
fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
|
| 89 |
-
log_magnitudes_stft, target_length
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
return fbank, log_magnitudes_stft, waveform
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
def uncapitalize(s):
|
| 96 |
-
if s:
|
| 97 |
-
return s[:1].lower() + s[1:]
|
| 98 |
-
else:
|
| 99 |
-
return ""
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|