mrfakename committed on
Commit 0929406 · verified · 1 Parent(s): fa77ee6

Delete MuCodec

Files changed (37)
  1. MuCodec/.DS_Store +0 -0
  2. MuCodec/.gitattributes +0 -2
  3. MuCodec/.gitignore +0 -3
  4. MuCodec/LICENSE +0 -21
  5. MuCodec/LICENSE_weights +0 -399
  6. MuCodec/configs/models/transformer2D.json +0 -25
  7. MuCodec/configs/scheduler/stable_diffusion_2.1_largenoise_sample.json +0 -14
  8. MuCodec/generate.py +0 -248
  9. MuCodec/libs/rvq/descript_quantize3.py +0 -298
  10. MuCodec/model.py +0 -367
  11. MuCodec/models/attention.py +0 -682
  12. MuCodec/models/transformer_2d_flow.py +0 -545
  13. MuCodec/muq_dev/muq_fairseq/data/__init__.py +0 -1
  14. MuCodec/muq_dev/muq_fairseq/data/ark_dataset.py +0 -71
  15. MuCodec/muq_dev/muq_fairseq/data/mert_dataset.py +0 -295
  16. MuCodec/muq_dev/muq_fairseq/data/utils/data_utils.py +0 -535
  17. MuCodec/muq_dev/muq_fairseq/models/muq/__init__.py +0 -1
  18. MuCodec/muq_dev/muq_fairseq/models/muq/model/__init__.py +0 -2
  19. MuCodec/muq_dev/muq_fairseq/models/muq/model/muq.py +0 -520
  20. MuCodec/muq_dev/muq_fairseq/models/muq/model/pred_ark_target_with_model.py +0 -151
  21. MuCodec/muq_dev/muq_fairseq/models/muq/model/rvq.py +0 -459
  22. MuCodec/muq_dev/muq_fairseq/models/muq/model/rvq_muq.py +0 -394
  23. MuCodec/muq_dev/muq_fairseq/models/muq/model/w2v2_config.json +0 -113
  24. MuCodec/muq_dev/muq_fairseq/models/muq/modules/__init__.py +0 -2
  25. MuCodec/muq_dev/muq_fairseq/models/muq/modules/conv.py +0 -77
  26. MuCodec/muq_dev/muq_fairseq/models/muq/modules/features.py +0 -67
  27. MuCodec/muq_dev/muq_fairseq/models/muq/modules/flash_conformer.py +0 -2114
  28. MuCodec/muq_dev/muq_fairseq/models/muq/modules/random_quantizer.py +0 -68
  29. MuCodec/muq_dev/muq_fairseq/models/muq/muq_model.py +0 -139
  30. MuCodec/muq_dev/muq_fairseq/tasks/muq_pretraining.py +0 -354
  31. MuCodec/muq_dev/test.py +0 -22
  32. MuCodec/readme.md +0 -67
  33. MuCodec/reconstructed/test.wav +0 -3
  34. MuCodec/requirements.txt +0 -335
  35. MuCodec/test_wav/test.wav +0 -3
  36. MuCodec/tools/get_melvaehifigan48k.py +0 -1551
  37. MuCodec/tools/torch_tools.py +0 -100
MuCodec/.DS_Store DELETED
Binary file (8.2 kB)
 
MuCodec/.gitattributes DELETED
@@ -1,2 +0,0 @@
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
 
MuCodec/.gitignore DELETED
@@ -1,3 +0,0 @@
- __pycache__
- *.pt
- *.pth
 
MuCodec/LICENSE DELETED
@@ -1,21 +0,0 @@
- MIT License
-
- Copyright (c) Meta Platforms, Inc. and affiliates.
-
- Permission is hereby granted, free of charge, to any person obtaining a copy
- of this software and associated documentation files (the "Software"), to deal
- in the Software without restriction, including without limitation the rights
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- copies of the Software, and to permit persons to whom the Software is
- furnished to do so, subject to the following conditions:
-
- The above copyright notice and this permission notice shall be included in all
- copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- SOFTWARE.
 
MuCodec/LICENSE_weights DELETED
@@ -1,399 +0,0 @@
1
- Attribution-NonCommercial 4.0 International
2
-
3
- =======================================================================
4
-
5
- Creative Commons Corporation ("Creative Commons") is not a law firm and
6
- does not provide legal services or legal advice. Distribution of
7
- Creative Commons public licenses does not create a lawyer-client or
8
- other relationship. Creative Commons makes its licenses and related
9
- information available on an "as-is" basis. Creative Commons gives no
10
- warranties regarding its licenses, any material licensed under their
11
- terms and conditions, or any related information. Creative Commons
12
- disclaims all liability for damages resulting from their use to the
13
- fullest extent possible.
14
-
15
- Using Creative Commons Public Licenses
16
-
17
- Creative Commons public licenses provide a standard set of terms and
18
- conditions that creators and other rights holders may use to share
19
- original works of authorship and other material subject to copyright
20
- and certain other rights specified in the public license below. The
21
- following considerations are for informational purposes only, are not
22
- exhaustive, and do not form part of our licenses.
23
-
24
- Considerations for licensors: Our public licenses are
25
- intended for use by those authorized to give the public
26
- permission to use material in ways otherwise restricted by
27
- copyright and certain other rights. Our licenses are
28
- irrevocable. Licensors should read and understand the terms
29
- and conditions of the license they choose before applying it.
30
- Licensors should also secure all rights necessary before
31
- applying our licenses so that the public can reuse the
32
- material as expected. Licensors should clearly mark any
33
- material not subject to the license. This includes other CC-
34
- licensed material, or material used under an exception or
35
- limitation to copyright. More considerations for licensors:
36
- wiki.creativecommons.org/Considerations_for_licensors
37
-
38
- Considerations for the public: By using one of our public
39
- licenses, a licensor grants the public permission to use the
40
- licensed material under specified terms and conditions. If
41
- the licensor's permission is not necessary for any reason--for
42
- example, because of any applicable exception or limitation to
43
- copyright--then that use is not regulated by the license. Our
44
- licenses grant only permissions under copyright and certain
45
- other rights that a licensor has authority to grant. Use of
46
- the licensed material may still be restricted for other
47
- reasons, including because others have copyright or other
48
- rights in the material. A licensor may make special requests,
49
- such as asking that all changes be marked or described.
50
- Although not required by our licenses, you are encouraged to
51
- respect those requests where reasonable. More_considerations
52
- for the public:
53
- wiki.creativecommons.org/Considerations_for_licensees
54
-
55
- =======================================================================
56
-
57
- Creative Commons Attribution-NonCommercial 4.0 International Public
58
- License
59
-
60
- By exercising the Licensed Rights (defined below), You accept and agree
61
- to be bound by the terms and conditions of this Creative Commons
62
- Attribution-NonCommercial 4.0 International Public License ("Public
63
- License"). To the extent this Public License may be interpreted as a
64
- contract, You are granted the Licensed Rights in consideration of Your
65
- acceptance of these terms and conditions, and the Licensor grants You
66
- such rights in consideration of benefits the Licensor receives from
67
- making the Licensed Material available under these terms and
68
- conditions.
69
-
70
- Section 1 -- Definitions.
71
-
72
- a. Adapted Material means material subject to Copyright and Similar
73
- Rights that is derived from or based upon the Licensed Material
74
- and in which the Licensed Material is translated, altered,
75
- arranged, transformed, or otherwise modified in a manner requiring
76
- permission under the Copyright and Similar Rights held by the
77
- Licensor. For purposes of this Public License, where the Licensed
78
- Material is a musical work, performance, or sound recording,
79
- Adapted Material is always produced where the Licensed Material is
80
- synched in timed relation with a moving image.
81
-
82
- b. Adapter's License means the license You apply to Your Copyright
83
- and Similar Rights in Your contributions to Adapted Material in
84
- accordance with the terms and conditions of this Public License.
85
-
86
- c. Copyright and Similar Rights means copyright and/or similar rights
87
- closely related to copyright including, without limitation,
88
- performance, broadcast, sound recording, and Sui Generis Database
89
- Rights, without regard to how the rights are labeled or
90
- categorized. For purposes of this Public License, the rights
91
- specified in Section 2(b)(1)-(2) are not Copyright and Similar
92
- Rights.
93
- d. Effective Technological Measures means those measures that, in the
94
- absence of proper authority, may not be circumvented under laws
95
- fulfilling obligations under Article 11 of the WIPO Copyright
96
- Treaty adopted on December 20, 1996, and/or similar international
97
- agreements.
98
-
99
- e. Exceptions and Limitations means fair use, fair dealing, and/or
100
- any other exception or limitation to Copyright and Similar Rights
101
- that applies to Your use of the Licensed Material.
102
-
103
- f. Licensed Material means the artistic or literary work, database,
104
- or other material to which the Licensor applied this Public
105
- License.
106
-
107
- g. Licensed Rights means the rights granted to You subject to the
108
- terms and conditions of this Public License, which are limited to
109
- all Copyright and Similar Rights that apply to Your use of the
110
- Licensed Material and that the Licensor has authority to license.
111
-
112
- h. Licensor means the individual(s) or entity(ies) granting rights
113
- under this Public License.
114
-
115
- i. NonCommercial means not primarily intended for or directed towards
116
- commercial advantage or monetary compensation. For purposes of
117
- this Public License, the exchange of the Licensed Material for
118
- other material subject to Copyright and Similar Rights by digital
119
- file-sharing or similar means is NonCommercial provided there is
120
- no payment of monetary compensation in connection with the
121
- exchange.
122
-
123
- j. Share means to provide material to the public by any means or
124
- process that requires permission under the Licensed Rights, such
125
- as reproduction, public display, public performance, distribution,
126
- dissemination, communication, or importation, and to make material
127
- available to the public including in ways that members of the
128
- public may access the material from a place and at a time
129
- individually chosen by them.
130
-
131
- k. Sui Generis Database Rights means rights other than copyright
132
- resulting from Directive 96/9/EC of the European Parliament and of
133
- the Council of 11 March 1996 on the legal protection of databases,
134
- as amended and/or succeeded, as well as other essentially
135
- equivalent rights anywhere in the world.
136
-
137
- l. You means the individual or entity exercising the Licensed Rights
138
- under this Public License. Your has a corresponding meaning.
139
-
140
- Section 2 -- Scope.
141
-
142
- a. License grant.
143
-
144
- 1. Subject to the terms and conditions of this Public License,
145
- the Licensor hereby grants You a worldwide, royalty-free,
146
- non-sublicensable, non-exclusive, irrevocable license to
147
- exercise the Licensed Rights in the Licensed Material to:
148
-
149
- a. reproduce and Share the Licensed Material, in whole or
150
- in part, for NonCommercial purposes only; and
151
-
152
- b. produce, reproduce, and Share Adapted Material for
153
- NonCommercial purposes only.
154
-
155
- 2. Exceptions and Limitations. For the avoidance of doubt, where
156
- Exceptions and Limitations apply to Your use, this Public
157
- License does not apply, and You do not need to comply with
158
- its terms and conditions.
159
-
160
- 3. Term. The term of this Public License is specified in Section
161
- 6(a).
162
-
163
- 4. Media and formats; technical modifications allowed. The
164
- Licensor authorizes You to exercise the Licensed Rights in
165
- all media and formats whether now known or hereafter created,
166
- and to make technical modifications necessary to do so. The
167
- Licensor waives and/or agrees not to assert any right or
168
- authority to forbid You from making technical modifications
169
- necessary to exercise the Licensed Rights, including
170
- technical modifications necessary to circumvent Effective
171
- Technological Measures. For purposes of this Public License,
172
- simply making modifications authorized by this Section 2(a)
173
- (4) never produces Adapted Material.
174
-
175
- 5. Downstream recipients.
176
-
177
- a. Offer from the Licensor -- Licensed Material. Every
178
- recipient of the Licensed Material automatically
179
- receives an offer from the Licensor to exercise the
180
- Licensed Rights under the terms and conditions of this
181
- Public License.
182
-
183
- b. No downstream restrictions. You may not offer or impose
184
- any additional or different terms or conditions on, or
185
- apply any Effective Technological Measures to, the
186
- Licensed Material if doing so restricts exercise of the
187
- Licensed Rights by any recipient of the Licensed
188
- Material.
189
-
190
- 6. No endorsement. Nothing in this Public License constitutes or
191
- may be construed as permission to assert or imply that You
192
- are, or that Your use of the Licensed Material is, connected
193
- with, or sponsored, endorsed, or granted official status by,
194
- the Licensor or others designated to receive attribution as
195
- provided in Section 3(a)(1)(A)(i).
196
-
197
- b. Other rights.
198
-
199
- 1. Moral rights, such as the right of integrity, are not
200
- licensed under this Public License, nor are publicity,
201
- privacy, and/or other similar personality rights; however, to
202
- the extent possible, the Licensor waives and/or agrees not to
203
- assert any such rights held by the Licensor to the limited
204
- extent necessary to allow You to exercise the Licensed
205
- Rights, but not otherwise.
206
-
207
- 2. Patent and trademark rights are not licensed under this
208
- Public License.
209
-
210
- 3. To the extent possible, the Licensor waives any right to
211
- collect royalties from You for the exercise of the Licensed
212
- Rights, whether directly or through a collecting society
213
- under any voluntary or waivable statutory or compulsory
214
- licensing scheme. In all other cases the Licensor expressly
215
- reserves any right to collect such royalties, including when
216
- the Licensed Material is used other than for NonCommercial
217
- purposes.
218
-
219
- Section 3 -- License Conditions.
220
-
221
- Your exercise of the Licensed Rights is expressly made subject to the
222
- following conditions.
223
-
224
- a. Attribution.
225
-
226
- 1. If You Share the Licensed Material (including in modified
227
- form), You must:
228
-
229
- a. retain the following if it is supplied by the Licensor
230
- with the Licensed Material:
231
-
232
- i. identification of the creator(s) of the Licensed
233
- Material and any others designated to receive
234
- attribution, in any reasonable manner requested by
235
- the Licensor (including by pseudonym if
236
- designated);
237
-
238
- ii. a copyright notice;
239
-
240
- iii. a notice that refers to this Public License;
241
-
242
- iv. a notice that refers to the disclaimer of
243
- warranties;
244
-
245
- v. a URI or hyperlink to the Licensed Material to the
246
- extent reasonably practicable;
247
-
248
- b. indicate if You modified the Licensed Material and
249
- retain an indication of any previous modifications; and
250
-
251
- c. indicate the Licensed Material is licensed under this
252
- Public License, and include the text of, or the URI or
253
- hyperlink to, this Public License.
254
-
255
- 2. You may satisfy the conditions in Section 3(a)(1) in any
256
- reasonable manner based on the medium, means, and context in
257
- which You Share the Licensed Material. For example, it may be
258
- reasonable to satisfy the conditions by providing a URI or
259
- hyperlink to a resource that includes the required
260
- information.
261
-
262
- 3. If requested by the Licensor, You must remove any of the
263
- information required by Section 3(a)(1)(A) to the extent
264
- reasonably practicable.
265
-
266
- 4. If You Share Adapted Material You produce, the Adapter's
267
- License You apply must not prevent recipients of the Adapted
268
- Material from complying with this Public License.
269
-
270
- Section 4 -- Sui Generis Database Rights.
271
-
272
- Where the Licensed Rights include Sui Generis Database Rights that
273
- apply to Your use of the Licensed Material:
274
-
275
- a. for the avoidance of doubt, Section 2(a)(1) grants You the right
276
- to extract, reuse, reproduce, and Share all or a substantial
277
- portion of the contents of the database for NonCommercial purposes
278
- only;
279
-
280
- b. if You include all or a substantial portion of the database
281
- contents in a database in which You have Sui Generis Database
282
- Rights, then the database in which You have Sui Generis Database
283
- Rights (but not its individual contents) is Adapted Material; and
284
-
285
- c. You must comply with the conditions in Section 3(a) if You Share
286
- all or a substantial portion of the contents of the database.
287
-
288
- For the avoidance of doubt, this Section 4 supplements and does not
289
- replace Your obligations under this Public License where the Licensed
290
- Rights include other Copyright and Similar Rights.
291
-
292
- Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293
-
294
- a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295
- EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296
- AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297
- ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298
- IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299
- WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300
- PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301
- ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302
- KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303
- ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304
-
305
- b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306
- TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307
- NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308
- INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309
- COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310
- USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311
- ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312
- DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313
- IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314
-
315
- c. The disclaimer of warranties and limitation of liability provided
316
- above shall be interpreted in a manner that, to the extent
317
- possible, most closely approximates an absolute disclaimer and
318
- waiver of all liability.
319
-
320
- Section 6 -- Term and Termination.
321
-
322
- a. This Public License applies for the term of the Copyright and
323
- Similar Rights licensed here. However, if You fail to comply with
324
- this Public License, then Your rights under this Public License
325
- terminate automatically.
326
-
327
- b. Where Your right to use the Licensed Material has terminated under
328
- Section 6(a), it reinstates:
329
-
330
- 1. automatically as of the date the violation is cured, provided
331
- it is cured within 30 days of Your discovery of the
332
- violation; or
333
-
334
- 2. upon express reinstatement by the Licensor.
335
-
336
- For the avoidance of doubt, this Section 6(b) does not affect any
337
- right the Licensor may have to seek remedies for Your violations
338
- of this Public License.
339
-
340
- c. For the avoidance of doubt, the Licensor may also offer the
341
- Licensed Material under separate terms or conditions or stop
342
- distributing the Licensed Material at any time; however, doing so
343
- will not terminate this Public License.
344
-
345
- d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
346
- License.
347
-
348
- Section 7 -- Other Terms and Conditions.
349
-
350
- a. The Licensor shall not be bound by any additional or different
351
- terms or conditions communicated by You unless expressly agreed.
352
-
353
- b. Any arrangements, understandings, or agreements regarding the
354
- Licensed Material not stated herein are separate from and
355
- independent of the terms and conditions of this Public License.
356
-
357
- Section 8 -- Interpretation.
358
-
359
- a. For the avoidance of doubt, this Public License does not, and
360
- shall not be interpreted to, reduce, limit, restrict, or impose
361
- conditions on any use of the Licensed Material that could lawfully
362
- be made without permission under this Public License.
363
-
364
- b. To the extent possible, if any provision of this Public License is
365
- deemed unenforceable, it shall be automatically reformed to the
366
- minimum extent necessary to make it enforceable. If the provision
367
- cannot be reformed, it shall be severed from this Public License
368
- without affecting the enforceability of the remaining terms and
369
- conditions.
370
-
371
- c. No term or condition of this Public License will be waived and no
372
- failure to comply consented to unless expressly agreed to by the
373
- Licensor.
374
-
375
- d. Nothing in this Public License constitutes or may be interpreted
376
- as a limitation upon, or waiver of, any privileges and immunities
377
- that apply to the Licensor or You, including from the legal
378
- processes of any jurisdiction or authority.
379
-
380
- =======================================================================
381
-
382
- Creative Commons is not a party to its public
383
- licenses. Notwithstanding, Creative Commons may elect to apply one of
384
- its public licenses to material it publishes and in those instances
385
- will be considered the “Licensor.” The text of the Creative Commons
386
- public licenses is dedicated to the public domain under the CC0 Public
387
- Domain Dedication. Except for the limited purpose of indicating that
388
- material is shared under a Creative Commons public license or as
389
- otherwise permitted by the Creative Commons policies published at
390
- creativecommons.org/policies, Creative Commons does not authorize the
391
- use of the trademark "Creative Commons" or any other trademark or logo
392
- of Creative Commons without its prior written consent including,
393
- without limitation, in connection with any unauthorized modifications
394
- to any of its public licenses or any other arrangements,
395
- understandings, or agreements concerning use of licensed material. For
396
- the avoidance of doubt, this paragraph does not form part of the
397
- public licenses.
398
-
399
- Creative Commons may be contacted at creativecommons.org.
 
MuCodec/configs/models/transformer2D.json DELETED
@@ -1,25 +0,0 @@
- {
-   "_class_name": "Transformer2DModel",
-   "activation_fn": "gelu-approximate",
-   "attention_bias": true,
-   "attention_head_dim": 72,
-   "attention_type": "default",
-   "cross_attention_dim": null,
-   "double_self_attention": false,
-   "dropout": 0.0,
-   "in_channels": 96,
-   "norm_elementwise_affine": false,
-   "norm_eps": 1e-06,
-   "norm_num_groups": 32,
-   "norm_type": "ada_norm_single",
-   "num_attention_heads": 22,
-   "num_embeds_ada_norm": 1000,
-   "num_layers": 24,
-   "num_vector_embeds": null,
-   "only_cross_attention": false,
-   "out_channels": 32,
-   "patch_size": 2,
-   "sample_size": 384,
-   "upcast_attention": false,
-   "use_linear_projection": false
- }
 
MuCodec/configs/scheduler/stable_diffusion_2.1_largenoise_sample.json DELETED
@@ -1,14 +0,0 @@
- {
-   "_class_name": "DDIMScheduler",
-   "_diffusers_version": "0.8.0",
-   "beta_end": 0.02,
-   "beta_schedule": "scaled_linear",
-   "beta_start": 0.0015,
-   "clip_sample": false,
-   "num_train_timesteps": 1000,
-   "prediction_type": "sample",
-   "set_alpha_to_one": false,
-   "skip_prk_steps": true,
-   "steps_offset": 1,
-   "trained_betas": null
- }
 
MuCodec/generate.py DELETED
@@ -1,248 +0,0 @@
- import json
- import torch
- from tqdm import tqdm
- import sys
- from model import PromptCondAudioDiffusion
- from diffusers import DDIMScheduler, DDPMScheduler
- import torchaudio
- import librosa
- import os
- import math
- import numpy as np
- from tools.get_melvaehifigan48k import build_pretrained_models
- import tools.torch_tools as torch_tools
- from safetensors.torch import load_file
- from cached_path import cached_path
-
- class MuCodec:
-     def __init__(self, \
-                  model_path, \
-                  layer_num, \
-                  load_main_model=True, \
-                  device="cuda:0"):
-
-         self.layer_num = layer_num - 1
-         self.sample_rate = 48000
-         self.device = device
-
-         self.MAX_DURATION = 360
-         if load_main_model:
-             audio_ldm_path = str(cached_path("hf://haoheliu/audioldm_48k/audioldm_48k.pth"))
-             self.vae, self.stft = build_pretrained_models(audio_ldm_path)
-             self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
-             main_config = {
-                 "num_channels":32,
-                 "unet_model_name":None,
-                 "unet_model_config_path":os.path.dirname(os.path.abspath(__file__)) + "/configs/models/transformer2D.json",
-                 "snr_gamma":None,
-             }
-             self.model = PromptCondAudioDiffusion(**main_config)
-             if model_path.endswith('.safetensors'):
-                 main_weights = load_file(model_path)
-             else:
-                 main_weights = torch.load(model_path, map_location='cpu')
-             self.model.load_state_dict(main_weights, strict=False)
-             self.model = self.model.to(device)
-             print("Successfully loaded checkpoint from:", model_path)
-         else:
-             main_config = {
-                 "num_channels":32,
-                 "unet_model_name":None,
-                 "unet_model_config_path":None,
-                 "snr_gamma":None,
-             }
-             self.model = PromptCondAudioDiffusion(**main_config).to(device)
-             main_weights = torch.load(model_path, map_location='cpu')
-             self.model.load_state_dict(main_weights, strict=False)
-             self.model = self.model.to(device)
-             print("Successfully loaded checkpoint from:", model_path)
-
-         self.model.eval()
-         self.model.init_device_dtype(torch.device(device), torch.float32)
-         print("scaling factor: ", self.model.normfeat.std)
-
-     def file2code(self, fname):
-         orig_samples, fs = torchaudio.load(fname)
-         if(fs!=self.sample_rate):
-             orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
-             fs = self.sample_rate
-         if orig_samples.shape[0] == 1:
-             orig_samples = torch.cat([orig_samples, orig_samples], 0)
-         return self.sound2code(orig_samples)
-
-     @torch.no_grad()
-     @torch.autocast(device_type="cuda", dtype=torch.float32)
-     def sound2code(self, orig_samples, batch_size=3):
-         if(orig_samples.ndim == 2):
-             audios = orig_samples.unsqueeze(0).to(self.device)
-         elif(orig_samples.ndim == 3):
-             audios = orig_samples.to(self.device)
-         else:
-             assert orig_samples.ndim in (2,3), orig_samples.shape
-         audios = self.preprocess_audio(audios)
-         audios = audios.squeeze(0)
-         orig_length = audios.shape[-1]
-         min_samples = int(40.96 * self.sample_rate)
-         output_len = int(orig_length / float(self.sample_rate) * 25) + 1
-         print("output_len: ", output_len)
-
-         while(audios.shape[-1] < min_samples + 480):
-             audios = torch.cat([audios, audios], -1)
-         int_max_len=audios.shape[-1]//min_samples+1
-         # print("int_max_len: ", int_max_len)
-         audios = torch.cat([audios, audios], -1)
-         # print("audios:",audios.shape)
-         audios=audios[:,:int(int_max_len*(min_samples+480))]
-         codes_list=[]
-
-         audio_input = audios.reshape(2, -1, min_samples+480).permute(1, 0, 2).reshape(-1, 2, min_samples+480)
-
-         for audio_inx in range(0, audio_input.shape[0], batch_size):
-             # import pdb; pdb.set_trace()
-             codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num)
-             codes_list.append(torch.cat(codes, 1))
-             # print("codes_list",codes_list[0].shape)
-
-         codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(1, -1)[None] # B 3 T -> 3 B T
-         codes=codes[:,:,:output_len]
-
-         return codes
-
-     @torch.no_grad()
-     def code2sound(self, codes, prompt=None, duration=40.96, guidance_scale=1.5, num_steps=20, disable_progress=False):
-         codes = codes.to(self.device)
-         first_latent = torch.randn(codes.shape[0], 32, 512, 32).to(self.device)
-         first_latent_length = 0
-         first_latent_codes_length = 0
-         if(isinstance(prompt, torch.Tensor)):
-             prompt = prompt.to(self.device)
-             if(prompt.ndim == 3):
-                 assert prompt.shape[0] == 1, prompt.shape
-                 prompt = prompt[0]
-             elif(prompt.ndim == 1):
-                 prompt = prompt.unsqueeze(0).repeat(2,1)
-             elif(prompt.ndim == 2):
-                 if(prompt.shape[0] == 1):
-                     prompt = prompt.repeat(2,1)
-
-             if(prompt.shape[-1] < int(30.76 * self.sample_rate)):
-                 prompt = prompt[:,:int(10.24*self.sample_rate)] # limit max length to 10.24
-             else:
-                 prompt = prompt[:,int(20.48*self.sample_rate):int(30.72*self.sample_rate)] # limit max length to 10.24
-
-             true_mel, _, _ = torch_tools.wav_to_fbank2(prompt, -1, fn_STFT=self.stft) # maximum 10.24s
-             true_mel = true_mel.unsqueeze(1).to(self.device)
-             true_latent = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(true_mel[[m]])) for m in range(true_mel.shape[0])],0)
-             true_latent = true_latent.reshape(true_latent.shape[0]//2, -1, true_latent.shape[2], true_latent.shape[3]).detach()
-
-             first_latent[:,:,0:true_latent.shape[2],:] = true_latent
-             first_latent_length = true_latent.shape[2]
-             first_latent_codes = self.sound2code(prompt)[:,:,0:first_latent_length*2] # B 4 T
-             first_latent_codes_length = first_latent_codes.shape[-1]
-             codes = torch.cat([first_latent_codes, codes], -1)
-
-         min_samples = 1024
-         hop_samples = min_samples // 4 * 3
-         ovlp_samples = min_samples - hop_samples
-         hop_frames = hop_samples // 2
-         ovlp_frames = ovlp_samples // 2
-
-         codes_len = codes.shape[-1]
-         target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
-
-         if(codes_len < min_samples):
-             while(codes.shape[-1] < min_samples):
-                 codes = torch.cat([codes, codes], -1)
-             codes = codes[:,:,0:min_samples]
-         codes_len = codes.shape[-1]
-         if((codes_len - ovlp_frames) % hop_samples > 0):
-             len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
-             while(codes.shape[-1] < len_codes):
-                 codes = torch.cat([codes, codes], -1)
-             codes = codes[:,:,0:len_codes]
-         latent_length = 512
-         latent_list = []
-         spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device)
-         with torch.autocast(device_type="cuda", dtype=torch.float16):
-             for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples):
-                 codes_input=[]
-                 codes_input.append(codes[:,:,sinx:sinx+min_samples])
-                 if(sinx == 0):
-                     incontext_length = first_latent_length
-                     latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
-                     latent_list.append(latents)
-                 else:
-                     true_latent = latent_list[-1][:,:,-ovlp_frames:,:]
-                     len_add_to_512 = 512 - true_latent.shape[-2]
-                     incontext_length = true_latent.shape[-2]
-                     true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], true_latent.shape[1], len_add_to_512, true_latent.shape[-1]).to(self.device)], -2)
-                     latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
-                     latent_list.append(latents)
-
-         latent_list = [l.float() for l in latent_list]
-         latent_list[0] = latent_list[0][:,:,first_latent_length:,:]
-         min_samples = int(duration * self.sample_rate)
-         hop_samples = min_samples // 4 * 3
-         ovlp_samples = min_samples - hop_samples
-         with torch.no_grad():
-             output = None
-             for i in range(len(latent_list)):
-                 latent = latent_list[i]
-                 bsz, ch, t, f = latent.shape
-                 latent = latent.reshape(bsz*2, ch//2, t, f)
-                 mel = self.vae.decode_first_stage(latent)
-                 cur_output = self.vae.decode_to_waveform(mel)
-                 cur_output = torch.from_numpy(cur_output)[:, 0:min_samples]
-
-                 if output is None:
-                     output = cur_output
-                 else:
-                     ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
-                     ov_win = torch.cat([ov_win, 1 - ov_win], -1)
-                     output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
-                     output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
-             output = output[:, 0:target_len]
-         return output
-
-     @torch.no_grad()
-     def preprocess_audio(self, input_audios, threshold=0.8):
-         assert len(input_audios.shape) == 3, input_audios.shape
-         nchan = input_audios.shape[1]
-         input_audios = input_audios.reshape(input_audios.shape[0], -1)
-         norm_value = torch.ones_like(input_audios[:,0])
-         max_volume = input_audios.abs().max(dim=-1)[0]
-         norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
-         return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1)
-
-     @torch.no_grad()
-     def sound2sound(self, sound, prompt=None, min_duration=40.96, steps=50, disable_progress=False):
-         codes = self.sound2code(sound)
-         wave = self.code2sound(codes, prompt, duration=min_duration, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
-         return wave
-
- if __name__=="__main__":
-     ckpt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ckpt/mucodec.pt")
-     mucodec = MuCodec(model_path=ckpt_path,layer_num=7,load_main_model=True)
-
-     filelist = []
-
-     root_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_wav")
-     for f in [os.path.join(root_dir, f) for f in os.listdir(root_dir) if '.flac' in f or '.wav' in f or '.mp3' in f]:
-         a, fs = torchaudio.load(f)
-         if(fs!=48000):
-             a = torchaudio.functional.resample(a, fs, 48000)
-         if(a.shape[0]==1):
-             a = torch.cat([a,a],0)
-         ori_len = a.shape[-1]
-         filelist.append([a, '', [0, a.shape[-1]/48000.], f,ori_len])
-
-     reconstructed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "reconstructed")
-
-     os.makedirs(reconstructed_dir, exist_ok=True)
-
-     for sample_idx, (orig_samples, lyric, st_et, fname,ori_len) in enumerate(filelist):
-         print(fname, lyric)
-         wave = mucodec.sound2sound(orig_samples,None)
-         wave = wave[:,0:ori_len]
-         torchaudio.save(os.path.join(reconstructed_dir, os.path.basename(fname)),wave.detach().cpu(), 48000)
-
 
MuCodec/libs/rvq/descript_quantize3.py DELETED
@@ -1,298 +0,0 @@
1
- from typing import Union
2
-
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- from einops import rearrange
8
- from torch.nn.utils import weight_norm
9
-
10
- def WNConv1d(*args, **kwargs):
11
- return weight_norm(nn.Conv1d(*args, **kwargs))
12
-
13
- class VectorQuantize(nn.Module):
14
- """
15
- Implementation of VQ similar to Karpathy's repo:
16
- https://github.com/karpathy/deep-vector-quantization
17
- Additionally uses following tricks from Improved VQGAN
18
- (https://arxiv.org/pdf/2110.04627.pdf):
19
- 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
20
- for improved codebook usage
21
- 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
22
- improves training stability
23
- """
24
-
25
- def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100):
26
- super().__init__()
27
- self.codebook_size = codebook_size
28
- self.codebook_dim = codebook_dim
29
-
30
- self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
31
- self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
32
- self.codebook = nn.Embedding(codebook_size, codebook_dim)
33
- self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
34
- self.stale_tolerance = stale_tolerance
35
-
36
- def forward(self, z):
37
- """Quantized the input tensor using a fixed codebook and returns
38
- the corresponding codebook vectors
39
-
40
- Parameters
41
- ----------
42
- z : Tensor[B x D x T]
43
-
44
- Returns
45
- -------
46
- Tensor[B x D x T]
47
- Quantized continuous representation of input
48
- Tensor[1]
49
- Commitment loss to train encoder to predict vectors closer to codebook
50
- entries
51
- Tensor[1]
52
- Codebook loss to update the codebook
53
- Tensor[B x T]
54
- Codebook indices (quantized discrete representation of input)
55
- Tensor[B x D x T]
56
- Projected latents (continuous representation of input before quantization)
57
- """
58
-
59
- # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
60
- z_e = self.in_proj(z) # z_e : (B x D x T)
61
- z_q, indices = self.decode_latents(z_e)
62
-
63
- commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
64
- codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
65
-
66
- z_q = (
67
- z_e + (z_q - z_e).detach()
68
- ) # noop in forward pass, straight-through gradient estimator in backward pass
69
-
70
- z_q = self.out_proj(z_q)
71
-
72
- return z_q, commitment_loss, codebook_loss, indices, z_e
73
-
74
- def embed_code(self, embed_id):
75
- return F.embedding(embed_id, self.codebook.weight)
76
-
77
- def decode_code(self, embed_id):
78
- return self.embed_code(embed_id).transpose(1, 2)
79
-
80
- def decode_latents(self, latents):
81
- encodings = rearrange(latents, "b d t -> (b t) d")
82
- codebook = self.codebook.weight # codebook: (N x D)
83
-
84
- # L2 normalize encodings and codebook (ViT-VQGAN)
85
- encodings = F.normalize(encodings)
86
- codebook = F.normalize(codebook)
87
-
88
- # Compute euclidean distance with codebook
89
- dist = (
90
- encodings.pow(2).sum(1, keepdim=True)
91
- - 2 * encodings @ codebook.t()
92
- + codebook.pow(2).sum(1, keepdim=True).t()
93
- )
94
- indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
95
- z_q = self.decode_code(indices)
96
-
97
- if(self.training):
98
- onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size
99
- stale_codes = (onehots.sum(0).sum(0) == 0).float()
100
- self.stale_counter = self.stale_counter * stale_codes + stale_codes
101
-
102
- # random replace codes that haven't been used for a while
103
- replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size
104
- if replace_code.sum(-1) > 0:
105
- print("Replace {} codes".format(replace_code.sum(-1)))
106
- random_input_idx = torch.randperm(encodings.shape[0])
107
- random_input = encodings[random_input_idx].view(encodings.shape)
108
- if random_input.shape[0] < self.codebook_size:
109
- random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
110
- random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim
111
-
112
- self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
113
- self.stale_counter = self.stale_counter * (1 - replace_code)
114
-
115
- return z_q, indices
116
-
117
-
118
- class ResidualVectorQuantize(nn.Module):
119
- """
120
- Introduced in SoundStream: An end2end neural audio codec
121
- https://arxiv.org/abs/2107.03312
122
- """
123
-
124
- def __init__(
125
- self,
126
- input_dim: int = 512,
127
- n_codebooks: int = 9,
128
- codebook_size: int = 1024,
129
- codebook_dim: Union[int, list] = 8,
130
- quantizer_dropout: float = 0.0,
131
- stale_tolerance: int = 100,
132
- ):
133
- super().__init__()
134
- if isinstance(codebook_dim, int):
135
- codebook_dim = [codebook_dim for _ in range(n_codebooks)]
136
-
137
- self.n_codebooks = n_codebooks
138
- self.codebook_dim = codebook_dim
139
- self.codebook_size = codebook_size
140
-
141
- self.quantizers = nn.ModuleList(
142
- [
143
- VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance)
144
- for i in range(n_codebooks)
145
- ]
146
- )
147
- self.quantizer_dropout = quantizer_dropout
148
-
149
- def forward(self, z, n_quantizers: int = None):
150
- """Quantized the input tensor using a fixed set of `n` codebooks and returns
151
- the corresponding codebook vectors
152
- Parameters
153
- ----------
154
- z : Tensor[B x D x T]
155
- n_quantizers : int, optional
156
- No. of quantizers to use
157
- (n_quantizers < self.n_codebooks ex: for quantizer dropout)
158
- Note: if `self.quantizer_dropout` is True, this argument is ignored
159
- when in training mode, and a random number of quantizers is used.
160
- Returns
161
- -------
162
- dict
163
- A dictionary with the following keys:
164
-
165
- "z" : Tensor[B x D x T]
166
- Quantized continuous representation of input
167
- "codes" : Tensor[B x N x T]
168
- Codebook indices for each codebook
169
- (quantized discrete representation of input)
170
- "latents" : Tensor[B x N*D x T]
171
- Projected latents (continuous representation of input before quantization)
172
- "vq/commitment_loss" : Tensor[1]
173
- Commitment loss to train encoder to predict vectors closer to codebook
174
- entries
175
- "vq/codebook_loss" : Tensor[1]
176
- Codebook loss to update the codebook
177
- """
178
- z_q = 0
179
- residual = z
180
- commitment_loss = 0
181
- codebook_loss = 0
182
-
183
- codebook_indices = []
184
- latents = []
185
-
186
- if n_quantizers is None:
187
- n_quantizers = self.n_codebooks
188
- if self.training:
189
- n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
190
- dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
191
- n_dropout = int(z.shape[0] * self.quantizer_dropout)
192
- n_quantizers[:n_dropout] = dropout[:n_dropout]
193
- n_quantizers = n_quantizers.to(z.device)
194
- else:
195
- n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1
196
- n_quantizers = n_quantizers.to(z.device)
197
-
198
- for i, quantizer in enumerate(self.quantizers):
199
- # if self.training is False and i >= n_quantizers:
200
- # break
201
-
202
- z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
203
- residual
204
- )
205
-
206
- # Create mask to apply quantizer dropout
207
- mask = (
208
- torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
209
- )
210
- z_q = z_q + z_q_i * mask[:, None, None]
211
- residual = residual - z_q_i
212
-
213
- # Sum losses
214
- commitment_loss += (commitment_loss_i * mask).mean()
215
- codebook_loss += (codebook_loss_i * mask).mean()
216
-
217
- codebook_indices.append(indices_i)
218
- latents.append(z_e_i)
219
-
220
- codes = torch.stack(codebook_indices, dim=1)
221
- latents = torch.cat(latents, dim=1)
222
-
223
- encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
224
- for n in range(encodings.shape[1]):
225
- print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n,
226
- (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100.
227
- ))
228
-
229
- return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1
230
-
231
- def from_codes(self, codes: torch.Tensor):
232
- """Given the quantized codes, reconstruct the continuous representation
233
- Parameters
234
- ----------
235
- codes : Tensor[B x N x T]
236
- Quantized discrete representation of input
237
- Returns
238
- -------
239
- Tensor[B x D x T]
240
- Quantized continuous representation of input
241
- """
242
- z_q = 0.0
243
- z_p = []
244
- n_codebooks = codes.shape[1]
245
- for i in range(n_codebooks):
246
- z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
247
- z_p.append(z_p_i)
248
-
249
- z_q_i = self.quantizers[i].out_proj(z_p_i)
250
- z_q = z_q + z_q_i
251
- return z_q, torch.cat(z_p, dim=1), codes
252
-
253
- def from_latents(self, latents: torch.Tensor):
254
- """Given the unquantized latents, reconstruct the
255
- continuous representation after quantization.
256
-
257
- Parameters
258
- ----------
259
- latents : Tensor[B x N x T]
260
- Continuous representation of input after projection
261
-
262
- Returns
263
- -------
264
- Tensor[B x D x T]
265
- Quantized representation of full-projected space
266
- Tensor[B x D x T]
267
- Quantized representation of latent space
268
- """
269
- z_q = 0
270
- z_p = []
271
- codes = []
272
- dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
273
-
274
- n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
275
- 0
276
- ]
277
- for i in range(n_codebooks):
278
- j, k = dims[i], dims[i + 1]
279
- z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
280
- z_p.append(z_p_i)
281
- codes.append(codes_i)
282
-
283
- z_q_i = self.quantizers[i].out_proj(z_p_i)
284
- z_q = z_q + z_q_i
285
-
286
- return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
287
-
288
-
289
- if __name__ == "__main__":
290
- rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0)
291
- x = torch.randn(16, 1024, 80)
292
- quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x)
293
- print(quantized_prompt_embeds.shape)
294
- print(codes.shape)
295
- # w/o reconstruction
296
- loss = commitment_loss * 0.25 + codebook_loss * 1.0
297
- # w/ reconstruction
298
- loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()
 
MuCodec/model.py DELETED
@@ -1,367 +0,0 @@
1
- import yaml
2
- import random
3
- import inspect
4
- import numpy as np
5
- from tqdm import tqdm
6
- import typing as tp
7
- from abc import ABC
8
-
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
- import torchaudio
13
-
14
- from einops import repeat
15
- from tools.torch_tools import wav_to_fbank
16
- import os
17
- import diffusers
18
- from diffusers.utils.torch_utils import randn_tensor
19
- from diffusers import DDPMScheduler
20
- from models.transformer_2d_flow import Transformer2DModel
21
- from libs.rvq.descript_quantize3 import ResidualVectorQuantize
22
- from torch.cuda.amp import autocast
23
- from muq_dev.test import load_model
24
-
25
-
26
-
27
-
28
- class SampleProcessor(torch.nn.Module):
29
- def project_sample(self, x: torch.Tensor):
30
- """Project the original sample to the 'space' where the diffusion will happen."""
31
- return x
32
-
33
- def return_sample(self, z: torch.Tensor):
34
- """Project back from diffusion space to the actual sample space."""
35
- return z
36
-
37
- class Feature2DProcessor(SampleProcessor):
38
- def __init__(self, dim: int = 8, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1., \
39
- num_samples: int = 100_000):
40
- super().__init__()
41
- self.num_samples = num_samples
42
- self.dim = dim
43
- self.power_std = power_std
44
- self.register_buffer('counts', torch.zeros(1))
45
- self.register_buffer('sum_x', torch.zeros(dim, 32))
46
- self.register_buffer('sum_x2', torch.zeros(dim, 32))
47
- self.register_buffer('sum_target_x2', torch.zeros(dim, 32))
48
- self.counts: torch.Tensor
49
- self.sum_x: torch.Tensor
50
- self.sum_x2: torch.Tensor
51
-
52
- @property
53
- def mean(self):
54
- mean = self.sum_x / self.counts
55
- return mean
56
-
57
- @property
58
- def std(self):
59
- std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
60
- return std
61
-
62
- @property
63
- def target_std(self):
64
- return 1
65
-
66
- def project_sample(self, x: torch.Tensor):
67
- assert x.dim() == 4
68
- if self.counts.item() < self.num_samples:
69
- self.counts += len(x)
70
- self.sum_x += x.mean(dim=(2,)).sum(dim=0)
71
- self.sum_x2 += x.pow(2).mean(dim=(2,)).sum(dim=0)
72
- rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
73
- x = (x - self.mean.view(1, -1, 1, 32).contiguous()) * rescale.view(1, -1, 1, 32).contiguous()
74
- return x
75
-
76
- def return_sample(self, x: torch.Tensor):
77
- assert x.dim() == 4
78
- rescale = (self.std / self.target_std) ** self.power_std
79
- x = x * rescale.view(1, -1, 1, 32).contiguous() + self.mean.view(1, -1, 1, 32).contiguous()
80
- return x
81
-
82
-
83
- class BASECFM(torch.nn.Module, ABC):
84
- def __init__(
85
- self,
86
- estimator,
87
- ):
88
- super().__init__()
89
- self.sigma_min = 1e-4
90
-
91
- self.estimator = estimator
92
-
93
- @torch.inference_mode()
94
- def forward(self, mu, n_timesteps, temperature=1.0):
95
- """Forward diffusion
96
-
97
- Args:
98
- mu (torch.Tensor): output of encoder
99
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
100
- n_timesteps (int): number of diffusion steps
101
- temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
102
-
103
- Returns:
104
- sample: generated mel-spectrogram
105
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
106
- """
107
- z = torch.randn_like(mu) * temperature
108
- t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
109
- return self.solve_euler(z, t_span=t_span)
110
-
111
- def solve_euler(self, x, incontext_x, incontext_length, t_span, mu, added_cond_kwargs, guidance_scale):
112
- """
113
- Fixed euler solver for ODEs.
114
- Args:
115
- x (torch.Tensor): random noise
116
- t_span (torch.Tensor): n_timesteps interpolated
117
- shape: (n_timesteps + 1,)
118
- mu (torch.Tensor): output of encoder
119
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
120
- """
121
- t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
122
- noise = x.clone()
123
-
124
- # I am storing this because I can later plot it by putting a debugger here and saving it to a file
125
- # Or in future might add like a return_all_steps flag
126
- sol = []
127
-
128
- for step in tqdm(range(1, len(t_span))):
129
- x[:,:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,:,0:incontext_length,:] + t * incontext_x[:,:,0:incontext_length,:]
130
- if(guidance_scale > 1.0):
131
- dphi_dt = self.estimator( \
132
- torch.cat([ \
133
- torch.cat([x, x], 0), \
134
- torch.cat([incontext_x, incontext_x], 0), \
135
- torch.cat([torch.zeros_like(mu), mu], 0), \
136
- ], 1), \
137
- timestep = t.unsqueeze(-1).repeat(2), \
138
- added_cond_kwargs={k:torch.cat([v,v],0) for k,v in added_cond_kwargs.items()}).sample
139
- dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0)
140
- dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond)
141
- else:
142
- dphi_dt = self.estimator(torch.cat([x, incontext_x, mu], 1), \
143
- timestep = t.unsqueeze(-1),
144
- added_cond_kwargs=added_cond_kwargs).sample
145
-
146
- x = x + dt * dphi_dt
147
- t = t + dt
148
- sol.append(x)
149
- if step < len(t_span) - 1:
150
- dt = t_span[step + 1] - t
151
-
152
- return sol[-1]
153
-
154
-
155
- class PromptCondAudioDiffusion(nn.Module):
156
- def __init__(
157
- self,
158
- num_channels,
159
- unet_model_name=None,
160
- unet_model_config_path=None,
161
- snr_gamma=None,
162
- uncondition=True,
163
- out_paint=False,
164
- ):
165
- super().__init__()
166
-
167
- assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
168
-
169
- self.unet_model_name = unet_model_name
170
- self.unet_model_config_path = unet_model_config_path
171
- self.snr_gamma = snr_gamma
172
- self.uncondition = uncondition
173
- self.num_channels = num_channels
174
-
175
- # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
176
- self.normfeat = Feature2DProcessor(dim=num_channels)
177
-
178
- self.sample_rate = 48000
179
- self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
180
- self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
181
- muencoder_dir = "muq_dev/muq_fairseq"
182
- muencoder_ckpt = "muq_dev/muq.pt"
183
-
184
- self.muencoder = load_model(
185
- model_dir=os.path.abspath(muencoder_dir),
186
- checkpoint_dir=os.path.abspath(muencoder_ckpt),
187
- )
188
- self.rsq48tomuencoder = torchaudio.transforms.Resample(48000, 24000)
189
- for v in self.muencoder.parameters():v.requires_grad = False
190
- self.rvq_muencoder_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
191
- self.cond_muencoder_emb = nn.Linear(1024, 16*32)
192
- self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
193
-
194
- unet = Transformer2DModel.from_config(
195
- unet_model_config_path,
196
- )
197
- self.set_from = "random"
198
- self.cfm_wrapper = BASECFM(unet)
199
- print("Transformer initialized from pretrain.")
200
-
201
-
202
- def compute_snr(self, timesteps):
203
- """
204
- Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
205
- """
206
- alphas_cumprod = self.noise_scheduler.alphas_cumprod
207
- sqrt_alphas_cumprod = alphas_cumprod**0.5
208
- sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
209
-
210
- # Expand the tensors.
211
- # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
212
- sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
213
- while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
214
- sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
215
- alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
216
-
217
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
218
- while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
219
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
220
- sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
221
-
222
- # Compute SNR.
223
- snr = (alpha / sigma) ** 2
224
- return snr
225
-
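
With alpha_bar_t = alphas_cumprod[t], the quantity computed above reduces to SNR(t) = alpha_bar_t / (1 - alpha_bar_t); the reshaping only aligns tensor shapes for broadcasting. A compact equivalent, shown with an illustrative linear-beta schedule rather than this model's scheduler:

import torch

def snr_from_alphas_cumprod(alphas_cumprod, timesteps):
    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t), gathered at each sample's timestep.
    a = alphas_cumprod.to(timesteps.device)[timesteps].float()
    return a / (1.0 - a)

betas = torch.linspace(1e-4, 0.02, 1000)              # illustrative DDPM-style schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
print(snr_from_alphas_cumprod(alphas_cumprod, torch.tensor([0, 500, 999])))
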
226
- def preprocess_audio(self, input_audios, threshold=0.9):
227
- assert len(input_audios.shape) == 2, input_audios.shape
228
- norm_value = torch.ones_like(input_audios[:,0])
229
- max_volume = input_audios.abs().max(dim=-1)[0]
230
- norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
231
- return input_audios/norm_value.unsqueeze(-1)
232
-
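
The preprocessing above only rescales clips whose absolute peak exceeds the 0.9 threshold; quieter clips pass through unchanged. A standalone illustration of the same rule:

import torch

def peak_limit(wav: torch.Tensor, threshold: float = 0.9) -> torch.Tensor:
    # wav: (batch, samples). Rescale only the clips that peak above the threshold.
    peak = wav.abs().max(dim=-1, keepdim=True).values
    scale = torch.where(peak > threshold, peak / threshold, torch.ones_like(peak))
    return wav / scale

print(peak_limit(torch.full((1, 4), 1.2)).abs().max())   # -> 0.9
print(peak_limit(torch.full((1, 4), 0.5)).abs().max())   # -> 0.5 (unchanged)
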
233
-
234
-
235
-
236
- def extract_muencoder_embeds(self, input_audio_0,input_audio_1,layer):
237
- input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
238
- input_wav_mean = self.muencoder(self.rsq48tomuencoder(input_wav_mean), features_only = True)
239
- layer_results = input_wav_mean['layer_results']
240
- muencoder_emb = layer_results[layer]
241
- muencoder_emb = muencoder_emb.permute(0,2,1).contiguous()
242
- return muencoder_emb
243
-
244
-
245
-
246
-
247
- def init_device_dtype(self, device, dtype):
248
- self.device = device
249
- self.dtype = dtype
250
-
251
- @torch.no_grad()
252
- def fetch_codes(self, input_audios, additional_feats,layer):
253
- input_audio_0 = input_audios[[0],:]
254
- input_audio_1 = input_audios[[1],:]
255
- input_audio_0 = self.preprocess_audio(input_audio_0)
256
- input_audio_1 = self.preprocess_audio(input_audio_1)
257
-
258
- self.muencoder.eval()
259
-
260
-
261
- muencoder_emb = self.extract_muencoder_embeds(input_audio_0,input_audio_1,layer)
262
- muencoder_emb = muencoder_emb.detach()
263
-
264
- self.rvq_muencoder_emb.eval()
265
- quantized_muencoder_emb, codes_muencoder_emb, *_ = self.rvq_muencoder_emb(muencoder_emb)
266
-
267
-
268
- spk_embeds = None
269
-
270
-
271
- return [codes_muencoder_emb], [muencoder_emb], spk_embeds
272
- @torch.no_grad()
273
- def fetch_codes_batch(self, input_audios, additional_feats,layer):
274
- input_audio_0 = input_audios[:,0,:]
275
- input_audio_1 = input_audios[:,1,:]
276
- input_audio_0 = self.preprocess_audio(input_audio_0)
277
- input_audio_1 = self.preprocess_audio(input_audio_1)
278
-
279
- self.muencoder.eval()
280
-
281
-
282
- muencoder_emb = self.extract_muencoder_embeds(input_audio_0,input_audio_1,layer)
283
- muencoder_emb = muencoder_emb.detach()
284
-
285
- self.rvq_muencoder_emb.eval()
286
- quantized_muencoder_emb, codes_muencoder_emb, *_ = self.rvq_muencoder_emb(muencoder_emb) # b,d,t
287
-
288
- spk_embeds = None
289
-
290
- return [codes_muencoder_emb], [muencoder_emb], spk_embeds
291
- @torch.no_grad()
292
- def inference_codes(self, codes, spk_embeds, true_latents, latent_length,incontext_length, additional_feats,
293
- guidance_scale=2, num_steps=20,
294
- disable_progress=True, scenario='start_seg'):
295
- classifier_free_guidance = guidance_scale > 1.0
296
- device = self.device
297
- dtype = self.dtype
298
- codes_muencoder_emb = codes[0]
299
-
300
-
301
- batch_size = codes_muencoder_emb.shape[0]
302
-
303
-
304
- quantized_muencoder_emb,_,_=self.rvq_muencoder_emb.from_codes(codes_muencoder_emb)
305
-
306
- quantized_muencoder_emb = self.cond_muencoder_emb(quantized_muencoder_emb.permute(0,2,1)) # b t 16*32
307
- quantized_muencoder_emb = quantized_muencoder_emb.reshape(quantized_muencoder_emb.shape[0], quantized_muencoder_emb.shape[1]//2, 2, 16, 32).reshape(quantized_muencoder_emb.shape[0], quantized_muencoder_emb.shape[1]//2, 2*16, 32).permute(0,2,1,3).contiguous() # b 32 t f
308
-
309
-
310
- num_frames = quantized_muencoder_emb.shape[-2]
311
-
312
- num_channels_latents = self.num_channels
313
- latents = self.prepare_latents(batch_size, num_frames, num_channels_latents, dtype, device)
314
-
315
- bsz, _, height, width = latents.shape
316
- resolution = torch.tensor([height, width]).repeat(bsz, 1)
317
- aspect_ratio = torch.tensor([float(height / width)]).repeat(bsz, 1)
318
- resolution = resolution.to(dtype=quantized_muencoder_emb.dtype, device=device)
319
- aspect_ratio = aspect_ratio.to(dtype=quantized_muencoder_emb.dtype, device=device)
320
- if classifier_free_guidance:
321
- resolution = torch.cat([resolution, resolution], 0)
322
- aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], 0)
323
- added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
324
-
325
- latent_masks = torch.zeros(latents.shape[0], latents.shape[2], dtype=torch.int64, device=latents.device)
326
- latent_masks[:,0:latent_length] = 2
327
- if(scenario=='other_seg'):
328
- latent_masks[:,0:incontext_length] = 1
329
-
330
-
331
-
332
- quantized_muencoder_emb = (latent_masks > 0.5).unsqueeze(1).unsqueeze(-1) * quantized_muencoder_emb \
333
- + (latent_masks < 0.5).unsqueeze(1).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,32,1,32)
334
- true_latents = self.normfeat.project_sample(true_latents)
335
- incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(1).unsqueeze(-1).float()
336
- incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
337
-
338
- additional_model_input = torch.cat([quantized_muencoder_emb],1)
339
-
340
- temperature = 1.0
341
- t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_muencoder_emb.device)
342
- latents = self.cfm_wrapper.solve_euler(latents * temperature, incontext_latents, incontext_length, t_span, additional_model_input, added_cond_kwargs, guidance_scale)
343
-
344
- latents[:,:,0:incontext_length,:] = incontext_latents[:,:,0:incontext_length,:]
345
- latents = self.normfeat.return_sample(latents)
346
- return latents
347
-
348
- @torch.no_grad()
349
- def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
350
- disable_progress=True,layer=5,scenario='start_seg'):
351
- codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
352
-
353
- latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
354
- guidance_scale=guidance_scale, num_steps=num_steps, \
355
- disable_progress=disable_progress,scenario=scenario)
356
- return latents
357
-
358
- def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
359
- divisor = 4
360
- shape = (batch_size, num_channels_latents, num_frames, 32)
361
- if(num_frames%divisor>0):
362
- num_frames = round(num_frames/float(divisor))*divisor
363
- shape = (batch_size, num_channels_latents, num_frames, 32)
364
- latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
365
- return latents
366
-
367
-
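
Taken together, inference in this deleted module chains fetch_codes (MuEncoder features quantized by the RVQ) with inference_codes, which builds the conditioning, keeps in-context frames from the true latents, runs the Euler flow-matching sampler, and finally undoes the feature normalization with normfeat.return_sample. The only subtlety in prepare_latents is that the frame count is rounded to a multiple of 4; a small self-contained sketch of that rounding (illustrative only):

def rounded_latent_shape(batch_size: int, num_frames: int, num_channels: int, divisor: int = 4):
    # Mirrors prepare_latents above: the time axis is rounded to the nearest multiple of `divisor`.
    if num_frames % divisor:
        num_frames = round(num_frames / float(divisor)) * divisor
    return (batch_size, num_channels, num_frames, 32)

print(rounded_latent_shape(1, 150, 32))   # -> (1, 32, 152, 32)
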
 
MuCodec/models/attention.py DELETED
@@ -1,682 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import Any, Dict, Optional
15
-
16
- import torch
17
- import torch.nn.functional as F
18
- from torch import nn
19
-
20
- from diffusers.utils import USE_PEFT_BACKEND
21
- from diffusers.utils.torch_utils import maybe_allow_in_graph
22
- from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
23
- from diffusers.models.attention_processor import Attention
24
- from diffusers.models.embeddings import SinusoidalPositionalEmbedding
25
- from diffusers.models.lora import LoRACompatibleLinear
26
- from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
27
-
28
-
29
- def _chunked_feed_forward(
30
- ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
31
- ):
32
- # "feed_forward_chunk_size" can be used to save memory
33
- if hidden_states.shape[chunk_dim] % chunk_size != 0:
34
- raise ValueError(
35
- f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
36
- )
37
-
38
- num_chunks = hidden_states.shape[chunk_dim] // chunk_size
39
- if lora_scale is None:
40
- ff_output = torch.cat(
41
- [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
42
- dim=chunk_dim,
43
- )
44
- else:
45
- # TODO(Patrick): LoRA scale can be removed once PEFT refactor is complete
46
- ff_output = torch.cat(
47
- [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
48
- dim=chunk_dim,
49
- )
50
-
51
- return ff_output
52
-
53
-
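
The helper above trades peak activation memory for extra kernel launches by slicing the sequence before the feed-forward expansion. A standalone illustration with a generic feed-forward module (illustrative names, LoRA scaling omitted):

import torch
from torch import nn

def chunked_ff(ff: nn.Module, x: torch.Tensor, chunk_size: int, dim: int = 1) -> torch.Tensor:
    # Only `chunk_size` tokens are expanded to the feed-forward inner width at any one time.
    assert x.shape[dim] % chunk_size == 0, "sequence length must be divisible by chunk_size"
    return torch.cat([ff(part) for part in x.chunk(x.shape[dim] // chunk_size, dim=dim)], dim=dim)

ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
out = chunked_ff(ff, torch.randn(2, 1024, 64), chunk_size=256)    # (2, 1024, 64)
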
54
- @maybe_allow_in_graph
55
- class GatedSelfAttentionDense(nn.Module):
56
- r"""
57
- A gated self-attention dense layer that combines visual features and object features.
58
-
59
- Parameters:
60
- query_dim (`int`): The number of channels in the query.
61
- context_dim (`int`): The number of channels in the context.
62
- n_heads (`int`): The number of heads to use for attention.
63
- d_head (`int`): The number of channels in each head.
64
- """
65
-
66
- def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
67
- super().__init__()
68
-
69
- # we need a linear projection since we need cat visual feature and obj feature
70
- self.linear = nn.Linear(context_dim, query_dim)
71
-
72
- self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
73
- self.ff = FeedForward(query_dim, activation_fn="geglu")
74
-
75
- self.norm1 = nn.LayerNorm(query_dim)
76
- self.norm2 = nn.LayerNorm(query_dim)
77
-
78
- self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
79
- self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
80
-
81
- self.enabled = True
82
-
83
- def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
84
- if not self.enabled:
85
- return x
86
-
87
- n_visual = x.shape[1]
88
- objs = self.linear(objs)
89
-
90
- x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
91
- x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
92
-
93
- return x
94
-
95
-
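
The zero-initialized alpha_attn and alpha_dense gates above make each injected branch start as an identity mapping and fade in as tanh(alpha) moves away from zero during training. The pattern in isolation (illustrative):

import torch
from torch import nn

class ZeroGatedBranch(nn.Module):
    def __init__(self, branch: nn.Module):
        super().__init__()
        self.branch = branch
        self.alpha = nn.Parameter(torch.tensor(0.0))   # tanh(0) = 0, so the branch is off at init

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.alpha.tanh() * self.branch(x)

block = ZeroGatedBranch(nn.Linear(8, 8))
x = torch.randn(2, 8)
assert torch.allclose(block(x), x)   # identity at initialization
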
96
- @maybe_allow_in_graph
97
- class BasicTransformerBlock(nn.Module):
98
- r"""
99
- A basic Transformer block.
100
-
101
- Parameters:
102
- dim (`int`): The number of channels in the input and output.
103
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
104
- attention_head_dim (`int`): The number of channels in each head.
105
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
106
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
107
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
108
- num_embeds_ada_norm (:
109
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
110
- attention_bias (:
111
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
112
- only_cross_attention (`bool`, *optional*):
113
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
114
- double_self_attention (`bool`, *optional*):
115
- Whether to use two self-attention layers. In this case no cross attention layers are used.
116
- upcast_attention (`bool`, *optional*):
117
- Whether to upcast the attention computation to float32. This is useful for mixed precision training.
118
- norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
119
- Whether to use learnable elementwise affine parameters for normalization.
120
- norm_type (`str`, *optional*, defaults to `"layer_norm"`):
121
- The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
122
- final_dropout (`bool` *optional*, defaults to False):
123
- Whether to apply a final dropout after the last feed-forward layer.
124
- attention_type (`str`, *optional*, defaults to `"default"`):
125
- The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
126
- positional_embeddings (`str`, *optional*, defaults to `None`):
127
- The type of positional embeddings to apply to.
128
- num_positional_embeddings (`int`, *optional*, defaults to `None`):
129
- The maximum number of positional embeddings to apply.
130
- """
131
-
132
- def __init__(
133
- self,
134
- dim: int,
135
- num_attention_heads: int,
136
- attention_head_dim: int,
137
- dropout=0.0,
138
- cross_attention_dim: Optional[int] = None,
139
- activation_fn: str = "geglu",
140
- num_embeds_ada_norm: Optional[int] = None,
141
- attention_bias: bool = False,
142
- only_cross_attention: bool = False,
143
- double_self_attention: bool = False,
144
- upcast_attention: bool = False,
145
- norm_elementwise_affine: bool = True,
146
- norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
147
- norm_eps: float = 1e-5,
148
- final_dropout: bool = False,
149
- attention_type: str = "default",
150
- positional_embeddings: Optional[str] = None,
151
- num_positional_embeddings: Optional[int] = None,
152
- ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
153
- ada_norm_bias: Optional[int] = None,
154
- ff_inner_dim: Optional[int] = None,
155
- ff_bias: bool = True,
156
- attention_out_bias: bool = True,
157
- ):
158
- super().__init__()
159
- self.only_cross_attention = only_cross_attention
160
-
161
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
162
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
163
- self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
164
- self.use_layer_norm = norm_type == "layer_norm"
165
- self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
166
-
167
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
168
- raise ValueError(
169
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
170
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
171
- )
172
-
173
- if positional_embeddings and (num_positional_embeddings is None):
174
- raise ValueError(
175
- "If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
176
- )
177
-
178
- if positional_embeddings == "sinusoidal":
179
- self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
180
- else:
181
- self.pos_embed = None
182
-
183
- # Define 3 blocks. Each block has its own normalization layer.
184
- # 1. Self-Attn
185
- if self.use_ada_layer_norm:
186
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
187
- elif self.use_ada_layer_norm_zero:
188
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
189
- elif self.use_ada_layer_norm_continuous:
190
- self.norm1 = AdaLayerNormContinuous(
191
- dim,
192
- ada_norm_continous_conditioning_embedding_dim,
193
- norm_elementwise_affine,
194
- norm_eps,
195
- ada_norm_bias,
196
- "rms_norm",
197
- )
198
- else:
199
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
200
-
201
- self.attn1 = Attention(
202
- query_dim=dim,
203
- heads=num_attention_heads,
204
- dim_head=attention_head_dim,
205
- dropout=dropout,
206
- bias=attention_bias,
207
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
208
- upcast_attention=upcast_attention,
209
- out_bias=attention_out_bias,
210
- )
211
-
212
- # 2. Cross-Attn
213
- if cross_attention_dim is not None or double_self_attention:
214
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
215
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
216
- # the second cross attention block.
217
- if self.use_ada_layer_norm:
218
- self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
219
- elif self.use_ada_layer_norm_continuous:
220
- self.norm2 = AdaLayerNormContinuous(
221
- dim,
222
- ada_norm_continous_conditioning_embedding_dim,
223
- norm_elementwise_affine,
224
- norm_eps,
225
- ada_norm_bias,
226
- "rms_norm",
227
- )
228
- else:
229
- self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
230
-
231
- self.attn2 = Attention(
232
- query_dim=dim,
233
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
234
- heads=num_attention_heads,
235
- dim_head=attention_head_dim,
236
- dropout=dropout,
237
- bias=attention_bias,
238
- upcast_attention=upcast_attention,
239
- out_bias=attention_out_bias,
240
- ) # is self-attn if encoder_hidden_states is none
241
- else:
242
- self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
243
- self.attn2 = None
244
-
245
- # 3. Feed-forward
246
- if self.use_ada_layer_norm_continuous:
247
- self.norm3 = AdaLayerNormContinuous(
248
- dim,
249
- ada_norm_continous_conditioning_embedding_dim,
250
- norm_elementwise_affine,
251
- norm_eps,
252
- ada_norm_bias,
253
- "layer_norm",
254
- )
255
- elif not self.use_ada_layer_norm_single:
256
- self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
257
-
258
- self.ff = FeedForward(
259
- dim,
260
- dropout=dropout,
261
- activation_fn=activation_fn,
262
- final_dropout=final_dropout,
263
- inner_dim=ff_inner_dim,
264
- bias=ff_bias,
265
- )
266
-
267
- # 4. Fuser
268
- if attention_type == "gated" or attention_type == "gated-text-image":
269
- self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
270
-
271
- # 5. Scale-shift for PixArt-Alpha.
272
- if self.use_ada_layer_norm_single:
273
- self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
274
-
275
- # let chunk size default to None
276
- self._chunk_size = None
277
- self._chunk_dim = 0
278
-
279
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
280
- # Sets chunk feed-forward
281
- self._chunk_size = chunk_size
282
- self._chunk_dim = dim
283
-
284
- def forward(
285
- self,
286
- hidden_states: torch.FloatTensor,
287
- attention_mask: Optional[torch.FloatTensor] = None,
288
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
289
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
290
- timestep: Optional[torch.LongTensor] = None,
291
- cross_attention_kwargs: Dict[str, Any] = None,
292
- class_labels: Optional[torch.LongTensor] = None,
293
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
294
- ) -> torch.FloatTensor:
295
- # Notice that normalization is always applied before the real computation in the following blocks.
296
- # 0. Self-Attention
297
- batch_size = hidden_states.shape[0]
298
-
299
- if self.use_ada_layer_norm:
300
- norm_hidden_states = self.norm1(hidden_states, timestep)
301
- elif self.use_ada_layer_norm_zero:
302
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
303
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
304
- )
305
- elif self.use_layer_norm:
306
- norm_hidden_states = self.norm1(hidden_states)
307
- elif self.use_ada_layer_norm_continuous:
308
- norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
309
- elif self.use_ada_layer_norm_single:
310
- # print("Using PixArt-Alpha norm")
311
- # print("time step: ", timestep.shape)
312
- # print("self.scale_shift_table: ", self.scale_shift_table.shape)
313
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
314
- self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
315
- ).chunk(6, dim=1)
316
- norm_hidden_states = self.norm1(hidden_states)
317
- # print("scale_msa: ", scale_msa.shape)
318
- # print("shift_msa: ", shift_msa.shape)
319
- #scale_msa: torch.Size([5, 1, 1152])
320
- #shift_msa: torch.Size([5, 1, 1152])
321
- # exit()
322
- # print("before: ", norm_hidden_states.shape)
323
- #before: torch.Size([5, 3584, 1152])
324
- norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
325
- # print("after: ", norm_hidden_states.shape)
326
- #before: torch.Size([5, 3584, 1152])
327
- # exit()
328
- norm_hidden_states = norm_hidden_states.squeeze(1)
329
- else:
330
- raise ValueError("Incorrect norm used")
331
-
332
- if self.pos_embed is not None:
333
- norm_hidden_states = self.pos_embed(norm_hidden_states)
334
-
335
-
336
- # 1. Retrieve lora scale.
337
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
338
-
339
- # 2. Prepare GLIGEN inputs
340
- cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
341
- gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
342
-
343
- attn_output = self.attn1(
344
- norm_hidden_states,
345
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
346
- attention_mask=attention_mask,
347
- **cross_attention_kwargs,
348
- )
349
- if self.use_ada_layer_norm_zero:
350
- attn_output = gate_msa.unsqueeze(1) * attn_output
351
- elif self.use_ada_layer_norm_single:
352
- attn_output = gate_msa * attn_output
353
-
354
- hidden_states = attn_output + hidden_states
355
- if hidden_states.ndim == 4:
356
- hidden_states = hidden_states.squeeze(1)
357
-
358
- # 2.5 GLIGEN Control
359
- if gligen_kwargs is not None:
360
- hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
361
-
362
- # 3. Cross-Attention
363
- if self.attn2 is not None:
364
- if self.use_ada_layer_norm:
365
- norm_hidden_states = self.norm2(hidden_states, timestep)
366
- elif self.use_ada_layer_norm_zero or self.use_layer_norm:
367
- norm_hidden_states = self.norm2(hidden_states)
368
- elif self.use_ada_layer_norm_single:
369
- # For PixArt norm2 isn't applied here:
370
- # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
371
- norm_hidden_states = hidden_states
372
- elif self.use_ada_layer_norm_continuous:
373
- norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
374
- else:
375
- raise ValueError("Incorrect norm")
376
-
377
- if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
378
- norm_hidden_states = self.pos_embed(norm_hidden_states)
379
-
380
- attn_output = self.attn2(
381
- norm_hidden_states,
382
- encoder_hidden_states=encoder_hidden_states,
383
- attention_mask=encoder_attention_mask,
384
- **cross_attention_kwargs,
385
- )
386
- hidden_states = attn_output + hidden_states
387
-
388
- # 4. Feed-forward
389
- if self.use_ada_layer_norm_continuous:
390
- norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
391
- elif not self.use_ada_layer_norm_single:
392
- norm_hidden_states = self.norm3(hidden_states)
393
-
394
- if self.use_ada_layer_norm_zero:
395
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
396
-
397
- if self.use_ada_layer_norm_single:
398
- norm_hidden_states = self.norm2(hidden_states)
399
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
400
-
401
- if self._chunk_size is not None:
402
- # "feed_forward_chunk_size" can be used to save memory
403
- ff_output = _chunked_feed_forward(
404
- self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
405
- )
406
- else:
407
- ff_output = self.ff(norm_hidden_states, scale=lora_scale)
408
-
409
- if self.use_ada_layer_norm_zero:
410
- ff_output = gate_mlp.unsqueeze(1) * ff_output
411
- elif self.use_ada_layer_norm_single:
412
- ff_output = gate_mlp * ff_output
413
-
414
- hidden_states = ff_output + hidden_states
415
- if hidden_states.ndim == 4:
416
- hidden_states = hidden_states.squeeze(1)
417
-
418
- return hidden_states
419
-
420
-
421
- @maybe_allow_in_graph
422
- class TemporalBasicTransformerBlock(nn.Module):
423
- r"""
424
- A basic Transformer block for video like data.
425
-
426
- Parameters:
427
- dim (`int`): The number of channels in the input and output.
428
- time_mix_inner_dim (`int`): The number of channels for temporal attention.
429
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
430
- attention_head_dim (`int`): The number of channels in each head.
431
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
432
- """
433
-
434
- def __init__(
435
- self,
436
- dim: int,
437
- time_mix_inner_dim: int,
438
- num_attention_heads: int,
439
- attention_head_dim: int,
440
- cross_attention_dim: Optional[int] = None,
441
- ):
442
- super().__init__()
443
- self.is_res = dim == time_mix_inner_dim
444
-
445
- self.norm_in = nn.LayerNorm(dim)
446
-
447
- # Define 3 blocks. Each block has its own normalization layer.
448
- # 1. Self-Attn
449
- self.norm_in = nn.LayerNorm(dim)
450
- self.ff_in = FeedForward(
451
- dim,
452
- dim_out=time_mix_inner_dim,
453
- activation_fn="geglu",
454
- )
455
-
456
- self.norm1 = nn.LayerNorm(time_mix_inner_dim)
457
- self.attn1 = Attention(
458
- query_dim=time_mix_inner_dim,
459
- heads=num_attention_heads,
460
- dim_head=attention_head_dim,
461
- cross_attention_dim=None,
462
- )
463
-
464
- # 2. Cross-Attn
465
- if cross_attention_dim is not None:
466
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
467
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
468
- # the second cross attention block.
469
- self.norm2 = nn.LayerNorm(time_mix_inner_dim)
470
- self.attn2 = Attention(
471
- query_dim=time_mix_inner_dim,
472
- cross_attention_dim=cross_attention_dim,
473
- heads=num_attention_heads,
474
- dim_head=attention_head_dim,
475
- ) # is self-attn if encoder_hidden_states is none
476
- else:
477
- self.norm2 = None
478
- self.attn2 = None
479
-
480
- # 3. Feed-forward
481
- self.norm3 = nn.LayerNorm(time_mix_inner_dim)
482
- self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
483
-
484
- # let chunk size default to None
485
- self._chunk_size = None
486
- self._chunk_dim = None
487
-
488
- def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
489
- # Sets chunk feed-forward
490
- self._chunk_size = chunk_size
491
- # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
492
- self._chunk_dim = 1
493
-
494
- def forward(
495
- self,
496
- hidden_states: torch.FloatTensor,
497
- num_frames: int,
498
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
499
- ) -> torch.FloatTensor:
500
- # Notice that normalization is always applied before the real computation in the following blocks.
501
- # 0. Self-Attention
502
- batch_size = hidden_states.shape[0]
503
-
504
- batch_frames, seq_length, channels = hidden_states.shape
505
- batch_size = batch_frames // num_frames
506
-
507
- hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
508
- hidden_states = hidden_states.permute(0, 2, 1, 3)
509
- hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
510
-
511
- residual = hidden_states
512
- hidden_states = self.norm_in(hidden_states)
513
-
514
- if self._chunk_size is not None:
515
- hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
516
- else:
517
- hidden_states = self.ff_in(hidden_states)
518
-
519
- if self.is_res:
520
- hidden_states = hidden_states + residual
521
-
522
- norm_hidden_states = self.norm1(hidden_states)
523
- attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
524
- hidden_states = attn_output + hidden_states
525
-
526
- # 3. Cross-Attention
527
- if self.attn2 is not None:
528
- norm_hidden_states = self.norm2(hidden_states)
529
- attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
530
- hidden_states = attn_output + hidden_states
531
-
532
- # 4. Feed-forward
533
- norm_hidden_states = self.norm3(hidden_states)
534
-
535
- if self._chunk_size is not None:
536
- ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
537
- else:
538
- ff_output = self.ff(norm_hidden_states)
539
-
540
- if self.is_res:
541
- hidden_states = ff_output + hidden_states
542
- else:
543
- hidden_states = ff_output
544
-
545
- hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
546
- hidden_states = hidden_states.permute(0, 2, 1, 3)
547
- hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
548
-
549
- return hidden_states
550
-
551
-
552
- class SkipFFTransformerBlock(nn.Module):
553
- def __init__(
554
- self,
555
- dim: int,
556
- num_attention_heads: int,
557
- attention_head_dim: int,
558
- kv_input_dim: int,
559
- kv_input_dim_proj_use_bias: bool,
560
- dropout=0.0,
561
- cross_attention_dim: Optional[int] = None,
562
- attention_bias: bool = False,
563
- attention_out_bias: bool = True,
564
- ):
565
- super().__init__()
566
- if kv_input_dim != dim:
567
- self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
568
- else:
569
- self.kv_mapper = None
570
-
571
- self.norm1 = RMSNorm(dim, 1e-06)
572
-
573
- self.attn1 = Attention(
574
- query_dim=dim,
575
- heads=num_attention_heads,
576
- dim_head=attention_head_dim,
577
- dropout=dropout,
578
- bias=attention_bias,
579
- cross_attention_dim=cross_attention_dim,
580
- out_bias=attention_out_bias,
581
- )
582
-
583
- self.norm2 = RMSNorm(dim, 1e-06)
584
-
585
- self.attn2 = Attention(
586
- query_dim=dim,
587
- cross_attention_dim=cross_attention_dim,
588
- heads=num_attention_heads,
589
- dim_head=attention_head_dim,
590
- dropout=dropout,
591
- bias=attention_bias,
592
- out_bias=attention_out_bias,
593
- )
594
-
595
- def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
596
- cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
597
-
598
- if self.kv_mapper is not None:
599
- encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
600
-
601
- norm_hidden_states = self.norm1(hidden_states)
602
-
603
- attn_output = self.attn1(
604
- norm_hidden_states,
605
- encoder_hidden_states=encoder_hidden_states,
606
- **cross_attention_kwargs,
607
- )
608
-
609
- hidden_states = attn_output + hidden_states
610
-
611
- norm_hidden_states = self.norm2(hidden_states)
612
-
613
- attn_output = self.attn2(
614
- norm_hidden_states,
615
- encoder_hidden_states=encoder_hidden_states,
616
- **cross_attention_kwargs,
617
- )
618
-
619
- hidden_states = attn_output + hidden_states
620
-
621
- return hidden_states
622
-
623
-
624
- class FeedForward(nn.Module):
625
- r"""
626
- A feed-forward layer.
627
-
628
- Parameters:
629
- dim (`int`): The number of channels in the input.
630
- dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
631
- mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
632
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
633
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
634
- final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
635
- bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
636
- """
637
-
638
- def __init__(
639
- self,
640
- dim: int,
641
- dim_out: Optional[int] = None,
642
- mult: int = 4,
643
- dropout: float = 0.0,
644
- activation_fn: str = "geglu",
645
- final_dropout: bool = False,
646
- inner_dim=None,
647
- bias: bool = True,
648
- ):
649
- super().__init__()
650
- if inner_dim is None:
651
- inner_dim = int(dim * mult)
652
- dim_out = dim_out if dim_out is not None else dim
653
- linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
654
-
655
- if activation_fn == "gelu":
656
- act_fn = GELU(dim, inner_dim, bias=bias)
657
- if activation_fn == "gelu-approximate":
658
- act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
659
- elif activation_fn == "geglu":
660
- act_fn = GEGLU(dim, inner_dim, bias=bias)
661
- elif activation_fn == "geglu-approximate":
662
- act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
663
-
664
- self.net = nn.ModuleList([])
665
- # project in
666
- self.net.append(act_fn)
667
- # project dropout
668
- self.net.append(nn.Dropout(dropout))
669
- # project out
670
- self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
671
- # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
672
- if final_dropout:
673
- self.net.append(nn.Dropout(dropout))
674
-
675
- def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
676
- compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
677
- for module in self.net:
678
- if isinstance(module, compatible_cls):
679
- hidden_states = module(hidden_states, scale)
680
- else:
681
- hidden_states = module(hidden_states)
682
- return hidden_states
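
For orientation, the default "geglu" activation used throughout these blocks gates the hidden projection with a GELU branch before the output projection. A minimal equivalent of the resulting feed-forward (a sketch, not the diffusers implementation):

import torch
from torch import nn

class TinyGEGLUFF(nn.Module):
    def __init__(self, dim: int, mult: int = 4):
        super().__init__()
        self.proj = nn.Linear(dim, 2 * dim * mult)   # produces a value half and a gate half
        self.out = nn.Linear(dim * mult, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = self.proj(x).chunk(2, dim=-1)
        return self.out(value * nn.functional.gelu(gate))

y = TinyGEGLUFF(64)(torch.randn(2, 10, 64))   # (2, 10, 64)
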
 
MuCodec/models/transformer_2d_flow.py DELETED
@@ -1,545 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- import math
16
- from typing import Any, Dict, Optional, Tuple
17
-
18
- import torch
19
- import torch.nn.functional as F
20
- from torch import nn
21
-
22
- from diffusers.configuration_utils import ConfigMixin, register_to_config
23
- from diffusers.models.embeddings import ImagePositionalEmbeddings
24
- from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
25
- from models.attention import BasicTransformerBlock
26
- from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection
27
- from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
28
- from diffusers.models.modeling_utils import ModelMixin
29
- from diffusers.models.embeddings import TimestepEmbedding, Timesteps
30
-
31
- class PixArtAlphaCombinedFlowEmbeddings(nn.Module):
32
- """
33
- For PixArt-Alpha.
34
-
35
- Reference:
36
- https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
37
- """
38
-
39
- def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
40
- super().__init__()
41
-
42
- self.flow_t_size = 512
43
- self.outdim = size_emb_dim
44
- self.timestep_embedder = TimestepEmbedding(in_channels=self.flow_t_size, time_embed_dim=embedding_dim)
45
-
46
- self.use_additional_conditions = use_additional_conditions
47
- if use_additional_conditions:
48
- self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
49
- self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
50
- self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
51
-
52
- # https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/nn.py#L87
53
- def timestep_embedding(self, timesteps, max_period=10000, scale=1000):
54
- """Create sinusoidal timestep embeddings.
55
-
56
- :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
57
- :param dim: the dimension of the output.
58
- :param max_period: controls the minimum frequency of the embeddings.
59
- :return: an [N x dim] Tensor of positional embeddings.
60
- """
61
- half = self.flow_t_size // 2
62
- freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, device=timesteps.device) / half).type(timesteps.type())
63
- args = timesteps[:, None] * freqs[None] * scale
64
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
65
- if self.flow_t_size % 2:
66
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
67
- return embedding
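
Because flow-matching timesteps live in [0, 1], the method above multiplies them by scale=1000 before the usual sinusoidal embedding so the resulting frequencies span a useful range. A standalone sketch of the same embedding (illustrative defaults, even dim assumed):

import math
import torch

def sinusoidal_embedding(t: torch.Tensor, dim: int = 512, max_period: float = 10000.0, scale: float = 1000.0):
    # t: (N,) values in [0, 1]; returns an (N, dim) tensor of cos/sin features.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, device=t.device) / half)
    args = t[:, None].float() * freqs[None] * scale
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = sinusoidal_embedding(torch.tensor([0.0, 0.5, 1.0]))   # (3, 512)
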
68
-
69
- def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
70
- timesteps_proj = self.timestep_embedding(timestep)
71
- timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
72
-
73
- if self.use_additional_conditions:
74
- resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
75
- resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
76
- aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
77
- aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
78
- conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
79
- else:
80
- conditioning = timesteps_emb
81
-
82
- return conditioning
83
-
84
- class AdaLayerNormSingleFlow(nn.Module):
85
- r"""
86
- Norm layer adaptive layer norm single (adaLN-single).
87
-
88
- As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
89
-
90
- Parameters:
91
- embedding_dim (`int`): The size of each embedding vector.
92
- use_additional_conditions (`bool`): To use additional conditions for normalization or not.
93
- """
94
-
95
- def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
96
- super().__init__()
97
-
98
- self.emb = PixArtAlphaCombinedFlowEmbeddings(
99
- embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
100
- )
101
-
102
- self.silu = nn.SiLU()
103
- self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
104
-
105
- def forward(
106
- self,
107
- timestep: torch.Tensor,
108
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
109
- batch_size: Optional[int] = None,
110
- hidden_dtype: Optional[torch.dtype] = None,
111
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
112
- # No modulation happening here.
113
- embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
114
- return self.linear(self.silu(embedded_timestep)), embedded_timestep
115
-
116
-
117
- @dataclass
118
- class Transformer2DModelOutput(BaseOutput):
119
- """
120
- The output of [`Transformer2DModel`].
121
-
122
- Args:
123
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
124
- The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
125
- distributions for the unnoised latent pixels.
126
- """
127
-
128
- sample: torch.FloatTensor
129
-
130
-
131
- class Transformer2DModel(ModelMixin, ConfigMixin):
132
- """
133
- A 2D Transformer model for image-like data.
134
-
135
- Parameters:
136
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
137
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
138
- in_channels (`int`, *optional*):
139
- The number of channels in the input and output (specify if the input is **continuous**).
140
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
141
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
142
- cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
143
- sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
144
- This is fixed during training since it is used to learn a number of position embeddings.
145
- num_vector_embeds (`int`, *optional*):
146
- The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
147
- Includes the class for the masked latent pixel.
148
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
149
- num_embeds_ada_norm ( `int`, *optional*):
150
- The number of diffusion steps used during training. Pass if at least one of the norm_layers is
151
- `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
152
- added to the hidden states.
153
-
154
- During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
155
- attention_bias (`bool`, *optional*):
156
- Configure if the `TransformerBlocks` attention should contain a bias parameter.
157
- """
158
-
159
- _supports_gradient_checkpointing = True
160
-
161
- @register_to_config
162
- def __init__(
163
- self,
164
- num_attention_heads: int = 16,
165
- attention_head_dim: int = 88,
166
- in_channels: Optional[int] = None,
167
- out_channels: Optional[int] = None,
168
- num_layers: int = 1,
169
- dropout: float = 0.0,
170
- norm_num_groups: int = 32,
171
- cross_attention_dim: Optional[int] = None,
172
- attention_bias: bool = False,
173
- sample_size: Optional[int] = None,
174
- num_vector_embeds: Optional[int] = None,
175
- patch_size: Optional[int] = None,
176
- activation_fn: str = "geglu",
177
- num_embeds_ada_norm: Optional[int] = None,
178
- use_linear_projection: bool = False,
179
- only_cross_attention: bool = False,
180
- double_self_attention: bool = False,
181
- upcast_attention: bool = False,
182
- norm_type: str = "layer_norm",
183
- norm_elementwise_affine: bool = True,
184
- norm_eps: float = 1e-5,
185
- attention_type: str = "default",
186
- caption_channels: int = None,
187
- ):
188
- super().__init__()
189
- self.use_linear_projection = use_linear_projection
190
- self.num_attention_heads = num_attention_heads
191
- self.attention_head_dim = attention_head_dim
192
- inner_dim = num_attention_heads * attention_head_dim
193
-
194
- conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
195
- linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
196
-
197
- # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
198
- # Define whether input is continuous or discrete depending on configuration
199
- self.is_input_continuous = (in_channels is not None) and (patch_size is None)
200
- self.is_input_vectorized = num_vector_embeds is not None
201
- self.is_input_patches = in_channels is not None and patch_size is not None
202
-
203
- if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
204
- deprecation_message = (
205
- f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
206
- " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
207
- " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
208
- " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
209
- " would be very nice if you could open a Pull request for the `transformer/config.json` file"
210
- )
211
- deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
212
- norm_type = "ada_norm"
213
-
214
- if self.is_input_continuous and self.is_input_vectorized:
215
- raise ValueError(
216
- f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
217
- " sure that either `in_channels` or `num_vector_embeds` is None."
218
- )
219
- elif self.is_input_vectorized and self.is_input_patches:
220
- raise ValueError(
221
- f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
222
- " sure that either `num_vector_embeds` or `num_patches` is None."
223
- )
224
- elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
225
- raise ValueError(
226
- f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
227
- f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
228
- )
229
-
230
- # 2. Define input layers
231
- if self.is_input_continuous:
232
- self.in_channels = in_channels
233
-
234
- self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
235
- if use_linear_projection:
236
- self.proj_in = linear_cls(in_channels, inner_dim)
237
- else:
238
- self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
239
- elif self.is_input_vectorized:
240
- assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
241
- assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
242
-
243
- self.height = sample_size
244
- self.width = sample_size
245
- self.num_vector_embeds = num_vector_embeds
246
- self.num_latent_pixels = self.height * self.width
247
-
248
- self.latent_image_embedding = ImagePositionalEmbeddings(
249
- num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
250
- )
251
- elif self.is_input_patches:
252
- assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
253
-
254
- self.height = sample_size
255
- self.width = sample_size
256
-
257
- self.patch_size = patch_size
258
- interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
259
- interpolation_scale = max(interpolation_scale, 1)
260
- self.pos_embed = PatchEmbed(
261
- height=sample_size,
262
- width=sample_size,
263
- patch_size=patch_size,
264
- in_channels=in_channels,
265
- embed_dim=inner_dim,
266
- interpolation_scale=interpolation_scale,
267
- )
268
-
269
- # 3. Define transformers blocks
270
- self.transformer_blocks = nn.ModuleList(
271
- [
272
- BasicTransformerBlock(
273
- inner_dim,
274
- num_attention_heads,
275
- attention_head_dim,
276
- dropout=dropout,
277
- cross_attention_dim=cross_attention_dim,
278
- activation_fn=activation_fn,
279
- num_embeds_ada_norm=num_embeds_ada_norm,
280
- attention_bias=attention_bias,
281
- only_cross_attention=only_cross_attention,
282
- double_self_attention=double_self_attention,
283
- upcast_attention=upcast_attention,
284
- norm_type=norm_type,
285
- norm_elementwise_affine=norm_elementwise_affine,
286
- norm_eps=norm_eps,
287
- attention_type=attention_type,
288
- )
289
- for d in range(num_layers)
290
- ]
291
- )
292
-
293
- # 4. Define output layers
294
- self.out_channels = in_channels if out_channels is None else out_channels
295
- if self.is_input_continuous:
296
- # TODO: should use out_channels for continuous projections
297
- if use_linear_projection:
298
- self.proj_out = linear_cls(inner_dim, in_channels)
299
- else:
300
- self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
301
- elif self.is_input_vectorized:
302
- self.norm_out = nn.LayerNorm(inner_dim)
303
- self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
304
- elif self.is_input_patches and norm_type != "ada_norm_single":
305
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
306
- self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
307
- self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
308
- elif self.is_input_patches and norm_type == "ada_norm_single":
309
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
310
- self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
311
- self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
312
-
313
- # 5. PixArt-Alpha blocks.
314
- self.adaln_single = None
315
- self.use_additional_conditions = False
316
- if norm_type == "ada_norm_single":
317
- self.use_additional_conditions = self.config.sample_size == 128
318
- # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
319
- # additional conditions until we find better name
320
- self.adaln_single = AdaLayerNormSingleFlow(inner_dim, use_additional_conditions=self.use_additional_conditions)
321
-
322
- self.caption_projection = None
323
- if caption_channels is not None:
324
- self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
325
-
326
- self.gradient_checkpointing = False
327
-
328
- def _set_gradient_checkpointing(self, module, value=False):
329
- if hasattr(module, "gradient_checkpointing"):
330
- module.gradient_checkpointing = value
331
-
332
- def forward(
333
- self,
334
- hidden_states: torch.Tensor,
335
- encoder_hidden_states: Optional[torch.Tensor] = None,
336
- timestep: Optional[torch.LongTensor] = None,
337
- added_cond_kwargs: Dict[str, torch.Tensor] = None,
338
- class_labels: Optional[torch.LongTensor] = None,
339
- cross_attention_kwargs: Dict[str, Any] = None,
340
- attention_mask: Optional[torch.Tensor] = None,
341
- encoder_attention_mask: Optional[torch.Tensor] = None,
342
- return_dict: bool = True,
343
- ):
344
- """
345
- The [`Transformer2DModel`] forward method.
346
-
347
- Args:
348
- hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
349
- Input `hidden_states`.
350
- encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
351
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
352
- self-attention.
353
- timestep ( `torch.LongTensor`, *optional*):
354
- Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
355
- class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
356
- Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
357
- `AdaLayerZeroNorm`.
358
- cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
359
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
360
- `self.processor` in
361
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
362
- attention_mask ( `torch.Tensor`, *optional*):
363
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
364
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
365
- negative values to the attention scores corresponding to "discard" tokens.
366
- encoder_attention_mask ( `torch.Tensor`, *optional*):
367
- Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
368
-
369
- * Mask `(batch, sequence_length)` True = keep, False = discard.
370
- * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
371
-
372
- If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
373
- above. This bias will be added to the cross-attention scores.
374
- return_dict (`bool`, *optional*, defaults to `True`):
375
- Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
376
- tuple.
377
-
378
- Returns:
379
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
380
- `tuple` where the first element is the sample tensor.
381
- """
382
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
383
- # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
384
- # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
385
- # expects mask of shape:
386
- # [batch, key_tokens]
387
- # adds singleton query_tokens dimension:
388
- # [batch, 1, key_tokens]
389
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
390
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
391
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
392
- if attention_mask is not None and attention_mask.ndim == 2:
393
- # assume that mask is expressed as:
394
- # (1 = keep, 0 = discard)
395
- # convert mask into a bias that can be added to attention scores:
396
- # (keep = +0, discard = -10000.0)
397
- attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
398
- attention_mask = attention_mask.unsqueeze(1)
399
-
400
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
401
- if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
402
- encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
403
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
404
-
405
- # Retrieve lora scale.
406
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
407
-
408
- # 1. Input
409
- if self.is_input_continuous:
410
- batch, _, height, width = hidden_states.shape
411
- residual = hidden_states
412
-
413
- hidden_states = self.norm(hidden_states)
414
- if not self.use_linear_projection:
415
- hidden_states = (
416
- self.proj_in(hidden_states, scale=lora_scale)
417
- if not USE_PEFT_BACKEND
418
- else self.proj_in(hidden_states)
419
- )
420
- inner_dim = hidden_states.shape[1]
421
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
422
- else:
423
- inner_dim = hidden_states.shape[1]
424
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
425
- hidden_states = (
426
- self.proj_in(hidden_states, scale=lora_scale)
427
- if not USE_PEFT_BACKEND
428
- else self.proj_in(hidden_states)
429
- )
430
-
431
- elif self.is_input_vectorized:
432
- hidden_states = self.latent_image_embedding(hidden_states)
433
- elif self.is_input_patches:
434
- height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
435
- hidden_states = self.pos_embed(hidden_states)
436
-
437
- if self.adaln_single is not None:
438
- if self.use_additional_conditions and added_cond_kwargs is None:
439
- raise ValueError(
440
- "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
441
- )
442
- batch_size = hidden_states.shape[0]
443
- timestep, embedded_timestep = self.adaln_single(
444
- timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
445
- )
446
-
447
- # 2. Blocks
448
- if self.caption_projection is not None:
449
- batch_size = hidden_states.shape[0]
450
- encoder_hidden_states = self.caption_projection(encoder_hidden_states)
451
- encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
452
-
453
- for block in self.transformer_blocks:
454
- if self.training and self.gradient_checkpointing:
455
-
456
- def create_custom_forward(module, return_dict=None):
457
- def custom_forward(*inputs):
458
- if return_dict is not None:
459
- return module(*inputs, return_dict=return_dict)
460
- else:
461
- return module(*inputs)
462
-
463
- return custom_forward
464
-
465
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
466
- hidden_states = torch.utils.checkpoint.checkpoint(
467
- create_custom_forward(block),
468
- hidden_states,
469
- attention_mask,
470
- encoder_hidden_states,
471
- encoder_attention_mask,
472
- timestep,
473
- cross_attention_kwargs,
474
- class_labels,
475
- **ckpt_kwargs,
476
- )
477
- else:
478
- hidden_states = block(
479
- hidden_states,
480
- attention_mask=attention_mask,
481
- encoder_hidden_states=encoder_hidden_states,
482
- encoder_attention_mask=encoder_attention_mask,
483
- timestep=timestep,
484
- cross_attention_kwargs=cross_attention_kwargs,
485
- class_labels=class_labels,
486
- )
487
-
488
- # 3. Output
489
- if self.is_input_continuous:
490
- if not self.use_linear_projection:
491
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
492
- hidden_states = (
493
- self.proj_out(hidden_states, scale=lora_scale)
494
- if not USE_PEFT_BACKEND
495
- else self.proj_out(hidden_states)
496
- )
497
- else:
498
- hidden_states = (
499
- self.proj_out(hidden_states, scale=lora_scale)
500
- if not USE_PEFT_BACKEND
501
- else self.proj_out(hidden_states)
502
- )
503
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
504
-
505
- output = hidden_states + residual
506
- elif self.is_input_vectorized:
507
- hidden_states = self.norm_out(hidden_states)
508
- logits = self.out(hidden_states)
509
- # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
510
- logits = logits.permute(0, 2, 1)
511
-
512
- # log(p(x_0))
513
- output = F.log_softmax(logits.double(), dim=1).float()
514
-
515
- if self.is_input_patches:
516
- if self.config.norm_type != "ada_norm_single":
517
- conditioning = self.transformer_blocks[0].norm1.emb(
518
- timestep, class_labels, hidden_dtype=hidden_states.dtype
519
- )
520
- shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
521
- hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
522
- hidden_states = self.proj_out_2(hidden_states)
523
- elif self.config.norm_type == "ada_norm_single":
524
- shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
525
- hidden_states = self.norm_out(hidden_states)
526
- # Modulation
527
- hidden_states = hidden_states * (1 + scale) + shift
528
- hidden_states = self.proj_out(hidden_states)
529
- hidden_states = hidden_states.squeeze(1)
530
-
531
- # unpatchify
532
- if self.adaln_single is None:
533
- height = width = int(hidden_states.shape[1] ** 0.5)
534
- hidden_states = hidden_states.reshape(
535
- shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
536
- )
537
- hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
538
- output = hidden_states.reshape(
539
- shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
540
- )
541
-
542
- if not return_dict:
543
- return (output,)
544
-
545
- return Transformer2DModelOutput(sample=output)
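
The mask handling in the forward pass above converts a 2-D keep/discard mask into an additive bias so it can broadcast over the attention scores. A minimal standalone sketch of that conversion (plain PyTorch; the function name and example tensor are illustrative, not part of the deleted module):

import torch

def mask_to_bias(mask, dtype=torch.float32):
    # mask: [batch, key_tokens] with 1 = keep, 0 = discard
    # returns a bias of shape [batch, 1, key_tokens]: 0 where kept, -10000 where discarded,
    # ready to broadcast over attention scores of shape [batch, heads, query_tokens, key_tokens]
    bias = (1 - mask.to(dtype)) * -10000.0
    return bias.unsqueeze(1)  # singleton query_tokens dimension for broadcasting

example = torch.tensor([[1, 1, 0]])
print(mask_to_bias(example))  # bias is 0 for kept tokens, -10000 for the discarded one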
 
MuCodec/muq_dev/muq_fairseq/data/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .mert_dataset import MERTDataset
 
 
MuCodec/muq_dev/muq_fairseq/data/ark_dataset.py DELETED
@@ -1,71 +0,0 @@
1
- import logging
2
- import torch
3
- import torch.nn.functional as F
4
- from fairseq.data.audio.raw_audio_dataset import RawAudioDataset
5
- from typing import Tuple
6
- try:
7
- import kaldiio
8
- except:
9
- kaldiio = None
10
- import warnings
11
-
12
- logger = logging.getLogger(__name__)
13
-
14
-
15
- class ArkDataset(RawAudioDataset):
16
- def __init__(
17
- self,
18
- wav_scp,
19
- dur_scp,
20
- sr = 24000,
21
- max_dur = 20,
22
- num_buckets=0,
23
- normalize=False,
24
- ):
25
- super().__init__(
26
- sample_rate=sr,
27
- max_sample_size=max_dur*sr,
28
- min_sample_size=1200,
29
- shuffle=True,
30
- pad=True,
31
- normalize=normalize,
32
- compute_mask=False,
33
- )
34
- self.sr = sr
35
- self.max_dur = max_dur
36
- self.normalize = normalize
37
-
38
- logger.info("Loading Kaldi scp files from {}".format(wav_scp))
39
-
40
- self.wav_data = kaldiio.load_scp(wav_scp)
41
- self.keys = list(self.wav_data.keys())
42
- dur_data = {}
43
- keys_set = set(self.keys)
44
-
45
- with open(dur_scp, 'r') as f:
46
- for line in f:
47
- line = line.strip().split()
48
- if line[0] in keys_set:
49
- dur_data[line[0]] = float(line[-1])
50
- self.sizes = [int(dur_data[k]*self.sr/100) for k in self.keys]
51
-
52
- logger.info("Loading Kaldi scp files done")
53
-
54
- self.dataset_len = len(self.keys)
55
- self.set_bucket_info(num_buckets)
56
-
57
- def __len__(self):
58
- return self.dataset_len
59
-
60
- def __getitem__(self, idx):
61
- pass
62
-
63
- def size(self, idx):
64
- pass
65
-
66
- def postprocess(self, wav):
67
- pass
68
-
69
- def collater(self, samples):
70
- pass
71
-
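
The size bookkeeping in ArkDataset above is worth restating on its own: the duration scp is read line by line and its last column is scaled by sr/100, i.e. the durations are assumed to be expressed in 10 ms frames. A minimal sketch under that assumption (the `dur.scp` path is hypothetical):

sr = 24000
dur_data = {}
with open("dur.scp") as f:  # hypothetical file, lines like "<utt-id> ... <duration in 10 ms frames>"
    for line in f:
        parts = line.strip().split()
        dur_data[parts[0]] = float(parts[-1])
# number of audio samples per utterance, matching the self.sizes computation above
sizes = {utt: int(dur * sr / 100) for utt, dur in dur_data.items()}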
 
MuCodec/muq_dev/muq_fairseq/data/mert_dataset.py DELETED
@@ -1,295 +0,0 @@
1
- # Copyright (c) Facebook, Inc. and its affiliates.
2
- #
3
- # This source code is licensed under the MIT license found in the
4
- # LICENSE file in the root directory of this source tree.
5
-
6
- import itertools
7
- import logging
8
- import os
9
- import sys
10
- from typing import Any, List, Optional, Union
11
-
12
- import numpy as np
13
- from typing import Tuple
14
- import torch
15
- import torch.nn.functional as F
16
- from fairseq.data import data_utils
17
- from fairseq.data.fairseq_dataset import FairseqDataset
18
- from fairseq.data.audio.audio_utils import (
19
- parse_path,
20
- read_from_stored_zip,
21
- )
22
-
23
- import math
24
- import io
25
- import torchaudio
26
- # this is in the user_dir
27
- from nnAudio import features as nnAudioFeatures
28
-
29
- # from tqdm import tqdm
30
- import tqdm
31
- import json
32
- import random
33
- import traceback
34
- from einops import rearrange
35
- # from scripts.prepare_codecs_from_manifest import *
36
-
37
- logger = logging.getLogger(__name__)
38
-
39
- class model_cqt_pred(torch.nn.Module):
40
- def __init__(self, n_bins=84, sr=16000, freq=50):
41
- super().__init__()
42
- self.epsilon=1e-10
43
- # Getting Mel Spectrogram on the fly
44
- self.spec_layer = nnAudioFeatures.cqt.CQT(sr=sr, hop_length=sr//freq, fmin=32.7,
45
- fmax=None, n_bins=n_bins, bins_per_octave=n_bins//7,
46
- filter_scale=1, norm=1, window='hann', center=True,
47
- pad_mode='constant', trainable=False,
48
- output_format='Magnitude', verbose=True)
49
-
50
- # self.fc = nn.Linear(input_dim, n_bins)
51
-
52
- # self.criterion = nn.MSELoss()
53
- self.forward_dict = {
54
- # 'masked_transformer_output': self.plain_forward
55
- 'compute_cqt': self.compute_cqt
56
- }
57
- def compute_cqt(self, x):
58
- '''
59
- convert waveform to CQT -> [batch, bins, len] -> transpose
60
- '''
61
- # align with the padding of HuBERT model,
62
- # the truncation is calculated by bruteforce search since the nnAudio padding strategy and fairseq models are different
63
- # x = x[..., :-560]
64
- return torch.transpose(self.spec_layer(x), -1, -2)
65
-
66
- def forward(self, x, forward_type='masked_transformer_output'):
67
- '''
68
- take input from transformer hidden states: [batch, len_seq, channel]
69
- output: [batch, len_seq, n_bins]
70
- '''
71
-
72
- return self.forward_dict[forward_type](x)
73
-
74
- def load_audio_by_json(json_path, max_keep, min_keep, tgt_sample_rate, clip_secs=5):
75
- # read json file
76
- print(json_path)
77
- datas = []
78
- inds = []
79
- sizes = []
80
- with open(json_path) as fp:
81
- for ind,line in enumerate(fp):
82
- data = json.loads(line)
83
- if 'duration' in data and min_keep is not None and tgt_sample_rate*data['duration'] < min_keep:
84
- continue
85
- datas.append(data)
86
- inds.append(ind)
87
- # sz = int(data['duration'] * data['sample_rate'])
88
- if clip_secs > 0:
89
- sz = int(tgt_sample_rate * clip_secs)
90
- else:
91
- sz = int(tgt_sample_rate * data['duration'])
92
- sizes.append(sz)
93
- tot = ind + 1
94
- return datas,inds,tot,sizes
95
- def load_audio(manifest_path, max_keep, min_keep):
96
- pass
97
-
98
-
99
- def load_label(label_path, inds, tot):
100
- pass
101
-
102
- def load_numpy_label(label_path, inds, tot):
103
- labels = np.load(label_path, mmap_mode='r')
104
- assert (labels.shape[0] == tot), f"number of labels does not match ({labels.shape[0]} != {tot})"
105
- return labels
106
-
107
- def verify_label_lengths(
108
- audio_sizes,
109
- audio_rate,
110
- label_path,
111
- label_rate,
112
- inds,
113
- tot,
114
- tol=0.1, # tolerance in seconds
115
- ):
116
- pass
117
-
118
- class Read_and_PadCrop_Normalized_T(torch.nn.Module):
119
- def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
120
-
121
- super().__init__()
122
-
123
- self.n_samples = n_samples
124
- self.sample_rate = sample_rate
125
- self.randomize = randomize
126
-
127
-
128
- def __call__(self, filename: str, duration: float, cur_sample_rate: int, fixed_offset_duration=None) -> Tuple[torch.Tensor, float, float, int, int]:
129
- pass
130
-
131
-
132
- class MERTDataset(FairseqDataset):
133
- def __init__(
134
- self,
135
- manifest_path: str,
136
- sample_rate: float,
137
- label_paths: List[str],
138
- label_rates: Union[List[float], float], # -1 for sequence labels
139
- pad_list: List[str],
140
- eos_list: List[str],
141
- label_scp_path: Optional[str] = None,
142
- label_scp_clip_duration: float = -1,
143
- label_processors: Optional[List[Any]] = None,
144
- max_keep_sample_size: Optional[int] = None,
145
- min_keep_sample_size: Optional[int] = None,
146
- max_sample_size: Optional[int] = None,
147
- shuffle: bool = True,
148
- pad_audio: bool = False,
149
- normalize: bool = False,
150
- store_labels: bool = True,
151
- npmemmap: bool = False,
152
- random_crop: bool = False,
153
- single_target: bool = False,
154
- augmentation_effects: List[str] = [],
155
- augmentation_probs: List[float] = [],
156
- inbatch_noise_augment_len_range: List[int] = [8000, 24000],
157
- inbatch_noise_augment_number_range: List[int] = [1, 3],
158
- inbatch_noise_augment_volume: float = 1.0,
159
- cqt_prediction_bin: int = -1,
160
- dataset_len:int = 128*3000,
161
- clip_secs = 5,
162
- ):
163
- self.sample_rate = sample_rate
164
- self.shuffle = shuffle
165
- self.random_crop = random_crop
166
- self.datas,inds,tot,self.sizes = load_audio_by_json(manifest_path,max_keep_sample_size,min_keep_sample_size, self.sample_rate, clip_secs)
167
- self.inds = inds
168
-
169
- self.num_labels = len(label_paths)
170
- self.pad_list = pad_list
171
- self.eos_list = eos_list
172
- self.label_processors = label_processors
173
- self.single_target = single_target
174
- self.label_rates = (
175
- [label_rates for _ in range(len(label_paths))]
176
- if isinstance(label_rates, float)
177
- else label_rates
178
- )
179
- self.store_labels = store_labels
180
- self.npmemmap = npmemmap
181
- self.label_scp_path = label_scp_path
182
- self.label_scp_clip_duration = label_scp_clip_duration
183
-
184
-
185
- if self.label_scp_path is not None:
186
- from kaldiio import load_scp
187
- self.label_scp = load_scp(self.label_scp_path)
188
-
189
- # self.dataset_len = dataset_len
190
- self.dataset_len = len(self.datas)
191
- logger.info('preparing labels')
192
- logger.info('========dataset len: {}=========='.format(self.dataset_len))
193
- if store_labels:
194
- if self.npmemmap:
195
- self.label_list = [load_numpy_label(p+'.npy', inds, tot) for p in label_paths]
196
- else:
197
- self.label_list = [load_label(p, inds, tot) for p in label_paths]
198
- else:
199
- self.label_paths = label_paths
200
- # self.label_offsets_list = [
201
- # load_label_offset(p, inds, tot) for p in label_paths
202
- # ]
203
- assert label_processors is None or len(label_processors) == self.num_labels
204
-
205
-
206
- self.max_sample_size = (
207
- max_sample_size if max_sample_size is not None else sys.maxsize
208
- )
209
- self.pad_audio = pad_audio
210
- self.normalize = normalize
211
- logger.info(
212
- f"pad_audio={pad_audio}, random_crop={random_crop}, "
213
- f"normalize={normalize}, max_sample_size={self.max_sample_size}"
214
- )
215
-
216
- self.augmentation_effects = augmentation_effects
217
- self.augmentation_probs = augmentation_probs
218
-
219
-
220
- self.inbatch_noise_augment_len_range = inbatch_noise_augment_len_range
221
- self.inbatch_noise_augment_number_range = inbatch_noise_augment_number_range
222
- self.inbatch_noise_augment_volume = inbatch_noise_augment_volume
223
-
224
-
225
- self.cqt_prediction_bin = cqt_prediction_bin
226
- if self.cqt_prediction_bin > 0:
227
- self.encoder_cqt_model = model_cqt_pred(n_bins=self.cqt_prediction_bin)
228
- logger.info('preparing cqt loss objective in dataloader with cpu')
229
-
230
- self.epoch = -1
231
-
232
- self.reader = Read_and_PadCrop_Normalized_T(n_samples=clip_secs*sample_rate if clip_secs>0 else None, sample_rate = self.sample_rate)
233
-
234
-
235
-
236
- @property
237
- def can_reuse_epoch_itr_across_epochs(self):
238
- pass
239
- def set_epoch(self, epoch):
240
- pass
241
-
242
- def inbatch_noise_augment(self,
243
- target_audio: torch.Tensor, target_audio_idx: int ,
244
- batch_audios: torch.Tensor, # [bsz, audio_lengths]
245
- noise_len_min: int, noise_len_max: int,
246
- n_noise_min: int, n_noise_max: int,
247
- noise_vol: float = 1.0):
248
- pass
249
-
250
- def get_audio_by_slice(self,index):
251
- pass
252
- def get_audio(self, index):
253
- pass
254
-
255
- def get_label(self, index, label_idx):
256
- pass
257
-
258
- def get_labels(self, index):
259
- pass
260
-
261
- def __getitem__(self, i):
262
- pass
263
-
264
- def __len__(self):
265
- return self.dataset_len
266
-
267
- def crop_to_max_size(self, wav, target_size):
268
- pass
269
-
270
- def collater(self, samples):
271
- pass
272
-
273
- def collater_audio(self, audios, audio_size):
274
- pass
275
-
276
- def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
277
- pass
278
-
279
- def collater_seq_label(self, targets, pad):
280
- pass
281
-
282
- def collater_label(self, targets_by_label, audio_size, audio_starts):
283
- pass
284
-
285
- def num_tokens(self, index):
286
- pass
287
-
288
- def size(self, index):
289
- pass
290
-
291
- def ordered_indices(self):
292
- pass
293
-
294
- def postprocess(self, wav, cur_sample_rate):
295
- pass
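
The manifest handling at the top of this file (load_audio_by_json) expects a JSONL file with one clip per line and at least a 'duration' field; clips shorter than min_keep samples are dropped, and sizes are pinned to clip_secs seconds when a clip length is set. A small self-contained sketch of that filtering, assuming only the 'duration' field shown above (any other field names would be assumptions):

import json

def load_manifest(json_path, sr, min_keep=None, clip_secs=5):
    datas, sizes = [], []
    with open(json_path) as fp:
        for line in fp:
            item = json.loads(line)
            if min_keep is not None and 'duration' in item and sr * item['duration'] < min_keep:
                continue  # drop clips shorter than min_keep samples
            datas.append(item)
            sizes.append(int(sr * (clip_secs if clip_secs > 0 else item['duration'])))
    return datas, sizes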
 
MuCodec/muq_dev/muq_fairseq/data/utils/data_utils.py DELETED
@@ -1,535 +0,0 @@
1
- # Copyright (c) Facebook, Inc. and its affiliates.
2
- #
3
- # This source code is licensed under the MIT license found in the
4
- # LICENSE file in the root directory of this source tree.
5
-
6
- import logging
7
- import math
8
- import numpy as np
9
- import torch
10
-
11
- from typing import Optional, Tuple
12
-
13
-
14
-
15
- logger = logging.getLogger(__name__)
16
-
17
-
18
-
19
- def compute_mask_indices(
20
- shape: Tuple[int, int],
21
- padding_mask: Optional[torch.Tensor],
22
- mask_prob: float,
23
- mask_length: int,
24
- mask_type: str = "static",
25
- mask_other: float = 0.0,
26
- min_masks: int = 0,
27
- no_overlap: bool = False,
28
- min_space: int = 0,
29
- require_same_masks: bool = True,
30
- mask_dropout: float = 0.0,
31
- add_masks: bool = False,
32
- seed: Optional[int] = None,
33
- epoch: Optional[int] = None,
34
- indices: Optional[torch.Tensor] = None,
35
- idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset
36
- num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset
37
- ) -> np.ndarray:
38
- """
39
- Computes random mask spans for a given shape
40
-
41
- Args:
42
- shape: the shape for which to compute masks.
43
- should be of size 2 where first element is batch size and 2nd is timesteps
44
- padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
45
- mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
46
- number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
47
- however due to overlaps, the actual number will be smaller (unless no_overlap is True)
48
- mask_type: how to compute mask lengths
49
- static = fixed size
50
- uniform = sample from uniform distribution [mask_other, mask_length*2]
51
- normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
52
- poisson = sample from poisson distribution with lambda = mask_length
53
- min_masks: minimum number of masked spans
54
- no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
55
- min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
56
- require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
57
- mask_dropout: randomly dropout this percentage of masks in each example
58
- """
59
-
60
- bsz, all_sz = shape
61
- mask = np.full((bsz, all_sz), False)
62
-
63
- if num_mask_ver == 1:
64
- all_num_mask = int(
65
- # add a random number for probabilistic rounding
66
- mask_prob * all_sz / float(mask_length)
67
- + np.random.rand()
68
- )
69
- all_num_mask = max(min_masks, all_num_mask)
70
-
71
- mask_idcs = []
72
- for i in range(bsz):
73
- if seed is not None and epoch is not None and indices is not None:
74
- seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
75
- else:
76
- seed_i = None
77
-
78
- rng = np.random.default_rng(seed_i)
79
-
80
- if padding_mask is not None:
81
- sz = all_sz - padding_mask[i].long().sum().item()
82
- assert sz >= 0, sz
83
- else:
84
- sz = all_sz
85
-
86
- if num_mask_ver == 1:
87
- if padding_mask is not None:
88
- num_mask = int(
89
- # add a random number for probabilistic rounding
90
- mask_prob * sz / float(mask_length)
91
- + np.random.rand()
92
- )
93
- num_mask = max(min_masks, num_mask)
94
- else:
95
- num_mask = all_num_mask
96
- elif num_mask_ver == 2:
97
- num_mask = int(
98
- # add a random number for probabilistic rounding
99
- mask_prob * sz / float(mask_length)
100
- + rng.random()
101
- )
102
- num_mask = max(min_masks, num_mask)
103
- else:
104
- raise ValueError()
105
-
106
- if mask_type == "static":
107
- lengths = np.full(num_mask, mask_length)
108
- elif mask_type == "uniform":
109
- lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask)
110
- elif mask_type == "normal":
111
- lengths = rng.normal(mask_length, mask_other, size=num_mask)
112
- lengths = [max(1, int(round(x))) for x in lengths]
113
- elif mask_type == "poisson":
114
- lengths = rng.poisson(mask_length, size=num_mask)
115
- lengths = [int(round(x)) for x in lengths]
116
- else:
117
- raise Exception("unknown mask selection " + mask_type)
118
-
119
- if sum(lengths) == 0:
120
- if mask_type == "static":
121
- raise ValueError("this should never happen")
122
- else:
123
- lengths = [min(mask_length, sz - 1)]
124
-
125
- if no_overlap:
126
- mask_idc = []
127
-
128
- def arrange(s, e, length, keep_length):
129
- span_start = rng.randint(s, e - length)
130
- mask_idc.extend(span_start + i for i in range(length))
131
-
132
- new_parts = []
133
- if span_start - s - min_space >= keep_length:
134
- new_parts.append((s, span_start - min_space + 1))
135
- if e - span_start - length - min_space > keep_length:
136
- new_parts.append((span_start + length + min_space, e))
137
- return new_parts
138
-
139
- parts = [(0, sz)]
140
- min_length = min(lengths)
141
- for length in sorted(lengths, reverse=True):
142
- lens = np.fromiter(
143
- (e - s if e - s >= length + min_space else 0 for s, e in parts),
144
- np.int,
145
- )
146
- l_sum = np.sum(lens)
147
- if l_sum == 0:
148
- break
149
- probs = lens / np.sum(lens)
150
- c = rng.choice(len(parts), p=probs)
151
- s, e = parts.pop(c)
152
- parts.extend(arrange(s, e, length, min_length))
153
- mask_idc = np.asarray(mask_idc)
154
- else:
155
- if idc_select_ver == 1:
156
- min_len = min(lengths)
157
- if sz - min_len <= num_mask:
158
- min_len = sz - num_mask - 1
159
- mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
160
- elif idc_select_ver == 2:
161
- mask_idc = rng.choice(sz, num_mask, replace=False)
162
- else:
163
- raise ValueError()
164
-
165
- mask_idc = np.asarray(
166
- [
167
- mask_idc[j] + offset
168
- for j in range(len(mask_idc))
169
- for offset in range(lengths[j])
170
- ]
171
- )
172
-
173
- mask_idc = np.unique(mask_idc[mask_idc < sz])
174
- if len(mask_idc) >= sz:
175
- raise ValueError(
176
- (
177
- f"the entire sequence is masked. "
178
- f"sz={sz}; mask_idc[mask_idc]; "
179
- f"index={indices[i] if indices is not None else None}"
180
- )
181
- )
182
- mask_idcs.append(mask_idc)
183
-
184
- target_len = None
185
- if require_same_masks:
186
- if add_masks:
187
- target_len = max([len(m) for m in mask_idcs])
188
- else:
189
- target_len = min([len(m) for m in mask_idcs])
190
-
191
- for i, mask_idc in enumerate(mask_idcs):
192
- if target_len is not None and len(mask_idc) > target_len:
193
- mask_idc = rng.choice(mask_idc, target_len, replace=False)
194
-
195
- mask[i, mask_idc] = True
196
-
197
- if target_len is not None and len(mask_idc) < target_len:
198
- unmasked = np.flatnonzero(~mask[i])
199
- to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
200
- mask[i, to_mask] = True
201
-
202
- if mask_dropout > 0:
203
- masked = np.flatnonzero(mask[i])
204
- num_holes = np.rint(len(masked) * mask_dropout).astype(int)
205
- to_drop = rng.choice(masked, num_holes, replace=False)
206
- mask[i, to_drop] = False
207
-
208
- return mask
209
-
210
-
211
- def compute_block_mask_2d(
212
- shape: Tuple[int, int],
213
- mask_prob: float,
214
- mask_length: int,
215
- mask_prob_adjust: float = 0,
216
- inverse_mask: bool = False,
217
- require_same_masks: bool = True,
218
- expand_adjcent: bool = False,
219
- mask_dropout: float = 0,
220
- non_overlapping: bool = False,
221
- img_shape: tuple = None, # For the situation when d[0] != d[1], especially for audio spectrogram shapes
222
- flexible_mask: bool = False,
223
- ) -> torch.Tensor:
224
-
225
- assert mask_length > 1
226
-
227
- B, L = shape
228
-
229
- d = (int(L**0.5),int(L**0.5))
230
-
231
- if img_shape:
232
- d = (img_shape[0],img_shape[1])
233
-
234
- if flexible_mask:
235
- index = np.random.randint(0,3)
236
- block_size_options = np.array([(6, 4), (5, 5), (8, 3)])
237
- block_size = block_size_options[index]
238
-
239
- if inverse_mask:
240
- mask_prob = 1 - mask_prob
241
-
242
- if flexible_mask:
243
- mask = torch.zeros((B, d[0], d[1]))
244
- mask_inds = torch.randint(
245
- 0,
246
- L,
247
- size=(
248
- B,
249
- int(
250
- L
251
- * ((mask_prob + mask_prob_adjust) / (block_size[0]*block_size[1]))
252
- * (1 + mask_dropout)
253
- ),
254
- ),
255
- )
256
- mask.view(B, -1).scatter_(1, mask_inds, 1)
257
- centers = mask.nonzero(as_tuple=True)
258
-
259
- inds = ([], [], [])
260
-
261
- offset = mask_length // 2
262
- for i in range(block_size[0]):
263
- for j in range(block_size[1]):
264
- k1 = i - offset
265
- k2 = j - offset
266
- inds[0].append(centers[0])
267
- inds[1].append(centers[1] + k1)
268
- inds[2].append(centers[2] + k2)
269
-
270
- i0 = torch.cat(inds[0])
271
- i1 = torch.cat(inds[1]).clamp_(min=0, max=d[0] - 1)
272
- i2 = torch.cat(inds[2]).clamp_(min=0, max=d[1] - 1)
273
-
274
- mask[(i0, i1, i2)] = 1
275
-
276
- elif non_overlapping:
277
- sz = math.ceil(d[0] / mask_length)
278
- inp_len = sz * sz
279
-
280
- inp = torch.zeros((B, 1, sz, sz))
281
- w = torch.ones((1, 1, mask_length, mask_length))
282
-
283
- mask_inds = torch.multinomial(
284
- 1 - inp.view(B, -1),
285
- int(inp_len * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)),
286
- replacement=False,
287
- )
288
- inp.view(B, -1).scatter_(1, mask_inds, 1)
289
-
290
- mask = torch.nn.functional.conv_transpose2d(inp, w, stride=mask_length).squeeze(
291
- 1
292
- )
293
- if mask.size(-1) > d[0]:
294
- mask = mask[..., :d, :d]
295
- else:
296
- mask = torch.zeros((B, d[0], d[1]))
297
- mask_inds = torch.randint(
298
- 0,
299
- L,
300
- size=(
301
- B,
302
- int(
303
- L
304
- * ((mask_prob + mask_prob_adjust) / mask_length**2)
305
- * (1 + mask_dropout)
306
- ),
307
- ),
308
- )
309
- mask.view(B, -1).scatter_(1, mask_inds, 1)
310
- centers = mask.nonzero(as_tuple=True)
311
-
312
- inds = ([], [], [])
313
-
314
- offset = mask_length // 2
315
- for i in range(mask_length):
316
- for j in range(mask_length):
317
- k1 = i - offset
318
- k2 = j - offset
319
- inds[0].append(centers[0])
320
- inds[1].append(centers[1] + k1)
321
- inds[2].append(centers[2] + k2)
322
-
323
- i0 = torch.cat(inds[0])
324
- i1 = torch.cat(inds[1]).clamp_(min=0, max=d[0] - 1)
325
- i2 = torch.cat(inds[2]).clamp_(min=0, max=d[1] - 1)
326
-
327
- mask[(i0, i1, i2)] = 1
328
-
329
- def get_nbs(b, m, w):
330
- all_nbs = torch.nn.functional.conv2d(m.unsqueeze(1), w, padding="same")
331
- all_nbs = all_nbs.clamp_max_(1).view(b, -1)
332
- return all_nbs
333
-
334
- if require_same_masks and expand_adjcent:
335
- w = torch.zeros((1, 1, 3, 3))
336
- w[..., 0, 1] = 1
337
- w[..., 2, 1] = 1
338
- w[..., 1, 0] = 1
339
- w[..., 1, 2] = 1
340
-
341
- all_nbs = get_nbs(B, mask, w)
342
-
343
- mask = mask.reshape(B, -1)
344
-
345
- if require_same_masks:
346
- n_masks = mask.sum(dim=-1)
347
- final_target_len = int(L * (mask_prob))
348
- target_len = int(final_target_len * (1 + mask_dropout))
349
-
350
- for i in range(len(mask)):
351
- n = n_masks[i]
352
- m = mask[i]
353
- r = 0
354
- while expand_adjcent and n < target_len:
355
- if r == 0:
356
- nbs = all_nbs[i]
357
- else:
358
- nbs = get_nbs(1, m.view(1, d[0], d[1]), w).flatten()
359
-
360
- cands = (1 - m + nbs) > 1
361
- cand_sz = int(cands.sum().item())
362
-
363
- assert cand_sz > 0, f"{nbs} {cand_sz}"
364
-
365
- to_mask = torch.multinomial(
366
- cands.float(), min(cand_sz, int(target_len - n)), replacement=False
367
- )
368
- m[to_mask] = 1
369
- assert to_mask.numel() > 0
370
- n += to_mask.numel()
371
- r += 1
372
-
373
- if n > final_target_len:
374
- to_unmask = torch.multinomial(
375
- m, int(n - final_target_len), replacement=False
376
- )
377
- m[to_unmask] = 0
378
- elif n < final_target_len:
379
- to_mask = torch.multinomial(
380
- (1 - m), int(final_target_len - n), replacement=False
381
- )
382
- m[to_mask] = 1
383
-
384
- if inverse_mask:
385
- mask = 1 - mask
386
-
387
- return mask
388
-
389
-
390
- def compute_block_mask_1d(
391
- shape: Tuple[int, int],
392
- mask_prob: float,
393
- mask_length: int,
394
- mask_prob_adjust: float = 0,
395
- inverse_mask: bool = False,
396
- require_same_masks: bool = True,
397
- expand_adjcent: bool = False,
398
- mask_dropout: float = 0,
399
- non_overlapping: bool = False,
400
- ) -> torch.Tensor:
401
-
402
- B, L = shape
403
-
404
- if inverse_mask:
405
- mask_prob = 1 - mask_prob
406
-
407
- if non_overlapping:
408
- sz = math.ceil(L / mask_length)
409
-
410
- inp = torch.zeros((B, 1, sz))
411
- w = torch.ones((1, 1, mask_length))
412
-
413
- mask_inds = torch.multinomial(
414
- 1 - inp.view(B, -1),
415
- int(sz * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)),
416
- replacement=False,
417
- )
418
- inp.view(B, -1).scatter_(1, mask_inds, 1)
419
-
420
- mask = torch.nn.functional.conv_transpose1d(inp, w, stride=mask_length).squeeze(
421
- 1
422
- )
423
- if mask.size(-1) > L:
424
- mask = mask[..., :L]
425
-
426
- else:
427
- mask = torch.zeros((B, L))
428
- mask_inds = torch.randint(
429
- 0,
430
- L,
431
- size=(
432
- B,
433
- int(
434
- L
435
- * ((mask_prob + mask_prob_adjust) / mask_length)
436
- * (1 + mask_dropout)
437
- ),
438
- ),
439
- )
440
-
441
- mask.view(B, -1).scatter_(1, mask_inds, 1)
442
- centers = mask.nonzero(as_tuple=True)
443
-
444
- inds = ([], [])
445
-
446
- offset = mask_length // 2
447
- for i in range(mask_length):
448
- k1 = i - offset
449
- inds[0].append(centers[0])
450
- inds[1].append(centers[1] + k1)
451
-
452
- i0 = torch.cat(inds[0])
453
- i1 = torch.cat(inds[1]).clamp_(min=0, max=L - 1)
454
-
455
- mask[(i0, i1)] = 1
456
-
457
- def get_nbs(b, m, w):
458
- all_nbs = torch.nn.functional.conv1d(m.unsqueeze(1), w, padding="same")
459
- all_nbs = all_nbs.clamp_max_(1).view(b, -1)
460
- return all_nbs
461
-
462
- if require_same_masks and expand_adjcent:
463
- w = torch.ones((1, 1, 3))
464
- w[..., 1] = 0
465
- all_nbs = get_nbs(B, mask, w)
466
-
467
- mask = mask.view(B, -1)
468
-
469
- if require_same_masks:
470
- n_masks = mask.sum(dim=-1)
471
- final_target_len = int(L * (mask_prob))
472
- target_len = int(final_target_len * (1 + mask_dropout))
473
-
474
- for i in range(len(mask)):
475
- n = n_masks[i]
476
- m = mask[i]
477
- r = 0
478
- while expand_adjcent and n < target_len:
479
- if r == 0:
480
- nbs = all_nbs[i]
481
- else:
482
- nbs = get_nbs(1, m.unsqueeze(0), w).squeeze(0)
483
-
484
- cands = (1 - m + nbs) > 1
485
- cand_sz = int(cands.sum().item())
486
-
487
- assert cand_sz > 0, f"{nbs} {cand_sz}"
488
-
489
- to_mask = torch.multinomial(
490
- cands.float(), min(cand_sz, int(target_len - n)), replacement=False
491
- )
492
- m[to_mask] = 1
493
- assert to_mask.numel() > 0
494
- n += to_mask.numel()
495
- r += 1
496
-
497
- if n > final_target_len:
498
- to_unmask = torch.multinomial(
499
- m, int(n - final_target_len), replacement=False
500
- )
501
- m[to_unmask] = 0
502
- elif n < final_target_len:
503
- to_mask = torch.multinomial(
504
- (1 - m), int(final_target_len - n), replacement=False
505
- )
506
- m[to_mask] = 1
507
-
508
- if inverse_mask:
509
- mask = 1 - mask
510
-
511
- return mask
512
-
513
-
514
- def get_buckets(sizes, num_buckets):
515
- buckets = np.unique(
516
- np.percentile(
517
- sizes,
518
- np.linspace(0, 100, num_buckets + 1),
519
- interpolation="lower",
520
- )[1:]
521
- )
522
- return buckets
523
-
524
-
525
- def get_bucketed_sizes(orig_sizes, buckets):
526
- sizes = np.copy(orig_sizes)
527
- assert np.min(sizes) >= 0
528
- start_val = -1
529
- for end_val in buckets:
530
- mask = (sizes > start_val) & (sizes <= end_val)
531
- sizes[mask] = end_val
532
- start_val = end_val
533
- return sizes
534
-
535
-
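
compute_mask_indices above is the piece most callers interact with. An illustrative call, assuming the function is importable from this module (the numbers are arbitrary):

# mask roughly 65% of a (2, 100) batch with 10-step spans, at least 2 spans per sample
mask = compute_mask_indices(
    shape=(2, 100),
    padding_mask=None,
    mask_prob=0.65,
    mask_length=10,
    mask_type="static",
    min_masks=2,
)
# boolean np.ndarray of shape (2, 100); per-sample counts are equal because
# require_same_masks defaults to True
print(mask.shape, mask.sum(axis=-1))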
 
MuCodec/muq_dev/muq_fairseq/models/muq/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .muq_model import *
 
 
MuCodec/muq_dev/muq_fairseq/models/muq/model/__init__.py DELETED
@@ -1,2 +0,0 @@
1
-
2
-
 
 
 
MuCodec/muq_dev/muq_fairseq/models/muq/model/muq.py DELETED
@@ -1,520 +0,0 @@
1
- import json
2
- import random
3
- import torch
4
- from torch import nn
5
- from einops import rearrange
6
- import os
7
- from fairseq.data.data_utils import compute_mask_indices
8
- from fairseq.models.wav2vec.wav2vec2 import ConvFeatureExtractionModel
9
- from fairseq.modules import LayerNorm
10
-
11
- try:
12
- from ..modules.random_quantizer import RandomProjectionQuantizer
13
- from ..modules.features import MelSTFT
14
- from ..modules.conv import Conv2dSubsampling
15
- except:
16
- import sys, os
17
- sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
18
- from modules.random_quantizer import RandomProjectionQuantizer
19
- from modules.features import MelSTFT
20
- from modules.conv import Conv2dSubsampling
21
-
22
-
23
- class MuQ(nn.Module):
24
- """
25
- MuQ
26
-
27
- Input: 128-band mel spectrogram
28
- Frontend: 2-layer Residual convolution
29
- Backend: 12-layer Conformer
30
- Quantizer: a codebook for mel spectrogram
31
- """
32
-
33
- def __init__(
34
- self,
35
- num_codebooks=1,
36
- codebook_dim=16,
37
- codebook_size=4096,
38
- features=["melspec_2048"],
39
- hop_length=240,
40
- n_mels=128,
41
- conv_dim=512,
42
- encoder_dim=1024,
43
- encoder_depth=12,
44
- mask_hop=0.4,
45
- mask_prob=0.6,
46
- is_flash=False,
47
- stat_path=None, #"./data/fma_stats.json",
48
- model_path=None, #"./data/pretrained_fma.pt",
49
- w2v2_config_path=None, #"facebook/wav2vec2-conformer-rope-large-960h-ft",
50
- use_rvq_target=False,
51
- use_vq_target=False,
52
- rvq_ckpt_path=None,
53
- recon_loss_ratio=None,
54
- label_rate=25,
55
- use_hubert_masking_strategy=False,
56
- use_hubert_featurizer=False,
57
- hubert_conv_feature_layers="[(512,10,5)] + [(512,3,2)] * 3 + [(512,3,3)] + [(512,2,2)] * 2",
58
- use_hubert_nce_loss=False,
59
- hubert_final_dim=256,
60
- rvq_n_codebooks=8,
61
- rvq_multi_layer_num=1,
62
- use_encodec_target=False,
63
- ):
64
- super(MuQ, self).__init__()
65
-
66
- # global variables
67
- self.hop_length = hop_length
68
- self.mask_hop = mask_hop
69
- self.mask_prob = mask_prob
70
- self.num_codebooks = num_codebooks
71
- self.codebook_size = codebook_size
72
- self.features = features
73
- self.recon_loss_ratio = recon_loss_ratio
74
- self.n_fold = int(100//label_rate)
75
- self.label_rate = label_rate
76
- self.use_hubert_masking_strategy = use_hubert_masking_strategy
77
- self.use_hubert_featurizer = use_hubert_featurizer
78
- self.use_hubert_nce_loss = use_hubert_nce_loss
79
-
80
- # load feature mean / std stats
81
- import os
82
- if stat_path is not None and os.path.exists(stat_path):
83
- with open(stat_path, "r") as f:
84
- self.stat = json.load(f)
85
- else:
86
- # print("No stats file found at `{}`, use default from msd.".format(stat_path))
87
- self.stat = {"spec_256_cnt": 14394344256, "spec_256_mean": -23.34296658431829, "spec_256_std": 26.189295587132637, "spec_512_cnt": 28677104448, "spec_512_mean": -21.31267396860235, "spec_512_std": 26.52644536245769, "spec_1024_cnt": 57242624832, "spec_1024_mean": -18.852271129208273, "spec_1024_std": 26.443154583585663, "spec_2048_cnt": 114373665600, "spec_2048_mean": -15.638743433896792, "spec_2048_std": 26.115825961611545, "spec_4096_cnt": 228635747136, "spec_4096_mean": -11.715532502794836, "spec_4096_std": 25.763972210234062, "melspec_256_cnt": 14282760192, "melspec_256_mean": -26.962600400166156, "melspec_256_std": 36.13614100912126, "melspec_512_cnt": 14282760192, "melspec_512_mean": -9.108344167718862, "melspec_512_std": 24.71910937988429, "melspec_1024_cnt": 14282760192, "melspec_1024_mean": 0.37302579246531126, "melspec_1024_std": 18.684082325919388, "melspec_2048_cnt": 14282760192, "melspec_2048_mean": 6.768444971712967, "melspec_2048_std": 18.417922652295623, "melspec_4096_cnt": 14282760192, "melspec_4096_mean": 13.617164614990036, "melspec_4096_std": 18.08552130124525, "cqt_cnt": 9373061376, "cqt_mean": 0.46341379757927165, "cqt_std": 0.9543998080910191, "mfcc_256_cnt": 1339008768, "mfcc_256_mean": -11.681755459447485, "mfcc_256_std": 29.183186444668316, "mfcc_512_cnt": 1339008768, "mfcc_512_mean": -2.540581461792183, "mfcc_512_std": 31.93752185832081, "mfcc_1024_cnt": 1339008768, "mfcc_1024_mean": 6.606636263169779, "mfcc_1024_std": 34.151644801729624, "mfcc_2048_cnt": 1339008768, "mfcc_2048_mean": 5.281600844245184, "mfcc_2048_std": 33.12784541220003, "mfcc_4096_cnt": 1339008768, "mfcc_4096_mean": 4.7616569480166095, "mfcc_4096_std": 32.61458906894133, "chromagram_256_cnt": 1339008768, "chromagram_256_mean": 55.15596556703181, "chromagram_256_std": 73.91858278719991, "chromagram_512_cnt": 1339008768, "chromagram_512_mean": 175.73092252759895, "chromagram_512_std": 248.48485148525953, "chromagram_1024_cnt": 1339008768, "chromagram_1024_mean": 589.2947481634608, "chromagram_1024_std": 913.857929063196, "chromagram_2048_cnt": 1339008768, "chromagram_2048_mean": 2062.286388327397, "chromagram_2048_std": 3458.92657915397, "chromagram_4096_cnt": 1339008768, "chromagram_4096_mean": 7673.039107997085, "chromagram_4096_std": 13009.883158267234}
88
-
89
- # feature extractor
90
- self.preprocessor_melspec_2048 = MelSTFT(
91
- n_fft=2048, hop_length=hop_length, is_db=True
92
- )
93
-
94
- # random quantizer
95
- self.use_rvq_target = use_rvq_target
96
- self.use_vq_target = use_vq_target
97
- self.use_encodec_target = use_encodec_target
98
-
99
- seed = 142
100
- if self.use_rvq_like_target:
101
- if use_rvq_target:
102
- try:
103
- from .rvq_muq import ResidualVectorQuantize
104
- except:
105
- import sys, os
106
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
107
- from rvq_muq import ResidualVectorQuantize
108
-
109
- inp_dim = 128*self.n_fold
110
- self.rvq = ResidualVectorQuantize(
111
- input_dim = inp_dim,
112
- n_codebooks = rvq_n_codebooks,
113
- codebook_size = 1024,
114
- codebook_dim = 16,
115
- quantizer_dropout = 0.0,
116
- use_multi_layer_num = rvq_multi_layer_num,
117
- )
118
- elif use_vq_target:
119
- try:
120
- from .rvq_muq import VectorQuantize
121
- except:
122
- import sys, os
123
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
124
- from rvq_muq import VectorQuantize
125
-
126
- self.rvq = VectorQuantize(
127
- input_dim = 128*self.n_fold,
128
- codebook_size = 1024,
129
- codebook_dim = 8,
130
- stale_tolerance = 1000,
131
- mfcc_clustering = False
132
- )
133
- elif use_encodec_target:
134
- from encodec import EncodecModel
135
- self.rvq = EncodecModel.encodec_model_24khz()
136
- self.rvq.set_target_bandwidth(6.0)
137
- for param in self.rvq.parameters():
138
- param.requires_grad = False
139
-
140
- import os
141
- if rvq_ckpt_path is not None and os.path.exists(rvq_ckpt_path):
142
- state_dict = torch.load(rvq_ckpt_path, map_location="cpu")
143
- self.rvq.load_state_dict(state_dict)
144
- else:
145
- print(f'Checkpoint for rvq `{rvq_ckpt_path}` not found. Using random initialization.')
146
- else:
147
- for feature in self.features:
148
- for i in range(num_codebooks):
149
- setattr(
150
- self,
151
- f"quantizer_{feature}", # _{i}
152
- RandomProjectionQuantizer(
153
- n_mels * self.n_fold, codebook_dim, codebook_size, seed=seed + i
154
- ),
155
- )
156
-
157
- if use_hubert_masking_strategy:
158
- self.mask_emb = nn.Parameter(
159
- torch.FloatTensor(encoder_dim).uniform_()
160
- )
161
-
162
- if use_hubert_featurizer:
163
- feature_enc_layers = eval(hubert_conv_feature_layers) # noqa
164
- hubert_feat_embed = feature_enc_layers[-1][0]
165
- self.hubert_feature_extractor = ConvFeatureExtractionModel(
166
- conv_layers=feature_enc_layers,
167
- dropout=0.0,
168
- mode='default', #cfg.extractor_mode,
169
- conv_bias=False, #cfg.conv_bias,
170
- )
171
- self.post_extract_proj = (
172
- nn.Linear(hubert_feat_embed, encoder_dim)
173
- if hubert_feat_embed != encoder_dim
174
- else None
175
- )
176
- self.layer_norm = LayerNorm(hubert_feat_embed)
177
- else:
178
- # two residual convolution layers + one projection layer
179
- strides_factory = {
180
- 4: [2, 2],
181
- 2: [2, 1]
182
- }
183
- self.conv = Conv2dSubsampling(
184
- 1, conv_dim, encoder_dim, strides=strides_factory.get(self.n_fold), n_bands=n_mels
185
- )
186
-
187
- # Conformer
188
- if is_flash:
189
- from modules.flash_conformer import (
190
- Wav2Vec2ConformerEncoder,
191
- Wav2Vec2ConformerConfig,
192
- )
193
- else:
194
- from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
195
- Wav2Vec2ConformerEncoder,
196
- Wav2Vec2ConformerConfig,
197
- )
198
- import os
199
- if w2v2_config_path is None or not os.path.exists(w2v2_config_path):
200
- w2v2_config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "w2v2_config.json")
201
- print("load w2v2 config from:", w2v2_config_path)
202
- config = Wav2Vec2ConformerConfig.from_pretrained(
203
- w2v2_config_path
204
- )
205
- config.num_hidden_layers = encoder_depth
206
- config.hidden_size = encoder_dim
207
-
208
- self.conformer = Wav2Vec2ConformerEncoder(config)
209
-
210
- if self.use_hubert_nce_loss:
211
- self.label_embs_concat = nn.Parameter(
212
- torch.FloatTensor(codebook_size, hubert_final_dim)
213
- ) # embeddings of codes
214
- nn.init.uniform_(self.label_embs_concat)
215
- self.linear = nn.Linear(encoder_dim, hubert_final_dim) # final_proj
216
- else:
217
- # projection
218
- self.linear = nn.Linear(encoder_dim, codebook_size) # N_SubSpec=8
219
-
220
- # reconstruct melspec
221
- if self.recon_loss_ratio is not None and self.recon_loss_ratio > 0:
222
- self.recon_proj = nn.Linear(encoder_dim, n_mels * self.n_fold)
223
- self.recon_loss = nn.MSELoss()
224
-
225
- # loss function
226
- self.loss = nn.CrossEntropyLoss()
227
-
228
- # cls token (used for sequence classification)
229
- random.seed(seed)
230
- self.cls_token = nn.Parameter(torch.randn(encoder_dim))
231
-
232
- # load model
233
- if model_path:
234
- S = torch.load(model_path)["state_dict"]
235
- SS = {k[6:]: v for k, v in S.items()}
236
- SS['quantizer_melspec_2048.random_projection'] = SS['quantizer_melspec_2048_0.random_projection']
237
- SS['quantizer_melspec_2048.codebook'] = SS['quantizer_melspec_2048_0.codebook']
238
- del SS['quantizer_melspec_2048_0.random_projection']
239
- del SS['quantizer_melspec_2048_0.codebook']
240
- unmatch = self.load_state_dict(SS, strict=False)
241
- if len(unmatch.missing_keys) > 0:
242
- print(f'Missing keys: {unmatch.missing_keys}')
243
-
244
- @property
245
- def use_rvq_like_target(self):
246
- return self.use_rvq_target or self.use_vq_target or self.use_encodec_target
247
-
248
-
249
- def apply_hubert_mask(self, x, padding_mask=None, target_list=None):
250
- B, T, C = x.shape
251
- if self.mask_prob > 0:
252
- mask_length = int(self.mask_hop / (1/self.label_rate))
253
- mask_indices = compute_mask_indices(
254
- (B, T),
255
- padding_mask,
256
- self.mask_prob,
257
- mask_length, # self.mask_length,
258
- "static", #self.mask_selection,
259
- 0, #self.mask_other,
260
- min_masks=2,
261
- no_overlap=False, #self.no_mask_overlap,
262
- min_space=1, #self.mask_min_space,
263
- )
264
- mask_indices = torch.from_numpy(mask_indices).to(x.device)
265
- x[mask_indices] = self.mask_emb
266
- mask_indices = torch.nonzero(mask_indices)
267
- else:
268
- mask_indices = None
269
-
270
- return x, mask_indices
271
-
272
- def masking(self, x, attention_mask=None):
273
- """random masking of 400ms with given probability"""
274
- if self.use_hubert_masking_strategy:
275
- return x, None
276
- mx = x.clone()
277
- b, t = mx.shape
278
- len_masking_raw = int(24000 * self.mask_hop) # 9600 = 24000 * 0.4
279
- len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop) # 10 = 25Hz * 0.4
280
-
281
- # get random mask indices
282
- start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
283
- time_domain_masked_indices = torch.nonzero(
284
- start_indices.repeat_interleave(len_masking_raw, dim=1)
285
- )
286
- token_domain_masked_indices = torch.nonzero(
287
- start_indices.repeat_interleave(len_masking_token, dim=1)
288
- )
289
-
290
- # mask with random values
291
- masking_noise = (
292
- torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
293
- ) # 0 mean 0.1 std
294
- mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)
295
-
296
- return mx, token_domain_masked_indices
297
-
298
-
299
- @torch.no_grad()
300
- def preprocessing(self, x, features):
301
- """extract classic audio features"""
302
- # check precision
303
- if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
304
- precision = 16
305
- else:
306
- precision = 32
307
-
308
- out = {}
309
- for key in features:
310
- layer = getattr(self, "preprocessor_%s" % key)
311
- layer.to(x.device)
312
- dtype = x.dtype
313
- out[key] = layer.float()(x.float())[..., :-1]
314
- if precision == 16:
315
- out[key] = out[key].half()
316
- if out[key].dtype != dtype:
317
- out[key].to(dtype=dtype)
318
- return out
319
-
320
- def encoder(self, x, *, attention_mask=None, is_features_only=False):
321
- """2-layer conv + w2v-conformer"""
322
- if not self.use_hubert_featurizer:
323
- x = self.conv(x) # [3, 128, 3000] -> [3, 750, 1024]
324
- if self.training and self.use_hubert_masking_strategy and not is_features_only:
325
- x, mask_indices = self.apply_hubert_mask(x)
326
- else:
327
- mask_indices = None
328
- if attention_mask is None:
329
- out = self.conformer(x, output_hidden_states=True)
330
- else:
331
- attention_mask = attention_mask.bool()
332
- skip_n = int(attention_mask.size(-1) / x.size(1))
333
- attention_mask = attention_mask[:, ::skip_n]
334
- attention_mask = attention_mask[:, :x.size(1)]
335
- out = self.conformer(x, attention_mask=attention_mask, output_hidden_states=True)
336
- hidden_emb = out["hidden_states"]
337
- last_emb = out["last_hidden_state"]
338
- logits = self.linear(last_emb)
339
- interval = self.codebook_size
340
- logits = {
341
- key: logits[:, :, i * interval : (i + 1) * interval]
342
- for i, key in enumerate(self.features)
343
- }
344
- return logits, hidden_emb, mask_indices
345
-
346
- @torch.no_grad()
347
- def normalize(self, x):
348
- """normalize the input audio to have zero mean unit variance"""
349
- for key in x.keys():
350
- x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key] # {'melspec_2048_cnt': 14282760192, 'melspec_2048_mean': 6.768444971712967}
351
- return x
352
-
353
- @torch.no_grad()
354
- def rearrange(self, x):
355
- """rearrange the batch to flatten every 4 steps"""
356
- for key in x.keys():
357
- if key == "chromagram":
358
- x[key] = rearrange(x[key], "b f t -> b t f")
359
- else:
360
- x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=self.n_fold)
361
- return x
362
-
363
- def get_rvq_codes(self, inp, raw_wav):
364
- if self.use_rvq_target:
365
- quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = self.rvq(inp)
366
- return codes
367
- if self.use_vq_target:
368
- quantized_prompt_embeds, commitment_loss, codebook_loss, codes, _ = self.rvq(inp)
369
- return codes.unsqueeze(1)
370
- if self.use_encodec_target:
371
- encoded_frames = self.rvq.encode(raw_wav.unsqueeze(1)) #list, B,[ 8,T ]
372
- codes = torch.cat([encoded[0].detach() for encoded in encoded_frames], dim=-1)
373
- if self.label_rate == 25:
374
- codes = codes[:, :, ::3]
375
- return codes
376
-
377
- @torch.no_grad()
378
- def tokenize(self, x, raw_wav):
379
- out = {}
380
- for key in x.keys():
381
- if self.use_rvq_like_target:
382
- self.rvq.eval()
383
- inp = x[key].permute((0, 2, 1))
384
- codes = self.get_rvq_codes(inp, raw_wav)
385
- out[key] = torch.cat([codes[:, idx, ...] for idx in range(int(self.codebook_size//1024))], dim=-1) # (when use freq mask)->[Batch, N_SubSpec, SeqLen=8*750]
386
- else:
387
- layer = getattr(self, "quantizer_%s" % key)
388
- out[key] = layer(x[key])
389
- return out
390
-
391
- def to_spec_wise_quad(self, x):
392
- Batch, QuadSpec, Time = x.shape
393
- SubSpec, N_SubSpec = 16, 8
394
- assert 4 * SubSpec * N_SubSpec == QuadSpec == 4*128
395
- x = rearrange(x, "b (q n s) t -> b (q s) (n t)", q=4, n=N_SubSpec, s=SubSpec)
396
- return x # [Batch, SubSpec=16, N_SubSpec*Time=8*100Hz]
397
-
398
- def get_targets(self, x, label=None):
399
- if self.use_encodec_target:
400
- raw_x = x.clone()
401
- else:
402
- raw_x = None
403
- x = self.preprocessing(x, features=self.features) # -> {'melspec_2048': Tensor{Size([3, 128, 3000]) cuda:0 f32}}
404
- x = self.normalize(x)
405
- x = self.rearrange(x) # -> {'melspec_2048': Tensor{Size([3, 750, 512]) cuda:0 f32}}
406
- melspec = x['melspec_2048']
407
- if label is None:
408
- target_tokens = self.tokenize(x, raw_x) # -> {'melspec_2048': Tensor{Size([3, 750]) cuda:0 i64}}
409
- else:
410
- # print("use_target from label")
411
- target_tokens = {'melspec_2048': rearrange(label, "b n s -> b (n s)").long()}
412
- return target_tokens, melspec
413
-
414
- def get_predictions(self, x, *, mask=None, attention_mask=None, return_new_mask=False, is_features_only=False):
415
- # preprocessing
416
- if not self.use_hubert_featurizer:
417
- x = self.preprocessing(x, features=["melspec_2048"])
418
- x = self.normalize(x) # -> {'melspec_2048': Tensor{Size([3, 128, 3000]) cuda:0 f32}}
419
- else:
420
- features = self.hubert_feature_extractor(x)
421
- features = self.layer_norm(features.transpose(1, 2))
422
- if self.post_extract_proj is not None:
423
- features = self.post_extract_proj(features)
424
- x = {"melspec_2048": features}
425
-
426
- # encoding
427
- logits, hidden_emb, new_mask = self.encoder(x["melspec_2048"], attention_mask=attention_mask, is_features_only=is_features_only)
428
-
429
- if return_new_mask:
430
- return logits, hidden_emb, mask if new_mask is None else new_mask
431
- else:
432
- return logits, hidden_emb
433
-
434
- def get_latent(self, x, layer_ix=12):
435
- _, hidden_states = self.get_predictions(x)
436
- emb = hidden_states[layer_ix]
437
- return emb
438
-
439
- def compute_nce(self, x, pos, negs):
440
- neg_is_pos = (pos == negs).all(-1)
441
- pos = pos.unsqueeze(0)
442
- targets = torch.cat([pos, negs], dim=0)
443
-
444
- logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
445
- logits /= 0.1
446
- if neg_is_pos.any():
447
- logits[1:][neg_is_pos] = float("-inf")
448
- logits = logits.transpose(0, 1) # (num_x, num_cls+1)
449
- return logits
450
-
451
- def compute_hubert_nce_loss(self, proj_xs, targets):
452
-
453
- label_embs_list = self.label_embs_concat.split(self.codebook_size, 0) # (self.num_classes, 0)
454
-
455
- def compute_pred(proj_x, target, label_embs):
456
- # compute logits for the i-th label set
457
- y = torch.index_select(label_embs, 0, target.long())
458
- negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
459
- return self.compute_nce(proj_x, y, negs)
460
-
461
- logit_list = [
462
- compute_pred(proj_x, t, label_embs_list[i])
463
- for i, (proj_x, t) in enumerate(zip(proj_xs, targets))
464
- ]
465
-
466
- return sum(logit_list)
467
-
468
-
469
- def get_loss(self, logits, target_tokens, masked_indices):
470
- losses = {}
471
- accuracies = {}
472
- for key in logits.keys():
473
- if not self.use_rvq_like_target:
474
- masked_logits = logits[key][tuple(masked_indices.t())]
475
- masked_tokens = target_tokens[key][tuple(masked_indices.t())]
476
- else:
477
- Batch, SeqLen, N_Codebook_x_CodebookSize = logits[key].shape # CodebookSize=4096
478
- Batch, N_Codebook_x_SeqLen = target_tokens[key].shape # N_Codebook*SeqLen=4*750
479
- N_Codebook = int(N_Codebook_x_SeqLen // SeqLen)
480
- # print("not use_virtual, n codebook = ", N_Codebook)
481
- target_tokens[key] = rearrange(target_tokens[key], "b (n s) -> b s n", n=N_Codebook) # Batch, SeqLen=750, N_Codebook=4
482
- masked_logits = logits[key][tuple(masked_indices.t())]
483
- masked_tokens = target_tokens[key][tuple(masked_indices.t())]
484
- masked_logits = rearrange(masked_logits, "b (n c) -> (b n) c", n=N_Codebook)
485
- masked_tokens = rearrange(masked_tokens, "b n -> (b n)", n=N_Codebook)
486
-
487
- if self.use_hubert_nce_loss:
488
- losses[key] = self.compute_hubert_nce_loss(masked_logits, masked_tokens)
489
- else:
490
- losses[key] = self.loss(masked_logits, masked_tokens)
491
- accuracies[key] = (
492
- torch.sum(masked_logits.argmax(-1) == masked_tokens)
493
- / masked_tokens.numel()
494
- )
495
- return losses, accuracies
496
-
497
- def get_recon_loss(self, last_hidden_emb, melspec, masked_indices):
498
- pred_melspec = self.recon_proj(last_hidden_emb[tuple(masked_indices.t())])
499
- target_melspec = melspec[tuple(masked_indices.t())]
500
- recon_loss = self.recon_loss(pred_melspec, target_melspec)
501
- return recon_loss
502
-
503
- def forward(self, x, attention_mask=None, label=None):
504
- dtype = x.dtype
505
- # get target feature tokens
506
- target_tokens, melspec = self.get_targets(x, label=label)
507
-
508
- # masking
509
- x, masked_indices = self.masking(x, attention_mask=attention_mask)
510
-
511
- # forward
512
- logits, hidden_emb, masked_indices = self.get_predictions(x, mask=masked_indices, attention_mask=attention_mask, return_new_mask=True)
513
-
514
- # get loss
515
- losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)
516
-
517
- if self.recon_loss_ratio:
518
- losses["recon_loss"] = self.get_recon_loss(hidden_emb[-1], melspec, masked_indices) * self.recon_loss_ratio
519
-
520
- return logits, hidden_emb, losses, accuracies
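
One detail of MuQ.rearrange above that is easy to miss: with label_rate=25 and 100 Hz mel frames, n_fold = 100 // 25 = 4, so every 4 consecutive 128-band frames are flattened into a single 512-dim vector at 25 Hz before quantization. A minimal sketch of that folding with einops (shapes follow the comments in the deleted code):

import torch
from einops import rearrange

n_fold = 4
mel = torch.randn(3, 128, 3000)          # [batch, n_mels, frames at 100 Hz]
folded = rearrange(mel, "b f (t s) -> b t (s f)", s=n_fold)
print(folded.shape)                      # torch.Size([3, 750, 512]), i.e. 25 Hz frames of 4*128 bands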
 
MuCodec/muq_dev/muq_fairseq/models/muq/model/pred_ark_target_with_model.py DELETED
@@ -1,151 +0,0 @@
1
- import sys
2
- import torch.nn as nn
3
- import torch
4
- import sys, os
5
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
6
- from rvq_musicfm import PreprocessorWithModel, ResidualVectorQuantize
7
-
8
- class RVQ(nn.Module):
9
- def __init__(self,
10
- model_config,
11
- rvq_ckpt_path,
12
- preprocess,
13
- ):
14
- super().__init__()
15
- self.rvq = ResidualVectorQuantize(**model_config)
16
- if rvq_ckpt_path is not None:
17
- self.rvq.load_state_dict(torch.load(rvq_ckpt_path, map_location='cpu'))
18
- self.preprocess = preprocess
19
-
20
- def get_targets(self, x):
21
- self.rvq.eval()
22
- x = self.preprocess(x)
23
- quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = self.rvq(x)
24
- return codes.permute(1,0,2)
25
-
26
- @torch.no_grad()
27
- def encode_wavs(self, wavs):
28
- wavs = wavs[..., :int((wavs.shape[-1]//320)*320)]
29
- return self.get_targets(wavs)
30
-
31
- def This_Music_ModelTarget_Config():
32
- config = dict(
33
- model = dict(
34
- input_dim = 1024,
35
- n_codebooks = 8,
36
- codebook_size = 1024,
37
- codebook_dim = 16,
38
- quantizer_dropout = 0.0,
39
- ),
40
- train = dict(
41
- batch_size = 32,
42
- num_workers = 6,
43
- valid_interval = 10,
44
- save_interval = 100,
45
- max_updates = 500000,
46
- lr = 1e-4,
47
- # device = 'cuda:1',
48
- loss = 'commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()',
49
- preprocess = PreprocessorWithModel(
50
- model_dir= 'path/to/muq_fairseq',
51
- checkpoint_dir='path/to/muq_m4a_75K.pt',
52
- use_layer_idx=9,
53
- )
54
- ),
55
- pred = dict(
56
- rvq_ckpt_path='path/to/runs/Aug07_18-09-24_ts-828fa13e58384d0bba4144fda78ecc92-launcher/ckpt/RVQ_8100.pth',
57
- sr=24000,
58
- data_jsonl_path='path/to/data/music4all/train.json',
59
- save_target_dir= 'path/to/data/music4all_ark/reiter_musicssl_m4a',
60
- ),
61
- )
62
- return config
63
-
64
-
65
- CLEN = 30
66
- N_GPU_PER = 8
67
- N_NODE = 4
68
-
69
- def parse_lr(wave_length, sr):
70
- n_step = int( wave_length // (sr*CLEN) )
71
- if n_step == 0:
72
- n_step = 1
73
- print('wave_length: ', wave_length, 'sr: ', sr, 'n_step: ', n_step)
74
- starts = torch.arange(n_step) * CLEN * sr
75
- left_rights = torch.stack((starts, starts+CLEN*sr)).T
76
- return left_rights[:10, ...]
77
-
78
- @torch.no_grad()
79
- def main(index, rank):
80
- device = f'cuda:{rank}'
81
- config = This_Music_ModelTarget_Config()
82
- preprocess = config['train']['preprocess']
83
- model = RVQ(
84
- model_config = config['model'],
85
- rvq_ckpt_path = config['pred']['rvq_ckpt_path'],
86
- preprocess = preprocess
87
- ).to(device)
88
- model.eval()
89
- sr = config['pred']['sr']
90
-
91
- fname_nobase = os.path.basename(config['pred']['data_jsonl_path']).split('.')[0]
92
- scp_dir = os.path.join(config['pred']['save_target_dir'], 'scp')
93
- ark_dir = os.path.join(config['pred']['save_target_dir'], 'ark')
94
- os.makedirs(scp_dir, exist_ok=True)
95
- os.makedirs(ark_dir, exist_ok=True)
96
-
97
- scp_path = os.path.join(scp_dir, f'{fname_nobase}.{index}_{rank}.scp')
98
- ark_path = os.path.join(ark_dir, f'{fname_nobase}.{index}_{rank}.ark')
99
-
100
- from kaldiio import WriteHelper
101
-
102
- with open(config['pred']['data_jsonl_path']) as f:
103
- lines = f.readlines()
104
-
105
- print("Total:", len(lines))
106
-
107
- from tqdm import tqdm
108
- import json
109
- import librosa
110
- import time
111
- from einops import rearrange
112
- import numpy as np
113
-
114
- # lines = lines[(index*N_GPU_PER+rank)::(N_GPU_PER*N_NODE)]
115
-
116
- with WriteHelper(f'ark,scp:{ark_path},{scp_path}') as writer:
117
- for idx, line in tqdm(enumerate(lines)):
118
- try:
119
- if idx % (N_GPU_PER*N_NODE) != (index*N_GPU_PER+rank):
120
- continue
121
- item = json.loads(line)
122
- path = item['path']
123
- wave, _ = librosa.load(path, sr=sr)
124
- wave = torch.from_numpy(wave)
125
- wave_length = wave.shape[-1]
126
- if wave_length < sr*CLEN:
127
- continue
128
- left_rights = parse_lr(wave_length, sr)
129
- lr = left_rights.tolist()
130
- wavs = torch.stack(
131
- [wave[l:r] for l,r in lr]
132
- ).to(device)
133
- targets = model.encode_wavs(wavs) # [Codebook=8, N_Steps, Feature]
134
-
135
- final_target = rearrange(targets, "c n f -> n (c f)").cpu().numpy().astype(np.int32)
136
- for j in range(final_target.shape[0]):
137
- writer(f'{idx}:{j}', final_target[j])
138
- except Exception as e:
139
- print(e)
140
-
141
-
142
- if __name__ == '__main__':
143
- import sys
144
- index = int(sys.argv[1])
145
- import multiprocessing
146
- pool = multiprocessing.Pool(processes=N_GPU_PER)
147
- for rank in range(8):
148
- pool.apply_async(main, (index, rank))
149
- pool.close()
150
- pool.join()
151
- print("Done.")
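
The export script deleted above shards a JSONL manifest so that each of the N_GPU_PER*N_NODE workers keeps only the lines where idx % (N_GPU_PER*N_NODE) equals index*N_GPU_PER + rank, cuts each waveform into non-overlapping 30-second windows (at most ten per file), and writes the encoded code matrices to Kaldi ark/scp via kaldiio. The two bits of bookkeeping are easy to check in isolation; this sketch reuses the script's constants but needs no audio or model:

```python
import torch

CLEN, N_GPU_PER, N_NODE = 30, 8, 4        # seconds per clip, GPUs per node, nodes

def belongs_to(idx, node_index, rank):
    # keep manifest line idx on exactly one of the N_GPU_PER * N_NODE workers
    return idx % (N_GPU_PER * N_NODE) == node_index * N_GPU_PER + rank

def clip_boundaries(wave_length, sr, max_clips=10):
    # non-overlapping CLEN-second windows, capped at max_clips, at least one window
    n_step = max(int(wave_length // (sr * CLEN)), 1)
    starts = torch.arange(n_step) * CLEN * sr
    return torch.stack((starts, starts + CLEN * sr)).T[:max_clips]

print(clip_boundaries(wave_length=24000 * 95, sr=24000))
# three rows: [0, 720000], [720000, 1440000], [1440000, 2160000]
print([i for i in range(64) if belongs_to(i, node_index=1, rank=2)])
# [10, 42]: this worker only sees every 32nd line, offset by 8*1 + 2
```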
 
MuCodec/muq_dev/muq_fairseq/models/muq/model/rvq.py DELETED
@@ -1,459 +0,0 @@
1
-
2
- from typing import Union
3
-
4
- import numpy as np
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from einops import rearrange
9
- from torch.nn.utils import weight_norm
10
-
11
- def WNConv1d(*args, **kwargs):
12
- return weight_norm(nn.Conv1d(*args, **kwargs))
13
-
14
-
15
- class VectorQuantize(nn.Module):
16
- """
17
- Implementation of VQ similar to Karpathy's repo:
18
- https://github.com/karpathy/deep-vector-quantization
19
- Additionally uses following tricks from Improved VQGAN
20
- (https://arxiv.org/pdf/2110.04627.pdf):
21
- 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
22
- for improved codebook usage
23
- 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
24
- improves training stability
25
- """
26
-
27
- def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 1000, mfcc_clustering=False, n_layer=1):
28
- super().__init__()
29
- self.codebook_size = codebook_size
30
- self.codebook_dim = codebook_dim
31
- self.mfcc_clustering = mfcc_clustering
32
-
33
- ProjClass = nn.Identity if mfcc_clustering else WNConv1d
34
- if n_layer==1:
35
- self.in_proj = ProjClass(input_dim, codebook_dim, kernel_size=1)
36
- self.out_proj = ProjClass(codebook_dim, input_dim, kernel_size=1)
37
- elif n_layer >= 2:
38
- ndim_hidden = 128
39
- self.in_proj = nn.Sequential(
40
- ProjClass(input_dim, ndim_hidden, kernel_size=1),
41
- *[nn.Sequential(nn.ReLU(), ProjClass(ndim_hidden, ndim_hidden, kernel_size=1),) for _ in range(n_layer-2)],
42
- nn.ReLU(),
43
- ProjClass(ndim_hidden, codebook_dim, kernel_size=1)
44
- )
45
- self.out_proj = nn.Sequential(
46
- ProjClass(codebook_dim, ndim_hidden, kernel_size=1),
47
- nn.ReLU(),
48
- *[nn.Sequential(ProjClass(ndim_hidden, ndim_hidden, kernel_size=1), nn.ReLU()) for _ in range(n_layer-2)],
49
- ProjClass(ndim_hidden, input_dim, kernel_size=1),
50
- )
51
- self.codebook = nn.Embedding(codebook_size, codebook_dim)
52
- self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
53
- self.stale_tolerance = stale_tolerance
54
-
55
- def forward(self, z):
56
- """Quantize the input tensor using a fixed codebook and return
57
- the corresponding codebook vectors
58
-
59
- Parameters
60
- ----------
61
- z : Tensor[B x D x T]
62
-
63
- Returns
64
- -------
65
- Tensor[B x D x T]
66
- Quantized continuous representation of input
67
- Tensor[1]
68
- Commitment loss to train encoder to predict vectors closer to codebook
69
- entries
70
- Tensor[1]
71
- Codebook loss to update the codebook
72
- Tensor[B x T]
73
- Codebook indices (quantized discrete representation of input)
74
- Tensor[B x D x T]
75
- Projected latents (continuous representation of input before quantization)
76
- """
77
-
78
- # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
79
-
80
- z_e = self.in_proj(z) # z_e : (B x D x T)
81
- z_q, indices = self.decode_latents(z_e)
82
-
83
- commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
84
- codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
85
-
86
- z_q = (
87
- z_e + (z_q - z_e).detach()
88
- ) # noop in forward pass, straight-through gradient estimator in backward pass
89
-
90
- z_q = self.out_proj(z_q)
91
-
92
- return z_q, commitment_loss, codebook_loss, indices, z_e
93
-
94
- def embed_code(self, embed_id):
95
- return F.embedding(embed_id, self.codebook.weight)
96
-
97
- def decode_code(self, embed_id):
98
- return self.embed_code(embed_id).transpose(1, 2)
99
-
100
- def decode_latents(self, latents):
101
- encodings = rearrange(latents, "b d t -> (b t) d")
102
- codebook = self.codebook.weight # codebook: (N x D)
103
-
104
- # L2 normalize encodings and codebook (ViT-VQGAN)
105
- encodings = F.normalize(encodings)
106
- codebook = F.normalize(codebook)
107
-
108
- # Compute euclidean distance with codebook
109
- dist = (
110
- encodings.pow(2).sum(1, keepdim=True)
111
- - 2 * encodings @ codebook.t()
112
- + codebook.pow(2).sum(1, keepdim=True).t()
113
- )
114
- indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
115
- z_q = self.decode_code(indices)
116
-
117
- if(self.training):
118
- onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size
119
- stale_codes = (onehots.sum(0).sum(0) == 0).float()
120
- self.stale_counter = self.stale_counter * stale_codes + stale_codes
121
-
122
- # random replace codes that haven't been used for a while
123
- replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size
124
- if replace_code.sum(-1) > 0:
125
- print("Replace {} codes".format(replace_code.sum(-1)))
126
- random_input_idx = torch.randperm(encodings.shape[0])
127
- random_input = encodings[random_input_idx].view(encodings.shape)
128
- if random_input.shape[0] < self.codebook_size:
129
- random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
130
- random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim
131
-
132
- self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
133
- self.stale_counter = self.stale_counter * (1 - replace_code)
134
-
135
- return z_q, indices
136
-
137
-
138
- class ResidualVectorQuantize(nn.Module):
139
- """
140
- Introduced in SoundStream: An end2end neural audio codec
141
- https://arxiv.org/abs/2107.03312
142
- """
143
-
144
- def __init__(
145
- self,
146
- input_dim: int = 512,
147
- n_codebooks: int = 9,
148
- codebook_size: int = 1024,
149
- codebook_dim: Union[int, list] = 8,
150
- quantizer_dropout: float = 0.0,
151
- stale_tolerance: int = 100,
152
- use_multi_layer_num:int = 1,
153
- ):
154
- super().__init__()
155
- if isinstance(codebook_dim, int):
156
- codebook_dim = [codebook_dim for _ in range(n_codebooks)]
157
-
158
- self.n_codebooks = n_codebooks
159
- self.codebook_dim = codebook_dim
160
- self.codebook_size = codebook_size
161
-
162
- self.quantizers = nn.ModuleList(
163
- [
164
- VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance, n_layer=use_multi_layer_num)
165
- for i in range(n_codebooks)
166
- ]
167
- )
168
- self.quantizer_dropout = quantizer_dropout
169
-
170
- def forward(self, z, n_quantizers: int = None):
171
- """Quantize the input tensor using a fixed set of `n` codebooks and return
172
- the corresponding codebook vectors
173
- Parameters
174
- ----------
175
- z : Tensor[B x D x T]
176
- n_quantizers : int, optional
177
- No. of quantizers to use
178
- (n_quantizers < self.n_codebooks ex: for quantizer dropout)
179
- Note: if `self.quantizer_dropout` is True, this argument is ignored
180
- when in training mode, and a random number of quantizers is used.
181
- Returns
182
- -------
183
- dict
184
- A dictionary with the following keys:
185
-
186
- "z" : Tensor[B x D x T]
187
- Quantized continuous representation of input
188
- "codes" : Tensor[B x N x T]
189
- Codebook indices for each codebook
190
- (quantized discrete representation of input)
191
- "latents" : Tensor[B x N*D x T]
192
- Projected latents (continuous representation of input before quantization)
193
- "vq/commitment_loss" : Tensor[1]
194
- Commitment loss to train encoder to predict vectors closer to codebook
195
- entries
196
- "vq/codebook_loss" : Tensor[1]
197
- Codebook loss to update the codebook
198
- """
199
- z_q = 0
200
- residual = z
201
- commitment_loss = 0
202
- codebook_loss = 0
203
-
204
- codebook_indices = []
205
- latents = []
206
-
207
- if n_quantizers is None:
208
- n_quantizers = self.n_codebooks
209
- if self.training:
210
- n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
211
- dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
212
- n_dropout = int(z.shape[0] * self.quantizer_dropout)
213
- n_quantizers[:n_dropout] = dropout[:n_dropout]
214
- n_quantizers = n_quantizers.to(z.device)
215
- else:
216
- n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1
217
- n_quantizers = n_quantizers.to(z.device)
218
-
219
- for i, quantizer in enumerate(self.quantizers):
220
- # if self.training is False and i >= n_quantizers:
221
- # break
222
-
223
- z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
224
- residual
225
- )
226
-
227
- # Create mask to apply quantizer dropout
228
- mask = (
229
- torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
230
- )
231
- z_q = z_q + z_q_i * mask[:, None, None]
232
- residual = residual - z_q_i
233
-
234
- # Sum losses
235
- commitment_loss += (commitment_loss_i * mask).mean()
236
- codebook_loss += (codebook_loss_i * mask).mean()
237
-
238
- codebook_indices.append(indices_i)
239
- latents.append(z_e_i)
240
-
241
- codes = torch.stack(codebook_indices, dim=1)
242
- latents = torch.cat(latents, dim=1)
243
-
244
- encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
245
-
246
- return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1
247
-
248
- def from_codes(self, codes: torch.Tensor):
249
- """Given the quantized codes, reconstruct the continuous representation
250
- Parameters
251
- ----------
252
- codes : Tensor[B x N x T]
253
- Quantized discrete representation of input
254
- Returns
255
- -------
256
- Tensor[B x D x T]
257
- Quantized continuous representation of input
258
- """
259
- z_q = 0.0
260
- z_p = []
261
- n_codebooks = codes.shape[1]
262
- for i in range(n_codebooks):
263
- z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
264
- z_p.append(z_p_i)
265
-
266
- z_q_i = self.quantizers[i].out_proj(z_p_i)
267
- z_q = z_q + z_q_i
268
- return z_q, torch.cat(z_p, dim=1), codes
269
-
270
- def from_latents(self, latents: torch.Tensor):
271
- """Given the unquantized latents, reconstruct the
272
- continuous representation after quantization.
273
-
274
- Parameters
275
- ----------
276
- latents : Tensor[B x N x T]
277
- Continuous representation of input after projection
278
-
279
- Returns
280
- -------
281
- Tensor[B x D x T]
282
- Quantized representation of full-projected space
283
- Tensor[B x D x T]
284
- Quantized representation of latent space
285
- """
286
- z_q = 0
287
- z_p = []
288
- codes = []
289
- dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
290
-
291
- n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
292
- 0
293
- ]
294
- for i in range(n_codebooks):
295
- j, k = dims[i], dims[i + 1]
296
- z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
297
- z_p.append(z_p_i)
298
- codes.append(codes_i)
299
-
300
- z_q_i = self.quantizers[i].out_proj(z_p_i)
301
- z_q = z_q + z_q_i
302
-
303
- return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
304
-
305
- from torch.utils.data import Dataset, DataLoader
306
- import json, traceback
307
- import torchaudio
308
- import math
309
-
310
- from typing import List, Tuple, Dict, Any
311
-
312
- CLIPSECS = 5
313
- def load_audio_by_json(json_path, max_keep, min_keep, tgt_sample_rate):
314
- # read json file
315
- print(json_path)
316
- datas = []
317
- inds = []
318
- sizes = []
319
- with open(json_path) as fp:
320
- for ind,line in enumerate(fp):
321
- data = json.loads(line)
322
- datas.append(data)
323
- inds.append(ind)
324
- # sz = int(data['duration'] * data['sample_rate'])
325
- sz = int(tgt_sample_rate * CLIPSECS)
326
- sizes.append(sz)
327
- tot = ind + 1
328
- return datas,inds,tot,sizes
329
-
330
- class Read_and_PadCrop_Normalized_T(torch.nn.Module):
331
- def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
332
-
333
- super().__init__()
334
-
335
- self.n_samples = n_samples
336
- self.sample_rate = sample_rate
337
- self.randomize = randomize
338
-
339
-
340
- def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
341
- if(duration<(float(self.n_samples)/self.sample_rate+1)):
342
- # print(duration,(float(self.n_samples)/self.sample_rate+1))
343
- chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
344
- t_start = 0.
345
- t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
346
- offset = 0
347
- # print('c1:',chunk.shape)
348
- else:
349
- offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
350
- t_start = offset / float(cur_sample_rate) / duration
351
- t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
352
- chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
353
- # print('offset:',offset)
354
- # print('c0:',chunk.shape)
355
- # Pad with silence if necessary.
356
- if(chunk.shape[0]>1):
357
- chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
358
- else:
359
- chunk = chunk[[0],:].float()
360
- if(cur_sample_rate!=self.sample_rate):
361
- # print('a:',cur_sample_rate,chunk.shape)
362
- chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
363
- # print('b:',self.sample_rate,chunk.shape)
364
- if chunk.shape[-1] < self.n_samples:
365
- chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
366
- else:
367
- chunk = chunk[:,0:self.n_samples]
368
- seconds_start = math.floor(offset / cur_sample_rate)
369
- seconds_total = math.floor(duration)
370
-
371
- return (
372
- chunk,
373
- t_start,
374
- t_end,
375
- seconds_start,
376
- seconds_total
377
- )
378
-
379
- class RVQDataset(Dataset):
380
- def __init__(
381
- self,
382
- manifest_path: str,
383
- sample_rate: float,
384
- normalize: bool = False,
385
- ):
386
- self.sample_rate = sample_rate
387
- self.datas,inds,tot,self.sizes = load_audio_by_json(manifest_path, None, None, self.sample_rate)
388
- self.dataset_len = len(self.datas)
389
-
390
- self.reader = Read_and_PadCrop_Normalized_T(n_samples=CLIPSECS*sample_rate,sample_rate = self.sample_rate)
391
- self.normalize = normalize
392
-
393
-
394
- def __getitem__(self, i):
395
- # WORLD_SIZE = int(torch.distributed.get_world_size())
396
- # WORLD_RANK = int(torch.distributed.get_rank())
397
- # np.random.seed(1337 + self.epoch * WORLD_SIZE + WORLD_RANK + i)
398
- # index = random.randint(0,len(self.sizes) - 1)
399
- index = i
400
- item = None
401
- while item is None:
402
- try:
403
- wav = self.get_audio_by_slice(index)
404
- # labels = self.get_labels(index)
405
- # labels = None
406
- # item = {"id": index, "source": wav, "label_list": labels}
407
- item = {"id": index, "source": wav}
408
- except Exception as e:
409
- # print(e)
410
- traceback.print_exc()
411
- print(f'skip damaged data {index}')
412
- index = np.random.randint(0,len(self.sizes)-1)
413
- return item
414
-
415
- def __len__(self):
416
- return self.dataset_len
417
-
418
- def get_audio_by_slice(self,index):
419
-
420
- wav_path = self.datas[index]['path']
421
- # print(wav_path)
422
- audio_info = torchaudio.info(wav_path)
423
- origin_sample_rate = audio_info.sample_rate
424
- origin_duration = audio_info.num_frames / origin_sample_rate
425
-
426
- wav, *ignored = self.reader(wav_path, origin_duration,origin_sample_rate)
427
- wav = wav.float()
428
-
429
- # _path, slice_ptr = parse_path(wav_path)
430
- # original way
431
- # if len(slice_ptr) == 0:
432
- # wav, cur_sample_rate = sf.read(_path)
433
- # else:
434
- # assert _path.endswith(".zip")
435
- # data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
436
- # f = io.BytesIO(data)
437
- # wav, cur_sample_rate = sf.read(f)
438
- # wav = torch.from_numpy(wav).float()
439
- # print(wav.shape)
440
- wav = wav.permute(1,0)
441
- wav = self.postprocess(wav, self.sample_rate)
442
- # print(wav.shape)
443
-
444
- # wav = wav.squeeze(0)
445
- return wav
446
-
447
- def postprocess(self, wav, cur_sample_rate):
448
- if wav.dim() == 2:
449
- wav = wav.mean(-1)
450
- assert wav.dim() == 1, wav.dim()
451
-
452
- if cur_sample_rate != self.sample_rate:
453
- raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
454
-
455
- if self.normalize:
456
- with torch.no_grad():
457
- wav = F.layer_norm(wav, wav.shape)
458
- return wav
459
-
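
The `ResidualVectorQuantize` deleted above is the SoundStream-style stack: quantizer i only sees the residual that quantizers 0..i-1 failed to explain, and gradients flow through every hard codebook lookup via the straight-through trick z_q = z_e + (z_q - z_e).detach(). A stripped-down residual loop with a plain nearest-neighbour lookup, leaving out the factorized projections, L2 normalization, stale-code replacement and quantizer dropout of the real module:

```python
import torch

def nearest_code(z_e, codebook):
    # z_e: (B, T, D), codebook: (K, D) -> indices (B, T) and quantized vectors (B, T, D)
    dist = torch.cdist(z_e, codebook.unsqueeze(0).expand(z_e.size(0), -1, -1))
    idx = dist.argmin(-1)
    return idx, codebook[idx]

def residual_vq(z, codebooks):
    residual, z_q, codes = z, torch.zeros_like(z), []
    for cb in codebooks:
        idx, q = nearest_code(residual, cb)
        q = residual + (q - residual).detach()   # straight-through estimator
        z_q = z_q + q                            # each stage adds its reconstruction
        residual = residual - q                  # next stage models what this one missed
        codes.append(idx)
    return z_q, torch.stack(codes, dim=1)        # codes: (B, n_codebooks, T)

# toy check with the sizes used above: 8 codebooks of 1024 entries, 16-dim codes
codebooks = [torch.randn(1024, 16) for _ in range(8)]
z_q, codes = residual_vq(torch.randn(2, 50, 16), codebooks)
print(z_q.shape, codes.shape)   # torch.Size([2, 50, 16]) torch.Size([2, 8, 50])
```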
 
MuCodec/muq_dev/muq_fairseq/models/muq/model/rvq_muq.py DELETED
@@ -1,394 +0,0 @@
1
- try:
2
- from .rvq import *
3
- except:
4
- import sys, os
5
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
6
- from rvq import *
7
-
8
- try:
9
- from ..modules.random_quantizer import RandomProjectionQuantizer
10
- from ..modules.features import MelSTFT
11
- from ..modules.conv import Conv2dSubsampling
12
- except:
13
- import sys, os
14
- sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
15
- from modules.random_quantizer import RandomProjectionQuantizer
16
- from modules.features import MelSTFT
17
- from modules.conv import Conv2dSubsampling
18
-
19
- import fairseq
20
-
21
- CLIPSECS = 5 # 5 for rvq, 30 for model
22
-
23
- class RVQDataset(Dataset):
24
- def __init__(
25
- self,
26
- manifest_path: str,
27
- sample_rate: float,
28
- normalize: bool = False,
29
- ):
30
- self.sample_rate = sample_rate
31
- self.datas,inds,tot,self.sizes = load_audio_by_json(manifest_path, None, None, self.sample_rate)
32
- self.dataset_len = len(self.datas)
33
-
34
- self.reader = Read_and_PadCrop_Normalized_T(n_samples=CLIPSECS*sample_rate,sample_rate = self.sample_rate)
35
- self.normalize = normalize
36
-
37
-
38
- def __getitem__(self, i):
39
- # WORLD_SIZE = int(torch.distributed.get_world_size())
40
- # WORLD_RANK = int(torch.distributed.get_rank())
41
- # np.random.seed(1337 + self.epoch * WORLD_SIZE + WORLD_RANK + i)
42
- # index = random.randint(0,len(self.sizes) - 1)
43
- index = i
44
- item = None
45
- while item is None:
46
- try:
47
- wav = self.get_audio_by_slice(index)
48
- item = {"id": index, "source": wav}
49
- except Exception as e:
50
- # print(e)
51
- traceback.print_exc()
52
- print(f'skip damaged data {index}')
53
- index = np.random.randint(0,len(self.sizes)-1)
54
- return item
55
-
56
- def __len__(self):
57
- return self.dataset_len
58
-
59
- def get_audio_by_slice(self,index):
60
-
61
- wav_path = self.datas[index]['path']
62
- audio_info = torchaudio.info(wav_path)
63
- origin_sample_rate = audio_info.sample_rate
64
- origin_duration = audio_info.num_frames / origin_sample_rate
65
-
66
- wav, *ignored = self.reader(wav_path, origin_duration,origin_sample_rate)
67
- wav = wav.float()
68
-
69
- # _path, slice_ptr = parse_path(wav_path)
70
- # original way
71
- # if len(slice_ptr) == 0:
72
- # wav, cur_sample_rate = sf.read(_path)
73
- # else:
74
- # assert _path.endswith(".zip")
75
- # data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
76
- # f = io.BytesIO(data)
77
- # wav, cur_sample_rate = sf.read(f)
78
- # wav = torch.from_numpy(wav).float()
79
- # print(wav.shape)
80
- wav = wav.permute(1,0)
81
- wav = self.postprocess(wav, self.sample_rate)
82
- # print(wav.shape)
83
-
84
- # wav = wav.squeeze(0)
85
- return wav
86
-
87
- def postprocess(self, wav, cur_sample_rate):
88
- if wav.dim() == 2:
89
- wav = wav.mean(-1)
90
- assert wav.dim() == 1, wav.dim()
91
-
92
- if cur_sample_rate != self.sample_rate:
93
- raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
94
-
95
- if self.normalize:
96
- with torch.no_grad():
97
- wav = F.layer_norm(wav, wav.shape)
98
- return wav
99
-
100
- class Preprocessor(nn.Module):
101
- def __init__(self,
102
- codebook_dim=16,
103
- codebook_size=4096,
104
- hop_length=240,
105
- n_mels=128,
106
- stat_path=None,
107
- is_spec_wise=False,
108
- s=4,
109
- ) -> None:
110
- super().__init__()
111
-
112
- self.features=["melspec_2048"]
113
- self.s = s
114
-
115
- # load feature mean / std stats
116
- import os
117
- if stat_path is not None and os.path.exists(stat_path):
118
- with open(stat_path, "r") as f:
119
- self.stat = json.load(f)
120
- else:
121
- # print("No stats file found at `{}`, use default from msd.".format(stat_path))
122
- self.stat = {"spec_256_cnt": 14394344256, "spec_256_mean": -23.34296658431829, "spec_256_std": 26.189295587132637, "spec_512_cnt": 28677104448, "spec_512_mean": -21.31267396860235, "spec_512_std": 26.52644536245769, "spec_1024_cnt": 57242624832, "spec_1024_mean": -18.852271129208273, "spec_1024_std": 26.443154583585663, "spec_2048_cnt": 114373665600, "spec_2048_mean": -15.638743433896792, "spec_2048_std": 26.115825961611545, "spec_4096_cnt": 228635747136, "spec_4096_mean": -11.715532502794836, "spec_4096_std": 25.763972210234062, "melspec_256_cnt": 14282760192, "melspec_256_mean": -26.962600400166156, "melspec_256_std": 36.13614100912126, "melspec_512_cnt": 14282760192, "melspec_512_mean": -9.108344167718862, "melspec_512_std": 24.71910937988429, "melspec_1024_cnt": 14282760192, "melspec_1024_mean": 0.37302579246531126, "melspec_1024_std": 18.684082325919388, "melspec_2048_cnt": 14282760192, "melspec_2048_mean": 6.768444971712967, "melspec_2048_std": 18.417922652295623, "melspec_4096_cnt": 14282760192, "melspec_4096_mean": 13.617164614990036, "melspec_4096_std": 18.08552130124525, "cqt_cnt": 9373061376, "cqt_mean": 0.46341379757927165, "cqt_std": 0.9543998080910191, "mfcc_256_cnt": 1339008768, "mfcc_256_mean": -11.681755459447485, "mfcc_256_std": 29.183186444668316, "mfcc_512_cnt": 1339008768, "mfcc_512_mean": -2.540581461792183, "mfcc_512_std": 31.93752185832081, "mfcc_1024_cnt": 1339008768, "mfcc_1024_mean": 6.606636263169779, "mfcc_1024_std": 34.151644801729624, "mfcc_2048_cnt": 1339008768, "mfcc_2048_mean": 5.281600844245184, "mfcc_2048_std": 33.12784541220003, "mfcc_4096_cnt": 1339008768, "mfcc_4096_mean": 4.7616569480166095, "mfcc_4096_std": 32.61458906894133, "chromagram_256_cnt": 1339008768, "chromagram_256_mean": 55.15596556703181, "chromagram_256_std": 73.91858278719991, "chromagram_512_cnt": 1339008768, "chromagram_512_mean": 175.73092252759895, "chromagram_512_std": 248.48485148525953, "chromagram_1024_cnt": 1339008768, "chromagram_1024_mean": 589.2947481634608, "chromagram_1024_std": 913.857929063196, "chromagram_2048_cnt": 1339008768, "chromagram_2048_mean": 2062.286388327397, "chromagram_2048_std": 3458.92657915397, "chromagram_4096_cnt": 1339008768, "chromagram_4096_mean": 7673.039107997085, "chromagram_4096_std": 13009.883158267234}
123
-
124
- # feature extractor
125
- self.preprocessor_melspec_2048 = MelSTFT(
126
- n_fft=2048, hop_length=hop_length, is_db=True
127
- )
128
-
129
- self.is_spec_wise = is_spec_wise
130
-
131
-
132
- @torch.no_grad()
133
- def normalize(self, x):
134
- """normalize the input audio to have zero mean unit variance"""
135
- for key in x.keys():
136
- x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key] # {'melspec_2048_cnt': 14282760192, 'melspec_2048_mean': 6.768444971712967}
137
- return x
138
-
139
- @torch.no_grad()
140
- def rearrange(self, x):
141
- """rearrange the batch to flatten every 4 steps"""
142
- for key in x.keys():
143
- if key == "chromagram":
144
- x[key] = rearrange(x[key], "b f t -> b t f")
145
- else:
146
- x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=self.s)
147
- return x
148
-
149
- @torch.no_grad()
150
- def preprocessing(self, x, features):
151
- """extract classic audio features"""
152
- # check precision
153
- if x.dtype == torch.float16:
154
- precision = 16
155
- else:
156
- precision = 32
157
-
158
- out = {}
159
- for key in features:
160
- layer = getattr(self, "preprocessor_%s" % key)
161
- out[key] = layer.float()(x.float())[..., :-1]
162
- if precision == 16:
163
- out[key] = out[key].half()
164
- return out
165
-
166
- @torch.no_grad()
167
- def tokenize(self, x):
168
- out = {}
169
- for key in x.keys():
170
- layer = getattr(self, "quantizer_%s" % key)
171
- out[key] = layer(x[key])
172
- return out
173
-
174
- def to_spec_wise(self, x):
175
- Batch, Spec, Time = x.shape
176
- SubSpec, N_SubSpec = 16, 8
177
- assert SubSpec * N_SubSpec == Spec == 128
178
- x = rearrange(x, "b (n s) t -> b s (n t)", n=N_SubSpec, s=SubSpec)
179
- return x # [Batch, SubSpec=16, N_SubSpec*Time=8*100Hz]
180
-
181
- @torch.no_grad()
182
- def __call__(self, x):
183
- x = self.preprocessing(x, features=self.features) # -> {'melspec_2048': Tensor{Size([3, 128, 3000]) cuda:0 f32}}
184
- x = self.normalize(x)
185
- if self.is_spec_wise:
186
- x = {k:self.to_spec_wise(v) for k,v in x.items()}
187
- x = self.rearrange(x) # -> {'melspec_2048': Tensor{Size([3, 750, 512]) cuda:0 f32}}
188
- return x['melspec_2048'].permute((0, 2, 1))
189
-
190
-
191
- class CQTPreprocessor(nn.Module):
192
- def __init__(self,
193
- sr=24000,
194
- hop=960,
195
- nb=84,
196
- to_db = True,
197
- ) -> None:
198
- super().__init__()
199
-
200
- from nnAudio.features.cqt import CQT
201
- import torchaudio
202
- self.cqt_fn = CQT(
203
- sr=sr,
204
- hop_length=hop,
205
- n_bins=nb,
206
- fmin=32.7 if nb == 84 else 27.5, # 84 or 88
207
- bins_per_octave=12,
208
- filter_scale=1,
209
- norm=1,
210
- window='hann',
211
- center=True,
212
- pad_mode='constant',
213
- trainable=False,
214
- output_format='Magnitude',
215
- verbose=True,
216
- )
217
- if to_db:
218
- self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
219
- else:
220
- self.amplitude_to_db = lambda x:x
221
-
222
- @torch.no_grad()
223
- def __call__(self, x):
224
- return self.amplitude_to_db(self.cqt_fn(x))
225
-
226
-
227
- from dataclasses import dataclass
228
-
229
- @dataclass
230
- class UserDirModule:
231
- user_dir: str
232
-
233
- def load_model(model_dir, checkpoint_dir):
234
- '''Load Fairseq SSL model'''
235
-
236
- if model_dir is not None:
237
- model_path = UserDirModule(model_dir)
238
- fairseq.utils.import_user_module(model_path)
239
-
240
- model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_dir], strict=False)
241
- model = model[0]
242
-
243
- return model
244
-
245
-
246
-
247
- class PreprocessorWithModel(nn.Module):
248
- def __init__(self, model_dir, checkpoint_dir, use_layer_idx=9) -> None:
249
- super().__init__()
250
- self.model = load_model(model_dir=model_dir, checkpoint_dir=checkpoint_dir)
251
- self.model.eval()
252
- self.use_layer_idx = use_layer_idx
253
-
254
- def forward(self, x):
255
- with torch.no_grad():
256
- self.model.eval()
257
- res = self.model(x, features_only = True)
258
- layer_results = res['layer_results']
259
- return layer_results[self.use_layer_idx].permute(0,2,1)
260
-
261
-
262
-
263
- def Music_Mel_Target_Config():
264
- config = dict(
265
- train_dataset = dict(
266
- manifest_path = 'path/to/data/music4all/train.json',
267
- sample_rate = 24000,
268
- normalize = False,
269
- ),
270
- valid_dataset = dict(
271
- manifest_path = 'path/to/data/music4all/valid.json',
272
- sample_rate = 24000,
273
- normalize = False,
274
- ),
275
- model = dict(
276
- input_dim = 128*4,
277
- n_codebooks = 8,
278
- codebook_size = 1024,
279
- codebook_dim = 16,
280
- quantizer_dropout = 0.0,
281
- ),
282
- train = dict(
283
- batch_size = 32,
284
- num_workers = 6,
285
- valid_interval = 10,
286
- save_interval = 100,
287
- max_updates = 500000,
288
- lr = 1e-4,
289
- device = 'cuda:0',
290
- loss = 'commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()',
291
- preprocess = Preprocessor()
292
- )
293
- )
294
- return config
295
-
296
-
297
- def main(config):
298
- train_dataset = RVQDataset(**config['train_dataset'])
299
- if config['valid_dataset']['manifest_path'] is None:
300
- # split train and valid dataset
301
- from torch.utils.data import random_split
302
- train_dataset, valid_dataset = random_split(
303
- train_dataset, lengths=[len(train_dataset) - 500, 500]
304
- )
305
- else:
306
- valid_dataset = RVQDataset(**config['valid_dataset'])
307
- train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=config['train']['batch_size'], drop_last=True, num_workers=config['train']['num_workers'])
308
- valid_dataloader = DataLoader(valid_dataset, shuffle=False, batch_size=config['train']['batch_size'], drop_last=True, num_workers=config['train']['num_workers'])
309
- model = ResidualVectorQuantize(**config['model'])
310
-
311
- device = config['train']['device']
312
- preprocess = config['train']['preprocess'].to(device)
313
- model = model.to(device)
314
-
315
- optimizer = torch.optim.Adam(model.parameters(), lr=config['train']['lr'])
316
- cur_updates = 0
317
- is_running = True
318
- result = {}
319
- from tqdm import tqdm
320
- from tensorboardX import SummaryWriter
321
- writer = SummaryWriter()
322
- from collections import defaultdict
323
- import os
324
- from logging import getLogger
325
- logger = getLogger()
326
-
327
- while is_running:
328
- results = defaultdict(lambda:0)
329
- for item in tqdm(train_dataloader, desc='train'):
330
- wavs = item['source']
331
- optimizer.zero_grad()
332
- wavs = wavs.to(device)
333
- x = preprocess(wavs)
334
- model.train()
335
- quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = model(x)
336
- loss = eval(config['train']['loss'])
337
- loss.backward()
338
- optimizer.step()
339
-
340
- results['loss/train'] += loss.item()
341
- results['commitment_loss/train'] += commitment_loss.item()
342
- results['codebook_loss/train'] += codebook_loss.item()
343
- results['rvq_usage/train'] += rvq_usage.float().mean().item()
344
-
345
- if cur_updates % config['train']['valid_interval'] == 0:
346
- model.eval()
347
- with torch.no_grad():
348
- for item in tqdm(valid_dataloader, desc='valid'):
349
- wavs = item['source']
350
- wavs = wavs.to(device)
351
- x = preprocess(wavs)
352
- quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = model(x)
353
- valid_loss = eval(config['train']['loss'])
354
-
355
- results['loss/valid'] += valid_loss.item()
356
- results['commitment_loss/valid'] += commitment_loss.item()
357
- results['codebook_loss/valid'] += codebook_loss.item()
358
- results['rvq_usage/valid'] += rvq_usage.float().mean().item()
359
-
360
- results['cur_updates'] = cur_updates
361
- results['loss/train'] /= config['train']['valid_interval']
362
- results['commitment_loss/train'] /= config['train']['valid_interval']
363
- results['codebook_loss/train'] /= config['train']['valid_interval']
364
- results['rvq_usage/train'] /= config['train']['valid_interval']
365
-
366
- results['loss/valid'] /= len(valid_dataloader)
367
- results['commitment_loss/valid'] /= len(valid_dataloader)
368
- results['codebook_loss/valid'] /= len(valid_dataloader)
369
- results['rvq_usage/valid'] /= len(valid_dataloader)
370
-
371
- print('')
372
- logger.info(str(results))
373
- for k,v in results.items():
374
- writer.add_scalar(k, v, cur_updates)
375
-
376
- results.clear()
377
-
378
- if cur_updates % config['train']['save_interval'] == 0:
379
- os.makedirs(f'{writer.logdir}/ckpt/', exist_ok=True)
380
- logger.info(f'saving checkpoint to {writer.logdir}/ckpt/RVQ_{cur_updates}.pth')
381
- torch.save(model.state_dict(), f'{writer.logdir}/ckpt/RVQ_{cur_updates}.pth')
382
-
383
-
384
- if cur_updates < config['train']['max_updates']:
385
- cur_updates += 1
386
- else:
387
- is_running = False
388
- break
389
-
390
-
391
-
392
- if __name__ == '__main__':
393
- config = Music_Mel_Target_Config()
394
- main(config)
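
The `Preprocessor` deleted above converts 24 kHz audio into dB mel spectrograms with hop 240 (about 100 frames per second), normalizes them with the stored melspec_2048 mean/std, and then flattens every s=4 consecutive 128-bin frames into one 512-dimensional vector, giving a 25 Hz sequence whose width matches the RVQ's input_dim = 128*4. The reshape is a single einops rearrange; a minimal shape check with random stand-in features:

```python
import torch
from einops import rearrange

n_mels, s = 128, 4
mel = torch.randn(2, n_mels, 500)     # (batch, mel bins, ~5 s at 100 frames per second)

mel = mel[..., : mel.shape[-1] // s * s]                # keep the time axis a multiple of s
tokens = rearrange(mel, "b f (t s) -> b t (s f)", s=s)
print(tokens.shape)                   # torch.Size([2, 125, 512]): 25 Hz frames of width 128*4
```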
 
MuCodec/muq_dev/muq_fairseq/models/muq/model/w2v2_config.json DELETED
@@ -1,113 +0,0 @@
1
- {
2
- "activation_dropout": 0.1,
3
- "adapter_kernel_size": 3,
4
- "adapter_stride": 2,
5
- "add_adapter": false,
6
- "apply_spec_augment": true,
7
- "architectures": [
8
- "Wav2Vec2ConformerForCTC"
9
- ],
10
- "attention_dropout": 0.1,
11
- "bos_token_id": 1,
12
- "classifier_proj_size": 256,
13
- "codevector_dim": 768,
14
- "conformer_conv_dropout": 0.1,
15
- "contrastive_logits_temperature": 0.1,
16
- "conv_bias": true,
17
- "conv_depthwise_kernel_size": 31,
18
- "conv_dim": [
19
- 512,
20
- 512,
21
- 512,
22
- 512,
23
- 512,
24
- 512,
25
- 512
26
- ],
27
- "conv_kernel": [
28
- 10,
29
- 3,
30
- 3,
31
- 3,
32
- 3,
33
- 2,
34
- 2
35
- ],
36
- "conv_stride": [
37
- 5,
38
- 2,
39
- 2,
40
- 2,
41
- 2,
42
- 2,
43
- 2
44
- ],
45
- "ctc_loss_reduction": "sum",
46
- "ctc_zero_infinity": false,
47
- "diversity_loss_weight": 0.1,
48
- "do_stable_layer_norm": true,
49
- "eos_token_id": 2,
50
- "feat_extract_activation": "gelu",
51
- "feat_extract_dropout": 0.0,
52
- "feat_extract_norm": "layer",
53
- "feat_proj_dropout": 0.1,
54
- "feat_quantizer_dropout": 0.0,
55
- "final_dropout": 0.1,
56
- "gradient_checkpointing": false,
57
- "hidden_act": "swish",
58
- "hidden_dropout": 0.1,
59
- "hidden_dropout_prob": 0.1,
60
- "hidden_size": 1024,
61
- "initializer_range": 0.02,
62
- "intermediate_size": 4096,
63
- "layer_norm_eps": 1e-05,
64
- "layerdrop": 0.0,
65
- "mask_feature_length": 10,
66
- "mask_feature_min_masks": 0,
67
- "mask_feature_prob": 0.0,
68
- "mask_time_length": 10,
69
- "mask_time_min_masks": 2,
70
- "mask_time_prob": 0.05,
71
- "max_source_positions": 5000,
72
- "model_type": "wav2vec2-conformer",
73
- "num_adapter_layers": 3,
74
- "num_attention_heads": 16,
75
- "num_codevector_groups": 2,
76
- "num_codevectors_per_group": 320,
77
- "num_conv_pos_embedding_groups": 16,
78
- "num_conv_pos_embeddings": 128,
79
- "num_feat_extract_layers": 7,
80
- "num_hidden_layers": 24,
81
- "num_negatives": 100,
82
- "output_hidden_size": 1024,
83
- "pad_token_id": 0,
84
- "position_embeddings_type": "rotary",
85
- "proj_codevector_dim": 768,
86
- "rotary_embedding_base": 10000,
87
- "tdnn_dilation": [
88
- 1,
89
- 2,
90
- 3,
91
- 1,
92
- 1
93
- ],
94
- "tdnn_dim": [
95
- 512,
96
- 512,
97
- 512,
98
- 512,
99
- 1500
100
- ],
101
- "tdnn_kernel": [
102
- 5,
103
- 3,
104
- 3,
105
- 1,
106
- 1
107
- ],
108
- "torch_dtype": "float32",
109
- "transformers_version": "4.19.0.dev0",
110
- "use_weighted_layer_sum": false,
111
- "vocab_size": 32,
112
- "xvector_output_dim": 512
113
- }
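
The JSON deleted above is an ordinary Hugging Face `Wav2Vec2ConformerConfig`: a 7-layer convolutional front end whose strides multiply to 5*2*2*2*2*2*2 = 320, so the encoder emits one frame per 320 input samples, which lines up with `encode_wavs` in the export script trimming waveforms to a multiple of 320 samples. A quick sketch of reading it back with transformers; the file path is a placeholder:

```python
from functools import reduce
from transformers import Wav2Vec2ConformerConfig

cfg = Wav2Vec2ConformerConfig.from_json_file("w2v2_config.json")  # placeholder path

total_stride = reduce(lambda a, b: a * b, cfg.conv_stride)
print(total_stride)                             # 320 input samples per encoder frame
print(cfg.hidden_size, cfg.num_hidden_layers)   # 1024-dim Conformer with 24 layers
```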
 
MuCodec/muq_dev/muq_fairseq/models/muq/modules/__init__.py DELETED
@@ -1,2 +0,0 @@
1
-
2
-
 
MuCodec/muq_dev/muq_fairseq/models/muq/modules/conv.py DELETED
@@ -1,77 +0,0 @@
1
- from torch import nn
2
- from einops import rearrange
3
-
4
-
5
- class Res2dModule(nn.Module):
6
- def __init__(self, idim, odim, stride=(2, 2)):
7
- super(Res2dModule, self).__init__()
8
- self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
9
- self.bn1 = nn.BatchNorm2d(odim)
10
- self.conv2 = nn.Conv2d(odim, odim, 3, padding=1)
11
- self.bn2 = nn.BatchNorm2d(odim)
12
- self.relu = nn.ReLU()
13
-
14
- # residual
15
- self.diff = False
16
- if (idim != odim) or (stride[0] > 1):
17
- self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
18
- self.bn3 = nn.BatchNorm2d(odim)
19
- self.diff = True
20
-
21
- def forward(self, x):
22
- out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
23
- if self.diff:
24
- x = self.bn3(self.conv3(x))
25
- out = x + out
26
- out = self.relu(out)
27
- return out
28
-
29
-
30
- class Conv2dSubsampling(nn.Module):
31
- """Convolutional 2D subsampling (to 1/4 length).
32
-
33
- Args:
34
- idim (int): Input dimension.
35
- hdim (int): Hidden dimension.
36
- odim (int): Output dimension.
37
- strides (list): Sizes of strides.
38
- n_bands (int): Number of frequency bands.
39
- """
40
-
41
- def __init__(self, idim, hdim, odim, strides=[2, 2], n_bands=64):
42
- """Construct an Conv2dSubsampling object."""
43
- super(Conv2dSubsampling, self).__init__()
44
-
45
- self.conv = nn.Sequential(
46
- Res2dModule(idim, hdim, (2, strides[0])),
47
- Res2dModule(hdim, hdim, (2, strides[1])),
48
- )
49
- self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim)
50
-
51
- def forward(self, x):
52
- """Subsample x.
53
-
54
- Args:
55
- x (torch.Tensor): Input tensor (#batch, idim, time).
56
-
57
- Returns:
58
- torch.Tensor: Subsampled tensor (#batch, time', odim),
59
- where time' = time // 4.
60
- """
61
-
62
- if x.dim() == 3:
63
- x = x.unsqueeze(1) # (b, c, f, t)
64
- x = self.conv(x)
65
- x = rearrange(x, "b c f t -> b t (c f)")
66
- x = self.linear(x)
67
- return x
68
-
69
- if __name__ == '__main__':
70
- import torch
71
- conv_dim, encoder_dim = 512, 1024
72
- conv = Conv2dSubsampling(
73
- 1, conv_dim, encoder_dim, strides=[2, 1], n_bands=128
74
- )
75
- inp = torch.randn((1, 128, 3000))
76
- out = conv(inp)
77
- print(out.shape)
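
`Conv2dSubsampling` deleted above applies two residual 2-D blocks with strides (2, strides[i]) over (frequency, time): frequency is always reduced 4x while time is reduced by strides[0]*strides[1], and the channel and frequency axes are flattened together before the final linear layer, hence its input size hdim * n_bands // 2 // 2. A back-of-the-envelope check of the shapes in the module's own __main__ test, using plain arithmetic rather than the model itself:

```python
hdim, n_bands, t_in = 512, 128, 3000
strides = [2, 1]

f_out = n_bands // 2 // 2                  # each Res2dModule halves the frequency axis
t_out = t_in // (strides[0] * strides[1])  # time is only reduced by the per-block time strides
linear_in = hdim * f_out                   # channels and remaining bands are flattened together

print(f_out, t_out, linear_in)             # 32 1500 16384 -> forward returns (1, 1500, odim)
```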
 
MuCodec/muq_dev/muq_fairseq/models/muq/modules/features.py DELETED
@@ -1,67 +0,0 @@
1
- import torchaudio
2
- from torch import nn
3
- import torch
4
-
5
-
6
- class MelSTFT(nn.Module):
7
- def __init__(
8
- self,
9
- sample_rate=24000,
10
- n_fft=2048,
11
- hop_length=240,
12
- n_mels=128,
13
- is_db=False,
14
- ):
15
- super(MelSTFT, self).__init__()
16
-
17
- # spectrogram
18
- self.mel_stft = torchaudio.transforms.MelSpectrogram(
19
- sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
20
- )
21
-
22
- # amplitude to decibel
23
- self.is_db = is_db
24
- if is_db:
25
- self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
26
-
27
- def forward(self, waveform):
28
- if self.is_db:
29
- return self.amplitude_to_db(self.mel_stft(waveform))
30
- else:
31
- return self.mel_stft(waveform)
32
-
33
-
34
- class CQTPreprocessor(nn.Module):
35
- def __init__(self,
36
- sr=24000,
37
- hop=960,
38
- nb=84,
39
- to_db = True,
40
- ) -> None:
41
- super().__init__()
42
-
43
- from nnAudio.features.cqt import CQT
44
- import torchaudio
45
- self.cqt_fn = CQT(
46
- sr=sr,
47
- hop_length=hop,
48
- n_bins=nb,
49
- fmin=32.7 if nb == 84 else 27.5, # 84 or 88
50
- bins_per_octave=12,
51
- filter_scale=1,
52
- norm=1,
53
- window='hann',
54
- center=True,
55
- pad_mode='constant',
56
- trainable=False,
57
- output_format='Magnitude',
58
- verbose=True,
59
- )
60
- if to_db:
61
- self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
62
- else:
63
- self.amplitude_to_db = lambda x:x
64
-
65
- @torch.no_grad()
66
- def __call__(self, x):
67
- return self.amplitude_to_db(self.cqt_fn(x))
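
`MelSTFT` deleted above is a thin wrapper around torchaudio's MelSpectrogram plus an optional AmplitudeToDB; with the defaults (24 kHz, n_fft 2048, hop 240, 128 mels) it produces roughly 100 dB-scaled frames per second, and the preprocessors above drop the last frame ([..., :-1]) so that 5 s of audio maps to exactly 500 frames. A minimal usage sketch with random audio in place of a real clip:

```python
import torch
import torchaudio

mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=24000, n_fft=2048, hop_length=240, n_mels=128
)
to_db = torchaudio.transforms.AmplitudeToDB()

wav = torch.randn(1, 24000 * 5)       # 5 s of stand-in audio at 24 kHz
spec = to_db(mel(wav))[..., :-1]      # drop the extra centre-padded frame, as the Preprocessor does
print(spec.shape)                     # torch.Size([1, 128, 500])
```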
 
MuCodec/muq_dev/muq_fairseq/models/muq/modules/flash_conformer.py DELETED
@@ -1,2114 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ PyTorch Wav2Vec2-Conformer model."""
16
-
17
- import math
18
- from dataclasses import dataclass
19
- from typing import Optional, Tuple, Union
20
-
21
- import numpy as np
22
- import torch
23
- import torch.utils.checkpoint
24
- from torch import nn
25
- from torch.nn import CrossEntropyLoss
26
- from torch.nn import functional as F
27
-
28
- from transformers.activations import ACT2FN
29
- from transformers.deepspeed import is_deepspeed_zero3_enabled
30
- from transformers.modeling_outputs import (
31
- BaseModelOutput,
32
- CausalLMOutput,
33
- SequenceClassifierOutput,
34
- TokenClassifierOutput,
35
- Wav2Vec2BaseModelOutput,
36
- XVectorOutput,
37
- )
38
- from transformers.modeling_utils import PreTrainedModel
39
- from transformers.utils import (
40
- ModelOutput,
41
- add_code_sample_docstrings,
42
- add_start_docstrings,
43
- add_start_docstrings_to_model_forward,
44
- logging,
45
- replace_return_docstrings,
46
- )
47
- from transformers.models.wav2vec2_conformer.configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
48
-
49
-
50
- logger = logging.get_logger(__name__)
51
-
52
-
53
- _HIDDEN_STATES_START_POSITION = 2
54
-
55
- # General docstring
56
- _CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
57
-
58
- # Base docstring
59
- _CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
60
- _EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
61
-
62
- # CTC docstring
63
- _CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
64
- _CTC_EXPECTED_LOSS = 64.21
65
-
66
-
67
- WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
68
- "facebook/wav2vec2-conformer-rel-pos-large",
69
- # See all Wav2Vec2Conformer models at https://huggingface.co/models?filter=wav2vec2-conformer
70
- ]
71
-
72
-
73
- @dataclass
74
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer
75
- class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
76
- """
77
- Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
78
-
79
- Args:
80
- loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
81
- Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
82
- paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss.
83
- projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
84
- Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
85
- projected quantized states.
86
- projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
87
- Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
88
- target vectors for contrastive loss.
89
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
90
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
91
- shape `(batch_size, sequence_length, hidden_size)`.
92
-
93
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
94
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
95
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
96
- sequence_length)`.
97
-
98
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
99
- heads.
100
- contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
101
- The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
102
- diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
103
- The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
104
- """
105
-
106
- loss: Optional[torch.FloatTensor] = None
107
- projected_states: torch.FloatTensor = None
108
- projected_quantized_states: torch.FloatTensor = None
109
- codevector_perplexity: torch.FloatTensor = None
110
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
111
- attentions: Optional[Tuple[torch.FloatTensor]] = None
112
- contrastive_loss: Optional[torch.FloatTensor] = None
113
- diversity_loss: Optional[torch.FloatTensor] = None
114
-
115
-
116
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
117
- def _compute_mask_indices(
118
- shape: Tuple[int, int],
119
- mask_prob: float,
120
- mask_length: int,
121
- attention_mask: Optional[torch.LongTensor] = None,
122
- min_masks: int = 0,
123
- ) -> np.ndarray:
124
- """
125
- Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
126
- ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
127
- CPU as part of the preprocessing during training.
128
-
129
- Args:
130
- shape: The shape for which to compute masks. This should be of a tuple of size 2 where
131
- the first element is the batch size and the second element is the length of the axis to span.
132
- mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
133
- independently generated mask spans of length `mask_length` is computed by
134
- `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
135
- actual percentage will be smaller.
136
- mask_length: size of the mask
137
- min_masks: minimum number of masked spans
138
- attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
139
- each batch dimension.
140
- """
141
- batch_size, sequence_length = shape
142
-
143
- if mask_length < 1:
144
- raise ValueError("`mask_length` has to be bigger than 0.")
145
-
146
- if mask_length > sequence_length:
147
- raise ValueError(
148
- f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
149
- f" and `sequence_length`: {sequence_length}`"
150
- )
151
-
152
- # epsilon is used for probabilistic rounding
153
- epsilon = np.random.rand(1).item()
154
-
155
- def compute_num_masked_span(input_length):
156
- """Given input length, compute how many spans should be masked"""
157
- num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
158
- num_masked_span = max(num_masked_span, min_masks)
159
-
160
- # make sure num masked span <= sequence_length
161
- if num_masked_span * mask_length > sequence_length:
162
- num_masked_span = sequence_length // mask_length
163
-
164
- # make sure num_masked span is also <= input_length - (mask_length - 1)
165
- if input_length - (mask_length - 1) < num_masked_span:
166
- num_masked_span = max(input_length - (mask_length - 1), 0)
167
-
168
- return num_masked_span
169
-
170
- # compute number of masked spans in batch
171
- input_lengths = (
172
- attention_mask.sum(-1).detach().tolist()
173
- if attention_mask is not None
174
- else [sequence_length for _ in range(batch_size)]
175
- )
176
-
177
- # SpecAugment mask to fill
178
- spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
179
- spec_aug_mask_idxs = []
180
-
181
- max_num_masked_span = compute_num_masked_span(sequence_length)
182
-
183
- if max_num_masked_span == 0:
184
- return spec_aug_mask
185
-
186
- for input_length in input_lengths:
187
- # compute num of masked spans for this input
188
- num_masked_span = compute_num_masked_span(input_length)
189
-
190
- # get random indices to mask
191
- spec_aug_mask_idx = np.random.choice(
192
- np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
193
- )
194
-
195
- # pick first sampled index that will serve as a dummy index to pad vector
196
- # to ensure same dimension for all batches due to probabilistic rounding
197
- # Picking first sample just pads those vectors twice.
198
- if len(spec_aug_mask_idx) == 0:
199
- # this case can only happen if `input_length` is strictly smaller then
200
- # `sequence_length` in which case the last token has to be a padding
201
- # token which we can use as a dummy mask id
202
- dummy_mask_idx = sequence_length - 1
203
- else:
204
- dummy_mask_idx = spec_aug_mask_idx[0]
205
-
206
- spec_aug_mask_idx = np.concatenate(
207
- [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
208
- )
209
- spec_aug_mask_idxs.append(spec_aug_mask_idx)
210
-
211
- spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
212
-
213
- # expand masked indices to masked spans
214
- spec_aug_mask_idxs = np.broadcast_to(
215
- spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
216
- )
217
- spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
218
-
219
- # add offset to the starting indexes so that indexes now create a span
220
- offsets = np.arange(mask_length)[None, None, :]
221
- offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
222
- batch_size, max_num_masked_span * mask_length
223
- )
224
- spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
225
-
226
- # ensure that we cannot have indices larger than sequence_length
227
- if spec_aug_mask_idxs.max() > sequence_length - 1:
228
- spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
229
-
230
- # scatter indices to mask
231
- np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
232
-
233
- return spec_aug_mask
234
-
235
-
236
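For orientation, a minimal usage sketch of the `_compute_mask_indices` helper above, assuming the helper is in scope; the batch size, sequence length and masking parameters are illustrative only:

```python
# Illustrative call: mask ~20% of frames in spans of length 10 for a batch of 2
# sequences of 100 frames each (no attention_mask, so every frame is a candidate).
mask = _compute_mask_indices(shape=(2, 100), mask_prob=0.2, mask_length=10)

print(mask.dtype, mask.shape)   # bool (2, 100)
# Each row has roughly mask_prob * sequence_length masked frames; spans may overlap,
# so the exact count can be smaller, and the probabilistic rounding via `epsilon`
# above makes it vary slightly between calls.
print(mask.sum(axis=-1))
```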
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
237
- def _sample_negative_indices(
238
- features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
239
- ):
240
- """
241
- Sample `num_negatives` vectors from feature vectors.
242
- """
243
- batch_size, sequence_length = features_shape
244
-
245
- # generate indices of the positive vectors themselves, repeat them `num_negatives` times
246
- sequence_length_range = np.arange(sequence_length)
247
-
248
- # get `num_negatives` random vector indices from the same utterance
249
- sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
250
-
251
- mask_time_indices = (
252
- mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
253
- )
254
-
255
- for batch_idx in range(batch_size):
256
- high = mask_time_indices[batch_idx].sum() - 1
257
- mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
258
-
259
- feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
260
- sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
261
- # avoid sampling the same positive vector, but keep the distribution uniform
262
- sampled_indices[sampled_indices >= feature_indices] += 1
263
-
264
- # remap to actual indices
265
- sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
266
-
267
- # correct for batch size
268
- sampled_negative_indices[batch_idx] += batch_idx * sequence_length
269
-
270
- return sampled_negative_indices
271
-
272
-
273
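The negative-sampling helper above is normally driven by exactly such a mask: for every masked time step it draws `num_negatives` distractor positions from the same utterance, shifting indices at or beyond the positive so the positive itself is never drawn. A short sketch, again with illustrative shapes and assuming both helpers are in scope:

```python
batch_size, sequence_length, num_negatives = 2, 100, 5

mask_time_indices = _compute_mask_indices(
    shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
)
sampled_negative_indices = _sample_negative_indices(
    features_shape=(batch_size, sequence_length),
    num_negatives=num_negatives,
    mask_time_indices=mask_time_indices,
)

# (batch_size, sequence_length, num_negatives); entries are flattened (batch * time)
# indices because each row is offset by batch_idx * sequence_length at the end.
print(sampled_negative_indices.shape)
```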
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
274
- class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
275
- def __init__(self, config, layer_id=0):
276
- super().__init__()
277
- self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
278
- self.out_conv_dim = config.conv_dim[layer_id]
279
-
280
- self.conv = nn.Conv1d(
281
- self.in_conv_dim,
282
- self.out_conv_dim,
283
- kernel_size=config.conv_kernel[layer_id],
284
- stride=config.conv_stride[layer_id],
285
- bias=config.conv_bias,
286
- )
287
- self.activation = ACT2FN[config.feat_extract_activation]
288
-
289
- def forward(self, hidden_states):
290
- hidden_states = self.conv(hidden_states)
291
- hidden_states = self.activation(hidden_states)
292
- return hidden_states
293
-
294
-
295
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
296
- class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
297
- def __init__(self, config, layer_id=0):
298
- super().__init__()
299
- self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
300
- self.out_conv_dim = config.conv_dim[layer_id]
301
-
302
- self.conv = nn.Conv1d(
303
- self.in_conv_dim,
304
- self.out_conv_dim,
305
- kernel_size=config.conv_kernel[layer_id],
306
- stride=config.conv_stride[layer_id],
307
- bias=config.conv_bias,
308
- )
309
- self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
310
- self.activation = ACT2FN[config.feat_extract_activation]
311
-
312
- def forward(self, hidden_states):
313
- hidden_states = self.conv(hidden_states)
314
-
315
- hidden_states = hidden_states.transpose(-2, -1)
316
- hidden_states = self.layer_norm(hidden_states)
317
- hidden_states = hidden_states.transpose(-2, -1)
318
-
319
- hidden_states = self.activation(hidden_states)
320
- return hidden_states
321
-
322
-
323
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
324
- class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
325
- def __init__(self, config, layer_id=0):
326
- super().__init__()
327
- self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
328
- self.out_conv_dim = config.conv_dim[layer_id]
329
-
330
- self.conv = nn.Conv1d(
331
- self.in_conv_dim,
332
- self.out_conv_dim,
333
- kernel_size=config.conv_kernel[layer_id],
334
- stride=config.conv_stride[layer_id],
335
- bias=config.conv_bias,
336
- )
337
- self.activation = ACT2FN[config.feat_extract_activation]
338
-
339
- self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
340
-
341
- def forward(self, hidden_states):
342
- hidden_states = self.conv(hidden_states)
343
- hidden_states = self.layer_norm(hidden_states)
344
- hidden_states = self.activation(hidden_states)
345
- return hidden_states
346
-
347
-
348
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer
349
- class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
350
- def __init__(self, config):
351
- super().__init__()
352
- self.conv = nn.Conv1d(
353
- config.hidden_size,
354
- config.hidden_size,
355
- kernel_size=config.num_conv_pos_embeddings,
356
- padding=config.num_conv_pos_embeddings // 2,
357
- groups=config.num_conv_pos_embedding_groups,
358
- )
359
-
360
- if is_deepspeed_zero3_enabled():
361
- import deepspeed
362
-
363
- with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
364
- self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
365
- deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
366
- deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
367
- else:
368
- self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
369
-
370
- self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings)
371
- self.activation = ACT2FN[config.feat_extract_activation]
372
-
373
- def forward(self, hidden_states):
374
- hidden_states = hidden_states.transpose(1, 2)
375
-
376
- hidden_states = self.conv(hidden_states)
377
- hidden_states = self.padding(hidden_states)
378
- hidden_states = self.activation(hidden_states)
379
-
380
- hidden_states = hidden_states.transpose(1, 2)
381
- return hidden_states
382
-
383
-
384
- class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
385
- """Rotary positional embedding
386
- Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
387
- """
388
-
389
- def __init__(self, config):
390
- super().__init__()
391
- dim = config.hidden_size // config.num_attention_heads
392
- base = config.rotary_embedding_base
393
-
394
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
395
- self.register_buffer("inv_freq", inv_freq)
396
- self.cached_sequence_length = None
397
- self.cached_rotary_positional_embedding = None
398
-
399
- def forward(self, hidden_states):
400
- sequence_length = hidden_states.shape[1]
401
-
402
- if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
403
- return self.cached_rotary_positional_embedding
404
-
405
- self.cached_sequence_length = sequence_length
406
- time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
407
- freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
408
- embeddings = torch.cat((freqs, freqs), dim=-1)
409
-
410
- cos_embeddings = embeddings.cos()[:, None, None, :]
411
- sin_embeddings = embeddings.sin()[:, None, None, :]
412
- self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
413
- return self.cached_rotary_positional_embedding
414
-
415
-
416
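The cached cos/sin tables produced above are consumed by `_apply_rotary_embedding` in the self-attention module further below. A self-contained sketch of the underlying rotation, with toy dimensions:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the head dimension in half and swap with a sign flip: (x1, x2) -> (-x2, x1).
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

seq_len, head_size, base = 8, 16, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2).float() / head_size))
freqs = torch.einsum("i,j->ij", torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)          # (seq_len, head_size)
cos, sin = emb.cos(), emb.sin()

x = torch.randn(seq_len, head_size)              # e.g. queries/keys for one head
x_rotary = x * cos + rotate_half(x) * sin        # position-dependent rotation
print(x_rotary.shape)                            # torch.Size([8, 16])
```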
- class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
417
- """Relative positional encoding module."""
418
-
419
- def __init__(self, config):
420
- super().__init__()
421
- self.max_len = config.max_source_positions
422
- self.d_model = config.hidden_size
423
- self.pe = None
424
- self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
425
-
426
- def extend_pe(self, x):
427
- # Reset the positional encodings
428
- if self.pe is not None:
429
- # self.pe contains both positive and negative parts
430
- # the length of self.pe is 2 * input_len - 1
431
- if self.pe.size(1) >= x.size(1) * 2 - 1:
432
- if self.pe.dtype != x.dtype or self.pe.device != x.device:
433
- self.pe = self.pe.to(dtype=x.dtype, device=x.device)
434
- return
435
- # Suppose `i` is the position of query vector and `j` is the
436
- # position of key vector. We use positive relative positions when keys
437
- # are to the left (i>j) and negative relative positions otherwise (i<j).
438
- pe_positive = torch.zeros(x.size(1), self.d_model)
439
- pe_negative = torch.zeros(x.size(1), self.d_model)
440
- position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
441
- div_term = torch.exp(
442
- torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)
443
- )
444
- pe_positive[:, 0::2] = torch.sin(position * div_term)
445
- pe_positive[:, 1::2] = torch.cos(position * div_term)
446
- pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
447
- pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
448
-
449
- # Reverse the order of positive indices and concat both positive and
450
- # negative indices. This is used to support the shifting trick
451
- # as in https://arxiv.org/abs/1901.02860
452
- pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
453
- pe_negative = pe_negative[1:].unsqueeze(0)
454
- pe = torch.cat([pe_positive, pe_negative], dim=1)
455
- self.pe = pe.to(device=x.device, dtype=x.dtype)
456
-
457
- def forward(self, hidden_states: torch.Tensor):
458
- self.extend_pe(hidden_states)
459
- start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
460
- end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
461
- relative_position_embeddings = self.pe[:, start_idx:end_idx]
462
-
463
- return relative_position_embeddings
464
-
465
-
466
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Wav2Vec2Conformer
467
- class Wav2Vec2ConformerSamePadLayer(nn.Module):
468
- def __init__(self, num_conv_pos_embeddings):
469
- super().__init__()
470
- self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
471
-
472
- def forward(self, hidden_states):
473
- if self.num_pad_remove > 0:
474
- hidden_states = hidden_states[:, :, : -self.num_pad_remove]
475
- return hidden_states
476
-
477
-
478
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer
479
- class Wav2Vec2ConformerFeatureEncoder(nn.Module):
480
- """Construct the features from raw audio waveform"""
481
-
482
- def __init__(self, config):
483
- super().__init__()
484
-
485
- if config.feat_extract_norm == "group":
486
- conv_layers = [Wav2Vec2ConformerGroupNormConvLayer(config, layer_id=0)] + [
487
- Wav2Vec2ConformerNoLayerNormConvLayer(config, layer_id=i + 1)
488
- for i in range(config.num_feat_extract_layers - 1)
489
- ]
490
- elif config.feat_extract_norm == "layer":
491
- conv_layers = [
492
- Wav2Vec2ConformerLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
493
- ]
494
- else:
495
- raise ValueError(
496
- f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
497
- )
498
- self.conv_layers = nn.ModuleList(conv_layers)
499
- self.gradient_checkpointing = False
500
- self._requires_grad = True
501
-
502
- def _freeze_parameters(self):
503
- for param in self.parameters():
504
- param.requires_grad = False
505
- self._requires_grad = False
506
-
507
- def forward(self, input_values):
508
- hidden_states = input_values[:, None]
509
-
510
- # make sure hidden_states require grad for gradient_checkpointing
511
- if self._requires_grad and self.training:
512
- hidden_states.requires_grad = True
513
-
514
- for conv_layer in self.conv_layers:
515
- if self._requires_grad and self.gradient_checkpointing and self.training:
516
-
517
- def create_custom_forward(module):
518
- def custom_forward(*inputs):
519
- return module(*inputs)
520
-
521
- return custom_forward
522
-
523
- hidden_states = torch.utils.checkpoint.checkpoint(
524
- create_custom_forward(conv_layer),
525
- hidden_states,
526
- )
527
- else:
528
- hidden_states = conv_layer(hidden_states)
529
-
530
- return hidden_states
531
-
532
-
533
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer
534
- class Wav2Vec2ConformerFeatureProjection(nn.Module):
535
- def __init__(self, config):
536
- super().__init__()
537
- self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
538
- self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
539
- self.dropout = nn.Dropout(config.feat_proj_dropout)
540
-
541
- def forward(self, hidden_states):
542
- # non-projected hidden states are needed for quantization
543
- norm_hidden_states = self.layer_norm(hidden_states)
544
- hidden_states = self.projection(norm_hidden_states)
545
- hidden_states = self.dropout(hidden_states)
546
- return hidden_states, norm_hidden_states
547
-
548
-
549
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer
550
- class Wav2Vec2ConformerFeedForward(nn.Module):
551
- def __init__(self, config):
552
- super().__init__()
553
- self.intermediate_dropout = nn.Dropout(config.activation_dropout)
554
-
555
- self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
556
- if isinstance(config.hidden_act, str):
557
- self.intermediate_act_fn = ACT2FN[config.hidden_act]
558
- else:
559
- self.intermediate_act_fn = config.hidden_act
560
-
561
- self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
562
- self.output_dropout = nn.Dropout(config.hidden_dropout)
563
-
564
- def forward(self, hidden_states):
565
- hidden_states = self.intermediate_dense(hidden_states)
566
- hidden_states = self.intermediate_act_fn(hidden_states)
567
- hidden_states = self.intermediate_dropout(hidden_states)
568
-
569
- hidden_states = self.output_dense(hidden_states)
570
- hidden_states = self.output_dropout(hidden_states)
571
- return hidden_states
572
-
573
-
574
- class Wav2Vec2ConformerConvolutionModule(nn.Module):
575
- """Convolution block used in the conformer block"""
576
-
577
- def __init__(self, config):
578
- super().__init__()
579
- if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
580
- raise ValueError("`config.conv_depthwise_kernel_size` should be an odd number for 'SAME' padding")
581
- self.layer_norm = nn.LayerNorm(config.hidden_size)
582
- self.pointwise_conv1 = torch.nn.Conv1d(
583
- config.hidden_size,
584
- 2 * config.hidden_size,
585
- kernel_size=1,
586
- stride=1,
587
- padding=0,
588
- bias=False,
589
- )
590
- self.glu = torch.nn.GLU(dim=1)
591
- self.depthwise_conv = torch.nn.Conv1d(
592
- config.hidden_size,
593
- config.hidden_size,
594
- config.conv_depthwise_kernel_size,
595
- stride=1,
596
- padding=(config.conv_depthwise_kernel_size - 1) // 2,
597
- groups=config.hidden_size,
598
- bias=False,
599
- )
600
- self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size)
601
- self.activation = ACT2FN[config.hidden_act]
602
- self.pointwise_conv2 = torch.nn.Conv1d(
603
- config.hidden_size,
604
- config.hidden_size,
605
- kernel_size=1,
606
- stride=1,
607
- padding=0,
608
- bias=False,
609
- )
610
- self.dropout = torch.nn.Dropout(config.conformer_conv_dropout)
611
-
612
- def forward(self, hidden_states):
613
- hidden_states = self.layer_norm(hidden_states)
614
- # exchange the temporal dimension and the feature dimension
615
- hidden_states = hidden_states.transpose(1, 2)
616
-
617
- # GLU mechanism
618
- # => (batch, 2*channel, dim)
619
- hidden_states = self.pointwise_conv1(hidden_states)
620
- # => (batch, channel, dim)
621
- hidden_states = self.glu(hidden_states)
622
-
623
- # 1D Depthwise Conv
624
- hidden_states = self.depthwise_conv(hidden_states)
625
- hidden_states = self.batch_norm(hidden_states)
626
- hidden_states = self.activation(hidden_states)
627
-
628
- hidden_states = self.pointwise_conv2(hidden_states)
629
- hidden_states = self.dropout(hidden_states)
630
- hidden_states = hidden_states.transpose(1, 2)
631
- return hidden_states
632
-
633
-
634
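As a quick sanity check of the shape bookkeeping in the convolution module above, the following illustrative trace reproduces the channel flow (layer norm, batch norm, activation and dropout are omitted; sizes are arbitrary):

```python
import torch
import torch.nn as nn

batch, time, hidden, kernel = 2, 50, 256, 31   # kernel must be odd for 'SAME' padding

x = torch.randn(batch, time, hidden)
x = x.transpose(1, 2)                                             # (batch, hidden, time)
x = nn.Conv1d(hidden, 2 * hidden, kernel_size=1)(x)               # pointwise conv 1
x = nn.GLU(dim=1)(x)                                              # GLU halves the channels
x = nn.Conv1d(hidden, hidden, kernel,                             # depthwise conv keeps time
              padding=(kernel - 1) // 2, groups=hidden)(x)
x = nn.Conv1d(hidden, hidden, kernel_size=1)(x)                   # pointwise conv 2
x = x.transpose(1, 2)                                             # back to (batch, time, hidden)
print(x.shape)                                                    # torch.Size([2, 50, 256])
```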
- class Wav2Vec2ConformerSelfAttention(nn.Module):
635
- """Construct an Wav2Vec2ConformerSelfAttention object.
636
- Can be enhanced with rotary or relative position embeddings.
637
- """
638
-
639
- def __init__(self, config):
640
- super().__init__()
641
-
642
- self.head_size = config.hidden_size // config.num_attention_heads
643
- self.num_heads = config.num_attention_heads
644
- self.position_embeddings_type = config.position_embeddings_type
645
-
646
- self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
647
- self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
648
- self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
649
- self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
650
-
651
- self.dropout = nn.Dropout(p=config.attention_dropout)
652
- self.dropout_p = config.attention_dropout
653
-
654
- self.is_causal = config.is_causal
655
-
656
- if self.position_embeddings_type == "relative":
657
- # linear transformation for positional encoding
658
- self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
659
- # these two learnable bias are used in matrix c and matrix d
660
- # as described in https://arxiv.org/abs/1901.02860 Section 3.3
661
- self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
662
- self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
663
-
664
- def forward(
665
- self,
666
- hidden_states: torch.Tensor,
667
- attention_mask: Optional[torch.Tensor] = None,
668
- relative_position_embeddings: Optional[torch.Tensor] = None,
669
- output_attentions: bool = False,
670
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
671
- # self-attention mechanism
672
- batch_size, sequence_length, hidden_size = hidden_states.size()
673
-
674
- # make sure query/key states can be != value states
675
- query_key_states = hidden_states
676
- value_states = hidden_states
677
-
678
- if self.position_embeddings_type == "rotary":
679
- if relative_position_embeddings is None:
680
- raise ValueError(
681
- "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
682
- )
683
- query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
684
-
685
- # project query_key_states and value_states
686
- query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
687
- key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
688
- value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
689
-
690
- # => (batch, head, time1, d_k)
691
- query = query.transpose(1, 2)
692
- key = key.transpose(1, 2)
693
- value = value.transpose(1, 2)
694
-
695
- with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
696
- hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=self.dropout_p, is_causal=self.is_causal)
697
- probs = None
698
-
699
- # # apply attention_mask if necessary
700
- # if attention_mask is not None:
701
- # scores = scores + attention_mask
702
-
703
- # # => (batch, head, time1, time2)
704
- # probs = torch.softmax(scores, dim=-1)
705
- # probs = self.dropout(probs)
706
-
707
- # # => (batch, head, time1, d_k)
708
- # hidden_states = torch.matmul(probs, value)
709
-
710
- # => (batch, time1, hidden_size)
711
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
712
- hidden_states = self.linear_out(hidden_states)
713
-
714
- return hidden_states, probs
715
-
716
- def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
717
- batch_size, sequence_length, hidden_size = hidden_states.size()
718
- hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
719
-
720
- cos = relative_position_embeddings[0, :sequence_length, ...]
721
- sin = relative_position_embeddings[1, :sequence_length, ...]
722
-
723
- # rotate hidden_states with rotary embeddings
724
- hidden_states = hidden_states.transpose(0, 1)
725
- rotated_states_begin = hidden_states[..., : self.head_size // 2]
726
- rotated_states_end = hidden_states[..., self.head_size // 2 :]
727
- rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
728
- hidden_states = (hidden_states * cos) + (rotated_states * sin)
729
- hidden_states = hidden_states.transpose(0, 1)
730
-
731
- hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
732
-
733
- return hidden_states
734
-
735
- def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
736
- # 1. project positional embeddings
737
- # => (batch, head, 2*time1-1, d_k)
738
- proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
739
- proj_relative_position_embeddings = proj_relative_position_embeddings.view(
740
- relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
741
- )
742
- proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
743
- proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
744
-
745
- # 2. Add bias to query
746
- # => (batch, head, time1, d_k)
747
- query = query.transpose(1, 2)
748
- q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
749
- q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
750
-
751
- # 3. attention score: first compute matrix a and matrix c
752
- # as described in https://arxiv.org/abs/1901.02860 Section 3.3
753
- # => (batch, head, time1, time2)
754
- scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
755
-
756
- # 4. then compute matrix b and matrix d
757
- # => (batch, head, time1, 2*time1-1)
758
- scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
759
-
760
- # 5. shift matrix b and matrix d
761
- zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
762
- scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
763
- scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
764
- scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
765
- scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
766
- scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
767
-
768
- # 6. sum matrices
769
- # => (batch, head, time1, time2)
770
- scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
771
-
772
- return scores
773
-
774
-
775
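The fused `scaled_dot_product_attention` call in `forward` above stands in for the commented-out softmax attention; with no mask, no dropout and non-causal attention, the two paths agree up to floating-point error, which a toy check makes concrete (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

batch, heads, time, d_k = 2, 4, 10, 16
q, k, v = (torch.randn(batch, heads, time, d_k) for _ in range(3))

fused = F.scaled_dot_product_attention(q, k, v)                   # fused kernel path

scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)      # manual reference path
manual = torch.matmul(torch.softmax(scores, dim=-1), v)

print(torch.allclose(fused, manual, atol=1e-5))                   # True up to numerics
```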
- class Wav2Vec2ConformerEncoderLayer(nn.Module):
776
- """Conformer block based on https://arxiv.org/abs/2005.08100."""
777
-
778
- def __init__(self, config):
779
- super().__init__()
780
- embed_dim = config.hidden_size
781
- dropout = config.attention_dropout
782
-
783
- # Feed-forward 1
784
- self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
785
- self.ffn1 = Wav2Vec2ConformerFeedForward(config)
786
-
787
- # Self-Attention
788
- self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
789
- self.self_attn_dropout = torch.nn.Dropout(dropout)
790
- self.self_attn = Wav2Vec2ConformerSelfAttention(config)
791
-
792
- # Conformer Convolution
793
- self.conv_module = Wav2Vec2ConformerConvolutionModule(config)
794
-
795
- # Feed-forward 2
796
- self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
797
- self.ffn2 = Wav2Vec2ConformerFeedForward(config)
798
- self.final_layer_norm = nn.LayerNorm(embed_dim)
799
-
800
- def forward(
801
- self,
802
- hidden_states,
803
- attention_mask: Optional[torch.Tensor] = None,
804
- relative_position_embeddings: Optional[torch.Tensor] = None,
805
- output_attentions: bool = False,
806
- ):
807
- hidden_states = hidden_states
808
-
809
- # 1. Feed-Forward 1 layer
810
- residual = hidden_states
811
- hidden_states = self.ffn1_layer_norm(hidden_states)
812
- hidden_states = self.ffn1(hidden_states)
813
- hidden_states = hidden_states * 0.5 + residual
814
- residual = hidden_states
815
-
816
- # 2. Self-Attention layer
817
- hidden_states = self.self_attn_layer_norm(hidden_states)
818
- hidden_states, attn_weights = self.self_attn(
819
- hidden_states=hidden_states,
820
- attention_mask=attention_mask,
821
- relative_position_embeddings=relative_position_embeddings,
822
- output_attentions=output_attentions,
823
- )
824
- hidden_states = self.self_attn_dropout(hidden_states)
825
- hidden_states = hidden_states + residual
826
-
827
- # 3. Convolutional Layer
828
- residual = hidden_states
829
- hidden_states = self.conv_module(hidden_states)
830
- hidden_states = residual + hidden_states
831
-
832
- # 4. Feed-Forward 2 Layer
833
- residual = hidden_states
834
- hidden_states = self.ffn2_layer_norm(hidden_states)
835
- hidden_states = self.ffn2(hidden_states)
836
- hidden_states = hidden_states * 0.5 + residual
837
- hidden_states = self.final_layer_norm(hidden_states)
838
-
839
- return hidden_states, attn_weights
840
-
841
-
842
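The residual wiring of the encoder layer above is the Conformer paper's macaron layout: two half-step feed-forward modules around self-attention and convolution, closed by a final layer norm. A compact, runnable restatement with the sub-modules abstracted away (identity modules stand in for the real ones; dropout is omitted):

```python
import torch
import torch.nn as nn

def conformer_block(x, ffn1, attn, conv, ffn2, ln_ffn1, ln_attn, ln_ffn2, ln_final):
    x = x + 0.5 * ffn1(ln_ffn1(x))             # 1. half-step feed-forward
    x = x + attn(ln_attn(x))                   # 2. self-attention
    x = x + conv(x)                            # 3. conv module (it layer-norms internally)
    x = ln_final(x + 0.5 * ffn2(ln_ffn2(x)))   # 4. half-step feed-forward + final norm
    return x

i = nn.Identity()
out = conformer_block(torch.randn(2, 5, 8), i, i, i, i, i, i, i, i)
print(out.shape)   # torch.Size([2, 5, 8])
```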
- class Wav2Vec2ConformerEncoder(nn.Module):
843
- def __init__(self, config, is_causal=False):
844
- super().__init__()
845
- config.is_causal = is_causal
846
- self.config = config
847
-
848
- if config.position_embeddings_type == "relative":
849
- self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config)
850
- elif config.position_embeddings_type == "rotary":
851
- self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
852
- else:
853
- self.embed_positions = None
854
-
855
- self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config)
856
- self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
857
- self.dropout = nn.Dropout(config.hidden_dropout)
858
- self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
859
- self.gradient_checkpointing = False
860
-
861
- def forward(
862
- self,
863
- hidden_states,
864
- attention_mask=None,
865
- output_attentions=False,
866
- output_hidden_states=False,
867
- return_dict=True,
868
- ):
869
- all_hidden_states = () if output_hidden_states else None
870
- all_self_attentions = () if output_attentions else None
871
-
872
- if attention_mask is not None:
873
- # make sure padded tokens output 0
874
- hidden_states[~attention_mask] = 0.0
875
-
876
- # extend attention_mask
877
- attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
878
- attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
879
- attention_mask = attention_mask.expand(
880
- attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
881
- )
882
-
883
- hidden_states = self.dropout(hidden_states)
884
-
885
- if self.embed_positions is not None:
886
- relative_position_embeddings = self.embed_positions(hidden_states)
887
- else:
888
- relative_position_embeddings = None
889
-
890
- deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
891
-
892
- for i, layer in enumerate(self.layers):
893
- if output_hidden_states:
894
- all_hidden_states = all_hidden_states + (hidden_states,)
895
-
896
- # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
897
- dropout_probability = np.random.uniform(0, 1)
898
-
899
- skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
900
- if not skip_the_layer or deepspeed_zero3_is_enabled:
901
- # under deepspeed zero3 all gpus must run in sync
902
- if self.gradient_checkpointing and self.training:
903
- # create gradient checkpointing function
904
- def create_custom_forward(module):
905
- def custom_forward(*inputs):
906
- return module(*inputs, output_attentions)
907
-
908
- return custom_forward
909
-
910
- layer_outputs = torch.utils.checkpoint.checkpoint(
911
- create_custom_forward(layer),
912
- hidden_states,
913
- attention_mask,
914
- relative_position_embeddings,
915
- )
916
- else:
917
- layer_outputs = layer(
918
- hidden_states,
919
- attention_mask=attention_mask,
920
- relative_position_embeddings=relative_position_embeddings,
921
- output_attentions=output_attentions,
922
- )
923
- hidden_states = layer_outputs[0]
924
-
925
- if skip_the_layer:
926
- layer_outputs = (None, None)
927
-
928
- if output_attentions:
929
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
930
-
931
- hidden_states = self.layer_norm(hidden_states)
932
- if output_hidden_states:
933
- all_hidden_states = all_hidden_states + (hidden_states,)
934
-
935
- if not return_dict:
936
- return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
937
- return BaseModelOutput(
938
- last_hidden_state=hidden_states,
939
- hidden_states=all_hidden_states,
940
- attentions=all_self_attentions,
941
- )
942
-
943
-
944
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer
945
- class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
946
- """
947
- Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
948
- GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf)` for more information.
949
- """
950
-
951
- def __init__(self, config):
952
- super().__init__()
953
- self.num_groups = config.num_codevector_groups
954
- self.num_vars = config.num_codevectors_per_group
955
-
956
- if config.codevector_dim % self.num_groups != 0:
957
- raise ValueError(
958
- f"`config.codevector_dim {config.codevector_dim} must be divisible "
959
- f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
960
- )
961
-
962
- # storage for codebook variables (codewords)
963
- self.codevectors = nn.Parameter(
964
- torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
965
- )
966
- self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
967
-
968
- # can be decayed for training
969
- self.temperature = 2
970
-
971
- @staticmethod
972
- def _compute_perplexity(probs, mask=None):
973
- if mask is not None:
974
- mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
975
- probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
976
- marginal_probs = probs.sum(dim=0) / mask.sum()
977
- else:
978
- marginal_probs = probs.mean(dim=0)
979
-
980
- perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
981
- return perplexity
982
-
983
- def forward(self, hidden_states, mask_time_indices=None):
984
- batch_size, sequence_length, hidden_size = hidden_states.shape
985
-
986
- # project to codevector dim
987
- hidden_states = self.weight_proj(hidden_states)
988
- hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
989
-
990
- if self.training:
991
- # sample code vector probs via gumbel in a differentiable way
992
- codevector_probs = nn.functional.gumbel_softmax(
993
- hidden_states.float(), tau=self.temperature, hard=True
994
- ).type_as(hidden_states)
995
-
996
- # compute perplexity
997
- codevector_soft_dist = torch.softmax(
998
- hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
999
- )
1000
- perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
1001
- else:
1002
- # take argmax in non-differentiable way
1003
- # compute hard codevector distribution (one hot)
1004
- codevector_idx = hidden_states.argmax(dim=-1)
1005
- codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
1006
- -1, codevector_idx.view(-1, 1), 1.0
1007
- )
1008
- codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
1009
-
1010
- perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
1011
-
1012
- codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
1013
- # use probs to retrieve codevectors
1014
- codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
1015
- codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
1016
- codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
1017
-
1018
- return codevectors, perplexity
1019
-
1020
-
1021
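A hedged, standalone sketch of the mechanism inside the quantizer above: Gumbel-softmax yields (nearly) one-hot codevector selections per group while keeping gradients flowing, and the perplexity of the marginal code usage tracks how evenly the codebook is exercised. Group count, codebook size and batch size below are illustrative:

```python
import torch
import torch.nn.functional as F

num_groups, num_vars, batch_times_seq = 2, 8, 6

# Per-group logits over codebook entries, as produced by `weight_proj` above.
logits = torch.randn(batch_times_seq * num_groups, num_vars)

# Differentiable, straight-through one-hot selection (training path).
probs = F.gumbel_softmax(logits, tau=2.0, hard=True)

# Perplexity of the marginal usage: num_vars per group when uniform, 1 when collapsed.
marginal = probs.view(batch_times_seq, num_groups, num_vars).mean(dim=0)
perplexity = torch.exp(-(marginal * torch.log(marginal + 1e-7)).sum(-1)).sum()
print(perplexity)   # bounded above by num_groups * num_vars
```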
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer
1022
- class Wav2Vec2ConformerAdapter(nn.Module):
1023
- def __init__(self, config):
1024
- super().__init__()
1025
-
1026
- # feature dim might need to be down-projected
1027
- if config.output_hidden_size != config.hidden_size:
1028
- self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
1029
- self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
1030
- else:
1031
- self.proj = self.proj_layer_norm = None
1032
-
1033
- self.layers = nn.ModuleList(Wav2Vec2ConformerAdapterLayer(config) for _ in range(config.num_adapter_layers))
1034
- self.layerdrop = config.layerdrop
1035
-
1036
- def forward(self, hidden_states):
1037
- # down project hidden_states if necessary
1038
- if self.proj is not None and self.proj_layer_norm is not None:
1039
- hidden_states = self.proj(hidden_states)
1040
- hidden_states = self.proj_layer_norm(hidden_states)
1041
-
1042
- hidden_states = hidden_states.transpose(1, 2)
1043
-
1044
- for layer in self.layers:
1045
- layerdrop_prob = np.random.random()
1046
- if not self.training or (layerdrop_prob > self.layerdrop):
1047
- hidden_states = layer(hidden_states)
1048
-
1049
- hidden_states = hidden_states.transpose(1, 2)
1050
- return hidden_states
1051
-
1052
-
1053
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer
1054
- class Wav2Vec2ConformerAdapterLayer(nn.Module):
1055
- def __init__(self, config):
1056
- super().__init__()
1057
- self.conv = nn.Conv1d(
1058
- config.output_hidden_size,
1059
- 2 * config.output_hidden_size,
1060
- config.adapter_kernel_size,
1061
- stride=config.adapter_stride,
1062
- padding=1,
1063
- )
1064
-
1065
- def forward(self, hidden_states):
1066
- hidden_states = self.conv(hidden_states)
1067
- hidden_states = nn.functional.glu(hidden_states, dim=1)
1068
-
1069
- return hidden_states
1070
-
1071
-
1072
- class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
1073
- """
1074
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
1075
- models.
1076
- """
1077
-
1078
- config_class = Wav2Vec2ConformerConfig
1079
- base_model_prefix = "wav2vec2_conformer"
1080
- main_input_name = "input_values"
1081
- _keys_to_ignore_on_load_missing = [r"position_ids"]
1082
- supports_gradient_checkpointing = True
1083
-
1084
- def _init_weights(self, module):
1085
- """Initialize the weights"""
1086
- # Wav2Vec2ConformerForPreTraining last 2 linear layers need standard Linear init.
1087
- if isinstance(module, Wav2Vec2ConformerForPreTraining):
1088
- module.project_hid.reset_parameters()
1089
- module.project_q.reset_parameters()
1090
- module.project_hid._is_hf_initialized = True
1091
- module.project_q._is_hf_initialized = True
1092
- # gumbel softmax requires special init
1093
- elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
1094
- module.weight_proj.weight.data.normal_(mean=0.0, std=1)
1095
- module.weight_proj.bias.data.zero_()
1096
- nn.init.uniform_(module.codevectors)
1097
- elif isinstance(module, Wav2Vec2ConformerSelfAttention):
1098
- if hasattr(module, "pos_bias_u"):
1099
- nn.init.xavier_uniform_(module.pos_bias_u)
1100
- if hasattr(module, "pos_bias_v"):
1101
- nn.init.xavier_uniform_(module.pos_bias_v)
1102
- elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding):
1103
- nn.init.normal_(
1104
- module.conv.weight,
1105
- mean=0,
1106
- std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
1107
- )
1108
- nn.init.constant_(module.conv.bias, 0)
1109
- elif isinstance(module, Wav2Vec2ConformerFeatureProjection):
1110
- k = math.sqrt(1 / module.projection.in_features)
1111
- nn.init.uniform_(module.projection.weight, a=-k, b=k)
1112
- nn.init.uniform_(module.projection.bias, a=-k, b=k)
1113
- elif isinstance(module, nn.Linear):
1114
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1115
-
1116
- if module.bias is not None:
1117
- module.bias.data.zero_()
1118
- elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
1119
- module.bias.data.zero_()
1120
- module.weight.data.fill_(1.0)
1121
- elif isinstance(module, nn.Conv1d):
1122
- nn.init.kaiming_normal_(module.weight)
1123
-
1124
- if module.bias is not None:
1125
- k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
1126
- nn.init.uniform_(module.bias, a=-k, b=k)
1127
-
1128
- def _get_feat_extract_output_lengths(
1129
- self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
1130
- ):
1131
- """
1132
- Computes the output length of the convolutional layers
1133
- """
1134
-
1135
- add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
1136
-
1137
- def _conv_out_length(input_length, kernel_size, stride):
1138
- # 1D convolutional layer output length formula taken
1139
- # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
1140
- return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
1141
-
1142
- for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
1143
- input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
1144
-
1145
- if add_adapter:
1146
- for _ in range(self.config.num_adapter_layers):
1147
- input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
1148
-
1149
- return input_lengths
1150
-
1151
- def _get_feature_vector_attention_mask(
1152
- self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
1153
- ):
1154
- # Effectively attention_mask.sum(-1), but not inplace to be able to run
1155
- # on inference mode.
1156
- non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
1157
-
1158
- output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
1159
- output_lengths = output_lengths.to(torch.long)
1160
-
1161
- batch_size = attention_mask.shape[0]
1162
-
1163
- attention_mask = torch.zeros(
1164
- (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
1165
- )
1166
- # these two operations make sure that all values before the output length indices are attended to
1167
- attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
1168
- attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
1169
- return attention_mask
1170
-
1171
- def _set_gradient_checkpointing(self, module, value=False):
1172
- if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)):
1173
- module.gradient_checkpointing = value
1174
-
1175
-
1176
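For intuition, the `_conv_out_length` recurrence used by `_get_feat_extract_output_lengths` above can be traced by hand. The kernel and stride values below are the usual wav2vec2 feature-extractor defaults and are assumed here purely for the sake of the example:

```python
def conv_out_length(input_length: int, kernel_size: int, stride: int) -> int:
    # Same formula as the helper above (no padding, dilation 1).
    return (input_length - kernel_size) // stride + 1

conv_kernel = (10, 3, 3, 3, 3, 2, 2)   # assumed defaults, 7 conv layers
conv_stride = (5, 2, 2, 2, 2, 2, 2)

length = 16000                          # one second of 16 kHz audio
for k, s in zip(conv_kernel, conv_stride):
    length = conv_out_length(length, k, s)
print(length)                           # 49 frames, i.e. roughly one frame per 20 ms
```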
- WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
1177
- Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
1178
- Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
1179
- Auli.
1180
-
1181
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1182
- library implements for all its models (such as downloading or saving, etc.).
1183
-
1184
- This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
1185
- regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
1186
-
1187
- Parameters:
1188
- config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model.
1189
- Initializing with a config file does not load the weights associated with the model, only the
1190
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1191
- """
1192
-
1193
-
1194
- WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
1195
- Args:
1196
- input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
1197
- Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
1198
- into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
1199
- soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
1200
- conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
1201
- attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1202
- Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
1203
- 1]`:
1204
-
1205
- - 1 for tokens that are **not masked**,
1206
- - 0 for tokens that are **masked**.
1207
-
1208
- [What are attention masks?](../glossary#attention-mask)
1209
-
1210
- <Tip warning={true}>
1211
-
1212
- `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
1213
- True`. For all models whose processor has `config.return_attention_mask == False`, such as
1214
- [wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large),
1215
- `attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
1216
- such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
1217
- that these models also yield slightly different results depending on whether `input_values` is padded or
1218
- not.
1219
-
1220
- </Tip>
1221
-
1222
- output_attentions (`bool`, *optional*):
1223
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1224
- tensors for more detail.
1225
- output_hidden_states (`bool`, *optional*):
1226
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1227
- more detail.
1228
- return_dict (`bool`, *optional*):
1229
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1230
- """
1231
-
1232
-
1233
- @add_start_docstrings(
1234
- "The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
1235
- WAV2VEC2_CONFORMER_START_DOCSTRING,
1236
- )
1237
- class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
1238
- def __init__(self, config: Wav2Vec2ConformerConfig):
1239
- super().__init__(config)
1240
- self.config = config
1241
- self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config)
1242
- self.feature_projection = Wav2Vec2ConformerFeatureProjection(config)
1243
-
1244
- # model only needs masking vector if mask prob is > 0.0
1245
- if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
1246
- self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
1247
-
1248
- self.encoder = Wav2Vec2ConformerEncoder(config)
1249
-
1250
- self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None
1251
-
1252
- # Initialize weights and apply final processing
1253
- self.post_init()
1254
-
1255
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder
1256
- def freeze_feature_encoder(self):
1257
- """
1258
- Calling this function will disable the gradient computation for the feature encoder so that its parameter will
1259
- not be updated during training.
1260
- """
1261
- self.feature_extractor._freeze_parameters()
1262
-
1263
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
1264
- def _mask_hidden_states(
1265
- self,
1266
- hidden_states: torch.FloatTensor,
1267
- mask_time_indices: Optional[torch.FloatTensor] = None,
1268
- attention_mask: Optional[torch.LongTensor] = None,
1269
- ):
1270
- """
1271
- Masks extracted features along time axis and/or along feature axis according to
1272
- [SpecAugment](https://arxiv.org/abs/1904.08779).
1273
- """
1274
-
1275
- # `config.apply_spec_augment` can set masking to False
1276
- if not getattr(self.config, "apply_spec_augment", True):
1277
- return hidden_states
1278
-
1279
- # generate indices & apply SpecAugment along time axis
1280
- batch_size, sequence_length, hidden_size = hidden_states.size()
1281
-
1282
- if mask_time_indices is not None:
1283
- # apply SpecAugment along time axis with given mask_time_indices
1284
- hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
1285
- elif self.config.mask_time_prob > 0 and self.training:
1286
- mask_time_indices = _compute_mask_indices(
1287
- (batch_size, sequence_length),
1288
- mask_prob=self.config.mask_time_prob,
1289
- mask_length=self.config.mask_time_length,
1290
- attention_mask=attention_mask,
1291
- min_masks=self.config.mask_time_min_masks,
1292
- )
1293
- mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
1294
- hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
1295
-
1296
- if self.config.mask_feature_prob > 0 and self.training:
1297
- # generate indices & apply SpecAugment along feature axis
1298
- mask_feature_indices = _compute_mask_indices(
1299
- (batch_size, hidden_size),
1300
- mask_prob=self.config.mask_feature_prob,
1301
- mask_length=self.config.mask_feature_length,
1302
- min_masks=self.config.mask_feature_min_masks,
1303
- )
1304
- mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
1305
- mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
1306
- hidden_states[mask_feature_indices] = 0
1307
-
1308
- return hidden_states
1309
-
1310
- @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1311
- @add_code_sample_docstrings(
1312
- checkpoint=_CHECKPOINT_FOR_DOC,
1313
- output_type=Wav2Vec2BaseModelOutput,
1314
- config_class=_CONFIG_FOR_DOC,
1315
- modality="audio",
1316
- expected_output=_EXPECTED_OUTPUT_SHAPE,
1317
- )
1318
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer
1319
- def forward(
1320
- self,
1321
- input_values: Optional[torch.Tensor],
1322
- attention_mask: Optional[torch.Tensor] = None,
1323
- mask_time_indices: Optional[torch.FloatTensor] = None,
1324
- output_attentions: Optional[bool] = None,
1325
- output_hidden_states: Optional[bool] = None,
1326
- return_dict: Optional[bool] = None,
1327
- ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
1328
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1329
- output_hidden_states = (
1330
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1331
- )
1332
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1333
-
1334
- extract_features = self.feature_extractor(input_values)
1335
- extract_features = extract_features.transpose(1, 2)
1336
-
1337
- if attention_mask is not None:
1338
- # compute reduced attention_mask corresponding to feature vectors
1339
- attention_mask = self._get_feature_vector_attention_mask(
1340
- extract_features.shape[1], attention_mask, add_adapter=False
1341
- )
1342
-
1343
- hidden_states, extract_features = self.feature_projection(extract_features)
1344
- hidden_states = self._mask_hidden_states(
1345
- hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
1346
- )
1347
-
1348
- encoder_outputs = self.encoder(
1349
- hidden_states,
1350
- attention_mask=attention_mask,
1351
- output_attentions=output_attentions,
1352
- output_hidden_states=output_hidden_states,
1353
- return_dict=return_dict,
1354
- )
1355
-
1356
- hidden_states = encoder_outputs[0]
1357
-
1358
- if self.adapter is not None:
1359
- hidden_states = self.adapter(hidden_states)
1360
-
1361
- if not return_dict:
1362
- return (hidden_states, extract_features) + encoder_outputs[1:]
1363
-
1364
- return Wav2Vec2BaseModelOutput(
1365
- last_hidden_state=hidden_states,
1366
- extract_features=extract_features,
1367
- hidden_states=encoder_outputs.hidden_states,
1368
- attentions=encoder_outputs.attentions,
1369
- )
1370
-
1371
-
1372
- @add_start_docstrings(
1373
- """Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
1374
- )
1375
- class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
1376
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
1377
- def __init__(self, config: Wav2Vec2ConformerConfig):
1378
- super().__init__(config)
1379
- self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1380
- self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
1381
-
1382
- self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)
1383
-
1384
- self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
1385
- self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
1386
-
1387
- # Initialize weights and apply final processing
1388
- self.post_init()
1389
-
1390
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
1391
- def set_gumbel_temperature(self, temperature: int):
1392
- """
1393
- Set the Gumbel softmax temperature to a given value. Only necessary for training
1394
- """
1395
- self.quantizer.temperature = temperature
1396
-
1397
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
1398
- def freeze_feature_encoder(self):
1399
- """
1400
- Calling this function will disable the gradient computation for the feature encoder so that its parameter will
1401
- not be updated during training.
1402
- """
1403
- self.wav2vec2_conformer.feature_extractor._freeze_parameters()
1404
-
1405
- @staticmethod
1406
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits
1407
- def compute_contrastive_logits(
1408
- target_features: torch.FloatTensor,
1409
- negative_features: torch.FloatTensor,
1410
- predicted_features: torch.FloatTensor,
1411
- temperature: int = 0.1,
1412
- ):
1413
- """
1414
- Compute logits for the contrastive loss using cosine similarity as the distance measure between
1415
- `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
1416
- """
1417
- target_features = torch.cat([target_features, negative_features], dim=0)
1418
-
1419
- logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
1420
- target_features
1421
- )
1422
-
1423
- # apply temperature
1424
- logits = logits / temperature
1425
- return logits
1426
-
1427
- @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1428
- @replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1429
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
1430
- def forward(
1431
- self,
1432
- input_values: Optional[torch.Tensor],
1433
- attention_mask: Optional[torch.Tensor] = None,
1434
- mask_time_indices: Optional[torch.BoolTensor] = None,
1435
- sampled_negative_indices: Optional[torch.BoolTensor] = None,
1436
- output_attentions: Optional[bool] = None,
1437
- output_hidden_states: Optional[bool] = None,
1438
- return_dict: Optional[bool] = None,
1439
- ) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]:
1440
- r"""
1441
- mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
1442
- Indices to mask extracted features for contrastive loss. When in training mode, the model learns to predict
1443
- masked extracted features in *config.proj_codevector_dim* space.
1444
- sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
1445
- Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
1446
- Required input for pre-training.
1447
-
1448
- Returns:
1449
-
1450
- Example:
1451
-
1452
- ```python
1453
- >>> import torch
1454
- >>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
1455
- >>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
1456
- ... _compute_mask_indices,
1457
- ... _sample_negative_indices,
1458
- ... )
1459
- >>> from datasets import load_dataset
1460
-
1461
- >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
1462
- >>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
1463
-
1464
- >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1465
- >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
1466
-
1467
- >>> # compute masked indices
1468
- >>> batch_size, raw_sequence_length = input_values.shape
1469
- >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
1470
- >>> mask_time_indices = _compute_mask_indices(
1471
- ... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
1472
- ... )
1473
- >>> sampled_negative_indices = _sample_negative_indices(
1474
- ... features_shape=(batch_size, sequence_length),
1475
- ... num_negatives=model.config.num_negatives,
1476
- ... mask_time_indices=mask_time_indices,
1477
- ... )
1478
- >>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
1479
- >>> sampled_negative_indices = torch.tensor(
1480
- ... data=sampled_negative_indices, device=input_values.device, dtype=torch.long
1481
- ... )
1482
-
1483
- >>> with torch.no_grad():
1484
- ... outputs = model(input_values, mask_time_indices=mask_time_indices)
1485
-
1486
- >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
1487
- >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
1488
-
1489
- >>> # show that cosine similarity is much higher than random
1490
- >>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
1491
- tensor(True)
1492
-
1493
- >>> # for contrastive loss training model should be put into train mode
1494
- >>> model = model.train()
1495
- >>> loss = model(
1496
- ... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
1497
- ... ).loss
1498
- ```"""
1499
-
1500
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1501
-
1502
- if mask_time_indices is not None:
1503
- mask_time_indices = mask_time_indices.to(torch.bool)
1504
-
1505
- outputs = self.wav2vec2_conformer(
1506
- input_values,
1507
- attention_mask=attention_mask,
1508
- output_attentions=output_attentions,
1509
- output_hidden_states=output_hidden_states,
1510
- mask_time_indices=mask_time_indices,
1511
- return_dict=return_dict,
1512
- )
1513
-
1514
- # 1. project all transformed features (including masked) to final vq dim
1515
- transformer_features = self.project_hid(outputs[0])
1516
-
1517
- # 2. quantize all (unmasked) extracted features and project to final vq dim
1518
- extract_features = self.dropout_features(outputs[1])
1519
-
1520
- if attention_mask is not None:
1521
- # compute reduced attention_mask corresponding to feature vectors
1522
- attention_mask = self._get_feature_vector_attention_mask(
1523
- extract_features.shape[1], attention_mask, add_adapter=False
1524
- )
1525
-
1526
- quantized_features, codevector_perplexity = self.quantizer(
1527
- extract_features, mask_time_indices=mask_time_indices
1528
- )
1529
- quantized_features = self.project_q(quantized_features)
1530
-
1531
- loss = contrastive_loss = diversity_loss = None
1532
- if sampled_negative_indices is not None:
1533
- batch_size, sequence_length, hidden_size = quantized_features.shape
1534
-
1535
- # for training, we sample negatives
1536
- # 3. sample K negatives (distractors) quantized states for contrastive loss
1537
- # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
1538
- # sample negative quantized vectors BTC => (BxT)C
1539
- negative_quantized_features = quantized_features.view(-1, hidden_size)[
1540
- sampled_negative_indices.long().view(-1)
1541
- ]
1542
- negative_quantized_features = negative_quantized_features.view(
1543
- batch_size, sequence_length, -1, hidden_size
1544
- ).permute(2, 0, 1, 3)
1545
-
1546
- # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
1547
- # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
1548
- logits = self.compute_contrastive_logits(
1549
- quantized_features[None, :],
1550
- negative_quantized_features,
1551
- transformer_features,
1552
- self.config.contrastive_logits_temperature,
1553
- )
1554
-
1555
- # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
1556
- # its cosine similarity will be masked
1557
- neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
1558
-
1559
- if neg_is_pos.any():
1560
- logits[1:][neg_is_pos] = float("-inf")
1561
-
1562
- # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
1563
- # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
1564
- logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
1565
- target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
1566
-
1567
- contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
1568
- # 7. compute diversity loss: \mathbf{L}_d
1569
- num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
1570
- diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
1571
-
1572
- # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
1573
- loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
1574
-
1575
- if not return_dict:
1576
- if loss is not None:
1577
- return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
1578
- return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
1579
-
1580
- return Wav2Vec2ConformerForPreTrainingOutput(
1581
- loss=loss,
1582
- projected_states=transformer_features,
1583
- projected_quantized_states=quantized_features,
1584
- codevector_perplexity=codevector_perplexity,
1585
- hidden_states=outputs.hidden_states,
1586
- attentions=outputs.attentions,
1587
- contrastive_loss=contrastive_loss,
1588
- diversity_loss=diversity_loss,
1589
- )
1590
-
1591
-
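Stated compactly, the pre-training objective computed in the forward above is the following; this is only a restatement of the code comments (equation 3 of the wav2vec 2.0 paper), not new behaviour:

```latex
% Contrastive term per masked step t: the true quantized target q_t against K sampled
% distractors, cosine similarity with temperature \kappa, summed over masked steps.
\mathcal{L}_m = -\sum_{t \in \mathrm{masked}} \log
    \frac{\exp\!\big(\mathrm{sim}(c_t, q_t)/\kappa\big)}
         {\sum_{\tilde{q} \in \{q_t\} \cup Q_t^{-}} \exp\!\big(\mathrm{sim}(c_t, \tilde{q})/\kappa\big)}

% Diversity term over G codebooks with V entries each (GV = num_codevectors);
% the code above additionally scales it by the number of masked steps.
\mathcal{L}_d = \frac{GV - \mathrm{perplexity}}{GV}

% Total loss, with \alpha = config.diversity_loss_weight
\mathcal{L} = \mathcal{L}_m + \alpha\,\mathcal{L}_d
```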
1592
- @add_start_docstrings(
1593
- """Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
1594
- WAV2VEC2_CONFORMER_START_DOCSTRING,
1595
- )
1596
- class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
1597
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
1598
- def __init__(self, config):
1599
- super().__init__(config)
1600
-
1601
- self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1602
- self.dropout = nn.Dropout(config.final_dropout)
1603
-
1604
- if config.vocab_size is None:
1605
- raise ValueError(
1606
- f"You are trying to instantiate {self.__class__} with a configuration that "
1607
- "does not define the vocabulary size of the language model head. Please "
1608
- "instantiate the model as follows: `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
1609
- "or define `vocab_size` of your model's configuration."
1610
- )
1611
- output_hidden_size = (
1612
- config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
1613
- )
1614
- self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
1615
-
1616
- # Initialize weights and apply final processing
1617
- self.post_init()
1618
-
1619
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
1620
- def freeze_feature_encoder(self):
1621
- """
1622
- Calling this function will disable the gradient computation for the feature encoder so that its parameter will
1623
- not be updated during training.
1624
- """
1625
- self.wav2vec2_conformer.feature_extractor._freeze_parameters()
1626
-
1627
- @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1628
- @add_code_sample_docstrings(
1629
- checkpoint=_CHECKPOINT_FOR_DOC,
1630
- output_type=CausalLMOutput,
1631
- config_class=_CONFIG_FOR_DOC,
1632
- expected_output=_CTC_EXPECTED_OUTPUT,
1633
- expected_loss=_CTC_EXPECTED_LOSS,
1634
- )
1635
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
1636
- def forward(
1637
- self,
1638
- input_values: Optional[torch.Tensor],
1639
- attention_mask: Optional[torch.Tensor] = None,
1640
- output_attentions: Optional[bool] = None,
1641
- output_hidden_states: Optional[bool] = None,
1642
- return_dict: Optional[bool] = None,
1643
- labels: Optional[torch.Tensor] = None,
1644
- ) -> Union[Tuple, CausalLMOutput]:
1645
- r"""
1646
- labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
1647
- Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
1648
- the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
1649
- All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
1650
- config.vocab_size - 1]`.
1651
- """
1652
-
1653
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1654
-
1655
- outputs = self.wav2vec2_conformer(
1656
- input_values,
1657
- attention_mask=attention_mask,
1658
- output_attentions=output_attentions,
1659
- output_hidden_states=output_hidden_states,
1660
- return_dict=return_dict,
1661
- )
1662
-
1663
- hidden_states = outputs[0]
1664
- hidden_states = self.dropout(hidden_states)
1665
-
1666
- logits = self.lm_head(hidden_states)
1667
-
1668
- loss = None
1669
- if labels is not None:
1670
- if labels.max() >= self.config.vocab_size:
1671
- raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
1672
-
1673
- # retrieve loss input_lengths from attention_mask
1674
- attention_mask = (
1675
- attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
1676
- )
1677
- input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
1678
-
1679
- # assuming that padded tokens are filled with -100
1680
- # when not being attended to
1681
- labels_mask = labels >= 0
1682
- target_lengths = labels_mask.sum(-1)
1683
- flattened_targets = labels.masked_select(labels_mask)
1684
-
1685
- # ctc_loss doesn't support fp16
1686
- log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
1687
-
1688
- with torch.backends.cudnn.flags(enabled=False):
1689
- loss = nn.functional.ctc_loss(
1690
- log_probs,
1691
- flattened_targets,
1692
- input_lengths,
1693
- target_lengths,
1694
- blank=self.config.pad_token_id,
1695
- reduction=self.config.ctc_loss_reduction,
1696
- zero_infinity=self.config.ctc_zero_infinity,
1697
- )
1698
-
1699
- if not return_dict:
1700
- output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
1701
- return ((loss,) + output) if loss is not None else output
1702
-
1703
- return CausalLMOutput(
1704
- loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
1705
- )
1706
-
1707
-
1708
- @add_start_docstrings(
1709
- """
1710
- Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for
1711
- tasks like SUPERB Keyword Spotting.
1712
- """,
1713
- WAV2VEC2_CONFORMER_START_DOCSTRING,
1714
- )
1715
- class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel):
1716
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
1717
- def __init__(self, config):
1718
- super().__init__(config)
1719
-
1720
- if hasattr(config, "add_adapter") and config.add_adapter:
1721
- raise ValueError(
1722
- "Sequence classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
1723
- )
1724
- self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1725
- num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
1726
- if config.use_weighted_layer_sum:
1727
- self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
1728
- self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
1729
- self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
1730
-
1731
- # Initialize weights and apply final processing
1732
- self.post_init()
1733
-
1734
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
1735
- def freeze_feature_encoder(self):
1736
- """
1737
- Calling this function will disable the gradient computation for the feature encoder so that its parameter will
1738
- not be updated during training.
1739
- """
1740
- self.wav2vec2_conformer.feature_extractor._freeze_parameters()
1741
-
1742
- def freeze_base_model(self):
1743
- """
1744
- Calling this function will disable the gradient computation for the base model so that its parameters will not
1745
- be updated during training. Only the classification head will be updated.
1746
- """
1747
- for param in self.wav2vec2_conformer.parameters():
1748
- param.requires_grad = False
1749
-
1750
- @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1751
- @add_code_sample_docstrings(
1752
- checkpoint=_CHECKPOINT_FOR_DOC,
1753
- output_type=SequenceClassifierOutput,
1754
- config_class=_CONFIG_FOR_DOC,
1755
- modality="audio",
1756
- )
1757
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
1758
- def forward(
1759
- self,
1760
- input_values: Optional[torch.Tensor],
1761
- attention_mask: Optional[torch.Tensor] = None,
1762
- output_attentions: Optional[bool] = None,
1763
- output_hidden_states: Optional[bool] = None,
1764
- return_dict: Optional[bool] = None,
1765
- labels: Optional[torch.Tensor] = None,
1766
- ) -> Union[Tuple, SequenceClassifierOutput]:
1767
- r"""
1768
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1769
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1770
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1771
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1772
- """
1773
-
1774
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1775
- output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
1776
-
1777
- outputs = self.wav2vec2_conformer(
1778
- input_values,
1779
- attention_mask=attention_mask,
1780
- output_attentions=output_attentions,
1781
- output_hidden_states=output_hidden_states,
1782
- return_dict=return_dict,
1783
- )
1784
-
1785
- if self.config.use_weighted_layer_sum:
1786
- hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
1787
- hidden_states = torch.stack(hidden_states, dim=1)
1788
- norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
1789
- hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
1790
- else:
1791
- hidden_states = outputs[0]
1792
-
1793
- hidden_states = self.projector(hidden_states)
1794
- if attention_mask is None:
1795
- pooled_output = hidden_states.mean(dim=1)
1796
- else:
1797
- padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
1798
- hidden_states[~padding_mask] = 0.0
1799
- pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
1800
-
1801
- logits = self.classifier(pooled_output)
1802
-
1803
- loss = None
1804
- if labels is not None:
1805
- loss_fct = CrossEntropyLoss()
1806
- loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1807
-
1808
- if not return_dict:
1809
- output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
1810
- return ((loss,) + output) if loss is not None else output
1811
-
1812
- return SequenceClassifierOutput(
1813
- loss=loss,
1814
- logits=logits,
1815
- hidden_states=outputs.hidden_states,
1816
- attentions=outputs.attentions,
1817
- )
1818
-
1819
-
1820
- @add_start_docstrings(
1821
- """
1822
- Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization.
1823
- """,
1824
- WAV2VEC2_CONFORMER_START_DOCSTRING,
1825
- )
1826
- class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
1827
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
1828
- def __init__(self, config):
1829
- super().__init__(config)
1830
-
1831
- if hasattr(config, "add_adapter") and config.add_adapter:
1832
- raise ValueError(
1833
- "Audio frame classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
1834
- )
1835
- self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1836
- num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
1837
- if config.use_weighted_layer_sum:
1838
- self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
1839
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1840
- self.num_labels = config.num_labels
1841
-
1842
- self.init_weights()
1843
-
1844
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
1845
- def freeze_feature_encoder(self):
1846
- """
1847
- Calling this function will disable the gradient computation for the feature encoder so that its parameter will
1848
- not be updated during training.
1849
- """
1850
- self.wav2vec2_conformer.feature_extractor._freeze_parameters()
1851
-
1852
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer
1853
- def freeze_base_model(self):
1854
- """
1855
- Calling this function will disable the gradient computation for the base model so that its parameters will not
1856
- be updated during training. Only the classification head will be updated.
1857
- """
1858
- for param in self.wav2vec2_conformer.parameters():
1859
- param.requires_grad = False
1860
-
1861
- @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
1862
- @add_code_sample_docstrings(
1863
- checkpoint=_CHECKPOINT_FOR_DOC,
1864
- output_type=TokenClassifierOutput,
1865
- config_class=_CONFIG_FOR_DOC,
1866
- modality="audio",
1867
- )
1868
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer
1869
- def forward(
1870
- self,
1871
- input_values: Optional[torch.Tensor],
1872
- attention_mask: Optional[torch.Tensor] = None,
1873
- labels: Optional[torch.Tensor] = None,
1874
- output_attentions: Optional[bool] = None,
1875
- output_hidden_states: Optional[bool] = None,
1876
- return_dict: Optional[bool] = None,
1877
- ) -> Union[Tuple, TokenClassifierOutput]:
1878
- r"""
1879
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1880
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1881
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1882
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1883
- """
1884
-
1885
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1886
- output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
1887
-
1888
- outputs = self.wav2vec2_conformer(
1889
- input_values,
1890
- attention_mask=attention_mask,
1891
- output_attentions=output_attentions,
1892
- output_hidden_states=output_hidden_states,
1893
- return_dict=return_dict,
1894
- )
1895
-
1896
- if self.config.use_weighted_layer_sum:
1897
- hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
1898
- hidden_states = torch.stack(hidden_states, dim=1)
1899
- norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
1900
- hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
1901
- else:
1902
- hidden_states = outputs[0]
1903
-
1904
- logits = self.classifier(hidden_states)
1905
-
1906
- loss = None
1907
- if labels is not None:
1908
- loss_fct = CrossEntropyLoss()
1909
- loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
1910
-
1911
- if not return_dict:
1912
- output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
1913
- return output
1914
-
1915
- return TokenClassifierOutput(
1916
- loss=loss,
1917
- logits=logits,
1918
- hidden_states=outputs.hidden_states,
1919
- attentions=outputs.attentions,
1920
- )
1921
-
1922
-
1923
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
1924
- class AMSoftmaxLoss(nn.Module):
1925
- def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
1926
- super(AMSoftmaxLoss, self).__init__()
1927
- self.scale = scale
1928
- self.margin = margin
1929
- self.num_labels = num_labels
1930
- self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
1931
- self.loss = nn.CrossEntropyLoss()
1932
-
1933
- def forward(self, hidden_states, labels):
1934
- labels = labels.flatten()
1935
- weight = nn.functional.normalize(self.weight, dim=0)
1936
- hidden_states = nn.functional.normalize(hidden_states, dim=1)
1937
- cos_theta = torch.mm(hidden_states, weight)
1938
- psi = cos_theta - self.margin
1939
-
1940
- onehot = nn.functional.one_hot(labels, self.num_labels)
1941
- logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
1942
- loss = self.loss(logits, labels)
1943
-
1944
- return loss
1945
-
1946
-
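For reference, the AM-Softmax objective implemented by the class above can be written as follows, with s the scale (default 30.0), m the margin (default 0.4), and the cosines taken between the L2-normalised embeddings and class weights; the margin is applied only to the target class y before a standard cross-entropy:

```latex
\mathcal{L}_{\mathrm{AMS}} = -\log
    \frac{e^{\,s\,(\cos\theta_{y} - m)}}
         {e^{\,s\,(\cos\theta_{y} - m)} + \sum_{j \neq y} e^{\,s\,\cos\theta_{j}}}
```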
1947
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
1948
- class TDNNLayer(nn.Module):
1949
- def __init__(self, config, layer_id=0):
1950
- super().__init__()
1951
- self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
1952
- self.out_conv_dim = config.tdnn_dim[layer_id]
1953
- self.kernel_size = config.tdnn_kernel[layer_id]
1954
- self.dilation = config.tdnn_dilation[layer_id]
1955
-
1956
- self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
1957
- self.activation = nn.ReLU()
1958
-
1959
- def forward(self, hidden_states):
1960
- hidden_states = hidden_states.unsqueeze(1)
1961
- hidden_states = nn.functional.unfold(
1962
- hidden_states,
1963
- (self.kernel_size, self.in_conv_dim),
1964
- stride=(1, self.in_conv_dim),
1965
- dilation=(self.dilation, 1),
1966
- )
1967
- hidden_states = hidden_states.transpose(1, 2)
1968
- hidden_states = self.kernel(hidden_states)
1969
-
1970
- hidden_states = self.activation(hidden_states)
1971
- return hidden_states
1972
-
1973
-
1974
- @add_start_docstrings(
1975
- """
1976
- Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification.
1977
- """,
1978
- WAV2VEC2_CONFORMER_START_DOCSTRING,
1979
- )
1980
- class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
1981
- def __init__(self, config):
1982
- super().__init__(config)
1983
-
1984
- self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
1985
- num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
1986
- if config.use_weighted_layer_sum:
1987
- self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
1988
- self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
1989
-
1990
- tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
1991
- self.tdnn = nn.ModuleList(tdnn_layers)
1992
-
1993
- self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
1994
- self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
1995
-
1996
- self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
1997
-
1998
- self.init_weights()
1999
-
2000
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
2001
- def freeze_feature_encoder(self):
2002
- """
2003
- Calling this function will disable the gradient computation for the feature encoder so that its parameter will
2004
- not be updated during training.
2005
- """
2006
- self.wav2vec2_conformer.feature_extractor._freeze_parameters()
2007
-
2008
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer
2009
- def freeze_base_model(self):
2010
- """
2011
- Calling this function will disable the gradient computation for the base model so that its parameters will not
2012
- be updated during training. Only the classification head will be updated.
2013
- """
2014
- for param in self.wav2vec2_conformer.parameters():
2015
- param.requires_grad = False
2016
-
2017
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer
2018
- def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
2019
- """
2020
- Computes the output length of the TDNN layers
2021
- """
2022
-
2023
- def _conv_out_length(input_length, kernel_size, stride):
2024
- # 1D convolutional layer output length formula taken
2025
- # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
2026
- return (input_length - kernel_size) // stride + 1
2027
-
2028
- for kernel_size in self.config.tdnn_kernel:
2029
- input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
2030
-
2031
- return input_lengths
2032
-
2033
- @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
2034
- @add_code_sample_docstrings(
2035
- checkpoint=_CHECKPOINT_FOR_DOC,
2036
- output_type=XVectorOutput,
2037
- config_class=_CONFIG_FOR_DOC,
2038
- modality="audio",
2039
- )
2040
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
2041
- def forward(
2042
- self,
2043
- input_values: Optional[torch.Tensor],
2044
- attention_mask: Optional[torch.Tensor] = None,
2045
- output_attentions: Optional[bool] = None,
2046
- output_hidden_states: Optional[bool] = None,
2047
- return_dict: Optional[bool] = None,
2048
- labels: Optional[torch.Tensor] = None,
2049
- ) -> Union[Tuple, XVectorOutput]:
2050
- r"""
2051
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
2052
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
2053
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
2054
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
2055
- """
2056
-
2057
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2058
- output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
2059
-
2060
- outputs = self.wav2vec2_conformer(
2061
- input_values,
2062
- attention_mask=attention_mask,
2063
- output_attentions=output_attentions,
2064
- output_hidden_states=output_hidden_states,
2065
- return_dict=return_dict,
2066
- )
2067
-
2068
- if self.config.use_weighted_layer_sum:
2069
- hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
2070
- hidden_states = torch.stack(hidden_states, dim=1)
2071
- norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
2072
- hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
2073
- else:
2074
- hidden_states = outputs[0]
2075
-
2076
- hidden_states = self.projector(hidden_states)
2077
-
2078
- for tdnn_layer in self.tdnn:
2079
- hidden_states = tdnn_layer(hidden_states)
2080
-
2081
- # Statistic Pooling
2082
- if attention_mask is None:
2083
- mean_features = hidden_states.mean(dim=1)
2084
- std_features = hidden_states.std(dim=1)
2085
- else:
2086
- feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
2087
- tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
2088
- mean_features = []
2089
- std_features = []
2090
- for i, length in enumerate(tdnn_output_lengths):
2091
- mean_features.append(hidden_states[i, :length].mean(dim=0))
2092
- std_features.append(hidden_states[i, :length].std(dim=0))
2093
- mean_features = torch.stack(mean_features)
2094
- std_features = torch.stack(std_features)
2095
- statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
2096
-
2097
- output_embeddings = self.feature_extractor(statistic_pooling)
2098
- logits = self.classifier(output_embeddings)
2099
-
2100
- loss = None
2101
- if labels is not None:
2102
- loss = self.objective(logits, labels)
2103
-
2104
- if not return_dict:
2105
- output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
2106
- return ((loss,) + output) if loss is not None else output
2107
-
2108
- return XVectorOutput(
2109
- loss=loss,
2110
- logits=logits,
2111
- embeddings=output_embeddings,
2112
- hidden_states=outputs.hidden_states,
2113
- attentions=outputs.attentions,
2114
- )
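The block above closes the deleted vendored copy of Hugging Face's `modeling_wav2vec2_conformer.py`. For anyone auditing this removal, the same heads remain available upstream; a minimal greedy-CTC decoding sketch against the public `transformers` API looks like the following. The checkpoint id is an assumption (one of the publicly fine-tuned conformer checkpoints), not something pinned by this repository:

```python
# Minimal CTC inference sketch using the upstream transformers classes mirrored by
# the deleted file. The checkpoint id below is an assumption.
import torch
from datasets import load_dataset
from transformers import AutoProcessor, Wav2Vec2ConformerForCTC

ckpt = "facebook/wav2vec2-conformer-rope-large-960h-ft"  # assumed public checkpoint
processor = AutoProcessor.from_pretrained(ckpt)
model = Wav2Vec2ConformerForCTC.from_pretrained(ckpt)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = processor(ds[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    logits = model(inputs.input_values).logits   # (batch, frames, vocab_size)

pred_ids = torch.argmax(logits, dim=-1)          # greedy CTC path
print(processor.batch_decode(pred_ids))          # collapse repeats/blanks into text
```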
 
MuCodec/muq_dev/muq_fairseq/models/muq/modules/random_quantizer.py DELETED
@@ -1,68 +0,0 @@
- import torch
- from torch import nn, einsum
- from einops import rearrange
-
-
- class RandomProjectionQuantizer(nn.Module):
-     """
-     Random projection and codebook lookup module
-
-     Some code is borrowed from:
-     https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py
-     But I did normalization using pre-computed global mean & variance instead of using layer norm.
-     """
-
-     def __init__(
-         self,
-         input_dim,
-         codebook_dim,
-         codebook_size,
-         seed=142,
-     ):
-         super().__init__()
-
-         # random seed
-         torch.manual_seed(seed)
-
-         # randomly initialized projection
-         random_projection = torch.empty(input_dim, codebook_dim)
-         nn.init.xavier_normal_(random_projection)
-         self.register_buffer("random_projection", random_projection)
-
-         # randomly initialized codebook
-         codebook = torch.empty(codebook_size, codebook_dim)
-         nn.init.normal_(codebook)
-         self.register_buffer("codebook", codebook)
-
-     def codebook_lookup(self, x):
-         # reshape
-         b = x.shape[0]
-         x = rearrange(x, "b n e -> (b n) e")
-
-         # L2 normalization
-         normalized_x = nn.functional.normalize(x, dim=1, p=2)
-         normalized_codebook = nn.functional.normalize(self.codebook, dim=1, p=2)
-
-         # compute distances
-         distances = torch.cdist(normalized_codebook, normalized_x)
-
-         # get nearest
-         nearest_indices = torch.argmin(distances, dim=0)
-
-         # reshape
-         xq = rearrange(nearest_indices, "(b n) -> b n", b=b)
-
-         return xq
-
-     @torch.no_grad()
-     def forward(self, x):
-         # always eval
-         self.eval()
-
-         # random projection [batch, length, input_dim] -> [batch, length, codebook_dim]
-         x = einsum("b n d, d e -> b n e", x, self.random_projection)
-
-         # codebook lookup
-         xq = self.codebook_lookup(x)
-
-         return xq
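Since the quantizer above is part of this deletion, here is a short self-contained usage sketch of the interface it exposed; the dimensions are illustrative (they follow the defaults in the deleted MuQ config further down):

```python
# Usage sketch for the deleted RandomProjectionQuantizer: frame features in,
# discrete codebook indices out. Dimensions are illustrative only.
import torch

quantizer = RandomProjectionQuantizer(
    input_dim=1024,      # per-frame feature size
    codebook_dim=16,     # projection dimension used for the nearest-neighbour lookup
    codebook_size=4096,  # number of discrete targets
)

features = torch.randn(2, 250, 1024)   # (batch, frames, input_dim)
tokens = quantizer(features)           # (batch, frames) integer codebook indices
print(tokens.shape)                    # torch.Size([2, 250])
```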
 
MuCodec/muq_dev/muq_fairseq/models/muq/muq_model.py DELETED
@@ -1,139 +0,0 @@
1
- try:
2
- from .model.muq import MuQ
3
- except:
4
- import sys, os
5
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
6
- from model.muq import MuQ
7
- try:
8
- from fairseq.fairseq.dataclass import FairseqDataclass
9
- from fairseq.fairseq.models import BaseFairseqModel, register_model
10
- from fairseq.fairseq.tasks.fairseq_task import FairseqTask
11
- except:
12
- from fairseq.dataclass import FairseqDataclass
13
- from fairseq.models import BaseFairseqModel, register_model
14
- from fairseq.tasks.fairseq_task import FairseqTask
15
-
16
- from dataclasses import dataclass, field
17
- from typing import List, Tuple, Optional
18
- import torch
19
-
20
- from logging import getLogger
21
-
22
- logger = getLogger(__name__)
23
-
24
- @dataclass
25
- class MuQConfig(FairseqDataclass):
26
- label_rate:int = field(default=25)
27
- num_codebooks:int = field(default=1)
28
- codebook_dim:int = field(default=16)
29
- codebook_size:int = field(default=4096)
30
- features:List[str] = field(default_factory=lambda:["melspec_2048"])
31
- hop_length:int = field(default=240)
32
- n_mels:int = field(default=128)
33
- conv_dim:int = field(default=512)
34
- encoder_dim:int = field(default=1024)
35
- encoder_depth:int = field(default=12)
36
- mask_hop:float = field(default=0.4)
37
- mask_prob:float = field(default=0.6)
38
- is_flash:bool = field(default=False)
39
- stat_path:Optional[str] = field(default=None)
40
- model_path:Optional[str] = field(default=None)
41
- w2v2_config_path:Optional[str] = field(default=None)
42
- use_rvq_target:bool = field(default=False)
43
- use_vq_target:bool = field(default=False)
44
- rvq_ckpt_path: Optional[str] = field(default=None)
45
- recon_loss_ratio: Optional[float] = field(default=None)
46
- resume_checkpoint: Optional[str] = None
47
- use_hubert_masking_strategy:bool = field(default=False)
48
- use_hubert_featurizer:bool = field(default=False)
49
- hubert_conv_feature_layers:str = field(default_factory=lambda:"[(512,10,5)] + [(512,3,2)] * 3 + [(512,3,3)] + [(512,2,2)] * 2")
50
- rvq_n_codebooks:int = field(default=8)
51
- rvq_multi_layer_num:int = field(default=1)
52
- use_encodec_target:bool = field(default=False)
53
-
54
- SAMPLE_RATE = 24_000
55
-
56
- @register_model("muq", dataclass=MuQConfig)
57
- class MuQModel(BaseFairseqModel):
58
- def __init__(self, cfg: MuQConfig, task_cfg: FairseqTask):
59
- super().__init__()
60
- self.cfg = cfg
61
- self.model = MuQ(
62
- num_codebooks=cfg.num_codebooks,
63
- codebook_dim=cfg.codebook_dim,
64
- codebook_size=cfg.codebook_size,
65
- features=cfg.features,
66
- n_mels=cfg.n_mels,
67
- conv_dim=cfg.conv_dim,
68
- encoder_dim=cfg.encoder_dim,
69
- encoder_depth=cfg.encoder_depth,
70
- mask_hop=cfg.mask_hop,
71
- mask_prob=cfg.mask_prob,
72
- is_flash=cfg.is_flash,
73
- stat_path=cfg.stat_path,
74
- model_path=cfg.model_path,
75
- w2v2_config_path=cfg.w2v2_config_path,
76
- use_rvq_target=cfg.use_rvq_target,
77
- use_vq_target=cfg.use_vq_target,
78
- rvq_ckpt_path=cfg.rvq_ckpt_path,
79
- recon_loss_ratio=cfg.recon_loss_ratio,
80
- label_rate=cfg.label_rate,
81
- use_hubert_masking_strategy=cfg.use_hubert_masking_strategy,
82
- use_hubert_featurizer=cfg.use_hubert_featurizer,
83
- hubert_conv_feature_layers=cfg.hubert_conv_feature_layers,
84
- rvq_n_codebooks=cfg.rvq_n_codebooks,
85
- rvq_multi_layer_num=cfg.rvq_multi_layer_num,
86
- use_encodec_target=cfg.use_encodec_target,
87
- )
88
-
89
- def forward(
90
- self,
91
- source: torch.Tensor, # B,L
92
- features_only: bool = False,
93
- label = None, # pre-extracted labels, dim is [Batch, N_Codebook, SeqLen]
94
- **kwargs,
95
- ):
96
- source = source[..., :int((source.shape[-1]//(SAMPLE_RATE//self.cfg.label_rate))*(SAMPLE_RATE//self.cfg.label_rate)) ]
97
- if features_only:
98
- if 'attention_mask' in kwargs:
99
- attention_mask = kwargs['attention_mask']
100
- elif 'padding_mask' in kwargs:
101
- attention_mask = ~kwargs['padding_mask'].bool()
102
- else:
103
- attention_mask = None
104
- _, hidden_states = self.model.get_predictions(source, attention_mask=attention_mask, is_features_only=True)
105
- result = {
106
- "layer_results": hidden_states
107
- }
108
- return result
109
- else:
110
- result = {}
111
- logits, hidden_emb, losses, accuracies = self.model(source, label=label)
112
- result["losses"] = losses
113
- result["accuracies"] = accuracies
114
- result["logits"] = logits
115
- result["hidden_emb"] = hidden_emb
116
- for k, v in losses.items():
117
- result[k] = v
118
- return result
119
-
120
- @classmethod
121
- def build_model(cls, cfg: MuQConfig, task: FairseqTask):
122
- """Build a new model instance."""
123
-
124
- model = MuQModel(cfg, task.cfg)
125
- import numpy as np
126
- s = 0
127
- for param in model.parameters():
128
- s += np.product(param.size())
129
- # print('# of parameters: '+str(s/1024.0/1024.0))
130
-
131
- if cfg.get("resume_checkpoint", None):
132
- print("Loading checkpoint from {}".format(cfg.resume_checkpoint))
133
- model.load_state_dict(torch.load(cfg.resume_checkpoint)['model'], strict=False)
134
-
135
- return model
136
-
137
- def get_losses(self, result, batch):
138
- return result['losses']
139
-
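As a sketch of the interface removed here: with `features_only=True`, `MuQModel.forward` returns a dict whose `"layer_results"` entry holds the per-layer hidden states. In the snippet below, `model` is an assumed, already-loaded `MuQModel` instance (for example via the fairseq loader in the deleted `muq_dev/test.py` further down), and the input length is illustrative:

```python
# Feature-extraction sketch against the deleted forward() contract.
import torch

SAMPLE_RATE = 24_000                      # matches the constant in the deleted file
wav = torch.randn(1, 10 * SAMPLE_RATE)    # (batch, samples): 10 s of 24 kHz audio

with torch.no_grad():
    out = model(wav, features_only=True)  # no padding_mask passed -> attention_mask is None

layer_results = out["layer_results"]      # hidden states, one entry per encoder layer
print(len(layer_results))
```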
 
MuCodec/muq_dev/muq_fairseq/tasks/muq_pretraining.py DELETED
@@ -1,354 +0,0 @@
1
- # Copyright (c) 2017-present, Facebook, Inc.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the LICENSE file in
5
- # the root directory of this source tree. An additional grant of patent rights
6
- # can be found in the PATENTS file in the same directory.
7
-
8
- import logging
9
- import os
10
- import sys
11
- from typing import Dict, List, Optional, Tuple
12
-
13
- import numpy as np
14
- import torch
15
-
16
- from dataclasses import dataclass, field
17
- from fairseq.data import Dictionary, HubertDataset
18
- from fairseq.dataclass.configs import FairseqDataclass
19
- from fairseq.tasks import register_task
20
- from fairseq.tasks.fairseq_task import FairseqTask
21
- from omegaconf import MISSING
22
-
23
- from ..data.mert_dataset import MERTDataset
24
- from ..data.ark_dataset import ArkDataset
25
-
26
- logger = logging.getLogger(__name__)
27
-
28
-
29
- class LabelEncoder(object):
30
- def __init__(self, dictionary: Dictionary) -> None:
31
- self.dictionary = dictionary
32
-
33
- def __call__(self, label: str) -> List[str]:
34
- # encode_line return a torch.IntTensor, should be all 1 for vanila HuBERT
35
- return self.dictionary.encode_line(
36
- label,
37
- append_eos=False,
38
- add_if_not_exist=False,
39
- )
40
- class PaddedNumpyLabelEncoder(object):
41
- def __init__(self):
42
- # self.dictionary = dictionary
43
- pass
44
-
45
- def __call__(self, label):
46
- t = torch.IntTensor(np.asarray(label))
47
- t = t[t>=0] # remove padded -1 values at the end
48
- return t
49
-
50
- @dataclass
51
- class MuQPretrainingConfig(FairseqDataclass):
52
- data: str = field(default=MISSING, metadata={"help": "path to data directory"})
53
- sharding_data: int = field(
54
- default=-1,
55
- metadata={
56
- "help": "set this para >1 to use sharding dataset to prevent OOM"
57
- "prepare data tsv and label files by adding postfix for sharding 64 like:"
58
- "train_28_64.tsv and train_28_64.encodec_6"
59
- },
60
- )
61
- load_random_data_shard: bool = field(
62
- default=True,
63
- metadata={
64
- "help": "whether to laod shards randomly or in order when use sharding_data"
65
- },
66
- )
67
- fine_tuning: bool = field(
68
- default=False, metadata={"help": "set to true if fine-tuning Hubert"}
69
- )
70
- labels: List[str] = field(
71
- default_factory=lambda: ["ltr"],
72
- metadata={
73
- "help": (
74
- "extension of the label files to load, frame-level labels for"
75
- " pre-training, and sequence-level label for fine-tuning"
76
- )
77
- },
78
- )
79
- label_dir: Optional[str] = field(
80
- default=None,
81
- metadata={
82
- "help": "if set, looks for labels in this directory instead",
83
- },
84
- )
85
- label_scp_path: Optional[str] = field(
86
- default=None,
87
- metadata={
88
- 'help': 'if set, load label from scp file'
89
- }
90
- )
91
- label_scp_clip_duration: float = field(
92
- default=-1,
93
- metadata={
94
- 'help': 'clip duration for loading scp label. if set to -1, this will not make effect.'
95
- }
96
- )
97
- label_rate: float = field(
98
- default=-1.0,
99
- metadata={"help": "label frame rate. -1.0 for sequence label"},
100
- )
101
- sample_rate: int = field(
102
- default=16_000,
103
- metadata={
104
- "help": "target sample rate. audio files will be up/down "
105
- "sampled to this rate"
106
- },
107
- )
108
- normalize: bool = field(
109
- default=False,
110
- metadata={"help": "if set, normalizes input to have 0 mean and unit variance"},
111
- )
112
- enable_padding: bool = field(
113
- default=False,
114
- metadata={"help": "pad shorter samples instead of cropping"},
115
- )
116
- max_keep_size: Optional[int] = field(
117
- default=None,
118
- metadata={"help": "exclude sample longer than this"},
119
- )
120
- max_sample_size: Optional[int] = field(
121
- default=None,
122
- metadata={"help": "max sample size to crop to for batching"},
123
- )
124
- min_sample_size: Optional[int] = field(
125
- default=None,
126
- metadata={"help": "min sample size to crop to for batching"},
127
- )
128
- single_target: Optional[bool] = field(
129
- default=False,
130
- metadata={
131
- "help": "if set, AddTargetDatasets outputs same keys " "as AddTargetDataset"
132
- },
133
- )
134
- random_crop: Optional[bool] = field(
135
- default=True,
136
- metadata={"help": "always crop from the beginning if false"},
137
- )
138
- pad_audio: Optional[bool] = field(
139
- default=False,
140
- metadata={"help": "pad audio to the longest one in the batch if true"},
141
- )
142
-
143
- store_labels: Optional[bool] = field(
144
- default=False,
145
- metadata={"help": "whether to load all of the label into memory"},
146
- )
147
-
148
- numpy_memmap_label: Optional[bool] = field(
149
- default=False,
150
- metadata={"help": "whether the label file is saved as a numpy file, each line is ended with padding -1"},
151
- )
152
-
153
- augmentation_effects: Optional[str] = field(
154
- default="[]",
155
- metadata={
156
- "help": (
157
- "a list of effects that might apply to the audios"
158
- "example: \"['random_mute', 'random_Gaussian', 'reverse_polarity']\" "
159
- "supported: random_mute,"
160
- "todo: "
161
- )
162
- },
163
- )
164
- augmentation_probs: Optional[str] = field(
165
- default="[]",
166
- metadata={
167
- "help": (
168
- "the corresponding probabilities for the data augmentation effects"
169
- "example: \"[0.1, 0.5, 0.8]\" "
170
- "the sum is not necessarily need to be 1.0, and multiple effects can be applied to the same audio"
171
- )
172
- },
173
- )
174
-
175
- # inbatch_noise_augment_len_range: Optional[List[int]] = field(
176
- # default_factory=lambda: [8000, 24000],
177
- # default = [8000, 24000],
178
- inbatch_noise_augment_len_range: Optional[str] = field(
179
- default = "[8000, 24000]",
180
- metadata={
181
- "help": (
182
- "the range of length of the mix-up noise augmentation, unit in smaples"
183
- )
184
- },
185
- )
186
- # inbatch_noise_augment_number_range: Optional[List[int]] = field(
187
- # default_factory=lambda: [1, 3],
188
- # default = [1, 3],
189
- inbatch_noise_augment_number_range: Optional[str] = field(
190
- default = "[1, 3]",
191
- metadata={
192
- "help": (
193
- "the range of numbers of the mix-up noise augmentation"
194
- )
195
- },
196
- )
197
- inbatch_noise_augment_volume: float = field(
198
- default = 1.0,
199
- metadata={
200
- "help": (
201
- "the coefficient used to modify the volume of the noise audios wavs"
202
- )
203
- },
204
- )
205
- dynamic_crops: Optional[str] = field(
206
- default="[]",
207
- metadata={
208
- "help": (
209
- "used to set the maximum audio length setting, for training"
210
- "example: \"[1, 2, 3, 4, 5, 10]\" "
211
- )
212
- },
213
- )
214
- dynamic_crops_epoches: Optional[str] = field(
215
- default="[]",
216
- metadata={
217
- "help": (
218
- "used to set training epoches of changing the maximum audio length"
219
- "example: \"[1, 10, 20, 40, 80, 160,]\" "
220
- "then len need to be equal to len(dynamic_crops)"
221
- )
222
- },
223
- )
224
-
225
- cqt_loss_bin_dataloader: Optional[int] = field(
226
- default=-1,
227
- metadata={
228
- "help": (
229
- "use this parameter to prepare cqt prediction objective in dataloader"
230
- )
231
- },
232
- )
233
-
234
- clip_secs: int = field(
235
- default=5,
236
- metadata={
237
- "help": "clip secs for each audio"
238
- }
239
- )
240
-
241
- dataset_shuffle: bool = field(
242
- default=True,
243
- metadata={
244
- "help": (
245
- "dataset shuffle when sample a batch"
246
- )
247
- },
248
- )
249
-
250
-
251
- @register_task("muq_pretraining", dataclass=MuQPretrainingConfig)
252
- class MuQPretrainingTask(FairseqTask):
253
-
254
- cfg: MuQPretrainingConfig
255
-
256
- def __init__(
257
- self,
258
- cfg: MuQPretrainingConfig,
259
- ) -> None:
260
- super().__init__(cfg)
261
-
262
- logger.info(f"current directory is {os.getcwd()}")
263
- logger.info(f"MuQPretrainingTask Config {cfg}")
264
-
265
- self.cfg = cfg
266
- self.fine_tuning = cfg.fine_tuning
267
-
268
- if cfg.fine_tuning:
269
- self.state.add_factory("target_dictionary", self.load_dictionaries)
270
- else:
271
- self.state.add_factory("dictionaries", self.load_dictionaries)
272
-
273
- self.blank_symbol = "<s>"
274
-
275
- # use eval() to pass list parameters, skirt the fairseq/torch error: Can't pickle <enum 'Choices'>: attribute lookup Choices on fairseq.dataclass.constants failed
276
- self.augmentation_effects = eval(self.cfg.augmentation_effects)
277
- self.augmentation_probs = eval(self.cfg.augmentation_probs)
278
- if len(self.augmentation_effects) > 0:
279
- assert len(self.augmentation_effects) == len(self.augmentation_probs)
280
- logger.info(f"Applying audio augmentation {self.augmentation_effects}, probabilities: {self.augmentation_probs}")
281
-
282
- self.inbatch_noise_augment_number_range = eval(self.cfg.inbatch_noise_augment_number_range)
283
- self.inbatch_noise_augment_len_range = eval(self.cfg.inbatch_noise_augment_len_range)
284
-
285
- self.max_sample_size = self.cfg.max_sample_size
286
-
287
- self.dynamic_crops = eval(self.cfg.dynamic_crops)
288
- self.dynamic_crops_epoches = eval(self.cfg.dynamic_crops_epoches)
289
- assert len(self.dynamic_crops) == len(self.dynamic_crops_epoches)
290
- if len(self.dynamic_crops) > 0:
291
- assert self.dynamic_crops_epoches[0] == 1
292
-
293
- self.cqt_loss_bin_dataloader = self.cfg.cqt_loss_bin_dataloader
294
-
295
- self.numpy_memmap_label = self.cfg.numpy_memmap_label
296
- self.store_labels = self.cfg.store_labels
297
- if self.numpy_memmap_label:
298
- assert self.store_labels
299
-
300
- @property
301
- def source_dictionary(self) -> Optional[Dictionary]:
302
- return None
303
-
304
- @property
305
- def target_dictionary(self) -> Optional[Dictionary]:
306
- return self.state.target_dictionary
307
-
308
- @property
309
- def dictionaries(self) -> List[Dictionary]:
310
- return self.state.dictionaries
311
-
312
- @classmethod
313
- def setup_task(
314
- cls, cfg: MuQPretrainingConfig, **kwargs
315
- ) -> "MuQPretrainingTask":
316
- return cls(cfg)
317
-
318
- def load_dictionaries(self):
319
- label_dir = self.cfg.data if (self.cfg.label_dir is None or self.cfg.label_dir == '') else self.cfg.label_dir
320
- print(label_dir)
321
- dictionaries = [
322
- Dictionary.load(f"{label_dir}/dict.{label}.txt")
323
- for label in self.cfg.labels
324
- ]
325
- return dictionaries[0] if self.cfg.fine_tuning else dictionaries
326
-
327
- def get_label_dir(self) -> str:
328
- if self.cfg.label_dir is None or self.cfg.label_dir=='':
329
- return self.cfg.data
330
- return self.cfg.label_dir
331
-
332
-
333
- def is_force_load_dataset(self, epoch, training_restore=False):
334
- # find the threshold that holds epoch \in [threshold, next_threshold)
335
- return (epoch in self.dynamic_crops_epoches) or training_restore or (self.cfg.sharding_data > 1)
336
-
337
-
338
- def set_dynamic_crop_max_sample(self, epoch):
339
- pass
340
-
341
- def load_dataset(self, split: str, **kwargs) -> None:
342
- pass
343
-
344
- def load_dataset_ark(self, split, **kwargs):
345
- pass
346
-
347
- def load_dataset_mert(self, split: str, **kwargs) -> None:
348
- pass
349
-
350
- def max_positions(self) -> Tuple[int, int]:
351
- return (sys.maxsize, sys.maxsize)
352
-
353
- def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array:
354
- return indices
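The list-valued options in the config above (`augmentation_effects`, `augmentation_probs`, `dynamic_crops`, `dynamic_crops_epoches`) are passed as strings and eval()'d by the task. A small sketch of the expected format, using the example values quoted in the field help strings:

```python
# String-encoded list options as the deleted task consumed them (values are the
# examples given in the help strings above).
augmentation_effects = "['random_mute', 'random_Gaussian', 'reverse_polarity']"
augmentation_probs = "[0.1, 0.5, 0.8]"
effects, probs = eval(augmentation_effects), eval(augmentation_probs)
assert len(effects) == len(probs)                    # mirrors the check in __init__

dynamic_crops = "[1, 2, 3, 4, 5, 10]"                # maximum audio length per stage, in seconds
dynamic_crops_epoches = "[1, 10, 20, 40, 80, 160]"   # epoch at which each stage starts
crops, epochs = eval(dynamic_crops), eval(dynamic_crops_epoches)
assert len(crops) == len(epochs) and epochs[0] == 1  # first stage must start at epoch 1
```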
 
MuCodec/muq_dev/test.py DELETED
@@ -1,22 +0,0 @@
- import torch
- from dataclasses import dataclass
- import fairseq
- import os.path as op
-
- root = op.dirname(op.abspath(__file__))
-
-
- @dataclass
- class UserDirModule:
-     user_dir: str
-
- def load_model(model_dir, checkpoint_dir):
-     '''Load Fairseq SSL model'''
-
-     model_path = UserDirModule(model_dir)
-     fairseq.utils.import_user_module(model_path)
-
-     model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_dir], strict=False)
-     model = model[0]
-
-     return model
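A usage sketch for the helper above; the directory and checkpoint paths follow the layout described in the deleted readme below and are assumptions, not values pinned by this file:

```python
# Assumed paths: the fairseq user-module directory holding the MuQ model/task
# registrations, and the MuQ checkpoint the readme places under muq_dev/.
model = load_model(
    model_dir="muq_dev/muq_fairseq",
    checkpoint_dir="muq_dev/muq.pt",
)
model.eval()
```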
 
MuCodec/readme.md DELETED
@@ -1,67 +0,0 @@
- # MuCodec: Ultra Low-Bitrate Music Codec
-
- This repository is the official code repository for MuCodec: Ultra Low-Bitrate Music Codec. You can find our paper on [arXiv](https://arxiv.org/pdf/2409.13216). The demo page is available [here](https://xuyaoxun.github.io/MuCodec_demo/).
-
- In this repository, we provide the MuCodec model, inference scripts, and the checkpoint that has been trained on the Million Song Dataset. Specifically, we have released the model and inference code corresponding to the lowest bitrate of 0.35 kbps as mentioned in the paper, to demonstrate the effectiveness of our work.
-
-
- MuCodec supports 48kHz, dual-channel (stereo) audio reconstruction. If the original audio is in a different format, it will first be converted to 48kHz, dual-channel audio.
-
- ## Installation
-
- You can install the necessary dependencies using the `requirements.txt` file with Python 3.8.12:
-
- ```bash
- pip install -r requirements.txt
- ```
-
- Due to storage limitations, we have saved the model checkpoints on Hugging Face at https://huggingface.co/yaoxunxu/mucodec. You can easily download the models from Hugging Face and save them in the following directories:
-
- - Save `audioldm_48k.pth` in the `tools` folder.
- - Save `muq.pt` in the `muq_dev` folder.
- - Save `mucodec.pt` in the `ckpt` folder.
-
- Please note that all three checkpoints must be downloaded completely for the model to load correctly. The final file paths should be:
-
- ```
- tools/audioldm_48k.pth
- muq_dev/muq.pt
- ckpt/mucodec.pt
- ```
-
- The file `audioldm_48k.pth` is sourced from https://huggingface.co/haoheliu/audioldm_48k/blob/main/audioldm_48k.pth.
-
- ## Inference
-
- To run inference, use the following command:
-
- ```bash
- python3 generate.py
- ```
-
- We have provided a sample song `test.wav`, randomly sampled from the Million Song Dataset, in the `test_wav` folder. The default input path is `test_wav/test.wav`, and the output path for the reconstructed audio is `reconstructed/test.wav`.
-
- In the `generate.py` file, we have implemented several functions to facilitate the music compression and reconstruction process. You can easily obtain compressed tokens from audio using the `sound2code` function, and reconstruct the audio from tokens using the `code2sound` function.
-
- ## Note
-
- Please note that the open-sourced model was trained solely on the Million Song Dataset. Considering the quality issues of this dataset, the open-sourced model may not achieve the same performance as demonstrated in the demo. Unfortunately, due to copyright restrictions, we are unable to release the checkpoints trained on additional datasets. However, you can use your own dataset to further train the model and achieve better results.
-
- ## License
-
- The code in this repository is released under the MIT license as found in the [LICENSE](LICENSE) file.
-
- The model weights (muq.pt, mucodec.pt) in this repository are released under the CC-BY-NC 4.0 license, as detailed in the [LICENSE_weights](LICENSE_weights) file.
-
- ## Citation
-
- If you find our work useful, please cite our paper:
-
- ```bibtex
- @article{xu2024mucodec,
-   title={MuCodec: Ultra Low-Bitrate Music Codec},
-   author={Xu, Yaoxun and Chen, Hangting and Yu, Jianwei and Tan, Wei and Gu, Rongzhi and Lei, Shun and Lin, Zhiwei and Wu, Zhiyong},
-   journal={arXiv preprint arXiv:2409.13216},
-   year={2024}
- }
- ```
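The readme above names `sound2code` and `code2sound` in `generate.py` as the compression and reconstruction entry points. Their exact signatures are not visible in this diff, so the following is only an assumed sketch of the described round trip, with tensor shapes and resampling left to the real implementation:

```python
# Assumed round-trip sketch based on the readme description; the real generate.py
# signatures may differ.
import torchaudio

wav, sr = torchaudio.load("test_wav/test.wav")      # default input path from the readme
codes = sound2code(wav)                             # compress: waveform -> discrete tokens (~0.35 kbps)
recon = code2sound(codes)                           # reconstruct: tokens -> 48 kHz stereo waveform
torchaudio.save("reconstructed/test.wav", recon, 48_000)
```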
 
MuCodec/reconstructed/test.wav DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:946e5815c7c3b8cab9f8eb6ca6707e821498fd59233d3ee356f6bb6f2fd2296b
- size 99367376
 
 
 
 
MuCodec/requirements.txt DELETED
@@ -1,335 +0,0 @@
1
- absl-py==2.0.0
2
- accelerate==0.30.1
3
- aeiou==0.0.20
4
- aiobotocore==2.13.1
5
- aiofiles==23.2.1
6
- aiohttp==3.9.3
7
- aioitertools==0.11.0
8
- aiosignal==1.3.1
9
- alias-free-torch==0.0.6
10
- altair==5.3.0
11
- annotated-types==0.6.0
12
- antlr4-python3-runtime==4.8
13
- anyio==4.3.0
14
- appdirs==1.4.4
15
- argbind==0.3.9
16
- asttokens==2.4.1
17
- astunparse==1.6.3
18
- async-timeout==4.0.3
19
- attrs==23.1.0
20
- audioread==3.0.1
21
- auraloss==0.4.0
22
- av==11.0.0
23
- backcall==0.2.0
24
- beartype==0.18.5
25
- bitarray==2.9.2
26
- bleach==6.1.0
27
- blis==0.7.11
28
- bokeh==3.1.1
29
- botocore==1.34.131
30
- braceexpand==0.1.7
31
- cachetools==5.3.2
32
- catalogue==2.0.10
33
- certifi==2023.11.17
34
- cffi==1.16.0
35
- charset-normalizer==3.3.2
36
- clean-fid==0.1.35
37
- click==8.1.7
38
- clip-anytorch==2.6.0
39
- cloudpathlib==0.16.0
40
- cloudpickle==3.0.0
41
- cn2an==0.5.22
42
- colorama==0.4.6
43
- colorcet==3.1.0
44
- colorlog==6.8.2
45
- confection==0.1.4
46
- configparser==7.0.0
47
- contourpy==1.1.1
48
- cycler==0.12.1
49
- cymem==2.0.8
50
- Cython==3.0.10
51
- dataclasses==0.6
52
- datasets
53
- dctorch==0.1.2
54
- decorator==5.1.1
55
- decord==0.6.0
56
- deepspeed==0.14.0
57
- demucs==4.0.1
58
- descript-audio-codec==1.0.0
59
- descript-audiotools==0.7.2
60
- diffusers==0.27.2
61
- dill==0.3.8
62
- Distance==0.1.3
63
- docker-pycreds==0.4.0
64
- docopt==0.6.2
65
- docstring_parser==0.16
66
- dora_search==0.1.12
67
- einops==0.7.0
68
- einops-exts==0.0.4
69
- einx==0.3.0
70
- ema-pytorch==0.2.3
71
- encodec==0.1.1
72
- exceptiongroup==1.2.0
73
- executing==2.0.1
74
- expecttest==0.1.6
75
- fairseq==0.12.2
76
- fastapi==0.110.3
77
- fastcore==1.6.3
78
- ffmpy==0.3.2
79
- filelock==3.13.1
80
- fire==0.6.0
81
- flashy==0.0.2
82
- flatten-dict==0.4.2
83
- fonttools==4.49.0
84
- frozendict==2.4.4
85
- frozenlist==1.4.1
86
- fsspec==2024.6.1
87
- ftfy==6.1.3
88
- future==1.0.0
89
- g2p-en==2.1.0
90
- gin-config==0.5.0
91
- gitdb==4.0.11
92
- GitPython==3.1.43
93
- google-auth==2.23.4
94
- google-auth-oauthlib==1.0.0
95
- gradio==4.26.0
96
- gradio_client==0.15.1
97
- grpcio==1.59.3
98
- h11==0.14.0
99
- h5py==3.11.0
100
- hjson==3.1.0
101
- holoviews==1.17.1
102
- httpcore==1.0.5
103
- httpx==0.27.0
104
- huggingface-hub==0.23.5
105
- hydra-colorlog==1.2.0
106
- hydra-core==1.0.7
107
- hypothesis==6.90.0
108
- idna==3.4
109
- imageio==2.34.2
110
- importlib-metadata==6.8.0
111
- importlib-resources==5.12.0
112
- inflect==7.0.0
113
- ipython==8.12.3
114
- jedi==0.19.1
115
- jieba-fast==0.53
116
- Jinja2==3.1.2
117
- jmespath==1.0.1
118
- joblib==1.3.2
119
- json5==0.9.25
120
- jsonlines==4.0.0
121
- jsonmerge==1.9.2
122
- jsonschema==4.22.0
123
- jsonschema-specifications==2023.12.1
124
- julius==0.2.7
125
- k-diffusion==0.1.1
126
- kaldiio==2.18.0
127
- kiwisolver==1.4.5
128
- kornia==0.7.3
129
- kornia_rs==0.1.5
130
- laion-clap==1.1.4
131
- lameenc==1.7.0
132
- langcodes==3.4.0
133
- language_data==1.2.0
134
- lazy_loader==0.3
135
- librosa==0.9.2
136
- lightning==2.2.1
137
- lightning-utilities==0.10.1
138
- linkify-it-py==2.0.3
139
- lion-pytorch==0.2.2
140
- llvmlite==0.41.1
141
- local-attention==1.8.6
142
- loguru==0.7.2
143
- lxml==5.2.2
144
- marisa-trie==1.1.1
145
- Markdown==3.5.1
146
- markdown-it-py==3.0.0
147
- markdown2==2.5.0
148
- MarkupSafe==2.1.3
149
- matplotlib==3.7.5
150
- matplotlib-inline==0.1.7
151
- mdit-py-plugins==0.4.1
152
- mdurl==0.1.2
153
- mpmath==1.3.0
154
- msgpack==1.0.8
155
- multidict==6.0.5
156
- multiprocess==0.70.16
157
- murmurhash==1.0.10
158
- mypy-extensions==1.0.0
159
- networkx==3.1
160
- ninja==1.11.1.1
161
- nltk==3.8.1
162
- nnAudio==0.3.3
163
- num2words==0.5.13
164
- numba==0.58.1
165
- numpy==1.23.5
166
- nvidia-cublas-cu11==11.11.3.6
167
- nvidia-cuda-cupti-cu11==11.8.87
168
- nvidia-cuda-nvrtc-cu11==11.8.89
169
- nvidia-cuda-runtime-cu11==11.8.89
170
- nvidia-cudnn-cu11==8.7.0.84
171
- nvidia-cufft-cu11==10.9.0.58
172
- nvidia-curand-cu11==10.3.0.86
173
- nvidia-cusolver-cu11==11.4.1.48
174
- nvidia-cusparse-cu11==11.7.5.86
175
- nvidia-nccl-cu11==2.19.3
176
- nvidia-nvtx-cu11==11.8.86
177
- oauthlib==3.2.2
178
- omegaconf
179
- opencv-contrib-python==4.8.1.78
180
- opencv-python==4.8.1.78
181
- openunmix==1.2.1
182
- orjson==3.10.3
183
- packaging==23.2
184
- pandas==2.0.2
185
- panel==1.2.3
186
- param==2.1.1
187
- parso==0.8.4
188
- pathtools==0.1.2
189
- pedalboard==0.7.4
190
- peft==0.10.0
191
- pexpect==4.9.0
192
- pickleshare==0.7.5
193
- Pillow==10.1.0
194
- pkgutil_resolve_name==1.3.10
195
- platformdirs==4.2.0
196
- plotly==5.23.0
197
- pooch==1.8.1
198
- portalocker==2.10.1
199
- prefigure==0.0.9
200
- preshed==3.0.9
201
- proces==0.1.7
202
- prodict==0.8.18
203
- progressbar==2.5
204
- prompt_toolkit==3.0.47
205
- protobuf==3.19.6
206
- psutil==5.9.6
207
- ptyprocess==0.7.0
208
- pure_eval==0.2.3
209
- py-cpuinfo==9.0.0
210
- pyarrow==17.0.0
211
- pyarrow-hotfix==0.6
212
- pyasn1==0.5.1
213
- pyasn1-modules==0.3.0
214
- pybind11==2.11.1
215
- pycparser==2.21
216
- pydantic==2.6.3
217
- pydantic_core==2.16.3
218
- pydub==0.25.1
219
- Pygments==2.18.0
220
- pyloudnorm==0.1.1
221
- pynndescent==0.5.13
222
- pynvml==11.5.0
223
- pyparsing==3.1.2
224
- pypinyin==0.51.0
225
- pyre-extensions==0.0.29
226
- pyreaper==0.0.10
227
- pystoi==0.4.1
228
- python-dateutil==2.8.2
229
- python-multipart==0.0.9
230
- pytorch-lightning==2.1.0
231
- pytz==2023.3.post1
232
- pyviz_comms==3.0.3
233
- PyWavelets==1.4.1
234
- PyYAML==6.0.1
235
- randomname==0.2.1
236
- referencing==0.35.1
237
- regex==2023.10.3
238
- requests==2.32.3
239
- requests-oauthlib==1.3.1
240
- resampy==0.4.3
241
- retrying==1.3.4
242
- rich==13.7.1
243
- rpds-py==0.18.1
244
- rsa==4.9
245
- ruamel.yaml==0.18.5
246
- ruamel.yaml.clib==0.2.8
247
- ruff==0.4.4
248
- s3fs==2024.6.1
249
- s3transfer==0.7.0
250
- sacrebleu==2.4.2
251
- safetensors==0.4.3
252
- scikit-image==0.21.0
253
- scikit-learn==1.3.2
254
- scipy==1.10.1
255
- semantic-version==2.10.0
256
- sentencepiece==0.1.99
257
- sentry-sdk==2.10.0
258
- setproctitle==1.3.3
259
- shellingham==1.5.4
260
- six==1.16.0
261
- smart-open==6.4.0
262
- smmap==5.0.1
263
- sniffio==1.3.1
264
- sortedcontainers==2.4.0
265
- SoundFile==0.10.2
266
- sox==1.4.1
267
- soxr==0.3.7
268
- spacy==3.7.4
269
- spacy-legacy==3.0.12
270
- spacy-loggers==1.0.5
271
- srsly==2.4.8
272
- stack-data==0.6.3
273
- starlette==0.37.2
274
- submitit==1.5.1
275
- sympy==1.12
276
- tabulate==0.9.0
277
- tenacity==9.0.0
278
- tensorboard==2.14.0
279
- tensorboard-data-server==0.7.2
280
- termcolor==2.3.0
281
- thinc==8.2.3
282
- threadpoolctl==3.3.0
283
- tifffile==2023.7.10
284
- timm==0.9.11
285
- tokenizers==0.19.1
286
- tomlkit==0.12.0
287
- toolz==0.12.1
288
- torch==2.2.0+cu118
289
- torch-stoi==0.2.1
290
- torchaudio==2.2.0+cu118
291
- torchdata==0.7.1
292
- torchdiffeq==0.2.4
293
- torchlibrosa==0.1.0
294
- torchmetrics==0.11.4
295
- torchsde==0.2.6
296
- torchtext==0.17.0
297
- torchvision==0.17.0+cu118
298
- tornado==6.4.1
299
- tqdm==4.66.4
300
- traitlets==5.14.3
301
- trampoline==0.1.2
302
- transformers==4.42.4
303
- treetable==0.2.5
304
- triton==2.2.0
305
- typeguard==2.13.0
306
- typer==0.9.4
307
- types-dataclasses==0.6.6
308
- typing-inspect==0.9.0
309
- typing_extensions==4.8.0
310
- tzdata==2023.3
311
- uc-micro-py==1.0.3
312
- umap-learn==0.5.6
313
- Unidecode==1.3.8
314
- urllib3==1.26.18
315
- uvicorn==0.29.0
316
- v-diffusion-pytorch==0.0.2
317
- vector-quantize-pytorch==1.9.14
318
- wandb==0.15.4
319
- wasabi==1.1.2
320
- wcwidth==0.2.12
321
- weasel==0.3.4
322
- webdataset==0.2.48
323
- webencodings==0.5.1
324
- websockets==11.0.3
325
- Werkzeug==3.0.1
326
- wget==3.2
327
- wordsegment==1.3.1
328
- wrapt==1.16.0
329
- x-clip==0.14.4
330
- x-transformers==1.26.6
331
- xformers==0.0.24+cu118
332
- xxhash==3.4.1
333
- xyzservices==2024.6.0
334
- yarl==1.9.4
335
- zipp==3.17.0
 
MuCodec/test_wav/test.wav DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8cd28fa4fc1e8695be47602407088fcc9c486ac27b0ac6712ad30b7c7bcef4f8
3
- size 22823468
 
MuCodec/tools/get_melvaehifigan48k.py DELETED
@@ -1,1551 +0,0 @@
1
-
2
- import soundfile as sf
3
- import os
4
- from librosa.filters import mel as librosa_mel_fn
5
- import sys
6
- sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
7
- import tools.torch_tools as torch_tools
8
- import torch.nn as nn
9
- import torch
10
- import numpy as np
11
- from einops import rearrange
12
- from scipy.signal import get_window
13
- from librosa.util import pad_center, tiny
14
- import librosa.util as librosa_util
15
-
16
- class AttrDict(dict):
17
- def __init__(self, *args, **kwargs):
18
- super(AttrDict, self).__init__(*args, **kwargs)
19
- self.__dict__ = self
20
-
21
- def init_weights(m, mean=0.0, std=0.01):
22
- classname = m.__class__.__name__
23
- if classname.find("Conv") != -1:
24
- m.weight.data.normal_(mean, std)
25
-
26
-
27
- def get_padding(kernel_size, dilation=1):
28
- return int((kernel_size * dilation - dilation) / 2)
29
-
30
- LRELU_SLOPE = 0.1
31
-
32
- class ResBlock(torch.nn.Module):
33
- def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
34
- super(ResBlock, self).__init__()
35
- self.h = h
36
- self.convs1 = nn.ModuleList(
37
- [
38
- torch.nn.utils.weight_norm(
39
- nn.Conv1d(
40
- channels,
41
- channels,
42
- kernel_size,
43
- 1,
44
- dilation=dilation[0],
45
- padding=get_padding(kernel_size, dilation[0]),
46
- )
47
- ),
48
- torch.nn.utils.weight_norm(
49
- nn.Conv1d(
50
- channels,
51
- channels,
52
- kernel_size,
53
- 1,
54
- dilation=dilation[1],
55
- padding=get_padding(kernel_size, dilation[1]),
56
- )
57
- ),
58
- torch.nn.utils.weight_norm(
59
- nn.Conv1d(
60
- channels,
61
- channels,
62
- kernel_size,
63
- 1,
64
- dilation=dilation[2],
65
- padding=get_padding(kernel_size, dilation[2]),
66
- )
67
- ),
68
- ]
69
- )
70
- self.convs1.apply(init_weights)
71
-
72
- self.convs2 = nn.ModuleList(
73
- [
74
- torch.nn.utils.weight_norm(
75
- nn.Conv1d(
76
- channels,
77
- channels,
78
- kernel_size,
79
- 1,
80
- dilation=1,
81
- padding=get_padding(kernel_size, 1),
82
- )
83
- ),
84
- torch.nn.utils.weight_norm(
85
- nn.Conv1d(
86
- channels,
87
- channels,
88
- kernel_size,
89
- 1,
90
- dilation=1,
91
- padding=get_padding(kernel_size, 1),
92
- )
93
- ),
94
- torch.nn.utils.weight_norm(
95
- nn.Conv1d(
96
- channels,
97
- channels,
98
- kernel_size,
99
- 1,
100
- dilation=1,
101
- padding=get_padding(kernel_size, 1),
102
- )
103
- ),
104
- ]
105
- )
106
- self.convs2.apply(init_weights)
107
-
108
- def forward(self, x):
109
- for c1, c2 in zip(self.convs1, self.convs2):
110
- xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
111
- xt = c1(xt)
112
- xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE)
113
- xt = c2(xt)
114
- x = xt + x
115
- return x
116
-
117
- def remove_weight_norm(self):
118
- for l in self.convs1:
119
- torch.nn.utils.remove_weight_norm(l)
120
- for l in self.convs2:
121
- torch.nn.utils.remove_weight_norm(l)
122
-
123
-
124
- class Generator_old(torch.nn.Module):
125
- def __init__(self, h):
126
- super(Generator_old, self).__init__()
127
- self.h = h
128
- self.num_kernels = len(h.resblock_kernel_sizes)
129
- self.num_upsamples = len(h.upsample_rates)
130
- self.conv_pre = torch.nn.utils.weight_norm(
131
- nn.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
132
- )
133
- resblock = ResBlock
134
-
135
- self.ups = nn.ModuleList()
136
- for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
137
- self.ups.append(
138
- torch.nn.utils.weight_norm(
139
- nn.ConvTranspose1d(
140
- h.upsample_initial_channel // (2**i),
141
- h.upsample_initial_channel // (2 ** (i + 1)),
142
- k,
143
- u,
144
- padding=(k - u) // 2,
145
- )
146
- )
147
- )
148
-
149
- self.resblocks = nn.ModuleList()
150
- for i in range(len(self.ups)):
151
- ch = h.upsample_initial_channel // (2 ** (i + 1))
152
- for j, (k, d) in enumerate(
153
- zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
154
- ):
155
- self.resblocks.append(resblock(h, ch, k, d))
156
-
157
- self.conv_post = torch.nn.utils.weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
158
- self.ups.apply(init_weights)
159
- self.conv_post.apply(init_weights)
160
-
161
- def forward(self, x):
162
- x = self.conv_pre(x)
163
- for i in range(self.num_upsamples):
164
- x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
165
- x = self.ups[i](x)
166
- xs = None
167
- for j in range(self.num_kernels):
168
- if xs is None:
169
- xs = self.resblocks[i * self.num_kernels + j](x)
170
- else:
171
- xs += self.resblocks[i * self.num_kernels + j](x)
172
- x = xs / self.num_kernels
173
- x = torch.nn.functional.leaky_relu(x)
174
- x = self.conv_post(x)
175
- x = torch.tanh(x)
176
-
177
- return x
178
-
179
- def remove_weight_norm(self):
180
- # print("Removing weight norm...")
181
- for l in self.ups:
182
- torch.nn.utils.remove_weight_norm(l)
183
- for l in self.resblocks:
184
- l.remove_weight_norm()
185
- torch.nn.utils.remove_weight_norm(self.conv_pre)
186
- torch.nn.utils.remove_weight_norm(self.conv_post)
187
-
188
-
189
-
190
- def nonlinearity(x):
191
- # swish
192
- return x * torch.sigmoid(x)
193
-
194
-
195
- def Normalize(in_channels, num_groups=32):
196
- return torch.nn.GroupNorm(
197
- num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
198
- )
199
-
200
- class Downsample(nn.Module):
201
- def __init__(self, in_channels, with_conv):
202
- super().__init__()
203
- self.with_conv = with_conv
204
- if self.with_conv:
205
- # Do time downsampling here
206
- # no asymmetric padding in torch conv, must do it ourselves
207
- self.conv = torch.nn.Conv2d(
208
- in_channels, in_channels, kernel_size=3, stride=2, padding=0
209
- )
210
-
211
- def forward(self, x):
212
- if self.with_conv:
213
- pad = (0, 1, 0, 1)
214
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
215
- x = self.conv(x)
216
- else:
217
- x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
218
- return x
219
-
220
-
221
- class DownsampleTimeStride4(nn.Module):
222
- def __init__(self, in_channels, with_conv):
223
- super().__init__()
224
- self.with_conv = with_conv
225
- if self.with_conv:
226
- # Do time downsampling here
227
- # no asymmetric padding in torch conv, must do it ourselves
228
- self.conv = torch.nn.Conv2d(
229
- in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
230
- )
231
-
232
- def forward(self, x):
233
- if self.with_conv:
234
- pad = (0, 1, 0, 1)
235
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
236
- x = self.conv(x)
237
- else:
238
- x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
239
- return x
240
-
241
- class Upsample(nn.Module):
242
- def __init__(self, in_channels, with_conv):
243
- super().__init__()
244
- self.with_conv = with_conv
245
- if self.with_conv:
246
- self.conv = torch.nn.Conv2d(
247
- in_channels, in_channels, kernel_size=3, stride=1, padding=1
248
- )
249
-
250
- def forward(self, x):
251
- x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
252
- if self.with_conv:
253
- x = self.conv(x)
254
- return x
255
-
256
-
257
- class UpsampleTimeStride4(nn.Module):
258
- def __init__(self, in_channels, with_conv):
259
- super().__init__()
260
- self.with_conv = with_conv
261
- if self.with_conv:
262
- self.conv = torch.nn.Conv2d(
263
- in_channels, in_channels, kernel_size=5, stride=1, padding=2
264
- )
265
-
266
- def forward(self, x):
267
- x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
268
- if self.with_conv:
269
- x = self.conv(x)
270
- return x
271
-
272
- class AttnBlock(nn.Module):
273
- def __init__(self, in_channels):
274
- super().__init__()
275
- self.in_channels = in_channels
276
-
277
- self.norm = Normalize(in_channels)
278
- self.q = torch.nn.Conv2d(
279
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
280
- )
281
- self.k = torch.nn.Conv2d(
282
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
283
- )
284
- self.v = torch.nn.Conv2d(
285
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
286
- )
287
- self.proj_out = torch.nn.Conv2d(
288
- in_channels, in_channels, kernel_size=1, stride=1, padding=0
289
- )
290
-
291
- def forward(self, x):
292
- h_ = x
293
- h_ = self.norm(h_)
294
- q = self.q(h_)
295
- k = self.k(h_)
296
- v = self.v(h_)
297
-
298
- # compute attention
299
- b, c, h, w = q.shape
300
- q = q.reshape(b, c, h * w).contiguous()
301
- q = q.permute(0, 2, 1).contiguous() # b,hw,c
302
- k = k.reshape(b, c, h * w).contiguous() # b,c,hw
303
- w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
304
- w_ = w_ * (int(c) ** (-0.5))
305
- w_ = torch.nn.functional.softmax(w_, dim=2)
306
-
307
- # attend to values
308
- v = v.reshape(b, c, h * w).contiguous()
309
- w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
310
- h_ = torch.bmm(
311
- v, w_
312
- ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
313
- h_ = h_.reshape(b, c, h, w).contiguous()
314
-
315
- h_ = self.proj_out(h_)
316
-
317
- return x + h_
318
-
319
-
320
- def make_attn(in_channels, attn_type="vanilla"):
321
- assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
322
- # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
323
- if attn_type == "vanilla":
324
- return AttnBlock(in_channels)
325
- elif attn_type == "none":
326
- return nn.Identity(in_channels)
327
- else:
328
- raise ValueError(attn_type)
329
-
330
-
331
- class ResnetBlock(nn.Module):
332
- def __init__(
333
- self,
334
- *,
335
- in_channels,
336
- out_channels=None,
337
- conv_shortcut=False,
338
- dropout,
339
- temb_channels=512,
340
- ):
341
- super().__init__()
342
- self.in_channels = in_channels
343
- out_channels = in_channels if out_channels is None else out_channels
344
- self.out_channels = out_channels
345
- self.use_conv_shortcut = conv_shortcut
346
-
347
- self.norm1 = Normalize(in_channels)
348
- self.conv1 = torch.nn.Conv2d(
349
- in_channels, out_channels, kernel_size=3, stride=1, padding=1
350
- )
351
- if temb_channels > 0:
352
- self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
353
- self.norm2 = Normalize(out_channels)
354
- self.dropout = torch.nn.Dropout(dropout)
355
- self.conv2 = torch.nn.Conv2d(
356
- out_channels, out_channels, kernel_size=3, stride=1, padding=1
357
- )
358
- if self.in_channels != self.out_channels:
359
- if self.use_conv_shortcut:
360
- self.conv_shortcut = torch.nn.Conv2d(
361
- in_channels, out_channels, kernel_size=3, stride=1, padding=1
362
- )
363
- else:
364
- self.nin_shortcut = torch.nn.Conv2d(
365
- in_channels, out_channels, kernel_size=1, stride=1, padding=0
366
- )
367
-
368
- def forward(self, x, temb):
369
- h = x
370
- h = self.norm1(h)
371
- h = nonlinearity(h)
372
- h = self.conv1(h)
373
-
374
- if temb is not None:
375
- h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
376
-
377
- h = self.norm2(h)
378
- h = nonlinearity(h)
379
- h = self.dropout(h)
380
- h = self.conv2(h)
381
-
382
- if self.in_channels != self.out_channels:
383
- if self.use_conv_shortcut:
384
- x = self.conv_shortcut(x)
385
- else:
386
- x = self.nin_shortcut(x)
387
-
388
- return x + h
389
-
390
-
391
- class Encoder(nn.Module):
392
- def __init__(
393
- self,
394
- *,
395
- ch,
396
- out_ch,
397
- ch_mult=(1, 2, 4, 8),
398
- num_res_blocks,
399
- attn_resolutions,
400
- dropout=0.0,
401
- resamp_with_conv=True,
402
- in_channels,
403
- resolution,
404
- z_channels,
405
- double_z=True,
406
- use_linear_attn=False,
407
- attn_type="vanilla",
408
- downsample_time_stride4_levels=[],
409
- **ignore_kwargs,
410
- ):
411
- super().__init__()
412
- if use_linear_attn:
413
- attn_type = "linear"
414
- self.ch = ch
415
- self.temb_ch = 0
416
- self.num_resolutions = len(ch_mult)
417
- self.num_res_blocks = num_res_blocks
418
- self.resolution = resolution
419
- self.in_channels = in_channels
420
- self.downsample_time_stride4_levels = downsample_time_stride4_levels
421
-
422
- if len(self.downsample_time_stride4_levels) > 0:
423
- assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
424
- "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
425
- % str(self.num_resolutions)
426
- )
427
-
428
- # downsampling
429
- self.conv_in = torch.nn.Conv2d(
430
- in_channels, self.ch, kernel_size=3, stride=1, padding=1
431
- )
432
-
433
- curr_res = resolution
434
- in_ch_mult = (1,) + tuple(ch_mult)
435
- self.in_ch_mult = in_ch_mult
436
- self.down = nn.ModuleList()
437
- for i_level in range(self.num_resolutions):
438
- block = nn.ModuleList()
439
- attn = nn.ModuleList()
440
- block_in = ch * in_ch_mult[i_level]
441
- block_out = ch * ch_mult[i_level]
442
- for i_block in range(self.num_res_blocks):
443
- block.append(
444
- ResnetBlock(
445
- in_channels=block_in,
446
- out_channels=block_out,
447
- temb_channels=self.temb_ch,
448
- dropout=dropout,
449
- )
450
- )
451
- block_in = block_out
452
- if curr_res in attn_resolutions:
453
- attn.append(make_attn(block_in, attn_type=attn_type))
454
- down = nn.Module()
455
- down.block = block
456
- down.attn = attn
457
- if i_level != self.num_resolutions - 1:
458
- if i_level in self.downsample_time_stride4_levels:
459
- down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
460
- else:
461
- down.downsample = Downsample(block_in, resamp_with_conv)
462
- curr_res = curr_res // 2
463
- self.down.append(down)
464
-
465
- # middle
466
- self.mid = nn.Module()
467
- self.mid.block_1 = ResnetBlock(
468
- in_channels=block_in,
469
- out_channels=block_in,
470
- temb_channels=self.temb_ch,
471
- dropout=dropout,
472
- )
473
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
474
- self.mid.block_2 = ResnetBlock(
475
- in_channels=block_in,
476
- out_channels=block_in,
477
- temb_channels=self.temb_ch,
478
- dropout=dropout,
479
- )
480
-
481
- # end
482
- self.norm_out = Normalize(block_in)
483
- self.conv_out = torch.nn.Conv2d(
484
- block_in,
485
- 2 * z_channels if double_z else z_channels,
486
- kernel_size=3,
487
- stride=1,
488
- padding=1,
489
- )
490
-
491
- def forward(self, x):
492
- # timestep embedding
493
- temb = None
494
- # downsampling
495
- hs = [self.conv_in(x)]
496
- for i_level in range(self.num_resolutions):
497
- for i_block in range(self.num_res_blocks):
498
- h = self.down[i_level].block[i_block](hs[-1], temb)
499
- if len(self.down[i_level].attn) > 0:
500
- h = self.down[i_level].attn[i_block](h)
501
- hs.append(h)
502
- if i_level != self.num_resolutions - 1:
503
- hs.append(self.down[i_level].downsample(hs[-1]))
504
-
505
- # middle
506
- h = hs[-1]
507
- h = self.mid.block_1(h, temb)
508
- h = self.mid.attn_1(h)
509
- h = self.mid.block_2(h, temb)
510
-
511
- # end
512
- h = self.norm_out(h)
513
- h = nonlinearity(h)
514
- h = self.conv_out(h)
515
- return h
516
-
517
-
518
- class Decoder(nn.Module):
519
- def __init__(
520
- self,
521
- *,
522
- ch,
523
- out_ch,
524
- ch_mult=(1, 2, 4, 8),
525
- num_res_blocks,
526
- attn_resolutions,
527
- dropout=0.0,
528
- resamp_with_conv=True,
529
- in_channels,
530
- resolution,
531
- z_channels,
532
- give_pre_end=False,
533
- tanh_out=False,
534
- use_linear_attn=False,
535
- downsample_time_stride4_levels=[],
536
- attn_type="vanilla",
537
- **ignorekwargs,
538
- ):
539
- super().__init__()
540
- if use_linear_attn:
541
- attn_type = "linear"
542
- self.ch = ch
543
- self.temb_ch = 0
544
- self.num_resolutions = len(ch_mult)
545
- self.num_res_blocks = num_res_blocks
546
- self.resolution = resolution
547
- self.in_channels = in_channels
548
- self.give_pre_end = give_pre_end
549
- self.tanh_out = tanh_out
550
- self.downsample_time_stride4_levels = downsample_time_stride4_levels
551
-
552
- if len(self.downsample_time_stride4_levels) > 0:
553
- assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
554
- "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
555
- % str(self.num_resolutions)
556
- )
557
-
558
- # compute in_ch_mult, block_in and curr_res at lowest res
559
- (1,) + tuple(ch_mult)
560
- block_in = ch * ch_mult[self.num_resolutions - 1]
561
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
562
- self.z_shape = (1, z_channels, curr_res, curr_res)
563
- # print(
564
- # "Working with z of shape {} = {} dimensions.".format(
565
- # self.z_shape, np.prod(self.z_shape)
566
- # )
567
- # )
568
-
569
- # z to block_in
570
- self.conv_in = torch.nn.Conv2d(
571
- z_channels, block_in, kernel_size=3, stride=1, padding=1
572
- )
573
-
574
- # middle
575
- self.mid = nn.Module()
576
- self.mid.block_1 = ResnetBlock(
577
- in_channels=block_in,
578
- out_channels=block_in,
579
- temb_channels=self.temb_ch,
580
- dropout=dropout,
581
- )
582
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
583
- self.mid.block_2 = ResnetBlock(
584
- in_channels=block_in,
585
- out_channels=block_in,
586
- temb_channels=self.temb_ch,
587
- dropout=dropout,
588
- )
589
-
590
- # upsampling
591
- self.up = nn.ModuleList()
592
- for i_level in reversed(range(self.num_resolutions)):
593
- block = nn.ModuleList()
594
- attn = nn.ModuleList()
595
- block_out = ch * ch_mult[i_level]
596
- for i_block in range(self.num_res_blocks + 1):
597
- block.append(
598
- ResnetBlock(
599
- in_channels=block_in,
600
- out_channels=block_out,
601
- temb_channels=self.temb_ch,
602
- dropout=dropout,
603
- )
604
- )
605
- block_in = block_out
606
- if curr_res in attn_resolutions:
607
- attn.append(make_attn(block_in, attn_type=attn_type))
608
- up = nn.Module()
609
- up.block = block
610
- up.attn = attn
611
- if i_level != 0:
612
- if i_level - 1 in self.downsample_time_stride4_levels:
613
- up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
614
- else:
615
- up.upsample = Upsample(block_in, resamp_with_conv)
616
- curr_res = curr_res * 2
617
- self.up.insert(0, up) # prepend to get consistent order
618
-
619
- # end
620
- self.norm_out = Normalize(block_in)
621
- self.conv_out = torch.nn.Conv2d(
622
- block_in, out_ch, kernel_size=3, stride=1, padding=1
623
- )
624
-
625
- def forward(self, z):
626
- # assert z.shape[1:] == self.z_shape[1:]
627
- self.last_z_shape = z.shape
628
-
629
- # timestep embedding
630
- temb = None
631
-
632
- # z to block_in
633
- h = self.conv_in(z)
634
-
635
- # middle
636
- h = self.mid.block_1(h, temb)
637
- h = self.mid.attn_1(h)
638
- h = self.mid.block_2(h, temb)
639
-
640
- # upsampling
641
- for i_level in reversed(range(self.num_resolutions)):
642
- for i_block in range(self.num_res_blocks + 1):
643
- h = self.up[i_level].block[i_block](h, temb)
644
- if len(self.up[i_level].attn) > 0:
645
- h = self.up[i_level].attn[i_block](h)
646
- if i_level != 0:
647
- h = self.up[i_level].upsample(h)
648
-
649
- # end
650
- if self.give_pre_end:
651
- return h
652
-
653
- h = self.norm_out(h)
654
- h = nonlinearity(h)
655
- h = self.conv_out(h)
656
- if self.tanh_out:
657
- h = torch.tanh(h)
658
- return h
659
-
660
-
661
- class DiagonalGaussianDistribution(object):
662
- def __init__(self, parameters, deterministic=False):
663
- self.parameters = parameters
664
- self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
665
- self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
666
- self.deterministic = deterministic
667
- self.std = torch.exp(0.5 * self.logvar)
668
- self.var = torch.exp(self.logvar)
669
- if self.deterministic:
670
- self.var = self.std = torch.zeros_like(self.mean).to(
671
- device=self.parameters.device
672
- )
673
-
674
- def sample(self):
675
- x = self.mean + self.std * torch.randn(self.mean.shape).to(
676
- device=self.parameters.device
677
- )
678
- return x
679
-
680
- def kl(self, other=None):
681
- if self.deterministic:
682
- return torch.Tensor([0.0])
683
- else:
684
- if other is None:
685
- return 0.5 * torch.mean(
686
- torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
687
- dim=[1, 2, 3],
688
- )
689
- else:
690
- return 0.5 * torch.mean(
691
- torch.pow(self.mean - other.mean, 2) / other.var
692
- + self.var / other.var
693
- - 1.0
694
- - self.logvar
695
- + other.logvar,
696
- dim=[1, 2, 3],
697
- )
698
-
699
- def nll(self, sample, dims=[1, 2, 3]):
700
- if self.deterministic:
701
- return torch.Tensor([0.0])
702
- logtwopi = np.log(2.0 * np.pi)
703
- return 0.5 * torch.sum(
704
- logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
705
- dim=dims,
706
- )
707
-
708
- def mode(self):
709
- return self.mean
710
-
711
- def get_vocoder_config_48k():
712
- return {
713
- "resblock": "1",
714
- "num_gpus": 8,
715
- "batch_size": 128,
716
- "learning_rate": 0.0001,
717
- "adam_b1": 0.8,
718
- "adam_b2": 0.99,
719
- "lr_decay": 0.999,
720
- "seed": 1234,
721
-
722
- "upsample_rates": [6,5,4,2,2],
723
- "upsample_kernel_sizes": [12,10,8,4,4],
724
- "upsample_initial_channel": 1536,
725
- "resblock_kernel_sizes": [3,7,11,15],
726
- "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5], [1,3,5]],
727
-
728
- "segment_size": 15360,
729
- "num_mels": 256,
730
- "n_fft": 2048,
731
- "hop_size": 480,
732
- "win_size": 2048,
733
-
734
- "sampling_rate": 48000,
735
-
736
- "fmin": 20,
737
- "fmax": 24000,
738
- "fmax_for_loss": None,
739
-
740
- "num_workers": 8,
741
-
742
- "dist_config": {
743
- "dist_backend": "nccl",
744
- "dist_url": "tcp://localhost:18273",
745
- "world_size": 1
746
- }
747
- }
748
-
749
- def get_vocoder(config, device, mel_bins):
750
- name = "HiFi-GAN"
751
- speaker = ""
752
- if name == "MelGAN":
753
- if speaker == "LJSpeech":
754
- vocoder = torch.hub.load(
755
- "descriptinc/melgan-neurips", "load_melgan", "linda_johnson"
756
- )
757
- elif speaker == "universal":
758
- vocoder = torch.hub.load(
759
- "descriptinc/melgan-neurips", "load_melgan", "multi_speaker"
760
- )
761
- vocoder.mel2wav.eval()
762
- vocoder.mel2wav.to(device)
763
- elif name == "HiFi-GAN":
764
- if(mel_bins == 256):
765
- config = get_vocoder_config_48k()
766
- config = AttrDict(config)
767
- vocoder = Generator_old(config)
768
- # print("Load hifigan/g_01080000")
769
- # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000"))
770
- # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000"))
771
- # ckpt = torch_version_orig_mod_remove(ckpt)
772
- # vocoder.load_state_dict(ckpt["generator"])
773
- vocoder.eval()
774
- vocoder.remove_weight_norm()
775
- vocoder.to(device)
776
- else:
777
- raise ValueError(mel_bins)
778
- return vocoder
779
-
780
- def vocoder_infer(mels, vocoder, lengths=None):
781
- with torch.no_grad():
782
- wavs = vocoder(mels).squeeze(1)
783
-
784
- #wavs = (wavs.cpu().numpy() * 32768).astype("int16")
785
- wavs = (wavs.cpu().numpy())
786
-
787
- if lengths is not None:
788
- wavs = wavs[:, :lengths]
789
-
790
- # wavs = [wav for wav in wavs]
791
-
792
- # for i in range(len(mels)):
793
- # if lengths is not None:
794
- # wavs[i] = wavs[i][: lengths[i]]
795
-
796
- return wavs
797
-
798
- @torch.no_grad()
799
- def vocoder_chunk_infer(mels, vocoder, lengths=None):
800
- chunk_size = 256*4
801
- shift_size = 256*1
802
- ov_size = chunk_size-shift_size
803
- # import pdb;pdb.set_trace()
804
-
805
- for cinx in range(0, mels.shape[2], shift_size):
806
- if(cinx==0):
807
- wavs = vocoder(mels[:,:,cinx:cinx+chunk_size]).squeeze(1).cpu()
808
- num_samples = int(wavs.shape[-1]/chunk_size)*chunk_size
809
- wavs = wavs[:,0:num_samples]
810
- ov_sample = int(float(wavs.shape[-1]) * ov_size / chunk_size)
811
- ov_win = torch.from_numpy(np.linspace(0,1,ov_sample)[None,:])
812
- ov_win = torch.cat([ov_win,1-ov_win],-1)
813
- if(cinx+chunk_size>=mels.shape[2]):
814
- break
815
- else:
816
- cur_wav = vocoder(mels[:,:,cinx:cinx+chunk_size]).squeeze(1).cpu()[:,0:num_samples]
817
- wavs[:,-ov_sample:] = wavs[:,-ov_sample:] * ov_win[:,-ov_sample:] + cur_wav[:,0:ov_sample] * ov_win[:,0:ov_sample]
818
- # wavs[:,-ov_sample:] = wavs[:,-ov_sample:] * 1.0 + cur_wav[:,0:ov_sample] * 0.0
819
- wavs = torch.cat([wavs, cur_wav[:,ov_sample:]],-1)
820
- if(cinx+chunk_size>=mels.shape[2]):
821
- break
822
- # print(wavs.shape)
823
-
824
- wavs = (wavs.cpu().numpy())
825
-
826
- if lengths is not None:
827
- wavs = wavs[:, :lengths]
828
- # print(wavs.shape)
829
- return wavs
830
-
831
- def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
832
- if vocoder is not None:
833
-
834
- wav_reconstruction = vocoder_infer(
835
- mel_input.permute(0, 2, 1),
836
- vocoder,
837
- )
838
- wav_prediction = vocoder_infer(
839
- mel_prediction.permute(0, 2, 1),
840
- vocoder,
841
- )
842
- else:
843
- wav_reconstruction = wav_prediction = None
844
-
845
- return wav_reconstruction, wav_prediction
846
-
847
-
848
- class AutoencoderKL(nn.Module):
849
- def __init__(
850
- self,
851
- ddconfig=None,
852
- lossconfig=None,
853
- batchsize=None,
854
- embed_dim=None,
855
- time_shuffle=1,
856
- subband=1,
857
- sampling_rate=16000,
858
- ckpt_path=None,
859
- reload_from_ckpt=None,
860
- ignore_keys=[],
861
- image_key="fbank",
862
- colorize_nlabels=None,
863
- monitor=None,
864
- base_learning_rate=1e-5,
865
- scale_factor=1
866
- ):
867
- super().__init__()
868
- self.automatic_optimization = False
869
- assert (
870
- "mel_bins" in ddconfig.keys()
871
- ), "mel_bins is not specified in the Autoencoder config"
872
- num_mel = ddconfig["mel_bins"]
873
- self.image_key = image_key
874
- self.sampling_rate = sampling_rate
875
- self.encoder = Encoder(**ddconfig)
876
- self.decoder = Decoder(**ddconfig)
877
-
878
- self.loss = None
879
- self.subband = int(subband)
880
-
881
- if self.subband > 1:
882
- print("Use subband decomposition %s" % self.subband)
883
-
884
- assert ddconfig["double_z"]
885
- self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
886
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
887
-
888
- if self.image_key == "fbank":
889
- self.vocoder = get_vocoder(None, "cpu", num_mel)
890
- self.embed_dim = embed_dim
891
- if colorize_nlabels is not None:
892
- assert type(colorize_nlabels) == int
893
- self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
894
- if monitor is not None:
895
- self.monitor = monitor
896
- if ckpt_path is not None:
897
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
898
- self.learning_rate = float(base_learning_rate)
899
- # print("Initial learning rate %s" % self.learning_rate)
900
-
901
- self.time_shuffle = time_shuffle
902
- self.reload_from_ckpt = reload_from_ckpt
903
- self.reloaded = False
904
- self.mean, self.std = None, None
905
-
906
- self.feature_cache = None
907
- self.flag_first_run = True
908
- self.train_step = 0
909
-
910
- self.logger_save_dir = None
911
- self.logger_exp_name = None
912
- self.scale_factor = scale_factor
913
-
914
- print("Num parameters:")
915
- print("Encoder : ", sum(p.numel() for p in self.encoder.parameters()))
916
- print("Decoder : ", sum(p.numel() for p in self.decoder.parameters()))
917
- print("Vocoder : ", sum(p.numel() for p in self.vocoder.parameters()))
918
-
919
- def get_log_dir(self):
920
- if self.logger_save_dir is None and self.logger_exp_name is None:
921
- return os.path.join(self.logger.save_dir, self.logger._project)
922
- else:
923
- return os.path.join(self.logger_save_dir, self.logger_exp_name)
924
-
925
- def set_log_dir(self, save_dir, exp_name):
926
- self.logger_save_dir = save_dir
927
- self.logger_exp_name = exp_name
928
-
929
- def init_from_ckpt(self, path, ignore_keys=list()):
930
- sd = torch.load(path, map_location="cpu")["state_dict"]
931
- keys = list(sd.keys())
932
- for k in keys:
933
- for ik in ignore_keys:
934
- if k.startswith(ik):
935
- print("Deleting key {} from state_dict.".format(k))
936
- del sd[k]
937
- self.load_state_dict(sd, strict=False)
938
- print(f"Restored from {path}")
939
-
940
- def encode(self, x):
941
- # x = self.time_shuffle_operation(x)
942
- # x = self.freq_split_subband(x)
943
- h = self.encoder(x)
944
- moments = self.quant_conv(h)
945
- posterior = DiagonalGaussianDistribution(moments)
946
- return posterior
947
-
948
- def decode(self, z):
949
- z = self.post_quant_conv(z)
950
- dec = self.decoder(z)
951
- # bs, ch, shuffled_timesteps, fbins = dec.size()
952
- # dec = self.time_unshuffle_operation(dec, bs, int(ch*shuffled_timesteps), fbins)
953
- # dec = self.freq_merge_subband(dec)
954
- return dec
955
-
956
- def decode_to_waveform(self, dec):
957
-
958
- if self.image_key == "fbank":
959
- dec = dec.squeeze(1).permute(0, 2, 1)
960
- wav_reconstruction = vocoder_chunk_infer(dec, self.vocoder)
961
- elif self.image_key == "stft":
962
- dec = dec.squeeze(1).permute(0, 2, 1)
963
- wav_reconstruction = self.wave_decoder(dec)
964
- return wav_reconstruction
965
-
966
- def mel_spectrogram_to_waveform(
967
- self, mel, savepath=".", bs=None, name="outwav", save=True
968
- ):
969
- # Mel: [bs, 1, t-steps, fbins]
970
- if len(mel.size()) == 4:
971
- mel = mel.squeeze(1)
972
- mel = mel.permute(0, 2, 1)
973
- waveform = self.vocoder(mel)
974
- waveform = waveform.cpu().detach().numpy()
975
- #if save:
976
- # self.save_waveform(waveform, savepath, name)
977
- return waveform
978
-
979
- @torch.no_grad()
980
- def encode_first_stage(self, x):
981
- return self.encode(x)
982
-
983
- @torch.no_grad()
984
- def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
985
- if predict_cids:
986
- if z.dim() == 4:
987
- z = torch.argmax(z.exp(), dim=1).long()
988
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
989
- z = rearrange(z, "b h w c -> b c h w").contiguous()
990
-
991
- z = 1.0 / self.scale_factor * z
992
- return self.decode(z)
993
-
994
- def decode_first_stage_withgrad(self, z):
995
- z = 1.0 / self.scale_factor * z
996
- return self.decode(z)
997
-
998
- def get_first_stage_encoding(self, encoder_posterior, use_mode=False):
999
- if isinstance(encoder_posterior, DiagonalGaussianDistribution) and not use_mode:
1000
- z = encoder_posterior.sample()
1001
- elif isinstance(encoder_posterior, DiagonalGaussianDistribution) and use_mode:
1002
- z = encoder_posterior.mode()
1003
- elif isinstance(encoder_posterior, torch.Tensor):
1004
- z = encoder_posterior
1005
- else:
1006
- raise NotImplementedError(
1007
- f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
1008
- )
1009
- return self.scale_factor * z
1010
-
1011
- def visualize_latent(self, input):
1012
- import matplotlib.pyplot as plt
1013
-
1014
- # for i in range(10):
1015
- # zero_input = torch.zeros_like(input) - 11.59
1016
- # zero_input[:,:,i * 16: i * 16 + 16,:16] += 13.59
1017
-
1018
- # posterior = self.encode(zero_input)
1019
- # latent = posterior.sample()
1020
- # avg_latent = torch.mean(latent, dim=1)[0]
1021
- # plt.imshow(avg_latent.cpu().detach().numpy().T)
1022
- # plt.savefig("%s.png" % i)
1023
- # plt.close()
1024
-
1025
- np.save("input.npy", input.cpu().detach().numpy())
1026
- # zero_input = torch.zeros_like(input) - 11.59
1027
- time_input = input.clone()
1028
- time_input[:, :, :, :32] *= 0
1029
- time_input[:, :, :, :32] -= 11.59
1030
-
1031
- np.save("time_input.npy", time_input.cpu().detach().numpy())
1032
-
1033
- posterior = self.encode(time_input)
1034
- latent = posterior.sample()
1035
- np.save("time_latent.npy", latent.cpu().detach().numpy())
1036
- avg_latent = torch.mean(latent, dim=1)
1037
- for i in range(avg_latent.size(0)):
1038
- plt.imshow(avg_latent[i].cpu().detach().numpy().T)
1039
- plt.savefig("freq_%s.png" % i)
1040
- plt.close()
1041
-
1042
- freq_input = input.clone()
1043
- freq_input[:, :, :512, :] *= 0
1044
- freq_input[:, :, :512, :] -= 11.59
1045
-
1046
- np.save("freq_input.npy", freq_input.cpu().detach().numpy())
1047
-
1048
- posterior = self.encode(freq_input)
1049
- latent = posterior.sample()
1050
- np.save("freq_latent.npy", latent.cpu().detach().numpy())
1051
- avg_latent = torch.mean(latent, dim=1)
1052
- for i in range(avg_latent.size(0)):
1053
- plt.imshow(avg_latent[i].cpu().detach().numpy().T)
1054
- plt.savefig("time_%s.png" % i)
1055
- plt.close()
1056
-
1057
- def get_input(self, batch):
1058
- fname, text, label_indices, waveform, stft, fbank = (
1059
- batch["fname"],
1060
- batch["text"],
1061
- batch["label_vector"],
1062
- batch["waveform"],
1063
- batch["stft"],
1064
- batch["log_mel_spec"],
1065
- )
1066
- # if(self.time_shuffle != 1):
1067
- # if(fbank.size(1) % self.time_shuffle != 0):
1068
- # pad_len = self.time_shuffle - (fbank.size(1) % self.time_shuffle)
1069
- # fbank = torch.nn.functional.pad(fbank, (0,0,0,pad_len))
1070
-
1071
- ret = {}
1072
-
1073
- ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = (
1074
- fbank.unsqueeze(1),
1075
- stft.unsqueeze(1),
1076
- fname,
1077
- waveform.unsqueeze(1),
1078
- )
1079
-
1080
- return ret
1081
-
1082
- def save_wave(self, batch_wav, fname, save_dir):
1083
- os.makedirs(save_dir, exist_ok=True)
1084
-
1085
- for wav, name in zip(batch_wav, fname):
1086
- name = os.path.basename(name)
1087
-
1088
- sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate)
1089
-
1090
- def get_last_layer(self):
1091
- return self.decoder.conv_out.weight
1092
-
1093
- @torch.no_grad()
1094
- def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs):
1095
- log = dict()
1096
- x = batch.to(self.device)
1097
- if not only_inputs:
1098
- xrec, posterior = self(x)
1099
- log["samples"] = self.decode(posterior.sample())
1100
- log["reconstructions"] = xrec
1101
-
1102
- log["inputs"] = x
1103
- wavs = self._log_img(log, train=train, index=0, waveform=waveform)
1104
- return wavs
1105
-
1106
- def _log_img(self, log, train=True, index=0, waveform=None):
1107
- images_input = self.tensor2numpy(log["inputs"][index, 0]).T
1108
- images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T
1109
- images_samples = self.tensor2numpy(log["samples"][index, 0]).T
1110
-
1111
- if train:
1112
- name = "train"
1113
- else:
1114
- name = "val"
1115
-
1116
- if self.logger is not None:
1117
- self.logger.log_image(
1118
- "img_%s" % name,
1119
- [images_input, images_reconstruct, images_samples],
1120
- caption=["input", "reconstruct", "samples"],
1121
- )
1122
-
1123
- inputs, reconstructions, samples = (
1124
- log["inputs"],
1125
- log["reconstructions"],
1126
- log["samples"],
1127
- )
1128
-
1129
- if self.image_key == "fbank":
1130
- wav_original, wav_prediction = synth_one_sample(
1131
- inputs[index],
1132
- reconstructions[index],
1133
- labels="validation",
1134
- vocoder=self.vocoder,
1135
- )
1136
- wav_original, wav_samples = synth_one_sample(
1137
- inputs[index], samples[index], labels="validation", vocoder=self.vocoder
1138
- )
1139
- wav_original, wav_samples, wav_prediction = (
1140
- wav_original[0],
1141
- wav_samples[0],
1142
- wav_prediction[0],
1143
- )
1144
- elif self.image_key == "stft":
1145
- wav_prediction = (
1146
- self.decode_to_waveform(reconstructions)[index, 0]
1147
- .cpu()
1148
- .detach()
1149
- .numpy()
1150
- )
1151
- wav_samples = (
1152
- self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy()
1153
- )
1154
- wav_original = waveform[index, 0].cpu().detach().numpy()
1155
-
1156
- if self.logger is not None:
1157
- self.logger.experiment.log(
1158
- {
1159
- "original_%s"
1160
- % name: wandb.Audio(
1161
- wav_original, caption="original", sample_rate=self.sampling_rate
1162
- ),
1163
- "reconstruct_%s"
1164
- % name: wandb.Audio(
1165
- wav_prediction,
1166
- caption="reconstruct",
1167
- sample_rate=self.sampling_rate,
1168
- ),
1169
- "samples_%s"
1170
- % name: wandb.Audio(
1171
- wav_samples, caption="samples", sample_rate=self.sampling_rate
1172
- ),
1173
- }
1174
- )
1175
-
1176
- return wav_original, wav_prediction, wav_samples
1177
-
1178
- def tensor2numpy(self, tensor):
1179
- return tensor.cpu().detach().numpy()
1180
-
1181
- def to_rgb(self, x):
1182
- assert self.image_key == "segmentation"
1183
- if not hasattr(self, "colorize"):
1184
- self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
1185
- x = torch.nn.functional.conv2d(x, weight=self.colorize)
1186
- x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
1187
- return x
1188
-
1189
-
1190
- class IdentityFirstStage(torch.nn.Module):
1191
- def __init__(self, *args, vq_interface=False, **kwargs):
1192
- self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
1193
- super().__init__()
1194
-
1195
- def encode(self, x, *args, **kwargs):
1196
- return x
1197
-
1198
- def decode(self, x, *args, **kwargs):
1199
- return x
1200
-
1201
- def quantize(self, x, *args, **kwargs):
1202
- if self.vq_interface:
1203
- return x, None, [None, None, None]
1204
- return x
1205
-
1206
- def forward(self, x, *args, **kwargs):
1207
- return x
1208
-
1209
-
1210
- def window_sumsquare(
1211
- window,
1212
- n_frames,
1213
- hop_length,
1214
- win_length,
1215
- n_fft,
1216
- dtype=np.float32,
1217
- norm=None,
1218
- ):
1219
- """
1220
- # from librosa 0.6
1221
- Compute the sum-square envelope of a window function at a given hop length.
1222
-
1223
- This is used to estimate modulation effects induced by windowing
1224
- observations in short-time fourier transforms.
1225
-
1226
- Parameters
1227
- ----------
1228
- window : string, tuple, number, callable, or list-like
1229
- Window specification, as in `get_window`
1230
-
1231
- n_frames : int > 0
1232
- The number of analysis frames
1233
-
1234
- hop_length : int > 0
1235
- The number of samples to advance between frames
1236
-
1237
- win_length : [optional]
1238
- The length of the window function. By default, this matches `n_fft`.
1239
-
1240
- n_fft : int > 0
1241
- The length of each analysis frame.
1242
-
1243
- dtype : np.dtype
1244
- The data type of the output
1245
-
1246
- Returns
1247
- -------
1248
- wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
1249
- The sum-squared envelope of the window function
1250
- """
1251
- if win_length is None:
1252
- win_length = n_fft
1253
-
1254
- n = n_fft + hop_length * (n_frames - 1)
1255
- x = np.zeros(n, dtype=dtype)
1256
-
1257
- # Compute the squared window at the desired length
1258
- win_sq = get_window(window, win_length, fftbins=True)
1259
- win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
1260
- win_sq = librosa_util.pad_center(win_sq, n_fft)
1261
-
1262
- # Fill the envelope
1263
- for i in range(n_frames):
1264
- sample = i * hop_length
1265
- x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
1266
- return x
1267
-
1268
- def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
1269
- """
1270
- PARAMS
1271
- ------
1272
- C: compression factor
1273
- """
1274
- return normalize_fun(torch.clamp(x, min=clip_val) * C)
1275
-
1276
-
1277
- def dynamic_range_decompression(x, C=1):
1278
- """
1279
- PARAMS
1280
- ------
1281
- C: compression factor used to compress
1282
- """
1283
- return torch.exp(x) / C
1284
-
1285
-
1286
- class STFT(torch.nn.Module):
1287
- """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
1288
-
1289
- def __init__(self, filter_length, hop_length, win_length, window="hann"):
1290
- super(STFT, self).__init__()
1291
- self.filter_length = filter_length
1292
- self.hop_length = hop_length
1293
- self.win_length = win_length
1294
- self.window = window
1295
- self.forward_transform = None
1296
- scale = self.filter_length / self.hop_length
1297
- fourier_basis = np.fft.fft(np.eye(self.filter_length))
1298
-
1299
- cutoff = int((self.filter_length / 2 + 1))
1300
- fourier_basis = np.vstack(
1301
- [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
1302
- )
1303
-
1304
- forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
1305
- inverse_basis = torch.FloatTensor(
1306
- np.linalg.pinv(scale * fourier_basis).T[:, None, :]
1307
- )
1308
-
1309
- if window is not None:
1310
- assert filter_length >= win_length
1311
- # get window and zero center pad it to filter_length
1312
- fft_window = get_window(window, win_length, fftbins=True)
1313
- fft_window = pad_center(fft_window, size=filter_length)
1314
- fft_window = torch.from_numpy(fft_window).float()
1315
-
1316
- # window the bases
1317
- forward_basis *= fft_window
1318
- inverse_basis *= fft_window
1319
-
1320
- self.register_buffer("forward_basis", forward_basis.float())
1321
- self.register_buffer("inverse_basis", inverse_basis.float())
1322
-
1323
- def transform(self, input_data):
1324
-
1325
- device = self.forward_basis.device
1326
- input_data = input_data.to(device)
1327
-
1328
- num_batches = input_data.size(0)
1329
- num_samples = input_data.size(1)
1330
-
1331
- self.num_samples = num_samples
1332
-
1333
- # similar to librosa, reflect-pad the input
1334
- input_data = input_data.view(num_batches, 1, num_samples)
1335
- input_data = torch.nn.functional.pad(
1336
- input_data.unsqueeze(1),
1337
- (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
1338
- mode="reflect",
1339
- )
1340
- input_data = input_data.squeeze(1)
1341
-
1342
- forward_transform = torch.nn.functional.conv1d(
1343
- input_data,
1344
- torch.autograd.Variable(self.forward_basis, requires_grad=False),
1345
- stride=self.hop_length,
1346
- padding=0,
1347
- )#.cpu()
1348
-
1349
- cutoff = int((self.filter_length / 2) + 1)
1350
- real_part = forward_transform[:, :cutoff, :]
1351
- imag_part = forward_transform[:, cutoff:, :]
1352
-
1353
- magnitude = torch.sqrt(real_part**2 + imag_part**2)
1354
- phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
1355
-
1356
- return magnitude, phase
1357
-
1358
- def inverse(self, magnitude, phase):
1359
-
1360
- device = self.forward_basis.device
1361
- magnitude, phase = magnitude.to(device), phase.to(device)
1362
-
1363
- recombine_magnitude_phase = torch.cat(
1364
- [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
1365
- )
1366
-
1367
- inverse_transform = torch.nn.functional.conv_transpose1d(
1368
- recombine_magnitude_phase,
1369
- torch.autograd.Variable(self.inverse_basis, requires_grad=False),
1370
- stride=self.hop_length,
1371
- padding=0,
1372
- )
1373
-
1374
- if self.window is not None:
1375
- window_sum = window_sumsquare(
1376
- self.window,
1377
- magnitude.size(-1),
1378
- hop_length=self.hop_length,
1379
- win_length=self.win_length,
1380
- n_fft=self.filter_length,
1381
- dtype=np.float32,
1382
- )
1383
- # remove modulation effects
1384
- approx_nonzero_indices = torch.from_numpy(
1385
- np.where(window_sum > tiny(window_sum))[0]
1386
- )
1387
- window_sum = torch.autograd.Variable(
1388
- torch.from_numpy(window_sum), requires_grad=False
1389
- )
1390
- window_sum = window_sum
1391
- inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
1392
- approx_nonzero_indices
1393
- ]
1394
-
1395
- # scale by hop ratio
1396
- inverse_transform *= float(self.filter_length) / self.hop_length
1397
-
1398
- inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
1399
- inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
1400
-
1401
- return inverse_transform
1402
-
1403
- def forward(self, input_data):
1404
- self.magnitude, self.phase = self.transform(input_data)
1405
- reconstruction = self.inverse(self.magnitude, self.phase)
1406
- return reconstruction
1407
-
1408
-
1409
- class TacotronSTFT(torch.nn.Module):
1410
- def __init__(
1411
- self,
1412
- filter_length,
1413
- hop_length,
1414
- win_length,
1415
- n_mel_channels,
1416
- sampling_rate,
1417
- mel_fmin,
1418
- mel_fmax,
1419
- ):
1420
- super(TacotronSTFT, self).__init__()
1421
- self.n_mel_channels = n_mel_channels
1422
- self.sampling_rate = sampling_rate
1423
- self.stft_fn = STFT(filter_length, hop_length, win_length)
1424
- mel_basis = librosa_mel_fn(
1425
- sr = sampling_rate, n_fft = filter_length, n_mels = n_mel_channels, fmin = mel_fmin, fmax = mel_fmax
1426
- )
1427
- mel_basis = torch.from_numpy(mel_basis).float()
1428
- self.register_buffer("mel_basis", mel_basis)
1429
-
1430
- def spectral_normalize(self, magnitudes, normalize_fun):
1431
- output = dynamic_range_compression(magnitudes, normalize_fun)
1432
- return output
1433
-
1434
- def spectral_de_normalize(self, magnitudes):
1435
- output = dynamic_range_decompression(magnitudes)
1436
- return output
1437
-
1438
- def mel_spectrogram(self, y, normalize_fun=torch.log):
1439
- """Computes mel-spectrograms from a batch of waves
1440
- PARAMS
1441
- ------
1442
- y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
1443
-
1444
- RETURNS
1445
- -------
1446
- mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
1447
- """
1448
- assert torch.min(y.data) >= -1, torch.min(y.data)
1449
- assert torch.max(y.data) <= 1, torch.max(y.data)
1450
-
1451
- magnitudes, phases = self.stft_fn.transform(y)
1452
- magnitudes = magnitudes.data
1453
- mel_output = torch.matmul(self.mel_basis, magnitudes)
1454
- mel_output = self.spectral_normalize(mel_output, normalize_fun)
1455
- energy = torch.norm(magnitudes, dim=1)
1456
-
1457
- log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
1458
-
1459
- return mel_output, log_magnitudes, energy
1460
-
1461
-
1462
- def build_pretrained_models(ckpt):
1463
- checkpoint = torch.load(ckpt, map_location="cpu")
1464
- scale_factor = checkpoint["state_dict"]["scale_factor"].item()
1465
- print("scale_factor: ", scale_factor)
1466
-
1467
- vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." in k}
1468
-
1469
- config = {
1470
- "preprocessing": {
1471
- "audio": {
1472
- "sampling_rate": 48000,
1473
- "max_wav_value": 32768,
1474
- "duration": 10.24
1475
- },
1476
- "stft": {
1477
- "filter_length": 2048,
1478
- "hop_length": 480,
1479
- "win_length": 2048
1480
- },
1481
- "mel": {
1482
- "n_mel_channels": 256,
1483
- "mel_fmin": 20,
1484
- "mel_fmax": 24000
1485
- }
1486
- },
1487
- "model": {
1488
- "params": {
1489
- "first_stage_config": {
1490
- "params": {
1491
- "sampling_rate": 48000,
1492
- "batchsize": 4,
1493
- "monitor": "val/rec_loss",
1494
- "image_key": "fbank",
1495
- "subband": 1,
1496
- "embed_dim": 16,
1497
- "time_shuffle": 1,
1498
- "lossconfig": {
1499
- "target": "audioldm2.latent_diffusion.modules.losses.LPIPSWithDiscriminator",
1500
- "params": {
1501
- "disc_start": 50001,
1502
- "kl_weight": 1000,
1503
- "disc_weight": 0.5,
1504
- "disc_in_channels": 1
1505
- }
1506
- },
1507
- "ddconfig": {
1508
- "double_z": True,
1509
- "mel_bins": 256,
1510
- "z_channels": 16,
1511
- "resolution": 256,
1512
- "downsample_time": False,
1513
- "in_channels": 1,
1514
- "out_ch": 1,
1515
- "ch": 128,
1516
- "ch_mult": [
1517
- 1,
1518
- 2,
1519
- 4,
1520
- 8
1521
- ],
1522
- "num_res_blocks": 2,
1523
- "attn_resolutions": [],
1524
- "dropout": 0
1525
- }
1526
- }
1527
- },
1528
- }
1529
- }
1530
- }
1531
- vae_config = config["model"]["params"]["first_stage_config"]["params"]
1532
- vae_config["scale_factor"] = scale_factor
1533
-
1534
- vae = AutoencoderKL(**vae_config)
1535
- vae.load_state_dict(vae_state_dict)
1536
-
1537
- fn_STFT = TacotronSTFT(
1538
- config["preprocessing"]["stft"]["filter_length"],
1539
- config["preprocessing"]["stft"]["hop_length"],
1540
- config["preprocessing"]["stft"]["win_length"],
1541
- config["preprocessing"]["mel"]["n_mel_channels"],
1542
- config["preprocessing"]["audio"]["sampling_rate"],
1543
- config["preprocessing"]["mel"]["mel_fmin"],
1544
- config["preprocessing"]["mel"]["mel_fmax"],
1545
- )
1546
-
1547
- vae.eval()
1548
- fn_STFT.eval()
1549
- return vae, fn_STFT
1550
-
1551
-
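For orientation, a sketch of how the pieces defined above might be wired together: `build_pretrained_models`, the `TacotronSTFT` front end, and the VAE/HiFi-GAN decode path. The checkpoint path and the placeholder waveform are hypothetical; the actual loading is handled elsewhere in the repository (e.g. `generate.py`):

```python
import torch
from tools.get_melvaehifigan48k import build_pretrained_models

# Hypothetical checkpoint path; the real 48 kHz mel-VAE checkpoint is configured elsewhere.
vae, fn_STFT = build_pretrained_models("checkpoints/melvae_48k.ckpt")

wav = torch.zeros(1, 48000 * 10)             # placeholder: 10 s of mono audio in [-1, 1]
mel, _, _ = fn_STFT.mel_spectrogram(wav)     # [B, 256, T] log-mel features
fbank = mel.transpose(1, 2).unsqueeze(1)     # [B, 1, T, 256], the layout the 2-D encoder expects

with torch.no_grad():
    z = vae.get_first_stage_encoding(vae.encode_first_stage(fbank))  # scaled latent
    recon_mel = vae.decode_first_stage(z)                            # [B, 1, T', 256]
    audio = vae.decode_to_waveform(recon_mel)                        # chunked HiFi-GAN vocoding (numpy)
```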
 
MuCodec/tools/torch_tools.py DELETED
@@ -1,100 +0,0 @@
- import torch
- import torchaudio
- import random
- import itertools
- import numpy as np
-
-
-
- def normalize_wav(waveform):
-     waveform = waveform - torch.mean(waveform)
-     waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
-     return waveform * 0.5
-
-
- def pad_wav(waveform, segment_length):
-     waveform_length = len(waveform)
-
-     if segment_length is None or waveform_length == segment_length:
-         return waveform
-     elif waveform_length > segment_length:
-         return waveform[:segment_length]
-     else:
-         pad_wav = torch.zeros(segment_length - waveform_length).to(waveform.device)
-         waveform = torch.cat([waveform, pad_wav])
-         return waveform
-
-
- def _pad_spec(fbank, target_length=1024):
-     batch, n_frames, channels = fbank.shape
-     p = target_length - n_frames
-     if p > 0:
-         pad = torch.zeros(batch, p, channels).to(fbank.device)
-         fbank = torch.cat([fbank, pad], 1)
-     elif p < 0:
-         fbank = fbank[:, :target_length, :]
-
-     if channels % 2 != 0:
-         fbank = fbank[:, :, :-1]
-
-     return fbank
-
-
- def read_wav_file(filename, segment_length):
-     waveform, sr = torchaudio.load(filename)  # Faster!!!
-     waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)[0]
-     try:
-         waveform = normalize_wav(waveform)
-     except:
-         print("Exception normalizing:", filename)
-         waveform = torch.ones(160000)
-     waveform = pad_wav(waveform, segment_length).unsqueeze(0)
-     waveform = waveform / torch.max(torch.abs(waveform))
-     waveform = 0.5 * waveform
-     return waveform
-
-
- def get_mel_from_wav(audio, _stft):
-     audio = torch.nan_to_num(torch.clip(audio, -1, 1))
-     audio = torch.autograd.Variable(audio, requires_grad=False)
-     melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
-     return melspec, log_magnitudes_stft, energy
-
-
- def wav_to_fbank(paths, target_length=1024, fn_STFT=None):
-     assert fn_STFT is not None
-
-     waveform = torch.cat([read_wav_file(path, target_length * 160) for path in paths], 0)  # hop size is 160
-
-     fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
-     fbank = fbank.transpose(1, 2)
-     log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
-
-     fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
-         log_magnitudes_stft, target_length
-     )
-
-     return fbank, log_magnitudes_stft, waveform
-
- def wav_to_fbank2(waveform, target_length=-1, fn_STFT=None):
-     assert fn_STFT is not None
-
-     fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
-     fbank = fbank.transpose(1, 2)
-     log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
-     # print(fbank.shape, log_magnitudes_stft.shape)
-
-     if target_length > 0:
-         fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
-             log_magnitudes_stft, target_length
-         )
-
-     return fbank, log_magnitudes_stft, waveform
-
-
- def uncapitalize(s):
-     if s:
-         return s[:1].lower() + s[1:]
-     else:
-         return ""
-