mie237 commited on
Commit
bedfeec
·
verified ·
1 Parent(s): dfc0841

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2026] [MiLM Plus, Xiaomi Inc.]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,4 +1,5 @@
1
  ---
 
2
  language:
3
  - en
4
  - es
@@ -8,40 +9,38 @@ language:
8
  - ja
9
  - ko
10
  - de
11
- license: apache-2.0
12
  tags:
13
  - audio-generation
14
  - text-to-audio
15
- - speech-synthesis
16
- - music-generation
17
  - sound-effects
18
- - flow-matching
19
- - diffusion-transformer
20
  - multilingual
 
21
  pipeline_tag: text-to-audio
22
  ---
23
 
24
  # Dasheng-AudioGen-Multilingual
25
 
26
- [**English**](./README.md) | [**中文**](./README_zh.md)
27
-
28
- **Dasheng-AudioGen-Multilingual** is the multilingual variant of [Dasheng-AudioGen](https://huggingface.co/mispeech/Dasheng-AudioGen). It replaces the text encoder with `google/mt5-large`, enabling text-to-audio generation from prompts in multiple languages.
 
29
 
30
- - GitHub: [https://github.com/xiaomi-research/dasheng-audiogen](https://github.com/xiaomi-research/dasheng-audiogen)
31
- - Demo: [https://huggingface.co/spaces/mispeech/Dasheng-AudioGen](https://huggingface.co/spaces/mispeech/Dasheng-AudioGen)
32
- - Web Demo: [https://nieeim.github.io/Dasheng-AudioGen-Web/](https://nieeim.github.io/Dasheng-AudioGen-Web/)
33
- - Base model: [mispeech/Dasheng-AudioGen](https://huggingface.co/mispeech/Dasheng-AudioGen)
34
 
35
- ## Differences from Base Model
36
 
37
- | | Dasheng-AudioGen | Dasheng-AudioGen-Multilingual |
38
- |---|---|---|
39
- | Text encoder | `google/flan-t5-large` | `google/mt5-large` |
40
- | Language support | English | Multilingual |
41
 
42
- ## Supported Languages
 
 
 
43
 
44
- Training data language distribution:
45
 
46
  | Language | Duration (h) | Proportion |
47
  |----------|------------:|----------:|
@@ -55,76 +54,108 @@ Training data language distribution:
55
  | German | 842.29 | 3.23% |
56
  | Other | 1,369.16 | 5.24% |
57
 
58
- > **Note:** The current multilingual model has notably higher synthesis error rates for all non-English languages. Languages outside the table above are even less reliable. For English-only use cases, the base model ([mispeech/Dasheng-AudioGen](https://huggingface.co/mispeech/Dasheng-AudioGen)) is recommended.
59
-
60
- ## Files
61
-
62
- | File | Description |
63
- |------|-------------|
64
- | `model.safetensors` | Model weights (~8.2 GB) |
65
- | `config.yaml` | Model architecture configuration |
66
-
67
- ## Usage
68
 
69
- ### Installation
70
 
71
  ```bash
72
- git clone https://github.com/xiaomi-research/dasheng-audiogen.git
73
- cd dasheng-audiogen
74
- conda create -n dasheng-audiogen python=3.10
75
- conda activate dasheng-audiogen
76
- pip install -r requirements.txt
77
  ```
78
 
79
- > torch 2.8.0+cu128 is recommended.
80
 
81
- ### Python API
 
 
82
 
83
  ```python
84
- from dasheng_audiogen.pipeline import DashengAudioGenPipeline
 
85
 
86
- pipe = DashengAudioGenPipeline(
87
- model_name_or_path="mispeech/Dasheng-AudioGen-Multilingual"
88
- )
 
 
 
 
89
 
90
- # Spanish speech example (only <asr> uses the target language)
91
- prompt = pipe.compose_prompt(
 
 
92
  caption="A conversation scene on a busy city street.",
93
  speech="A young woman speaking softly in Spanish.",
94
- asr="Creo que deberíamos irnos ya.",
95
  env="Rain and distant traffic noise.",
 
96
  )
97
- waveforms = pipe.generate(prompts=[prompt])
98
- pipe.save_waveform(waveforms[0], "output.wav")
99
  ```
100
 
101
- ### CLI
102
 
103
- ```bash
104
- python inference_cli.py infer \
105
- --model_name_or_path mispeech/Dasheng-AudioGen-Multilingual \
106
- --content "<|caption|> A conversation scene on a busy city street. <|speech|> A young woman speaking softly in Spanish. <|asr|> Creo que deberíamos irnos ya. <|env|> Rain and distant traffic noise." \
107
- --output_path ./outputs/multilingual.wav
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  ```
109
 
110
- ## Prompt Tags
 
 
111
 
112
  | Tag | Description |
113
  |-----|-------------|
114
- | `<\|caption\|>` | Overall audio scene |
115
- | `<\|speech\|>` | Speaker identity and style |
116
- | `<\|asr\|>` | Spoken transcript |
117
  | `<\|sfx\|>` | Sound effects |
118
  | `<\|music\|>` | Background music |
119
  | `<\|env\|>` | Environmental ambience |
120
 
121
- > **Prompt convention:** All descriptive tags (`caption`, `speech`, `sfx`, `music`, `env`) should be written in **English**. Only `<|asr|>` (the spoken content to synthesize) should use the target language.
122
 
123
- ## Dependencies
124
 
125
- - Audio tokenizer: [mispeech/dashengtokenizer](https://huggingface.co/mispeech/dashengtokenizer)
126
- - Text encoder: [google/mt5-large](https://huggingface.co/google/mt5-large)
127
 
128
- ## Acknowledgments
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- Developed by **XIAOMI LLM PLUS** and **SJTU X-LANCE**.
 
1
  ---
2
+ license: apache-2.0
3
  language:
4
  - en
5
  - es
 
9
  - ja
10
  - ko
11
  - de
12
+ - multilingual
13
  tags:
14
  - audio-generation
15
  - text-to-audio
16
+ - text-to-speech
17
+ - text-to-music
18
  - sound-effects
19
+ - diffusion
 
20
  - multilingual
21
+ library_name: transformers
22
  pipeline_tag: text-to-audio
23
  ---
24
 
25
  # Dasheng-AudioGen-Multilingual
26
 
27
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-b31b1b?logo=arxiv)](https://arxiv.org/abs/2505.XXXXX)
28
+ [![Hugging Face Model](https://img.shields.io/badge/HuggingFace-Model-orange?logo=huggingface)](https://huggingface.co/mispeech/Dasheng-AudioGen-Multilingual)
29
+ [![Hugging Face Demo](https://img.shields.io/badge/HuggingFace-Demo-orange?logo=huggingface)](https://huggingface.co/spaces/mispeech/Dasheng-AudioGen)
30
+ [![Web Demo](https://img.shields.io/badge/Website-Demo-181717?logo=google-chrome)](https://nieeim.github.io/Dasheng-AudioGen-Web/)
31
 
32
+ [**English**](./README.md) | [**中文**](./README_zh.md)
 
 
 
33
 
34
+ **Dasheng-AudioGen-Multilingual** is the multilingual variant of Dasheng-AudioGen, a unified audio generation model that can jointly synthesize **intelligible speech, music, sound effects, and environmental acoustics** from text descriptions.
35
 
36
+ ## Models
 
 
 
37
 
38
+ | Model | HuggingFace | Text Encoder | Language |
39
+ |-------|-------------|-------------|:--------:|
40
+ | Dasheng-AudioGen | [mispeech/Dasheng-AudioGen](https://huggingface.co/mispeech/Dasheng-AudioGen) | `google/flan-t5-large` | English |
41
+ | Dasheng-AudioGen-Multilingual | [mispeech/Dasheng-AudioGen-Multilingual](https://huggingface.co/mispeech/Dasheng-AudioGen-Multilingual) | `google/mt5-large` | Multilingual |
42
 
43
+ ### Language Support
44
 
45
  | Language | Duration (h) | Proportion |
46
  |----------|------------:|----------:|
 
54
  | German | 842.29 | 3.23% |
55
  | Other | 1,369.16 | 5.24% |
56
 
57
+ > **Note:** The current multilingual model has notably higher synthesis error rates for all non-English languages. Languages outside the table above are even less reliable. For English-only use cases, the base model (`mispeech/Dasheng-AudioGen`) is recommended.
 
 
 
 
 
 
 
 
 
58
 
59
+ ## Installation
60
 
61
  ```bash
62
+ pip install torch torchaudio "transformers<5" einops
 
 
 
 
63
  ```
64
 
65
+ > Tested with Python 3.10, torch 2.8.0+cu128, transformers 4.57. Not compatible with transformers 5.x.
66
 
67
+ ## Quick Start
68
+
69
+ ### Basic Usage
70
 
71
  ```python
72
+ import torchaudio
73
+ from transformers import AutoModel
74
 
75
+ model = AutoModel.from_pretrained("mispeech/Dasheng-AudioGen-Multilingual", trust_remote_code=True).cuda()
76
+
77
+ audio = model.generate("A dog barking in a park")
78
+ torchaudio.save("output.wav", audio.cpu(), 16000)
79
+ ```
80
+
81
+ ### Aspect-wise Prompt
82
 
83
+ Use `compose_prompt` to describe different audio aspects separately:
84
+
85
+ ```python
86
+ prompt = model.compose_prompt(
87
  caption="A conversation scene on a busy city street.",
88
  speech="A young woman speaking softly in Spanish.",
 
89
  env="Rain and distant traffic noise.",
90
+ asr="Creo que deberíamos irnos ya.",
91
  )
92
+ audio = model.generate(prompt)
93
+ torchaudio.save("output.wav", audio.cpu(), 16000)
94
  ```
95
 
96
+ You can also pass a pre-formatted string with tags directly:
97
 
98
+ ```python
99
+ audio = model.generate(
100
+ "<|caption|> A helicopter passing overhead. <|sfx|> Rhythmic helicopter blade sounds. <|env|> Open sky ambience."
101
+ )
102
+ ```
103
+
104
+ ### Batch Inference
105
+
106
+ ```python
107
+ prompts = [
108
+ model.compose_prompt(caption="A cat meowing softly.", sfx="Soft cat meow."),
109
+ model.compose_prompt(caption="Thunder rolling in the distance.", env="Stormy night ambience."),
110
+ model.compose_prompt(caption="A piano playing a gentle melody.", music="Soft piano ballad."),
111
+ ]
112
+ audios = model.generate(prompts)
113
+
114
+ for i, audio in enumerate(audios):
115
+ torchaudio.save(f"output_{i}.wav", audio.unsqueeze(0).cpu(), 16000)
116
+ ```
117
+
118
+ ### Generation Parameters
119
+
120
+ ```python
121
+ audio = model.generate(
122
+ prompts="A dog barking in a park",
123
+ num_steps=25, # number of denoising steps (default: 25)
124
+ guidance_scale=5.0, # classifier-free guidance scale (default: 5.0)
125
+ sway_sampling_coef=-1.0, # sway sampling coefficient (default: -1.0, 0 for linear)
126
+ )
127
  ```
128
 
129
+ ## Prompt Format
130
+
131
+ Dasheng-AudioGen uses structured tags to describe different audio aspects:
132
 
133
  | Tag | Description |
134
  |-----|-------------|
135
+ | `<\|caption\|>` | Overall audio scene description |
136
+ | `<\|speech\|>` | Speaker identity and speaking style |
137
+ | `<\|asr\|>` | Spoken transcript / dialogue |
138
  | `<\|sfx\|>` | Sound effects |
139
  | `<\|music\|>` | Background music |
140
  | `<\|env\|>` | Environmental ambience |
141
 
142
+ > **Multilingual prompt convention:** All descriptive tags (`caption`, `speech`, `sfx`, `music`, `env`) should be written in **English**. Only the `<|asr|>` field (the actual spoken content to be synthesized) should use the target language.
143
 
144
+ ## Acknowledgments
145
 
146
+ Dasheng-AudioGen was developed with contributions from **XIAOMI LLM PLUS** and **SJTU X-LANCE**.
 
147
 
148
+ ## Citation
149
+
150
+ ```bibtex
151
+ @article{dasheng-audiogen,
152
+ title={Dasheng-AudioGen},
153
+ author={},
154
+ journal={arXiv preprint arXiv:2505.XXXXX},
155
+ year={2025}
156
+ }
157
+ ```
158
+
159
+ ## License
160
 
161
+ This project is released under the [Apache License 2.0](LICENSE).
README_zh.md CHANGED
@@ -1,24 +1,22 @@
1
  # Dasheng-AudioGen-Multilingual
2
 
3
- [**English**](./README.md) | [**中文**](./README_zh.md)
4
-
5
- **Dasheng-AudioGen-Multilingual** [Dasheng-AudioGen](https://huggingface.co/mispeech/Dasheng-AudioGen) 的多语言版本。它将文本编码器替换为 `google/mt5-large`,支持使用多种语言的 prompt 进行音频生成。
 
6
 
7
- - GitHub: [https://github.com/xiaomi-research/dasheng-audiogen](https://github.com/xiaomi-research/dasheng-audiogen)
8
- - Demo: [https://huggingface.co/spaces/mispeech/Dasheng-AudioGen](https://huggingface.co/spaces/mispeech/Dasheng-AudioGen)
9
- - Web Demo: [https://nieeim.github.io/Dasheng-AudioGen-Web/](https://nieeim.github.io/Dasheng-AudioGen-Web/)
10
- - 基础模型: [mispeech/Dasheng-AudioGen](https://huggingface.co/mispeech/Dasheng-AudioGen)
11
 
12
- ## 与基础模型的区别
13
 
14
- | | Dasheng-AudioGen | Dasheng-AudioGen-Multilingual |
15
- |---|---|---|
16
- | 文本编码器 | `google/flan-t5-large` | `google/mt5-large` |
17
- | 语言支持 | 英语 | 多语言 |
18
 
19
- ## 支持语言
 
 
 
20
 
21
- 训练数据语言分布:
22
 
23
  | 语言 | 时长 (h) | 占比 |
24
  |------|--------:|-----:|
@@ -32,76 +30,108 @@
32
  | 德语 (German) | 842.29 | 3.23% |
33
  | 其他 | 1,369.16 | 5.24% |
34
 
35
- > **注意:** 当前多语言模型在所有非英语语言上的合成错误率都明显偏高,表中未列出的语言更不稳定。如果仅需英语生成,建议使用基础模型 ([mispeech/Dasheng-AudioGen](https://huggingface.co/mispeech/Dasheng-AudioGen))
36
 
37
- ## 文件说明
38
-
39
- | 文件 | 描述 |
40
- |------|------|
41
- | `model.safetensors` | 模型权重 (~8.2 GB) |
42
- | `config.yaml` | 模型结构配置 |
43
-
44
- ## 使用方法
45
-
46
- ### 安装
47
 
48
  ```bash
49
- git clone https://github.com/xiaomi-research/dasheng-audiogen.git
50
- cd dasheng-audiogen
51
- conda create -n dasheng-audiogen python=3.10
52
- conda activate dasheng-audiogen
53
- pip install -r requirements.txt
54
  ```
55
 
56
- > 推荐使用 torch 2.8.0+cu128。
57
 
58
- ### Python API
 
 
59
 
60
  ```python
61
- from dasheng_audiogen.pipeline import DashengAudioGenPipeline
 
62
 
63
- pipe = DashengAudioGenPipeline(
64
- model_name_or_path="mispeech/Dasheng-AudioGen-Multilingual"
65
- )
66
 
67
- # 西班牙语语音示例(仅 <asr> 使用目标语言)
68
- prompt = pipe.compose_prompt(
 
 
 
 
 
 
 
 
69
  caption="A conversation scene on a busy city street.",
70
  speech="A young woman speaking softly in Spanish.",
71
- asr="Creo que deberíamos irnos ya.",
72
  env="Rain and distant traffic noise.",
 
73
  )
74
- waveforms = pipe.generate(prompts=[prompt])
75
- pipe.save_waveform(waveforms[0], "output.wav")
76
  ```
77
 
78
- ### 命令行
79
 
80
- ```bash
81
- python inference_cli.py infer \
82
- --model_name_or_path mispeech/Dasheng-AudioGen-Multilingual \
83
- --content "<|caption|> A conversation scene on a busy city street. <|speech|> A young woman speaking softly in Spanish. <|asr|> Creo que deberíamos irnos ya. <|env|> Rain and distant traffic noise." \
84
- --output_path ./outputs/multilingual.wav
85
  ```
86
 
87
- ## Prompt 标签
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  | 标签 | 描述 |
90
  |------|------|
91
- | `<\|caption\|>` | 整体音频场景 |
92
- | `<\|speech\|>` | 说话人身份和风格 |
93
- | `<\|asr\|>` | 语音转写内容 |
94
  | `<\|sfx\|>` | 音效 |
95
  | `<\|music\|>` | 背景音乐 |
96
  | `<\|env\|>` | 环境音 |
97
 
98
- > **Prompt 规范:** 所有描述性标签(`caption`、`speech`、`sfx`、`music`、`env`)应使用**英文**填写,仅 `<|asr|>`(实际要合成的语音内容)使用目标语言。
 
 
99
 
100
- ## 依赖资源
101
 
102
- - 音频分词器: [mispeech/dashengtokenizer](https://huggingface.co/mispeech/dashengtokenizer)
103
- - 文本编码器: [google/mt5-large](https://huggingface.co/google/mt5-large)
104
 
105
- ## 致谢
 
 
 
 
 
 
 
 
 
106
 
107
- 由**小米 LLM PLUS** 和**上海交通大学 X-LANCE** 联合开发。
 
1
  # Dasheng-AudioGen-Multilingual
2
 
3
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-b31b1b?logo=arxiv)](https://arxiv.org/abs/2505.XXXXX)
4
+ [![Hugging Face Model](https://img.shields.io/badge/HuggingFace-Model-orange?logo=huggingface)](https://huggingface.co/mispeech/Dasheng-AudioGen-Multilingual)
5
+ [![Hugging Face Demo](https://img.shields.io/badge/HuggingFace-Demo-orange?logo=huggingface)](https://huggingface.co/spaces/mispeech/Dasheng-AudioGen)
6
+ [![Web Demo](https://img.shields.io/badge/Website-Demo-181717?logo=google-chrome)](https://nieeim.github.io/Dasheng-AudioGen-Web/)
7
 
8
+ [**English**](./README.md) | [**中文**](./README_zh.md)
 
 
 
9
 
10
+ **Dasheng-AudioGen-Multilingual** 是 Dasheng-AudioGen 的多语言版本,是一个统一的音频生成模型,能够根据文本描述同时合成**语音、音乐、音效和环境声**。
11
 
12
+ ## 模型
 
 
 
13
 
14
+ | 模型 | HuggingFace | 文本编码器 | 语言支持 |
15
+ |------|-------------|-----------|:--------:|
16
+ | Dasheng-AudioGen | [mispeech/Dasheng-AudioGen](https://huggingface.co/mispeech/Dasheng-AudioGen) | `google/flan-t5-large` | 英语 |
17
+ | Dasheng-AudioGen-Multilingual | [mispeech/Dasheng-AudioGen-Multilingual](https://huggingface.co/mispeech/Dasheng-AudioGen-Multilingual) | `google/mt5-large` | 多语言 |
18
 
19
+ ### 多语言支持
20
 
21
  | 语言 | 时长 (h) | 占比 |
22
  |------|--------:|-----:|
 
30
  | 德语 (German) | 842.29 | 3.23% |
31
  | 其他 | 1,369.16 | 5.24% |
32
 
33
+ > **注意:** 当前多语言模型在所有非英语语言上的合成错误率都明显偏高,表中未列出的语言更不稳定。如果仅需英语生成,建议使用基础模型 (`mispeech/Dasheng-AudioGen`)。
34
 
35
+ ## 安装
 
 
 
 
 
 
 
 
 
36
 
37
  ```bash
38
+ pip install torch torchaudio "transformers<5" einops
 
 
 
 
39
  ```
40
 
41
+ > 已在 Python 3.10、torch 2.8.0+cu128、transformers 4.57 上测试通过已知不兼容 transformers 5.x。
42
 
43
+ ## 快速开始
44
+
45
+ ### 基本用法
46
 
47
  ```python
48
+ import torchaudio
49
+ from transformers import AutoModel
50
 
51
+ model = AutoModel.from_pretrained("mispeech/Dasheng-AudioGen-Multilingual", trust_remote_code=True).cuda()
 
 
52
 
53
+ audio = model.generate("A dog barking in a park")
54
+ torchaudio.save("output.wav", audio.cpu(), 16000)
55
+ ```
56
+
57
+ ### 分项 Prompt
58
+
59
+ 使用 `compose_prompt` 分别描述不同的音频维度:
60
+
61
+ ```python
62
+ prompt = model.compose_prompt(
63
  caption="A conversation scene on a busy city street.",
64
  speech="A young woman speaking softly in Spanish.",
 
65
  env="Rain and distant traffic noise.",
66
+ asr="Creo que deberíamos irnos ya.",
67
  )
68
+ audio = model.generate(prompt)
69
+ torchaudio.save("output.wav", audio.cpu(), 16000)
70
  ```
71
 
72
+ 也可以直接传入包含标签的完整字符串:
73
 
74
+ ```python
75
+ audio = model.generate(
76
+ "<|caption|> A helicopter passing overhead. <|sfx|> Rhythmic helicopter blade sounds. <|env|> Open sky ambience."
77
+ )
 
78
  ```
79
 
80
+ ### 批量推理
81
+
82
+ ```python
83
+ prompts = [
84
+ model.compose_prompt(caption="A cat meowing softly.", sfx="Soft cat meow."),
85
+ model.compose_prompt(caption="Thunder rolling in the distance.", env="Stormy night ambience."),
86
+ model.compose_prompt(caption="A piano playing a gentle melody.", music="Soft piano ballad."),
87
+ ]
88
+ audios = model.generate(prompts)
89
+
90
+ for i, audio in enumerate(audios):
91
+ torchaudio.save(f"output_{i}.wav", audio.unsqueeze(0).cpu(), 16000)
92
+ ```
93
+
94
+ ### 生成参数
95
+
96
+ ```python
97
+ audio = model.generate(
98
+ prompts="A dog barking in a park",
99
+ num_steps=25, # 去噪步数(默认:25)
100
+ guidance_scale=5.0, # 无分类器引导强度(默认:5.0)
101
+ sway_sampling_coef=-1.0, # sway 采样系数(默认:-1.0,设为 0 使用线性调度)
102
+ )
103
+ ```
104
+
105
+ ## Prompt 格式
106
+
107
+ Dasheng-AudioGen 使用结构化标签来描述不同的音频维度:
108
 
109
  | 标签 | 描述 |
110
  |------|------|
111
+ | `<\|caption\|>` | 整体音频场景描述 |
112
+ | `<\|speech\|>` | 说话人身份和说话风格 |
113
+ | `<\|asr\|>` | 语音转写内容 / 对话文本 |
114
  | `<\|sfx\|>` | 音效 |
115
  | `<\|music\|>` | 背景音乐 |
116
  | `<\|env\|>` | 环境音 |
117
 
118
+ > **多语言 prompt 规范:** 使用多语言模型时,所有描述性标签(`caption`、`speech`、`sfx`、`music`、`env`)应使用**英文**填写,仅 `<|asr|>` 字段(实际要合成的语音内容)使用目标语言。
119
+
120
+ ## 致谢
121
 
122
+ Dasheng-AudioGen 由**小米 LLM PLUS** 和 **上海交通大学 X-LANCE** 联合开发。
123
 
124
+ ## 引用
 
125
 
126
+ ```bibtex
127
+ @article{dasheng-audiogen,
128
+ title={Dasheng-AudioGen},
129
+ author={},
130
+ journal={arXiv preprint arXiv:2505.XXXXX},
131
+ year={2025}
132
+ }
133
+ ```
134
+
135
+ ## 许可证
136
 
137
+ 本项目基于 [Apache License 2.0](LICENSE)
attention.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import einops
5
+ from einops import rearrange, repeat
6
+ from inspect import isfunction
7
+
8
+ from .modules import RMSNorm
9
+
10
+
11
+ # --- Rotary Position Embeddings ---
12
+
13
+ def rotate_half(x):
14
+ x1, x2 = x.chunk(2, dim=-1)
15
+ return torch.cat((-x2, x1), dim=-1)
16
+
17
+
18
+ def apply_rotary_pos_emb(x, cos, sin):
19
+ cos = cos[:, :, : x.shape[-2], :]
20
+ sin = sin[:, :, : x.shape[-2], :]
21
+ return (x * cos) + (rotate_half(x) * sin)
22
+
23
+
24
+ class RotaryEmbedding(nn.Module):
25
+ def __init__(self, dim: int):
26
+ super().__init__()
27
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
28
+ self.register_buffer("inv_freq", inv_freq)
29
+ self._seq_len_cached = None
30
+ self._cos_cached = None
31
+ self._sin_cached = None
32
+
33
+ def _update_cos_sin_tables(self, x, seq_dimension=-2):
34
+ seq_len = x.shape[seq_dimension]
35
+ if (
36
+ seq_len != self._seq_len_cached
37
+ or self._cos_cached.device != x.device
38
+ or self._cos_cached.dtype != x.dtype
39
+ ):
40
+ self._seq_len_cached = seq_len
41
+ t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
42
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
43
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
44
+ self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
45
+ self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
46
+ return self._cos_cached, self._sin_cached
47
+
48
+ def forward(self, q, k):
49
+ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
50
+ q.float(), seq_dimension=-2
51
+ )
52
+ if k is not None:
53
+ return (
54
+ apply_rotary_pos_emb(q.float(), self._cos_cached, self._sin_cached).type_as(q),
55
+ apply_rotary_pos_emb(k.float(), self._cos_cached, self._sin_cached).type_as(k),
56
+ )
57
+ else:
58
+ return (
59
+ apply_rotary_pos_emb(q.float(), self._cos_cached, self._sin_cached).type_as(q),
60
+ None,
61
+ )
62
+
63
+
64
+ # --- Attention Helpers ---
65
+
66
+ def add_mask(sim, mask):
67
+ b, ndim = sim.shape[0], mask.ndim
68
+ if ndim == 3:
69
+ mask = rearrange(mask, "b n m -> b 1 n m")
70
+ if ndim == 2:
71
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
72
+ max_neg_value = -torch.finfo(sim.dtype).max
73
+ sim = sim.masked_fill(~mask, max_neg_value)
74
+ return sim
75
+
76
+
77
+ def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
78
+ def default(val, d):
79
+ return val if val is not None else (d() if isfunction(d) else d)
80
+
81
+ b, i, j = q_shape[0], q_shape[-2], k_shape[-2]
82
+ q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool))
83
+ k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool))
84
+ attn_mask = rearrange(q_mask, "b i -> b 1 i 1") * rearrange(k_mask, "b j -> b 1 1 j")
85
+ return attn_mask
86
+
87
+
88
+ # --- Main Attention Module ---
89
+
90
+ class Attention(nn.Module):
91
+ def __init__(
92
+ self,
93
+ dim,
94
+ context_dim=None,
95
+ num_heads=8,
96
+ qkv_bias=False,
97
+ qk_scale=None,
98
+ qk_norm=None,
99
+ attn_drop=0.0,
100
+ proj_drop=0.0,
101
+ rope_mode="none",
102
+ ):
103
+ super().__init__()
104
+ self.num_heads = num_heads
105
+ head_dim = dim // num_heads
106
+ self.scale = qk_scale or head_dim ** -0.5
107
+
108
+ self.cross_attn = context_dim is not None
109
+ context_dim = dim if context_dim is None else context_dim
110
+
111
+ self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
112
+ self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
113
+ self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)
114
+
115
+ if qk_norm is None:
116
+ self.norm_q = nn.Identity()
117
+ self.norm_k = nn.Identity()
118
+ elif qk_norm == "layernorm":
119
+ self.norm_q = nn.LayerNorm(head_dim)
120
+ self.norm_k = nn.LayerNorm(head_dim)
121
+ elif qk_norm == "rmsnorm":
122
+ self.norm_q = RMSNorm(head_dim)
123
+ self.norm_k = RMSNorm(head_dim)
124
+ else:
125
+ raise NotImplementedError
126
+
127
+ self.attn_drop_p = attn_drop
128
+ self.attn_drop = nn.Dropout(attn_drop)
129
+ self.proj = nn.Linear(dim, dim)
130
+ self.proj_drop = nn.Dropout(proj_drop)
131
+
132
+ if self.cross_attn:
133
+ assert rope_mode == "none"
134
+ self.rope_mode = rope_mode
135
+ if self.rope_mode == "shared" or self.rope_mode == "x_only":
136
+ self.rotary = RotaryEmbedding(dim=head_dim)
137
+
138
+ def _rotary(self, q, k, extras):
139
+ if self.rope_mode == "shared":
140
+ q, k = self.rotary(q=q, k=k)
141
+ elif self.rope_mode == "x_only":
142
+ q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :])
143
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
144
+ q = torch.cat((q_c, q_x), dim=2)
145
+ k = torch.cat((k_c, k_x), dim=2)
146
+ elif self.rope_mode == "none":
147
+ pass
148
+ else:
149
+ raise NotImplementedError
150
+ return q, k
151
+
152
+ def _attn(self, q, k, v, mask_binary):
153
+ x = F.scaled_dot_product_attention(
154
+ q, k, v, dropout_p=self.attn_drop_p if self.training else 0.0,
155
+ attn_mask=mask_binary,
156
+ )
157
+ x = einops.rearrange(x, "B H L D -> B L (H D)")
158
+ return x
159
+
160
+ def forward(self, x, context=None, context_mask=None, extras=0):
161
+ B, L, C = x.shape
162
+ if context is None:
163
+ context = x
164
+
165
+ q = self.to_q(x)
166
+ k = self.to_k(context)
167
+ v = self.to_v(context)
168
+
169
+ if context_mask is not None:
170
+ mask_binary = create_mask(x.shape, context.shape, x.device, None, context_mask)
171
+ else:
172
+ mask_binary = None
173
+
174
+ q = einops.rearrange(q, "B L (H D) -> B H L D", H=self.num_heads)
175
+ k = einops.rearrange(k, "B L (H D) -> B H L D", H=self.num_heads)
176
+ v = einops.rearrange(v, "B L (H D) -> B H L D", H=self.num_heads)
177
+
178
+ q = self.norm_q(q)
179
+ k = self.norm_k(k)
180
+
181
+ q, k = self._rotary(q, k, extras)
182
+
183
+ x = self._attn(q, k, v, mask_binary)
184
+
185
+ x = self.proj(x)
186
+ x = self.proj_drop(x)
187
+ return x
config.json ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "dasheng_audiogen",
3
+ "architectures": [
4
+ "DashengAudioGenModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_dasheng_audiogen.DashengAudioGenConfig",
8
+ "AutoModel": "modeling_dasheng_audiogen.DashengAudioGenModel"
9
+ },
10
+ "text_encoder_name": "google/mt5-large",
11
+ "tokenizer_name": "mispeech/dashengtokenizer",
12
+ "use_zero_instruction": true,
13
+ "task_instruction_dim": 1024,
14
+ "sample_rate": 16000,
15
+ "downsampling_ratio": 640,
16
+ "latent_dim": 1280,
17
+ "content_dim": 1024,
18
+ "frame_resolution": 0.005,
19
+ "duration_offset": 1.0,
20
+ "tokenizer_max_length": 512,
21
+ "dit_img_size": 1000,
22
+ "dit_patch_size": 1,
23
+ "dit_in_chans": 1280,
24
+ "dit_out_chans": 1280,
25
+ "dit_input_type": "1d",
26
+ "dit_embed_dim": 1536,
27
+ "dit_depth": 32,
28
+ "dit_num_heads": 24,
29
+ "dit_mlp_ratio": 4.0,
30
+ "dit_qk_norm": "layernorm",
31
+ "dit_norm_layer": "layernorm",
32
+ "dit_act_layer": "geglu",
33
+ "dit_context_norm": true,
34
+ "dit_time_fusion": "ada",
35
+ "dit_ada_sola_rank": 32,
36
+ "dit_ada_sola_alpha": 32,
37
+ "dit_ta_context_dim": 1024,
38
+ "dit_ta_context_fusion": "add",
39
+ "dit_ta_context_norm": true,
40
+ "dit_context_dim": 1024,
41
+ "dit_context_fusion": "cross",
42
+ "dit_context_pe_method": "none",
43
+ "dit_pe_method": "none",
44
+ "dit_rope_mode": "shared",
45
+ "adapter_num_heads": 16,
46
+ "adapter_dropout": 0.2,
47
+ "adapter_duration_grad_scale": 0.1,
48
+ "duration_predictor_filter_channels": 512,
49
+ "duration_predictor_n_layers": 5,
50
+ "duration_predictor_kernel_size": 3,
51
+ "duration_predictor_p_dropout": 0.5,
52
+ "special_tokens": [
53
+ "<|caption|>",
54
+ "<|speech|>",
55
+ "<|sfx|>",
56
+ "<|music|>",
57
+ "<|env|>",
58
+ "<|asr|>",
59
+ "<|speech_start|>",
60
+ "<|speech_end|>"
61
+ ],
62
+ "train_special_tokens": true
63
+ }
configuration_dasheng_audiogen.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class DashengAudioGenConfig(PretrainedConfig):
5
+ model_type = "dasheng_audiogen"
6
+
7
+ def __init__(
8
+ self,
9
+ text_encoder_name: str = "google/flan-t5-large",
10
+ tokenizer_name: str = "mispeech/dashengtokenizer",
11
+ use_zero_instruction: bool = False,
12
+ task_instruction_dim: int = 1024,
13
+ sample_rate: int = 16000,
14
+ downsampling_ratio: int = 640,
15
+ latent_dim: int = 1280,
16
+ content_dim: int = 1024,
17
+ frame_resolution: float = 0.005,
18
+ duration_offset: float = 1.0,
19
+ tokenizer_max_length: int = 512,
20
+ dit_img_size: int = 1000,
21
+ dit_patch_size: int = 1,
22
+ dit_in_chans: int = 1280,
23
+ dit_out_chans: int = 1280,
24
+ dit_input_type: str = "1d",
25
+ dit_embed_dim: int = 1536,
26
+ dit_depth: int = 32,
27
+ dit_num_heads: int = 24,
28
+ dit_mlp_ratio: float = 4.0,
29
+ dit_qk_norm: str = "layernorm",
30
+ dit_norm_layer: str = "layernorm",
31
+ dit_act_layer: str = "geglu",
32
+ dit_context_norm: bool = True,
33
+ dit_time_fusion: str = "ada",
34
+ dit_ada_sola_rank: int = 32,
35
+ dit_ada_sola_alpha: int = 32,
36
+ dit_ta_context_dim: int = 1024,
37
+ dit_ta_context_fusion: str = "add",
38
+ dit_ta_context_norm: bool = True,
39
+ dit_context_dim: int = 1024,
40
+ dit_context_fusion: str = "cross",
41
+ dit_context_pe_method: str = "none",
42
+ dit_pe_method: str = "none",
43
+ dit_rope_mode: str = "shared",
44
+ adapter_num_heads: int = 16,
45
+ adapter_dropout: float = 0.2,
46
+ adapter_duration_grad_scale: float = 0.1,
47
+ duration_predictor_filter_channels: int = 512,
48
+ duration_predictor_n_layers: int = 5,
49
+ duration_predictor_kernel_size: int = 3,
50
+ duration_predictor_p_dropout: float = 0.5,
51
+ special_tokens: list = None,
52
+ train_special_tokens: bool = False,
53
+ **kwargs,
54
+ ):
55
+ super().__init__(**kwargs)
56
+ self.text_encoder_name = text_encoder_name
57
+ self.tokenizer_name = tokenizer_name
58
+ self.use_zero_instruction = use_zero_instruction
59
+ self.task_instruction_dim = task_instruction_dim
60
+ self.sample_rate = sample_rate
61
+ self.downsampling_ratio = downsampling_ratio
62
+ self.latent_dim = latent_dim
63
+ self.content_dim = content_dim
64
+ self.frame_resolution = frame_resolution
65
+ self.duration_offset = duration_offset
66
+ self.tokenizer_max_length = tokenizer_max_length
67
+ self.dit_img_size = dit_img_size
68
+ self.dit_patch_size = dit_patch_size
69
+ self.dit_in_chans = dit_in_chans
70
+ self.dit_out_chans = dit_out_chans
71
+ self.dit_input_type = dit_input_type
72
+ self.dit_embed_dim = dit_embed_dim
73
+ self.dit_depth = dit_depth
74
+ self.dit_num_heads = dit_num_heads
75
+ self.dit_mlp_ratio = dit_mlp_ratio
76
+ self.dit_qk_norm = dit_qk_norm
77
+ self.dit_norm_layer = dit_norm_layer
78
+ self.dit_act_layer = dit_act_layer
79
+ self.dit_context_norm = dit_context_norm
80
+ self.dit_time_fusion = dit_time_fusion
81
+ self.dit_ada_sola_rank = dit_ada_sola_rank
82
+ self.dit_ada_sola_alpha = dit_ada_sola_alpha
83
+ self.dit_ta_context_dim = dit_ta_context_dim
84
+ self.dit_ta_context_fusion = dit_ta_context_fusion
85
+ self.dit_ta_context_norm = dit_ta_context_norm
86
+ self.dit_context_dim = dit_context_dim
87
+ self.dit_context_fusion = dit_context_fusion
88
+ self.dit_context_pe_method = dit_context_pe_method
89
+ self.dit_pe_method = dit_pe_method
90
+ self.dit_rope_mode = dit_rope_mode
91
+ self.adapter_num_heads = adapter_num_heads
92
+ self.adapter_dropout = adapter_dropout
93
+ self.adapter_duration_grad_scale = adapter_duration_grad_scale
94
+ self.duration_predictor_filter_channels = duration_predictor_filter_channels
95
+ self.duration_predictor_n_layers = duration_predictor_n_layers
96
+ self.duration_predictor_kernel_size = duration_predictor_kernel_size
97
+ self.duration_predictor_p_dropout = duration_predictor_p_dropout
98
+ self.special_tokens = special_tokens or []
99
+ self.train_special_tokens = train_special_tokens
content_adapter.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class LayerNorm(nn.LayerNorm):
6
+ def __init__(self, nout, dim=-1):
7
+ super().__init__(nout, eps=1e-12)
8
+ self.dim = dim
9
+
10
+ def forward(self, x):
11
+ if self.dim == -1:
12
+ return super().forward(x)
13
+ return super().forward(x.transpose(1, -1)).transpose(1, -1)
14
+
15
+
16
+ class DurationPredictor(nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_channels: int,
20
+ filter_channels: int,
21
+ n_layers: int = 2,
22
+ kernel_size: int = 3,
23
+ p_dropout: float = 0.1,
24
+ padding: str = "SAME"
25
+ ):
26
+ super().__init__()
27
+ self.conv = nn.ModuleList()
28
+ self.kernel_size = kernel_size
29
+ self.padding = padding
30
+ for idx in range(n_layers):
31
+ in_chans = in_channels if idx == 0 else filter_channels
32
+ self.conv += [
33
+ nn.Sequential(
34
+ nn.ConstantPad1d(
35
+ ((kernel_size - 1) // 2, (kernel_size - 1) // 2)
36
+ if padding == 'SAME' else (kernel_size - 1, 0),
37
+ 0
38
+ ),
39
+ nn.Conv1d(
40
+ in_chans, filter_channels,
41
+ kernel_size, stride=1, padding=0
42
+ ),
43
+ nn.ReLU(),
44
+ LayerNorm(filter_channels, dim=1),
45
+ nn.Dropout(p_dropout)
46
+ )
47
+ ]
48
+ self.linear = nn.Linear(filter_channels, 1)
49
+
50
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor):
51
+ x = x.transpose(1, -1)
52
+ x_mask = x_mask.unsqueeze(1).to(x.device)
53
+ for f in self.conv:
54
+ x = f(x)
55
+ x = x * x_mask.float()
56
+ x = self.linear(x.transpose(1, -1)) * x_mask.transpose(1, -1).float()
57
+ return x
58
+
59
+
60
+ class ContentAdapterBase(nn.Module):
61
+ def __init__(self, d_out):
62
+ super().__init__()
63
+ self.d_out = d_out
64
+
65
+
66
+ class CrossAttentionAdapter(ContentAdapterBase):
67
+ def __init__(
68
+ self,
69
+ d_out: int,
70
+ content_dim: int,
71
+ prefix_dim: int,
72
+ num_heads: int,
73
+ duration_predictor: DurationPredictor,
74
+ dropout: float = 0.1,
75
+ duration_grad_scale: float = 0.1,
76
+ ):
77
+ super().__init__(d_out)
78
+ self.attn = nn.MultiheadAttention(
79
+ embed_dim=content_dim,
80
+ num_heads=num_heads,
81
+ dropout=dropout,
82
+ kdim=prefix_dim,
83
+ vdim=prefix_dim,
84
+ batch_first=True,
85
+ )
86
+ self.duration_grad_scale = duration_grad_scale
87
+ self.duration_predictor = duration_predictor
88
+ self.global_duration_mlp = nn.Sequential(
89
+ nn.Linear(content_dim, content_dim), nn.ReLU(),
90
+ nn.Dropout(dropout), nn.Linear(content_dim, 1)
91
+ )
92
+ self.norm = nn.LayerNorm(content_dim)
93
+ self.content_proj = nn.Conv1d(content_dim, d_out, 1)
94
+
95
+ def forward(self, content, content_mask, prefix, prefix_mask):
96
+ attn_output, attn_output_weights = self.attn(
97
+ query=content,
98
+ key=prefix,
99
+ value=prefix,
100
+ key_padding_mask=~prefix_mask.bool()
101
+ )
102
+ attn_output = attn_output * content_mask.unsqueeze(-1).float()
103
+ x = self.norm(attn_output + content)
104
+ x_grad_rescaled = x * self.duration_grad_scale + x.detach() * (
105
+ 1 - self.duration_grad_scale
106
+ )
107
+ x_aggregated = (
108
+ x_grad_rescaled * content_mask.unsqueeze(-1).float()
109
+ ).sum(dim=1) / content_mask.sum(dim=1, keepdim=True).float()
110
+ global_duration = self.global_duration_mlp(x_aggregated).squeeze(-1)
111
+ local_duration = self.duration_predictor(
112
+ x_grad_rescaled, content_mask
113
+ ).squeeze(-1)
114
+ content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
115
+ return content, content_mask, global_duration, local_duration
dit.py ADDED
@@ -0,0 +1,1153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .modules import (
7
+ film_modulate,
8
+ unpatchify,
9
+ PatchEmbed,
10
+ PE_wrapper,
11
+ TimestepEmbedder,
12
+ FeedForward,
13
+ RMSNorm,
14
+ )
15
+ from .attention import Attention
16
+
17
+
18
+ class AdaLN(nn.Module):
19
+ def __init__(self, dim, ada_mode='ada', r=None, alpha=None):
20
+ super().__init__()
21
+ self.ada_mode = ada_mode
22
+ self.scale_shift_table = None
23
+ if ada_mode == 'ada':
24
+ self.time_ada = nn.Linear(dim, 6 * dim, bias=True)
25
+ elif ada_mode == 'ada_single':
26
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
27
+ elif ada_mode in ['ada_sola', 'ada_sola_bias']:
28
+ self.lora_a = nn.Linear(dim, r * 6, bias=False)
29
+ self.lora_b = nn.Linear(r * 6, dim * 6, bias=False)
30
+ self.scaling = alpha / r
31
+ if ada_mode == 'ada_sola_bias':
32
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
33
+ else:
34
+ raise NotImplementedError
35
+
36
+ def forward(self, time_token=None, time_ada=None):
37
+ if self.ada_mode == 'ada':
38
+ assert time_ada is None
39
+ B = time_token.shape[0]
40
+ time_ada = self.time_ada(time_token).reshape(B, 6, -1)
41
+ elif self.ada_mode == 'ada_single':
42
+ B = time_ada.shape[0]
43
+ time_ada = time_ada.reshape(B, 6, -1)
44
+ time_ada = self.scale_shift_table[None] + time_ada
45
+ elif self.ada_mode in ['ada_sola', 'ada_sola_bias']:
46
+ B = time_ada.shape[0]
47
+ time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling
48
+ time_ada = time_ada + time_ada_lora
49
+ time_ada = time_ada.reshape(B, 6, -1)
50
+ if self.scale_shift_table is not None:
51
+ time_ada = self.scale_shift_table[None] + time_ada
52
+ else:
53
+ raise NotImplementedError
54
+ return time_ada
55
+
56
+
57
+ class DiTBlock(nn.Module):
58
+ def __init__(
59
+ self,
60
+ dim,
61
+ context_dim=None,
62
+ num_heads=8,
63
+ mlp_ratio=4.,
64
+ qkv_bias=False,
65
+ qk_scale=None,
66
+ qk_norm=None,
67
+ act_layer='gelu',
68
+ norm_layer=nn.LayerNorm,
69
+ time_fusion='none',
70
+ ada_sola_rank=None,
71
+ ada_sola_alpha=None,
72
+ skip=False,
73
+ skip_norm=False,
74
+ rope_mode='none',
75
+ context_norm=False,
76
+ use_checkpoint=False
77
+ ):
78
+ super().__init__()
79
+ self.norm1 = norm_layer(dim)
80
+ self.attn = Attention(
81
+ dim=dim,
82
+ num_heads=num_heads,
83
+ qkv_bias=qkv_bias,
84
+ qk_scale=qk_scale,
85
+ qk_norm=qk_norm,
86
+ rope_mode=rope_mode
87
+ )
88
+
89
+ if context_dim is not None:
90
+ self.use_context = True
91
+ self.cross_attn = Attention(
92
+ dim=dim,
93
+ num_heads=num_heads,
94
+ context_dim=context_dim,
95
+ qkv_bias=qkv_bias,
96
+ qk_scale=qk_scale,
97
+ qk_norm=qk_norm,
98
+ rope_mode='none'
99
+ )
100
+ self.norm2 = norm_layer(dim)
101
+ if context_norm:
102
+ self.norm_context = norm_layer(context_dim)
103
+ else:
104
+ self.norm_context = nn.Identity()
105
+ else:
106
+ self.use_context = False
107
+
108
+ self.norm3 = norm_layer(dim)
109
+ self.mlp = FeedForward(
110
+ dim=dim, mult=mlp_ratio, activation_fn=act_layer, dropout=0
111
+ )
112
+
113
+ self.use_adanorm = True if time_fusion != 'token' else False
114
+ if self.use_adanorm:
115
+ self.adaln = AdaLN(
116
+ dim,
117
+ ada_mode=time_fusion,
118
+ r=ada_sola_rank,
119
+ alpha=ada_sola_alpha
120
+ )
121
+ if skip:
122
+ self.skip_norm = norm_layer(2 * dim) if skip_norm else nn.Identity()
123
+ self.skip_linear = nn.Linear(2 * dim, dim)
124
+ else:
125
+ self.skip_linear = None
126
+
127
+ self.use_checkpoint = use_checkpoint
128
+
129
+ def forward(
130
+ self,
131
+ x,
132
+ time_token=None,
133
+ time_ada=None,
134
+ skip=None,
135
+ context=None,
136
+ x_mask=None,
137
+ context_mask=None,
138
+ extras=None
139
+ ):
140
+ if self.use_checkpoint:
141
+ from torch.utils.checkpoint import checkpoint
142
+ return checkpoint(
143
+ self._forward,
144
+ x, time_token, time_ada, skip, context, x_mask, context_mask,
145
+ extras,
146
+ use_reentrant=False
147
+ )
148
+ else:
149
+ return self._forward(
150
+ x, time_token, time_ada, skip, context, x_mask, context_mask,
151
+ extras
152
+ )
153
+
154
+ def _forward(
155
+ self,
156
+ x,
157
+ time_token=None,
158
+ time_ada=None,
159
+ skip=None,
160
+ context=None,
161
+ x_mask=None,
162
+ context_mask=None,
163
+ extras=None
164
+ ):
165
+ B, T, C = x.shape
166
+ if self.skip_linear is not None:
167
+ assert skip is not None
168
+ cat = torch.cat([x, skip], dim=-1)
169
+ cat = self.skip_norm(cat)
170
+ x = self.skip_linear(cat)
171
+
172
+ if self.use_adanorm:
173
+ time_ada = self.adaln(time_token, time_ada)
174
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
175
+ gate_mlp) = time_ada.chunk(6, dim=1)
176
+
177
+ if self.use_adanorm:
178
+ x_norm = film_modulate(
179
+ self.norm1(x), shift=shift_msa, scale=scale_msa
180
+ )
181
+ x = x + (1 - gate_msa) * self.attn(
182
+ x_norm, context=None, context_mask=x_mask, extras=extras
183
+ )
184
+ else:
185
+ x = x + self.attn(
186
+ self.norm1(x),
187
+ context=None,
188
+ context_mask=x_mask,
189
+ extras=extras
190
+ )
191
+
192
+ if self.use_context:
193
+ assert context is not None
194
+ x = x + self.cross_attn(
195
+ x=self.norm2(x),
196
+ context=self.norm_context(context),
197
+ context_mask=context_mask,
198
+ extras=extras
199
+ )
200
+
201
+ if self.use_adanorm:
202
+ x_norm = film_modulate(
203
+ self.norm3(x), shift=shift_mlp, scale=scale_mlp
204
+ )
205
+ x = x + (1 - gate_mlp) * self.mlp(x_norm)
206
+ else:
207
+ x = x + self.mlp(self.norm3(x))
208
+
209
+ return x
210
+
211
+
212
+ class FinalBlock(nn.Module):
213
+ def __init__(
214
+ self,
215
+ embed_dim,
216
+ patch_size,
217
+ in_chans,
218
+ img_size,
219
+ input_type='2d',
220
+ norm_layer=nn.LayerNorm,
221
+ use_conv=True,
222
+ use_adanorm=True
223
+ ):
224
+ super().__init__()
225
+ self.in_chans = in_chans
226
+ self.img_size = img_size
227
+ self.input_type = input_type
228
+
229
+ self.norm = norm_layer(embed_dim)
230
+ self.use_adanorm = use_adanorm
231
+
232
+ if input_type == '2d':
233
+ self.patch_dim = patch_size**2 * in_chans
234
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
235
+ if use_conv:
236
+ self.final_layer = nn.Conv2d(
237
+ self.in_chans, self.in_chans, 3, padding=1
238
+ )
239
+ else:
240
+ self.final_layer = nn.Identity()
241
+
242
+ elif input_type == '1d':
243
+ self.patch_dim = patch_size * in_chans
244
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
245
+ if use_conv:
246
+ self.final_layer = nn.Conv1d(
247
+ self.in_chans, self.in_chans, 3, padding=1
248
+ )
249
+ else:
250
+ self.final_layer = nn.Identity()
251
+
252
+ def forward(self, x, time_ada=None, extras=0):
253
+ B, T, C = x.shape
254
+ x = x[:, extras:, :]
255
+ if self.use_adanorm:
256
+ shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1)
257
+ x = film_modulate(self.norm(x), shift, scale)
258
+ else:
259
+ x = self.norm(x)
260
+ x = self.linear(x)
261
+ x = unpatchify(x, self.in_chans, self.input_type, self.img_size)
262
+ x = self.final_layer(x)
263
+ return x
264
+
265
+
266
+ class UDiT(nn.Module):
267
+ def __init__(
268
+ self,
269
+ img_size=224,
270
+ patch_size=16,
271
+ in_chans=3,
272
+ input_type='2d',
273
+ out_chans=None,
274
+ embed_dim=768,
275
+ depth=12,
276
+ num_heads=12,
277
+ mlp_ratio=4.,
278
+ qkv_bias=False,
279
+ qk_scale=None,
280
+ qk_norm=None,
281
+ act_layer='gelu',
282
+ norm_layer='layernorm',
283
+ context_norm=False,
284
+ use_checkpoint=False,
285
+ time_fusion='token',
286
+ ada_sola_rank=None,
287
+ ada_sola_alpha=None,
288
+ cls_dim=None,
289
+ context_dim=768,
290
+ context_fusion='concat',
291
+ context_max_length=128,
292
+ context_pe_method='sinu',
293
+ pe_method='abs',
294
+ rope_mode='none',
295
+ use_conv=True,
296
+ skip=True,
297
+ skip_norm=True
298
+ ):
299
+ super().__init__()
300
+ self.num_features = self.embed_dim = embed_dim
301
+
302
+ self.in_chans = in_chans
303
+ self.input_type = input_type
304
+ if self.input_type == '2d':
305
+ num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
306
+ elif self.input_type == '1d':
307
+ num_patches = img_size // patch_size
308
+ self.patch_embed = PatchEmbed(
309
+ patch_size=patch_size,
310
+ in_chans=in_chans,
311
+ embed_dim=embed_dim,
312
+ input_type=input_type
313
+ )
314
+ out_chans = in_chans if out_chans is None else out_chans
315
+ self.out_chans = out_chans
316
+
317
+ self.rope = rope_mode
318
+ self.x_pe = PE_wrapper(
319
+ dim=embed_dim, method=pe_method, length=num_patches
320
+ )
321
+
322
+ self.time_embed = TimestepEmbedder(embed_dim)
323
+ self.time_fusion = time_fusion
324
+ self.use_adanorm = False
325
+
326
+ if cls_dim is not None:
327
+ self.cls_embed = nn.Sequential(
328
+ nn.Linear(cls_dim, embed_dim, bias=True),
329
+ nn.SiLU(),
330
+ nn.Linear(embed_dim, embed_dim, bias=True),
331
+ )
332
+ else:
333
+ self.cls_embed = None
334
+
335
+ if time_fusion == 'token':
336
+ self.extras = 2 if self.cls_embed else 1
337
+ self.time_pe = PE_wrapper(
338
+ dim=embed_dim, method='abs', length=self.extras
339
+ )
340
+ elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']:
341
+ self.use_adanorm = True
342
+ self.time_act = nn.SiLU()
343
+ self.extras = 0
344
+ self.time_ada_final = nn.Linear(
345
+ embed_dim, 2 * embed_dim, bias=True
346
+ )
347
+ if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']:
348
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
349
+ else:
350
+ self.time_ada = None
351
+ else:
352
+ raise NotImplementedError
353
+
354
+ self.use_context = False
355
+ self.context_cross = False
356
+ self.context_max_length = context_max_length
357
+ self.context_fusion = 'none'
358
+ if context_dim is not None:
359
+ self.use_context = True
360
+ self.context_embed = nn.Sequential(
361
+ nn.Linear(context_dim, embed_dim, bias=True),
362
+ nn.SiLU(),
363
+ nn.Linear(embed_dim, embed_dim, bias=True),
364
+ )
365
+ self.context_fusion = context_fusion
366
+ if context_fusion == 'concat' or context_fusion == 'joint':
367
+ self.extras += context_max_length
368
+ self.context_pe = PE_wrapper(
369
+ dim=embed_dim,
370
+ method=context_pe_method,
371
+ length=context_max_length
372
+ )
373
+ context_dim = None
374
+ elif context_fusion == 'cross':
375
+ self.context_pe = PE_wrapper(
376
+ dim=embed_dim,
377
+ method=context_pe_method,
378
+ length=context_max_length
379
+ )
380
+ self.context_cross = True
381
+ context_dim = embed_dim
382
+ else:
383
+ raise NotImplementedError
384
+
385
+ self.use_skip = skip
386
+
387
+ if norm_layer == 'layernorm':
388
+ norm_layer = nn.LayerNorm
389
+ elif norm_layer == 'rmsnorm':
390
+ norm_layer = RMSNorm
391
+ else:
392
+ raise NotImplementedError
393
+
394
+ self.in_blocks = nn.ModuleList([
395
+ DiTBlock(
396
+ dim=embed_dim,
397
+ context_dim=context_dim,
398
+ num_heads=num_heads,
399
+ mlp_ratio=mlp_ratio,
400
+ qkv_bias=qkv_bias,
401
+ qk_scale=qk_scale,
402
+ qk_norm=qk_norm,
403
+ act_layer=act_layer,
404
+ norm_layer=norm_layer,
405
+ time_fusion=time_fusion,
406
+ ada_sola_rank=ada_sola_rank,
407
+ ada_sola_alpha=ada_sola_alpha,
408
+ skip=False,
409
+ skip_norm=False,
410
+ rope_mode=self.rope,
411
+ context_norm=context_norm,
412
+ use_checkpoint=use_checkpoint
413
+ ) for _ in range(depth // 2)
414
+ ])
415
+
416
+ self.mid_block = DiTBlock(
417
+ dim=embed_dim,
418
+ context_dim=context_dim,
419
+ num_heads=num_heads,
420
+ mlp_ratio=mlp_ratio,
421
+ qkv_bias=qkv_bias,
422
+ qk_scale=qk_scale,
423
+ qk_norm=qk_norm,
424
+ act_layer=act_layer,
425
+ norm_layer=norm_layer,
426
+ time_fusion=time_fusion,
427
+ ada_sola_rank=ada_sola_rank,
428
+ ada_sola_alpha=ada_sola_alpha,
429
+ skip=False,
430
+ skip_norm=False,
431
+ rope_mode=self.rope,
432
+ context_norm=context_norm,
433
+ use_checkpoint=use_checkpoint
434
+ )
435
+
436
+ self.out_blocks = nn.ModuleList([
437
+ DiTBlock(
438
+ dim=embed_dim,
439
+ context_dim=context_dim,
440
+ num_heads=num_heads,
441
+ mlp_ratio=mlp_ratio,
442
+ qkv_bias=qkv_bias,
443
+ qk_scale=qk_scale,
444
+ qk_norm=qk_norm,
445
+ act_layer=act_layer,
446
+ norm_layer=norm_layer,
447
+ time_fusion=time_fusion,
448
+ ada_sola_rank=ada_sola_rank,
449
+ ada_sola_alpha=ada_sola_alpha,
450
+ skip=skip,
451
+ skip_norm=skip_norm,
452
+ rope_mode=self.rope,
453
+ context_norm=context_norm,
454
+ use_checkpoint=use_checkpoint
455
+ ) for _ in range(depth // 2)
456
+ ])
457
+
458
+ self.use_conv = use_conv
459
+ self.final_block = FinalBlock(
460
+ embed_dim=embed_dim,
461
+ patch_size=patch_size,
462
+ img_size=img_size,
463
+ in_chans=out_chans,
464
+ input_type=input_type,
465
+ norm_layer=norm_layer,
466
+ use_conv=use_conv,
467
+ use_adanorm=self.use_adanorm
468
+ )
469
+ self.initialize_weights()
470
+
471
+ def _init_ada(self):
472
+ if self.time_fusion == 'ada':
473
+ nn.init.constant_(self.time_ada_final.weight, 0)
474
+ nn.init.constant_(self.time_ada_final.bias, 0)
475
+ for block in self.in_blocks:
476
+ nn.init.constant_(block.adaln.time_ada.weight, 0)
477
+ nn.init.constant_(block.adaln.time_ada.bias, 0)
478
+ nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0)
479
+ nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0)
480
+ for block in self.out_blocks:
481
+ nn.init.constant_(block.adaln.time_ada.weight, 0)
482
+ nn.init.constant_(block.adaln.time_ada.bias, 0)
483
+ elif self.time_fusion == 'ada_single':
484
+ nn.init.constant_(self.time_ada.weight, 0)
485
+ nn.init.constant_(self.time_ada.bias, 0)
486
+ nn.init.constant_(self.time_ada_final.weight, 0)
487
+ nn.init.constant_(self.time_ada_final.bias, 0)
488
+ elif self.time_fusion in ['ada_sola', 'ada_sola_bias']:
489
+ nn.init.constant_(self.time_ada.weight, 0)
490
+ nn.init.constant_(self.time_ada.bias, 0)
491
+ nn.init.constant_(self.time_ada_final.weight, 0)
492
+ nn.init.constant_(self.time_ada_final.bias, 0)
493
+ for block in self.in_blocks:
494
+ nn.init.kaiming_uniform_(
495
+ block.adaln.lora_a.weight, a=math.sqrt(5)
496
+ )
497
+ nn.init.constant_(block.adaln.lora_b.weight, 0)
498
+ nn.init.kaiming_uniform_(
499
+ self.mid_block.adaln.lora_a.weight, a=math.sqrt(5)
500
+ )
501
+ nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0)
502
+ for block in self.out_blocks:
503
+ nn.init.kaiming_uniform_(
504
+ block.adaln.lora_a.weight, a=math.sqrt(5)
505
+ )
506
+ nn.init.constant_(block.adaln.lora_b.weight, 0)
507
+
508
+ def initialize_weights(self):
509
+ def _basic_init(module):
510
+ if isinstance(module, nn.Linear):
511
+ nn.init.xavier_uniform_(module.weight)
512
+ if module.bias is not None:
513
+ nn.init.constant_(module.bias, 0)
514
+
515
+ self.apply(_basic_init)
516
+
517
+ w = self.patch_embed.proj.weight.data
518
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
519
+ nn.init.constant_(self.patch_embed.proj.bias, 0)
520
+
521
+ if self.use_adanorm:
522
+ self._init_ada()
523
+
524
+ if self.context_cross:
525
+ for block in self.in_blocks:
526
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
527
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
528
+ nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0)
529
+ nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0)
530
+ for block in self.out_blocks:
531
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
532
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
533
+
534
+ if self.cls_embed:
535
+ if self.use_adanorm:
536
+ nn.init.constant_(self.cls_embed[-1].weight, 0)
537
+ nn.init.constant_(self.cls_embed[-1].bias, 0)
538
+
539
+ if self.use_conv:
540
+ nn.init.xavier_uniform_(self.final_block.final_layer.weight)
541
+ nn.init.constant_(self.final_block.final_layer.bias, 0)
542
+
543
+ def _concat_x_context(self, x, context, x_mask=None, context_mask=None):
544
+ assert context.shape[-2] == self.context_max_length
545
+ B = x.shape[0]
546
+ if x_mask is None:
547
+ x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
548
+ if context_mask is None:
549
+ context_mask = torch.ones(
550
+ B, context.shape[-2], device=context.device
551
+ ).bool()
552
+ x_mask = torch.cat([context_mask, x_mask], dim=1)
553
+ x = torch.cat((context, x), dim=1)
554
+ return x, x_mask
555
+
556
+ def forward(
557
+ self,
558
+ x,
559
+ timesteps,
560
+ context,
561
+ x_mask=None,
562
+ context_mask=None,
563
+ cls_token=None,
564
+ controlnet_skips=None,
565
+ ):
566
+ if timesteps.dim() == 0:
567
+ timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)
568
+
569
+ x = self.patch_embed(x)
570
+ x = self.x_pe(x)
571
+
572
+ B, L, D = x.shape
573
+
574
+ if self.use_context:
575
+ context_token = self.context_embed(context)
576
+ context_token = self.context_pe(context_token)
577
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
578
+ x, x_mask = self._concat_x_context(
579
+ x=x,
580
+ context=context_token,
581
+ x_mask=x_mask,
582
+ context_mask=context_mask
583
+ )
584
+ context_token, context_mask = None, None
585
+ else:
586
+ context_token, context_mask = None, None
587
+
588
+ time_token = self.time_embed(timesteps)
589
+ if self.cls_embed:
590
+ cls_token = self.cls_embed(cls_token)
591
+ time_ada = None
592
+ time_ada_final = None
593
+ if self.use_adanorm:
594
+ if self.cls_embed:
595
+ time_token = time_token + cls_token
596
+ time_token = self.time_act(time_token)
597
+ time_ada_final = self.time_ada_final(time_token)
598
+ if self.time_ada is not None:
599
+ time_ada = self.time_ada(time_token)
600
+ else:
601
+ time_token = time_token.unsqueeze(dim=1)
602
+ if self.cls_embed:
603
+ cls_token = cls_token.unsqueeze(dim=1)
604
+ time_token = torch.cat([time_token, cls_token], dim=1)
605
+ time_token = self.time_pe(time_token)
606
+ x = torch.cat((time_token, x), dim=1)
607
+ if x_mask is not None:
608
+ x_mask = torch.cat([
609
+ torch.ones(B, time_token.shape[1],
610
+ device=x_mask.device).bool(), x_mask
611
+ ], dim=1)
612
+ time_token = None
613
+
614
+ skips = []
615
+ for blk in self.in_blocks:
616
+ x = blk(
617
+ x=x,
618
+ time_token=time_token,
619
+ time_ada=time_ada,
620
+ skip=None,
621
+ context=context_token,
622
+ x_mask=x_mask,
623
+ context_mask=context_mask,
624
+ extras=self.extras
625
+ )
626
+ if self.use_skip:
627
+ skips.append(x)
628
+
629
+ x = self.mid_block(
630
+ x=x,
631
+ time_token=time_token,
632
+ time_ada=time_ada,
633
+ skip=None,
634
+ context=context_token,
635
+ x_mask=x_mask,
636
+ context_mask=context_mask,
637
+ extras=self.extras
638
+ )
639
+ for blk in self.out_blocks:
640
+ if self.use_skip:
641
+ skip = skips.pop()
642
+ if controlnet_skips:
643
+ skip = skip + controlnet_skips.pop()
644
+ else:
645
+ skip = None
646
+ if controlnet_skips:
647
+ x = x + controlnet_skips.pop()
648
+
649
+ x = blk(
650
+ x=x,
651
+ time_token=time_token,
652
+ time_ada=time_ada,
653
+ skip=skip,
654
+ context=context_token,
655
+ x_mask=x_mask,
656
+ context_mask=context_mask,
657
+ extras=self.extras
658
+ )
659
+
660
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
661
+ return x
662
+
663
+
664
+ class LayerFusionDiTBlock(DiTBlock):
665
+ def __init__(
666
+ self,
667
+ dim,
668
+ ta_context_dim,
669
+ ta_context_norm=False,
670
+ context_dim=None,
671
+ num_heads=8,
672
+ mlp_ratio=4.,
673
+ qkv_bias=False,
674
+ qk_scale=None,
675
+ qk_norm=None,
676
+ act_layer='gelu',
677
+ norm_layer=nn.LayerNorm,
678
+ ta_context_fusion='add',
679
+ time_fusion='none',
680
+ ada_sola_rank=None,
681
+ ada_sola_alpha=None,
682
+ skip=False,
683
+ skip_norm=False,
684
+ rope_mode='none',
685
+ context_norm=False,
686
+ use_checkpoint=False
687
+ ):
688
+ super().__init__(
689
+ dim=dim,
690
+ context_dim=context_dim,
691
+ num_heads=num_heads,
692
+ mlp_ratio=mlp_ratio,
693
+ qkv_bias=qkv_bias,
694
+ qk_scale=qk_scale,
695
+ qk_norm=qk_norm,
696
+ act_layer=act_layer,
697
+ norm_layer=norm_layer,
698
+ time_fusion=time_fusion,
699
+ ada_sola_rank=ada_sola_rank,
700
+ ada_sola_alpha=ada_sola_alpha,
701
+ skip=skip,
702
+ skip_norm=skip_norm,
703
+ rope_mode=rope_mode,
704
+ context_norm=context_norm,
705
+ use_checkpoint=use_checkpoint
706
+ )
707
+ self.ta_context_fusion = ta_context_fusion
708
+ self.ta_context_norm = ta_context_norm
709
+ if self.ta_context_fusion == "add":
710
+ self.ta_context_projection = nn.Linear(
711
+ ta_context_dim, dim, bias=False
712
+ )
713
+ self.ta_context_norm = norm_layer(
714
+ ta_context_dim
715
+ ) if self.ta_context_norm else nn.Identity()
716
+ elif self.ta_context_fusion == "concat":
717
+ self.ta_context_projection = nn.Linear(ta_context_dim + dim, dim)
718
+ self.ta_context_norm = norm_layer(
719
+ ta_context_dim + dim
720
+ ) if self.ta_context_norm else nn.Identity()
721
+
722
+ def forward(
723
+ self,
724
+ x,
725
+ time_aligned_context,
726
+ time_token=None,
727
+ time_ada=None,
728
+ skip=None,
729
+ context=None,
730
+ x_mask=None,
731
+ context_mask=None,
732
+ extras=None
733
+ ):
734
+ if self.use_checkpoint:
735
+ from torch.utils.checkpoint import checkpoint
736
+ return checkpoint(
737
+ self._forward,
738
+ x, time_aligned_context, time_token, time_ada, skip, context,
739
+ x_mask, context_mask, extras,
740
+ use_reentrant=False
741
+ )
742
+ else:
743
+ return self._forward(
744
+ x, time_aligned_context, time_token, time_ada, skip, context,
745
+ x_mask, context_mask, extras,
746
+ )
747
+
748
+ def _forward(
749
+ self,
750
+ x,
751
+ time_aligned_context,
752
+ time_token=None,
753
+ time_ada=None,
754
+ skip=None,
755
+ context=None,
756
+ x_mask=None,
757
+ context_mask=None,
758
+ extras=None
759
+ ):
760
+ B, T, C = x.shape
761
+
762
+ if self.skip_linear is not None:
763
+ assert skip is not None
764
+ cat = torch.cat([x, skip], dim=-1)
765
+ cat = self.skip_norm(cat)
766
+ x = self.skip_linear(cat)
767
+
768
+ if self.use_adanorm:
769
+ time_ada = self.adaln(time_token, time_ada)
770
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
771
+ gate_mlp) = time_ada.chunk(6, dim=1)
772
+
773
+ if self.use_adanorm:
774
+ x_norm = film_modulate(
775
+ self.norm1(x), shift=shift_msa, scale=scale_msa
776
+ )
777
+ tanh_gate_msa = torch.tanh(1 - gate_msa)
778
+ x = x + tanh_gate_msa * self.attn(
779
+ x_norm, context=None, context_mask=x_mask, extras=extras
780
+ )
781
+ else:
782
+ x = x + self.attn(
783
+ self.norm1(x),
784
+ context=None,
785
+ context_mask=x_mask,
786
+ extras=extras
787
+ )
788
+
789
+ if self.ta_context_fusion == "add":
790
+ time_aligned_context = self.ta_context_projection(
791
+ self.ta_context_norm(time_aligned_context)
792
+ )
793
+ if time_aligned_context.size(1) < x.size(1):
794
+ time_aligned_context = nn.functional.pad(
795
+ time_aligned_context, (0, 0, 1, 0)
796
+ )
797
+ x = x + time_aligned_context
798
+ elif self.ta_context_fusion == "concat":
799
+ if time_aligned_context.size(1) < x.size(1):
800
+ time_aligned_context = nn.functional.pad(
801
+ time_aligned_context, (0, 0, 1, 0)
802
+ )
803
+ cat = torch.cat([x, time_aligned_context], dim=-1)
804
+ cat = self.ta_context_norm(cat)
805
+ x = self.ta_context_projection(cat)
806
+
807
+ if self.use_context:
808
+ assert context is not None
809
+ x = x + self.cross_attn(
810
+ x=self.norm2(x),
811
+ context=self.norm_context(context),
812
+ context_mask=context_mask,
813
+ extras=extras
814
+ )
815
+
816
+ if self.use_adanorm:
817
+ x_norm = film_modulate(
818
+ self.norm3(x), shift=shift_mlp, scale=scale_mlp
819
+ )
820
+ x = x + (1 - gate_mlp) * self.mlp(x_norm)
821
+ else:
822
+ x = x + self.mlp(self.norm3(x))
823
+
824
+ return x
825
+
826
+
827
+ class LayerFusionAudioDiT(UDiT):
828
+ def __init__(
829
+ self,
830
+ img_size=224,
831
+ patch_size=16,
832
+ in_chans=3,
833
+ input_type='2d',
834
+ out_chans=None,
835
+ embed_dim=768,
836
+ depth=12,
837
+ num_heads=12,
838
+ mlp_ratio=4,
839
+ qkv_bias=False,
840
+ qk_scale=None,
841
+ qk_norm=None,
842
+ act_layer='gelu',
843
+ norm_layer='layernorm',
844
+ context_norm=False,
845
+ use_checkpoint=False,
846
+ time_fusion='token',
847
+ ada_sola_rank=None,
848
+ ada_sola_alpha=None,
849
+ cls_dim=None,
850
+ ta_context_dim=768,
851
+ ta_context_fusion='concat',
852
+ ta_context_norm=True,
853
+ context_dim=768,
854
+ context_fusion='concat',
855
+ context_max_length=128,
856
+ context_pe_method='sinu',
857
+ pe_method='abs',
858
+ rope_mode='none',
859
+ use_conv=True,
860
+ skip=True,
861
+ skip_norm=True
862
+ ):
863
+ nn.Module.__init__(self)
864
+ self.num_features = self.embed_dim = embed_dim
865
+
866
+ self.in_chans = in_chans
867
+ self.input_type = input_type
868
+ if self.input_type == '2d':
869
+ num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
870
+ elif self.input_type == '1d':
871
+ num_patches = img_size // patch_size
872
+ self.patch_embed = PatchEmbed(
873
+ patch_size=patch_size,
874
+ in_chans=in_chans,
875
+ embed_dim=embed_dim,
876
+ input_type=input_type
877
+ )
878
+ out_chans = in_chans if out_chans is None else out_chans
879
+ self.out_chans = out_chans
880
+
881
+ self.rope = rope_mode
882
+ self.x_pe = PE_wrapper(
883
+ dim=embed_dim, method=pe_method, length=num_patches
884
+ )
885
+
886
+ self.time_embed = TimestepEmbedder(embed_dim)
887
+ self.time_fusion = time_fusion
888
+ self.use_adanorm = False
889
+
890
+ if cls_dim is not None:
891
+ self.cls_embed = nn.Sequential(
892
+ nn.Linear(cls_dim, embed_dim, bias=True),
893
+ nn.SiLU(),
894
+ nn.Linear(embed_dim, embed_dim, bias=True),
895
+ )
896
+ else:
897
+ self.cls_embed = None
898
+
899
+ if time_fusion == 'token':
900
+ self.extras = 2 if self.cls_embed else 1
901
+ self.time_pe = PE_wrapper(
902
+ dim=embed_dim, method='abs', length=self.extras
903
+ )
904
+ elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']:
905
+ self.use_adanorm = True
906
+ self.time_act = nn.SiLU()
907
+ self.extras = 0
908
+ self.time_ada_final = nn.Linear(
909
+ embed_dim, 2 * embed_dim, bias=True
910
+ )
911
+ if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']:
912
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
913
+ else:
914
+ self.time_ada = None
915
+ else:
916
+ raise NotImplementedError
917
+
918
+ self.use_context = False
919
+ self.context_cross = False
920
+ self.context_max_length = context_max_length
921
+ self.context_fusion = 'none'
922
+ if context_dim is not None:
923
+ self.use_context = True
924
+ self.context_embed = nn.Sequential(
925
+ nn.Linear(context_dim, embed_dim, bias=True),
926
+ nn.SiLU(),
927
+ nn.Linear(embed_dim, embed_dim, bias=True),
928
+ )
929
+ self.context_fusion = context_fusion
930
+ if context_fusion == 'concat' or context_fusion == 'joint':
931
+ self.extras += context_max_length
932
+ self.context_pe = PE_wrapper(
933
+ dim=embed_dim,
934
+ method=context_pe_method,
935
+ length=context_max_length
936
+ )
937
+ context_dim = None
938
+ elif context_fusion == 'cross':
939
+ self.context_pe = PE_wrapper(
940
+ dim=embed_dim,
941
+ method=context_pe_method,
942
+ length=context_max_length
943
+ )
944
+ self.context_cross = True
945
+ context_dim = embed_dim
946
+ else:
947
+ raise NotImplementedError
948
+
949
+ self.use_skip = skip
950
+
951
+ if norm_layer == 'layernorm':
952
+ norm_layer = nn.LayerNorm
953
+ elif norm_layer == 'rmsnorm':
954
+ norm_layer = RMSNorm
955
+ else:
956
+ raise NotImplementedError
957
+
958
+ self.in_blocks = nn.ModuleList([
959
+ LayerFusionDiTBlock(
960
+ dim=embed_dim,
961
+ ta_context_dim=ta_context_dim,
962
+ ta_context_fusion=ta_context_fusion,
963
+ ta_context_norm=ta_context_norm,
964
+ context_dim=context_dim,
965
+ num_heads=num_heads,
966
+ mlp_ratio=mlp_ratio,
967
+ qkv_bias=qkv_bias,
968
+ qk_scale=qk_scale,
969
+ qk_norm=qk_norm,
970
+ act_layer=act_layer,
971
+ norm_layer=norm_layer,
972
+ time_fusion=time_fusion,
973
+ ada_sola_rank=ada_sola_rank,
974
+ ada_sola_alpha=ada_sola_alpha,
975
+ skip=False,
976
+ skip_norm=False,
977
+ rope_mode=self.rope,
978
+ context_norm=context_norm,
979
+ use_checkpoint=use_checkpoint
980
+ ) for i in range(depth // 2)
981
+ ])
982
+
983
+ self.mid_block = LayerFusionDiTBlock(
984
+ dim=embed_dim,
985
+ ta_context_dim=ta_context_dim,
986
+ context_dim=context_dim,
987
+ num_heads=num_heads,
988
+ mlp_ratio=mlp_ratio,
989
+ qkv_bias=qkv_bias,
990
+ qk_scale=qk_scale,
991
+ qk_norm=qk_norm,
992
+ act_layer=act_layer,
993
+ norm_layer=norm_layer,
994
+ time_fusion=time_fusion,
995
+ ada_sola_rank=ada_sola_rank,
996
+ ada_sola_alpha=ada_sola_alpha,
997
+ ta_context_fusion=ta_context_fusion,
998
+ ta_context_norm=ta_context_norm,
999
+ skip=False,
1000
+ skip_norm=False,
1001
+ rope_mode=self.rope,
1002
+ context_norm=context_norm,
1003
+ use_checkpoint=use_checkpoint
1004
+ )
1005
+
1006
+ self.out_blocks = nn.ModuleList([
1007
+ LayerFusionDiTBlock(
1008
+ dim=embed_dim,
1009
+ ta_context_dim=ta_context_dim,
1010
+ context_dim=context_dim,
1011
+ num_heads=num_heads,
1012
+ mlp_ratio=mlp_ratio,
1013
+ qkv_bias=qkv_bias,
1014
+ qk_scale=qk_scale,
1015
+ qk_norm=qk_norm,
1016
+ act_layer=act_layer,
1017
+ norm_layer=norm_layer,
1018
+ time_fusion=time_fusion,
1019
+ ada_sola_rank=ada_sola_rank,
1020
+ ada_sola_alpha=ada_sola_alpha,
1021
+ ta_context_fusion=ta_context_fusion,
1022
+ ta_context_norm=ta_context_norm,
1023
+ skip=skip,
1024
+ skip_norm=skip_norm,
1025
+ rope_mode=self.rope,
1026
+ context_norm=context_norm,
1027
+ use_checkpoint=use_checkpoint
1028
+ ) for i in range(depth // 2)
1029
+ ])
1030
+
1031
+ self.use_conv = use_conv
1032
+ self.final_block = FinalBlock(
1033
+ embed_dim=embed_dim,
1034
+ patch_size=patch_size,
1035
+ img_size=img_size,
1036
+ in_chans=out_chans,
1037
+ input_type=input_type,
1038
+ norm_layer=norm_layer,
1039
+ use_conv=use_conv,
1040
+ use_adanorm=self.use_adanorm
1041
+ )
1042
+ self.initialize_weights()
1043
+
1044
+ def forward(
1045
+ self,
1046
+ x,
1047
+ timesteps,
1048
+ time_aligned_context,
1049
+ context,
1050
+ x_mask=None,
1051
+ context_mask=None,
1052
+ cls_token=None,
1053
+ controlnet_skips=None,
1054
+ ):
1055
+ if timesteps.dim() == 0:
1056
+ timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)
1057
+
1058
+ x = self.patch_embed(x)
1059
+ x = self.x_pe(x)
1060
+
1061
+ B, L, D = x.shape
1062
+
1063
+ if self.use_context:
1064
+ context_token = self.context_embed(context)
1065
+ context_token = self.context_pe(context_token)
1066
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
1067
+ x, x_mask = self._concat_x_context(
1068
+ x=x,
1069
+ context=context_token,
1070
+ x_mask=x_mask,
1071
+ context_mask=context_mask
1072
+ )
1073
+ context_token, context_mask = None, None
1074
+ else:
1075
+ context_token, context_mask = None, None
1076
+
1077
+ time_token = self.time_embed(timesteps)
1078
+ if self.cls_embed:
1079
+ cls_token = self.cls_embed(cls_token)
1080
+ time_ada = None
1081
+ time_ada_final = None
1082
+ if self.use_adanorm:
1083
+ if self.cls_embed:
1084
+ time_token = time_token + cls_token
1085
+ time_token = self.time_act(time_token)
1086
+ time_ada_final = self.time_ada_final(time_token)
1087
+ if self.time_ada is not None:
1088
+ time_ada = self.time_ada(time_token)
1089
+ else:
1090
+ time_token = time_token.unsqueeze(dim=1)
1091
+ if self.cls_embed:
1092
+ cls_token = cls_token.unsqueeze(dim=1)
1093
+ time_token = torch.cat([time_token, cls_token], dim=1)
1094
+ time_token = self.time_pe(time_token)
1095
+ x = torch.cat((time_token, x), dim=1)
1096
+ if x_mask is not None:
1097
+ x_mask = torch.cat([
1098
+ torch.ones(B, time_token.shape[1],
1099
+ device=x_mask.device).bool(), x_mask
1100
+ ], dim=1)
1101
+ time_token = None
1102
+
1103
+ skips = []
1104
+ for blk in self.in_blocks:
1105
+ x = blk(
1106
+ x=x,
1107
+ time_aligned_context=time_aligned_context,
1108
+ time_token=time_token,
1109
+ time_ada=time_ada,
1110
+ skip=None,
1111
+ context=context_token,
1112
+ x_mask=x_mask,
1113
+ context_mask=context_mask,
1114
+ extras=self.extras
1115
+ )
1116
+ if self.use_skip:
1117
+ skips.append(x)
1118
+
1119
+ x = self.mid_block(
1120
+ x=x,
1121
+ time_aligned_context=time_aligned_context,
1122
+ time_token=time_token,
1123
+ time_ada=time_ada,
1124
+ skip=None,
1125
+ context=context_token,
1126
+ x_mask=x_mask,
1127
+ context_mask=context_mask,
1128
+ extras=self.extras
1129
+ )
1130
+ for blk in self.out_blocks:
1131
+ if self.use_skip:
1132
+ skip = skips.pop()
1133
+ if controlnet_skips:
1134
+ skip = skip + controlnet_skips.pop()
1135
+ else:
1136
+ skip = None
1137
+ if controlnet_skips:
1138
+ x = x + controlnet_skips.pop()
1139
+
1140
+ x = blk(
1141
+ x=x,
1142
+ time_aligned_context=time_aligned_context,
1143
+ time_token=time_token,
1144
+ time_ada=time_ada,
1145
+ skip=skip,
1146
+ context=context_token,
1147
+ x_mask=x_mask,
1148
+ context_mask=context_mask,
1149
+ extras=self.extras
1150
+ )
1151
+
1152
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
1153
+ return x
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9e0ea86afdf5e73d8de8d65ca572e6ada7c59daa122176bcb56c8d27dbd018cd
3
- size 8742180416
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e5c17670507a4d658b7650aef24a3861c4145f4386105ccbc0cba18ab9e28acd
3
+ size 8742184656
modeling_dasheng_audiogen.py ADDED
@@ -0,0 +1,549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from collections import OrderedDict
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from transformers import AutoModel, AutoTokenizer, PreTrainedModel
8
+
9
+ from .configuration_dasheng_audiogen import DashengAudioGenConfig
10
+ from .modules import * # noqa: F401,F403 — ensures HF copies this file
11
+ from .attention import * # noqa: F401,F403 — ensures HF copies this file
12
+ from .dit import LayerFusionAudioDiT
13
+ from .content_adapter import CrossAttentionAdapter, DurationPredictor
14
+ from .scheduler import FlowMatchEulerScheduler, compute_sway_sigmas, compute_linear_sigmas
15
+ from .utils import create_mask_from_length, create_alignment_path, trim_or_pad_length
16
+
17
+
18
+ # ---------------------------------------------------------------------------
19
+ # Prompt formatting
20
+ # ---------------------------------------------------------------------------
21
+
22
+ TAG_ORDER = OrderedDict([
23
+ ("caption", "<|caption|>"),
24
+ ("speech", "<|speech|>"),
25
+ ("asr", "<|asr|>"),
26
+ ("sfx", "<|sfx|>"),
27
+ ("music", "<|music|>"),
28
+ ("env", "<|env|>"),
29
+ ])
30
+
31
+
32
+ def compose_prompt(
33
+ content: str | None = None,
34
+ caption: str | None = None,
35
+ speech: str | None = None,
36
+ asr: str | None = None,
37
+ sfx: str | None = None,
38
+ music: str | None = None,
39
+ env: str | None = None,
40
+ ) -> str:
41
+ if content is not None:
42
+ content = str(content).strip()
43
+ if content:
44
+ return content
45
+
46
+ values = {
47
+ "caption": caption, "speech": speech, "asr": asr,
48
+ "sfx": sfx, "music": music, "env": env,
49
+ }
50
+ chunks: list[str] = []
51
+ for key, tag in TAG_ORDER.items():
52
+ value = values[key]
53
+ if value is not None:
54
+ value = str(value).strip()
55
+ if value:
56
+ chunks.append(f"{tag} {value}")
57
+ if not chunks:
58
+ raise ValueError(
59
+ "No prompt content provided. Pass `content` or at least one aspect field."
60
+ )
61
+ return " ".join(chunks)
62
+
63
+
64
+ # ---------------------------------------------------------------------------
65
+ # Model
66
+ # ---------------------------------------------------------------------------
67
+
68
+ def _load_text_encoder_backbone(name: str, **kwargs):
69
+ name_lower = name.lower()
70
+ if "mt5" in name_lower:
71
+ from transformers import MT5EncoderModel
72
+ return MT5EncoderModel.from_pretrained(name, **kwargs)
73
+ else:
74
+ from transformers import T5EncoderModel
75
+ return T5EncoderModel.from_pretrained(name, **kwargs)
76
+
77
+
78
+ class DashengAudioGenModel(PreTrainedModel):
79
+ config_class = DashengAudioGenConfig
80
+
81
+ _DYNAMIC_BUFFERS = {"instruction_embedding", "instruction_lengths"}
82
+
83
+ def _load_from_state_dict(
84
+ self, state_dict, prefix, local_metadata, strict,
85
+ missing_keys, unexpected_keys, error_msgs,
86
+ ):
87
+ for name in self._DYNAMIC_BUFFERS:
88
+ key = prefix + name
89
+ if key in state_dict:
90
+ self.register_buffer(name, state_dict.pop(key))
91
+ super()._load_from_state_dict(
92
+ state_dict, prefix, local_metadata, strict,
93
+ missing_keys, unexpected_keys, error_msgs,
94
+ )
95
+
96
+ def __init__(self, config: DashengAudioGenConfig):
97
+ super().__init__(config)
98
+
99
+ # -- Backbone (DiT) --
100
+ self.backbone = LayerFusionAudioDiT(
101
+ img_size=config.dit_img_size,
102
+ patch_size=config.dit_patch_size,
103
+ in_chans=config.dit_in_chans,
104
+ out_chans=config.dit_out_chans,
105
+ input_type=config.dit_input_type,
106
+ embed_dim=config.dit_embed_dim,
107
+ depth=config.dit_depth,
108
+ num_heads=config.dit_num_heads,
109
+ mlp_ratio=config.dit_mlp_ratio,
110
+ qkv_bias=False,
111
+ qk_scale=None,
112
+ qk_norm=config.dit_qk_norm,
113
+ norm_layer=config.dit_norm_layer,
114
+ act_layer=config.dit_act_layer,
115
+ context_norm=config.dit_context_norm,
116
+ use_checkpoint=False,
117
+ time_fusion=config.dit_time_fusion,
118
+ ada_sola_rank=config.dit_ada_sola_rank,
119
+ ada_sola_alpha=config.dit_ada_sola_alpha,
120
+ cls_dim=None,
121
+ ta_context_dim=config.dit_ta_context_dim,
122
+ ta_context_fusion=config.dit_ta_context_fusion,
123
+ ta_context_norm=config.dit_ta_context_norm,
124
+ context_dim=config.dit_context_dim,
125
+ context_fusion=config.dit_context_fusion,
126
+ context_max_length=None,
127
+ context_pe_method=config.dit_context_pe_method,
128
+ pe_method=config.dit_pe_method,
129
+ rope_mode=config.dit_rope_mode,
130
+ use_conv=True,
131
+ skip=True,
132
+ skip_norm=True,
133
+ )
134
+
135
+ # -- Content adapter --
136
+ duration_predictor = DurationPredictor(
137
+ in_channels=config.content_dim,
138
+ filter_channels=config.duration_predictor_filter_channels,
139
+ n_layers=config.duration_predictor_n_layers,
140
+ kernel_size=config.duration_predictor_kernel_size,
141
+ p_dropout=config.duration_predictor_p_dropout,
142
+ )
143
+ self.content_adapter = CrossAttentionAdapter(
144
+ d_out=config.content_dim,
145
+ content_dim=config.content_dim,
146
+ prefix_dim=config.task_instruction_dim,
147
+ num_heads=config.adapter_num_heads,
148
+ duration_predictor=duration_predictor,
149
+ dropout=config.adapter_dropout,
150
+ duration_grad_scale=config.adapter_duration_grad_scale,
151
+ )
152
+
153
+ # -- Content encoder projection (matches safetensors key path) --
154
+ _text_enc = nn.Module()
155
+ _text_enc.proj = nn.Linear(config.content_dim, config.content_dim)
156
+ if config.special_tokens:
157
+ _text_enc.special_token_embedding = nn.Embedding(
158
+ len(config.special_tokens), config.content_dim
159
+ )
160
+ _content_enc = nn.Module()
161
+ _content_enc.text_encoder = _text_enc
162
+ self.content_encoder = _content_enc
163
+
164
+ # -- Dummy parameters (match safetensors keys) --
165
+ self.dummy_param = nn.Parameter(torch.empty(0))
166
+ self.dummy_nta_embed = nn.Parameter(torch.zeros(config.content_dim))
167
+ self.dummy_ta_embed = nn.Parameter(torch.zeros(config.content_dim))
168
+
169
+ # -- Instruction embedding (actual values loaded from safetensors) --
170
+ self.register_buffer(
171
+ "instruction_embedding",
172
+ torch.zeros(1, 1, config.task_instruction_dim),
173
+ )
174
+ self.register_buffer(
175
+ "instruction_lengths",
176
+ torch.ones(1, dtype=torch.long),
177
+ )
178
+
179
+ # -- Scheduler --
180
+ self.scheduler = FlowMatchEulerScheduler()
181
+
182
+ # -- Derived constants --
183
+ self.latent_token_rate = config.sample_rate // config.downsampling_ratio
184
+
185
+ # External models are loaded AFTER weight loading in from_pretrained
186
+ self.text_encoder_backbone = None
187
+ self.text_tokenizer = None
188
+ self.audio_tokenizer = None
189
+ self._special_token_ids = []
190
+ self._special_token_id_to_index = {}
191
+
192
+ self.post_init()
193
+
194
+ def _load_external_models(self, model_dir: str | None = None, **kwargs):
195
+ self.text_encoder_backbone = _load_text_encoder_backbone(
196
+ self.config.text_encoder_name, **kwargs
197
+ )
198
+ self.text_encoder_backbone.eval()
199
+ for p in self.text_encoder_backbone.parameters():
200
+ p.requires_grad = False
201
+
202
+ import os
203
+ tokenizer_local = (
204
+ model_dir
205
+ if model_dir and os.path.isfile(os.path.join(model_dir, "tokenizer.json"))
206
+ else None
207
+ )
208
+ self.text_tokenizer = AutoTokenizer.from_pretrained(
209
+ tokenizer_local or self.config.text_encoder_name, **kwargs
210
+ )
211
+ if self.config.special_tokens:
212
+ self.text_tokenizer.add_special_tokens(
213
+ {"additional_special_tokens": self.config.special_tokens}
214
+ )
215
+ old_vocab = self.text_encoder_backbone.get_input_embeddings().num_embeddings
216
+ new_vocab = len(self.text_tokenizer)
217
+ if new_vocab != old_vocab:
218
+ self.text_encoder_backbone.resize_token_embeddings(new_vocab)
219
+ self._special_token_ids = [
220
+ self.text_tokenizer.convert_tokens_to_ids(t)
221
+ for t in self.config.special_tokens
222
+ ]
223
+ self._special_token_id_to_index = {
224
+ tid: idx for idx, tid in enumerate(self._special_token_ids)
225
+ }
226
+
227
+ self.audio_tokenizer = AutoModel.from_pretrained(
228
+ self.config.tokenizer_name, trust_remote_code=True, **kwargs
229
+ )
230
+ self.audio_tokenizer.eval()
231
+ for p in self.audio_tokenizer.parameters():
232
+ p.requires_grad = False
233
+
234
+ def _load_dynamic_buffers(self, model_dir: str):
235
+ import os
236
+ from safetensors.torch import load_file
237
+ sf_path = os.path.join(model_dir, "model.safetensors")
238
+ if not os.path.isfile(sf_path):
239
+ return
240
+ state = load_file(sf_path)
241
+ for name in self._DYNAMIC_BUFFERS:
242
+ if name in state:
243
+ self.register_buffer(name, state[name])
244
+
245
+ @classmethod
246
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
247
+ kwargs.setdefault("ignore_mismatched_sizes", True)
248
+ model = super().from_pretrained(
249
+ pretrained_model_name_or_path, *model_args, **kwargs
250
+ )
251
+ model._load_dynamic_buffers(str(pretrained_model_name_or_path))
252
+ ext_kwargs = {}
253
+ if kwargs.get("local_files_only"):
254
+ ext_kwargs["local_files_only"] = True
255
+ model._load_external_models(
256
+ model_dir=str(pretrained_model_name_or_path), **ext_kwargs
257
+ )
258
+ return model
259
+
260
+ @staticmethod
261
+ def compose_prompt(
262
+ content: str | None = None,
263
+ caption: str | None = None,
264
+ speech: str | None = None,
265
+ asr: str | None = None,
266
+ sfx: str | None = None,
267
+ music: str | None = None,
268
+ env: str | None = None,
269
+ ) -> str:
270
+ return compose_prompt(
271
+ content=content, caption=caption, speech=speech,
272
+ asr=asr, sfx=sfx, music=music, env=env,
273
+ )
274
+
275
+ # ------------------------------------------------------------------
276
+ # Text encoding
277
+ # ------------------------------------------------------------------
278
+
279
+ def _get_model_inputs(self, input_ids: torch.Tensor):
280
+ if not self._special_token_ids:
281
+ return {"input_ids": input_ids}
282
+ special_emb = self.content_encoder.text_encoder.special_token_embedding
283
+ input_embeds = self.text_encoder_backbone.get_input_embeddings()(input_ids)
284
+ for token_id, token_idx in self._special_token_id_to_index.items():
285
+ mask = input_ids == token_id
286
+ if mask.any():
287
+ input_embeds[mask] = special_emb.weight[token_idx].to(
288
+ input_embeds.dtype
289
+ )
290
+ return {"inputs_embeds": input_embeds}
291
+
292
+ @torch.no_grad()
293
+ def encode_text(self, prompts: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
294
+ device = self.dummy_param.device
295
+ batch = self.text_tokenizer(
296
+ prompts,
297
+ max_length=self.config.tokenizer_max_length,
298
+ padding=True,
299
+ truncation=True,
300
+ return_tensors="pt",
301
+ )
302
+ input_ids = batch.input_ids.to(device)
303
+ attention_mask = batch.attention_mask.to(device)
304
+ model_inputs = self._get_model_inputs(input_ids)
305
+ output = self.text_encoder_backbone(
306
+ **model_inputs, attention_mask=attention_mask
307
+ ).last_hidden_state
308
+ content = self.content_encoder.text_encoder.proj(output)
309
+ content_mask = attention_mask.bool()
310
+ return content, content_mask
311
+
312
+ # ------------------------------------------------------------------
313
+ # Duration helpers
314
+ # ------------------------------------------------------------------
315
+
316
+ def _prepare_local_duration(
317
+ self, pred: torch.Tensor, mask: torch.Tensor
318
+ ) -> torch.Tensor:
319
+ pred = torch.exp(pred) * mask
320
+ pred = torch.ceil(pred) - self.config.duration_offset
321
+ pred *= self.config.frame_resolution
322
+ pred = torch.round(pred * self.latent_token_rate)
323
+ return pred
324
+
325
+ def _prepare_global_duration(
326
+ self,
327
+ global_pred: torch.Tensor,
328
+ local_pred: torch.Tensor,
329
+ is_time_aligned: torch.Tensor,
330
+ ) -> torch.Tensor:
331
+ global_pred = torch.exp(global_pred) - self.config.duration_offset
332
+ result = torch.round(global_pred * self.latent_token_rate)
333
+ pred_from_local = local_pred.sum(1)
334
+ result[is_time_aligned] = pred_from_local[is_time_aligned]
335
+ return result.long()
336
+
337
+ def _expand_by_duration(
338
+ self,
339
+ x: torch.Tensor,
340
+ content_mask: torch.Tensor,
341
+ local_duration: torch.Tensor,
342
+ global_duration: torch.Tensor,
343
+ ) -> tuple[torch.Tensor, torch.Tensor]:
344
+ latent_length = global_duration
345
+ latent_mask = create_mask_from_length(latent_length).to(
346
+ content_mask.device
347
+ )
348
+ attn_mask = content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1)
349
+ align_path = create_alignment_path(local_duration, attn_mask)
350
+ expanded_x = torch.matmul(
351
+ align_path.transpose(1, 2).to(x.dtype), x
352
+ )
353
+ return expanded_x, latent_mask
354
+
355
+ def _get_backbone_input(
356
+ self,
357
+ target_length: int,
358
+ content: torch.Tensor,
359
+ content_mask: torch.Tensor,
360
+ time_aligned_content: torch.Tensor,
361
+ is_time_aligned: torch.Tensor,
362
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
363
+ time_aligned_content = trim_or_pad_length(
364
+ time_aligned_content, target_length, 1
365
+ )
366
+ # For text_to_audio: length_aligned_content is zeros, so skip addition
367
+ # Replace non-time-aligned samples with dummy
368
+ time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
369
+ time_aligned_content.dtype
370
+ )
371
+
372
+ context = content.clone()
373
+ context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype)
374
+ context_mask = content_mask.detach().clone()
375
+ context_mask[is_time_aligned, 1:] = False
376
+
377
+ if is_time_aligned.sum().item() < content.size(0):
378
+ trunc_nta_length = int(
379
+ content_mask[~is_time_aligned].sum(1).max().item()
380
+ )
381
+ else:
382
+ trunc_nta_length = content.size(1)
383
+ context = context[:, :trunc_nta_length]
384
+ context_mask = context_mask[:, :trunc_nta_length]
385
+
386
+ return context, context_mask, time_aligned_content
387
+
388
+ # ------------------------------------------------------------------
389
+ # Denoising loop
390
+ # ------------------------------------------------------------------
391
+
392
+ def _iterative_denoise(
393
+ self,
394
+ latent: torch.Tensor,
395
+ timesteps: torch.Tensor,
396
+ cfg: bool,
397
+ cfg_scale: float,
398
+ backbone_input: dict,
399
+ ) -> torch.Tensor:
400
+ for timestep in timesteps:
401
+ if cfg:
402
+ latent_input = torch.cat([latent, latent])
403
+ else:
404
+ latent_input = latent
405
+
406
+ noise_pred: torch.Tensor = self.backbone(
407
+ x=latent_input, timesteps=timestep, **backbone_input
408
+ )
409
+
410
+ if cfg:
411
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
412
+ noise_pred = noise_pred_uncond + cfg_scale * (
413
+ noise_pred_cond - noise_pred_uncond
414
+ )
415
+
416
+ latent = self.scheduler.step(
417
+ noise_pred, timestep, latent
418
+ ).prev_sample
419
+
420
+ return latent
421
+
422
+ # ------------------------------------------------------------------
423
+ # Main generation entry point
424
+ # ------------------------------------------------------------------
425
+
426
+ @torch.inference_mode()
427
+ def generate(
428
+ self,
429
+ prompts: str | list[str],
430
+ num_steps: int = 25,
431
+ guidance_scale: float = 5.0,
432
+ sway_sampling_coef: float = -1.0,
433
+ ) -> torch.Tensor:
434
+ if isinstance(prompts, str):
435
+ prompts = [prompts]
436
+
437
+ device = self.dummy_param.device
438
+ batch_size = len(prompts)
439
+ classifier_free_guidance = guidance_scale > 1.0
440
+
441
+ # 1. Encode text
442
+ content, content_mask = self.encode_text(prompts)
443
+
444
+ # 2. Get instruction embedding
445
+ if self.config.use_zero_instruction:
446
+ instruction = torch.zeros(
447
+ 1, 1, self.config.task_instruction_dim,
448
+ device=device, dtype=content.dtype,
449
+ ).expand(batch_size, -1, -1)
450
+ instruction_lengths = torch.ones(
451
+ batch_size, device=device, dtype=torch.long
452
+ )
453
+ else:
454
+ instruction = self.instruction_embedding.to(content.dtype).expand(
455
+ batch_size, -1, -1
456
+ )
457
+ instruction_lengths = self.instruction_lengths.expand(batch_size)
458
+
459
+ # 3. Content adapter
460
+ instruction_mask = create_mask_from_length(
461
+ instruction_lengths, max_length=instruction.size(1)
462
+ ).to(device)
463
+ (
464
+ content, content_mask, global_duration_pred, local_duration_pred,
465
+ ) = self.content_adapter(
466
+ content, content_mask, instruction, instruction_mask
467
+ )
468
+
469
+ # 4. Duration
470
+ is_time_aligned = torch.zeros(
471
+ batch_size, dtype=torch.bool, device=device
472
+ )
473
+
474
+ local_latent_duration = self._prepare_local_duration(
475
+ local_duration_pred, content_mask
476
+ )
477
+ global_latent_duration = self._prepare_global_duration(
478
+ global_duration_pred, local_latent_duration, is_time_aligned
479
+ )
480
+
481
+ time_aligned_content, latent_mask = self._expand_by_duration(
482
+ x=content,
483
+ content_mask=content_mask,
484
+ local_duration=local_latent_duration,
485
+ global_duration=global_latent_duration,
486
+ )
487
+
488
+ # 5. Prepare backbone input
489
+ context, context_mask, time_aligned_content = self._get_backbone_input(
490
+ target_length=time_aligned_content.size(1),
491
+ content=content,
492
+ content_mask=content_mask,
493
+ time_aligned_content=time_aligned_content,
494
+ is_time_aligned=is_time_aligned,
495
+ )
496
+
497
+ # 6. CFG: duplicate with unconditional
498
+ if classifier_free_guidance:
499
+ time_aligned_content = torch.cat([
500
+ torch.zeros_like(time_aligned_content),
501
+ time_aligned_content,
502
+ ])
503
+ context = torch.cat([
504
+ torch.zeros_like(context), context
505
+ ])
506
+ context_mask = torch.cat([
507
+ context_mask.detach().clone(), context_mask
508
+ ])
509
+ latent_mask = torch.cat([
510
+ latent_mask.detach().clone(), latent_mask
511
+ ])
512
+
513
+ # 7. Prepare latent noise
514
+ latent_length = int(latent_mask.sum(1).max().item())
515
+ latent = torch.randn(
516
+ batch_size, self.config.latent_dim, latent_length,
517
+ device=device, dtype=content.dtype,
518
+ )
519
+
520
+ # 8. Sigmas schedule
521
+ if sway_sampling_coef:
522
+ sigmas = compute_sway_sigmas(num_steps, sway_sampling_coef)
523
+ else:
524
+ sigmas = compute_linear_sigmas(num_steps)
525
+ self.scheduler.set_timesteps(sigmas, device=device)
526
+ timesteps = self.scheduler.timesteps
527
+
528
+ # 9. Denoise
529
+ latent = self._iterative_denoise(
530
+ latent=latent,
531
+ timesteps=timesteps,
532
+ cfg=classifier_free_guidance,
533
+ cfg_scale=guidance_scale,
534
+ backbone_input={
535
+ "x_mask": latent_mask,
536
+ "context": context,
537
+ "context_mask": context_mask,
538
+ "time_aligned_context": time_aligned_content,
539
+ },
540
+ )
541
+
542
+ # 10. Decode to waveform
543
+ waveform = self.audio_tokenizer.decode(
544
+ latent.transpose(1, 2)
545
+ )
546
+ if waveform.dim() == 3:
547
+ waveform = waveform.squeeze(1)
548
+
549
+ return waveform
modules.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import einops
8
+ from einops import rearrange
9
+
10
+
11
+ def trunc_normal_(tensor, mean, std, a, b):
12
+ def norm_cdf(x):
13
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
14
+
15
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
16
+ warnings.warn(
17
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
18
+ "The distribution of values may be incorrect.",
19
+ stacklevel=2,
20
+ )
21
+
22
+ with torch.no_grad():
23
+ l = norm_cdf((a - mean) / std)
24
+ u = norm_cdf((b - mean) / std)
25
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
26
+ tensor.erfinv_()
27
+ tensor.mul_(std * math.sqrt(2.0))
28
+ tensor.add_(mean)
29
+ tensor.clamp_(min=a, max=b)
30
+ return tensor
31
+
32
+
33
+ def film_modulate(x, shift, scale):
34
+ return x * (1 + scale) + shift
35
+
36
+
37
+ def timestep_embedding(timesteps, dim, max_period=10000):
38
+ half = dim // 2
39
+ freqs = torch.exp(
40
+ -math.log(max_period)
41
+ * torch.arange(start=0, end=half, dtype=torch.float32)
42
+ / half
43
+ ).to(device=timesteps.device)
44
+ args = timesteps[:, None].float() * freqs[None]
45
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
46
+ if dim % 2:
47
+ embedding = torch.cat(
48
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
49
+ )
50
+ return embedding
51
+
52
+
53
+ def unpatchify(x, channels=3, input_type="2d", img_size=None):
54
+ if input_type == "2d":
55
+ patch_size = int((x.shape[2] // channels) ** 0.5)
56
+ h, w = img_size[0] // patch_size, img_size[1] // patch_size
57
+ x = rearrange(
58
+ x,
59
+ "B (h w) (p1 p2 C) -> B C (h p1) (w p2)",
60
+ h=h,
61
+ p1=patch_size,
62
+ p2=patch_size,
63
+ )
64
+ elif input_type == "1d":
65
+ patch_size = int(x.shape[2] // channels)
66
+ h = x.shape[1]
67
+ x = rearrange(x, "B h (p1 C) -> B C (h p1)", h=h, p1=patch_size)
68
+ return x
69
+
70
+
71
+ class TimestepEmbedder(nn.Module):
72
+ def __init__(self, hidden_size, frequency_embedding_size=256, out_size=None):
73
+ super().__init__()
74
+ if out_size is None:
75
+ out_size = hidden_size
76
+ self.mlp = nn.Sequential(
77
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
78
+ nn.SiLU(),
79
+ nn.Linear(hidden_size, out_size, bias=True),
80
+ )
81
+ self.frequency_embedding_size = frequency_embedding_size
82
+
83
+ def forward(self, t):
84
+ t_freq = timestep_embedding(t, self.frequency_embedding_size).type(
85
+ self.mlp[0].weight.dtype
86
+ )
87
+ t_emb = self.mlp(t_freq)
88
+ return t_emb
89
+
90
+
91
+ class PatchEmbed(nn.Module):
92
+ def __init__(self, patch_size, in_chans=3, embed_dim=768, input_type="2d"):
93
+ super().__init__()
94
+ self.patch_size = patch_size
95
+ self.input_type = input_type
96
+ if input_type == "2d":
97
+ self.proj = nn.Conv2d(
98
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True
99
+ )
100
+ elif input_type == "1d":
101
+ self.proj = nn.Conv1d(
102
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True
103
+ )
104
+
105
+ def forward(self, x):
106
+ x = self.proj(x).flatten(2).transpose(1, 2)
107
+ return x
108
+
109
+
110
+ class PE_wrapper(nn.Module):
111
+ def __init__(self, dim=768, method="abs", length=None, **kwargs):
112
+ super().__init__()
113
+ self.method = method
114
+ if method == "abs":
115
+ self.length = length
116
+ self.abs_pe = nn.Parameter(torch.zeros(1, length, dim))
117
+ trunc_normal_(self.abs_pe, mean=0.0, std=0.02, a=-0.04, b=0.04)
118
+ elif method == "none":
119
+ self.id = nn.Identity()
120
+ else:
121
+ raise NotImplementedError
122
+
123
+ def forward(self, x):
124
+ if self.method == "abs":
125
+ _, L, _ = x.shape
126
+ assert L <= self.length
127
+ x = x + self.abs_pe[:, :L, :]
128
+ elif self.method == "none":
129
+ x = self.id(x)
130
+ return x
131
+
132
+
133
+ class RMSNorm(nn.Module):
134
+ def __init__(self, dim: int, eps: float = 1e-6):
135
+ super().__init__()
136
+ self.eps = eps
137
+ self.weight = nn.Parameter(torch.ones(dim))
138
+
139
+ def _norm(self, x):
140
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
141
+
142
+ def forward(self, x):
143
+ output = self._norm(x.float()).type_as(x)
144
+ return output * self.weight
145
+
146
+
147
+ class GELU(nn.Module):
148
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
149
+ super().__init__()
150
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
151
+ self.approximate = approximate
152
+
153
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
154
+ if gate.device.type != "mps":
155
+ return F.gelu(gate, approximate=self.approximate)
156
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(
157
+ dtype=gate.dtype
158
+ )
159
+
160
+ def forward(self, hidden_states):
161
+ hidden_states = self.proj(hidden_states)
162
+ hidden_states = self.gelu(hidden_states)
163
+ return hidden_states
164
+
165
+
166
+ class GEGLU(nn.Module):
167
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
168
+ super().__init__()
169
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
170
+
171
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
172
+ if gate.device.type != "mps":
173
+ return F.gelu(gate)
174
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
175
+
176
+ def forward(self, hidden_states):
177
+ hidden_states = self.proj(hidden_states)
178
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
179
+ return hidden_states * self.gelu(gate)
180
+
181
+
182
+ class FeedForward(nn.Module):
183
+ def __init__(
184
+ self,
185
+ dim,
186
+ dim_out=None,
187
+ mult=4,
188
+ dropout=0.0,
189
+ activation_fn="geglu",
190
+ final_dropout=False,
191
+ inner_dim=None,
192
+ bias=True,
193
+ ):
194
+ super().__init__()
195
+ if inner_dim is None:
196
+ inner_dim = int(dim * mult)
197
+ dim_out = dim_out if dim_out is not None else dim
198
+
199
+ if activation_fn == "gelu":
200
+ act_fn = GELU(dim, inner_dim, bias=bias)
201
+ elif activation_fn == "gelu-approximate":
202
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
203
+ elif activation_fn == "geglu":
204
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
205
+ else:
206
+ raise NotImplementedError
207
+
208
+ self.net = nn.ModuleList([])
209
+ self.net.append(act_fn)
210
+ self.net.append(nn.Dropout(dropout))
211
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
212
+ if final_dropout:
213
+ self.net.append(nn.Dropout(dropout))
214
+
215
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
216
+ for module in self.net:
217
+ hidden_states = module(hidden_states)
218
+ return hidden_states
scheduler.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+
6
+
7
+ @dataclass
8
+ class SchedulerOutput:
9
+ prev_sample: torch.FloatTensor
10
+
11
+
12
+ class FlowMatchEulerScheduler:
13
+
14
+ def __init__(self, num_train_timesteps: int = 1000):
15
+ self.num_train_timesteps = num_train_timesteps
16
+ self.sigmas = None
17
+ self.timesteps = None
18
+ self._step_index = None
19
+
20
+ def set_timesteps(self, sigmas, device):
21
+ if isinstance(sigmas, (list, tuple)):
22
+ sigmas = torch.tensor(sigmas, dtype=torch.float32)
23
+ elif not isinstance(sigmas, torch.Tensor):
24
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
25
+
26
+ sigmas = sigmas.to(device=device)
27
+ self.timesteps = sigmas * self.num_train_timesteps
28
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=device)])
29
+ self._step_index = None
30
+
31
+ def step(
32
+ self,
33
+ model_output: torch.FloatTensor,
34
+ timestep: torch.FloatTensor,
35
+ sample: torch.FloatTensor,
36
+ ) -> SchedulerOutput:
37
+ if self._step_index is None:
38
+ self._step_index = (self.timesteps == timestep).nonzero()
39
+ self._step_index = 0 if self._step_index.numel() == 0 else self._step_index[0].item()
40
+
41
+ sample = sample.to(torch.float32)
42
+
43
+ sigma = self.sigmas[self._step_index]
44
+ sigma_next = self.sigmas[self._step_index + 1]
45
+
46
+ prev_sample = sample + (sigma_next - sigma) * model_output
47
+ prev_sample = prev_sample.to(model_output.dtype)
48
+
49
+ self._step_index += 1
50
+ return SchedulerOutput(prev_sample=prev_sample)
51
+
52
+
53
+ def compute_sway_sigmas(num_steps: int, sway_sampling_coef: float = -1.0):
54
+ t = torch.linspace(0, 1, num_steps + 1)
55
+ t = t + sway_sampling_coef * (torch.cos(math.pi / 2.0 * t) - 1.0 + t)
56
+ sigmas = 1.0 - t
57
+ return sigmas
58
+
59
+
60
+ def compute_linear_sigmas(num_steps: int):
61
+ return torch.linspace(1.0, 1.0 / num_steps, num_steps)
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "eos_token": {
3
+ "content": "</s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "pad_token": {
10
+ "content": "<pad>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<unk>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef78f86560d809067d12bac6c09f19a462cb3af3f54d2b8acbba26e1433125d6
3
+ size 4309802
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65c2d7defb6472fada8a935bb364ae3433f7451780c8a59ab6b3cfbaadb32608
3
+ size 16349930
tokenizer_config.json ADDED
@@ -0,0 +1,840 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": null,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "<pad>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "1": {
13
+ "content": "</s>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "2": {
21
+ "content": "<unk>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "250000": {
29
+ "content": "▁<extra_id_99>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": false
35
+ },
36
+ "250001": {
37
+ "content": "▁<extra_id_98>",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": false
43
+ },
44
+ "250002": {
45
+ "content": "▁<extra_id_97>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": false
51
+ },
52
+ "250003": {
53
+ "content": "▁<extra_id_96>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": false
59
+ },
60
+ "250004": {
61
+ "content": "▁<extra_id_95>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": false
67
+ },
68
+ "250005": {
69
+ "content": "▁<extra_id_94>",
70
+ "lstrip": false,
71
+ "normalized": false,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": false
75
+ },
76
+ "250006": {
77
+ "content": "▁<extra_id_93>",
78
+ "lstrip": false,
79
+ "normalized": false,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": false
83
+ },
84
+ "250007": {
85
+ "content": "▁<extra_id_92>",
86
+ "lstrip": false,
87
+ "normalized": false,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": false
91
+ },
92
+ "250008": {
93
+ "content": "▁<extra_id_91>",
94
+ "lstrip": false,
95
+ "normalized": false,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": false
99
+ },
100
+ "250009": {
101
+ "content": "▁<extra_id_90>",
102
+ "lstrip": false,
103
+ "normalized": false,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": false
107
+ },
108
+ "250010": {
109
+ "content": "▁<extra_id_89>",
110
+ "lstrip": false,
111
+ "normalized": false,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": false
115
+ },
116
+ "250011": {
117
+ "content": "▁<extra_id_88>",
118
+ "lstrip": false,
119
+ "normalized": false,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": false
123
+ },
124
+ "250012": {
125
+ "content": "▁<extra_id_87>",
126
+ "lstrip": false,
127
+ "normalized": false,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": false
131
+ },
132
+ "250013": {
133
+ "content": "▁<extra_id_86>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": false
139
+ },
140
+ "250014": {
141
+ "content": "▁<extra_id_85>",
142
+ "lstrip": false,
143
+ "normalized": false,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": false
147
+ },
148
+ "250015": {
149
+ "content": "▁<extra_id_84>",
150
+ "lstrip": false,
151
+ "normalized": false,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": false
155
+ },
156
+ "250016": {
157
+ "content": "▁<extra_id_83>",
158
+ "lstrip": false,
159
+ "normalized": false,
160
+ "rstrip": false,
161
+ "single_word": false,
162
+ "special": false
163
+ },
164
+ "250017": {
165
+ "content": "▁<extra_id_82>",
166
+ "lstrip": false,
167
+ "normalized": false,
168
+ "rstrip": false,
169
+ "single_word": false,
170
+ "special": false
171
+ },
172
+ "250018": {
173
+ "content": "▁<extra_id_81>",
174
+ "lstrip": false,
175
+ "normalized": false,
176
+ "rstrip": false,
177
+ "single_word": false,
178
+ "special": false
179
+ },
180
+ "250019": {
181
+ "content": "▁<extra_id_80>",
182
+ "lstrip": false,
183
+ "normalized": false,
184
+ "rstrip": false,
185
+ "single_word": false,
186
+ "special": false
187
+ },
188
+ "250020": {
189
+ "content": "▁<extra_id_79>",
190
+ "lstrip": false,
191
+ "normalized": false,
192
+ "rstrip": false,
193
+ "single_word": false,
194
+ "special": false
195
+ },
196
+ "250021": {
197
+ "content": "▁<extra_id_78>",
198
+ "lstrip": false,
199
+ "normalized": false,
200
+ "rstrip": false,
201
+ "single_word": false,
202
+ "special": false
203
+ },
204
+ "250022": {
205
+ "content": "▁<extra_id_77>",
206
+ "lstrip": false,
207
+ "normalized": false,
208
+ "rstrip": false,
209
+ "single_word": false,
210
+ "special": false
211
+ },
212
+ "250023": {
213
+ "content": "▁<extra_id_76>",
214
+ "lstrip": false,
215
+ "normalized": false,
216
+ "rstrip": false,
217
+ "single_word": false,
218
+ "special": false
219
+ },
220
+ "250024": {
221
+ "content": "▁<extra_id_75>",
222
+ "lstrip": false,
223
+ "normalized": false,
224
+ "rstrip": false,
225
+ "single_word": false,
226
+ "special": false
227
+ },
228
+ "250025": {
229
+ "content": "▁<extra_id_74>",
230
+ "lstrip": false,
231
+ "normalized": false,
232
+ "rstrip": false,
233
+ "single_word": false,
234
+ "special": false
235
+ },
236
+ "250026": {
237
+ "content": "▁<extra_id_73>",
238
+ "lstrip": false,
239
+ "normalized": false,
240
+ "rstrip": false,
241
+ "single_word": false,
242
+ "special": false
243
+ },
244
+ "250027": {
245
+ "content": "▁<extra_id_72>",
246
+ "lstrip": false,
247
+ "normalized": false,
248
+ "rstrip": false,
249
+ "single_word": false,
250
+ "special": false
251
+ },
252
+ "250028": {
253
+ "content": "▁<extra_id_71>",
254
+ "lstrip": false,
255
+ "normalized": false,
256
+ "rstrip": false,
257
+ "single_word": false,
258
+ "special": false
259
+ },
260
+ "250029": {
261
+ "content": "▁<extra_id_70>",
262
+ "lstrip": false,
263
+ "normalized": false,
264
+ "rstrip": false,
265
+ "single_word": false,
266
+ "special": false
267
+ },
268
+ "250030": {
269
+ "content": "▁<extra_id_69>",
270
+ "lstrip": false,
271
+ "normalized": false,
272
+ "rstrip": false,
273
+ "single_word": false,
274
+ "special": false
275
+ },
276
+ "250031": {
277
+ "content": "▁<extra_id_68>",
278
+ "lstrip": false,
279
+ "normalized": false,
280
+ "rstrip": false,
281
+ "single_word": false,
282
+ "special": false
283
+ },
284
+ "250032": {
285
+ "content": "▁<extra_id_67>",
286
+ "lstrip": false,
287
+ "normalized": false,
288
+ "rstrip": false,
289
+ "single_word": false,
290
+ "special": false
291
+ },
292
+ "250033": {
293
+ "content": "▁<extra_id_66>",
294
+ "lstrip": false,
295
+ "normalized": false,
296
+ "rstrip": false,
297
+ "single_word": false,
298
+ "special": false
299
+ },
300
+ "250034": {
301
+ "content": "▁<extra_id_65>",
302
+ "lstrip": false,
303
+ "normalized": false,
304
+ "rstrip": false,
305
+ "single_word": false,
306
+ "special": false
307
+ },
308
+ "250035": {
309
+ "content": "▁<extra_id_64>",
310
+ "lstrip": false,
311
+ "normalized": false,
312
+ "rstrip": false,
313
+ "single_word": false,
314
+ "special": false
315
+ },
316
+ "250036": {
317
+ "content": "▁<extra_id_63>",
318
+ "lstrip": false,
319
+ "normalized": false,
320
+ "rstrip": false,
321
+ "single_word": false,
322
+ "special": false
323
+ },
324
+ "250037": {
325
+ "content": "▁<extra_id_62>",
326
+ "lstrip": false,
327
+ "normalized": false,
328
+ "rstrip": false,
329
+ "single_word": false,
330
+ "special": false
331
+ },
332
+ "250038": {
333
+ "content": "▁<extra_id_61>",
334
+ "lstrip": false,
335
+ "normalized": false,
336
+ "rstrip": false,
337
+ "single_word": false,
338
+ "special": false
339
+ },
340
+ "250039": {
341
+ "content": "▁<extra_id_60>",
342
+ "lstrip": false,
343
+ "normalized": false,
344
+ "rstrip": false,
345
+ "single_word": false,
346
+ "special": false
347
+ },
348
+ "250040": {
349
+ "content": "▁<extra_id_59>",
350
+ "lstrip": false,
351
+ "normalized": false,
352
+ "rstrip": false,
353
+ "single_word": false,
354
+ "special": false
355
+ },
356
+ "250041": {
357
+ "content": "▁<extra_id_58>",
358
+ "lstrip": false,
359
+ "normalized": false,
360
+ "rstrip": false,
361
+ "single_word": false,
362
+ "special": false
363
+ },
364
+ "250042": {
365
+ "content": "▁<extra_id_57>",
366
+ "lstrip": false,
367
+ "normalized": false,
368
+ "rstrip": false,
369
+ "single_word": false,
370
+ "special": false
371
+ },
372
+ "250043": {
373
+ "content": "▁<extra_id_56>",
374
+ "lstrip": false,
375
+ "normalized": false,
376
+ "rstrip": false,
377
+ "single_word": false,
378
+ "special": false
379
+ },
380
+ "250044": {
381
+ "content": "▁<extra_id_55>",
382
+ "lstrip": false,
383
+ "normalized": false,
384
+ "rstrip": false,
385
+ "single_word": false,
386
+ "special": false
387
+ },
388
+ "250045": {
389
+ "content": "▁<extra_id_54>",
390
+ "lstrip": false,
391
+ "normalized": false,
392
+ "rstrip": false,
393
+ "single_word": false,
394
+ "special": false
395
+ },
396
+ "250046": {
397
+ "content": "▁<extra_id_53>",
398
+ "lstrip": false,
399
+ "normalized": false,
400
+ "rstrip": false,
401
+ "single_word": false,
402
+ "special": false
403
+ },
404
+ "250047": {
405
+ "content": "▁<extra_id_52>",
406
+ "lstrip": false,
407
+ "normalized": false,
408
+ "rstrip": false,
409
+ "single_word": false,
410
+ "special": false
411
+ },
412
+ "250048": {
413
+ "content": "▁<extra_id_51>",
414
+ "lstrip": false,
415
+ "normalized": false,
416
+ "rstrip": false,
417
+ "single_word": false,
418
+ "special": false
419
+ },
420
+ "250049": {
421
+ "content": "▁<extra_id_50>",
422
+ "lstrip": false,
423
+ "normalized": false,
424
+ "rstrip": false,
425
+ "single_word": false,
426
+ "special": false
427
+ },
428
+ "250050": {
429
+ "content": "���<extra_id_49>",
430
+ "lstrip": false,
431
+ "normalized": false,
432
+ "rstrip": false,
433
+ "single_word": false,
434
+ "special": false
435
+ },
436
+ "250051": {
437
+ "content": "▁<extra_id_48>",
438
+ "lstrip": false,
439
+ "normalized": false,
440
+ "rstrip": false,
441
+ "single_word": false,
442
+ "special": false
443
+ },
444
+ "250052": {
445
+ "content": "▁<extra_id_47>",
446
+ "lstrip": false,
447
+ "normalized": false,
448
+ "rstrip": false,
449
+ "single_word": false,
450
+ "special": false
451
+ },
452
+ "250053": {
453
+ "content": "▁<extra_id_46>",
454
+ "lstrip": false,
455
+ "normalized": false,
456
+ "rstrip": false,
457
+ "single_word": false,
458
+ "special": false
459
+ },
460
+ "250054": {
461
+ "content": "▁<extra_id_45>",
462
+ "lstrip": false,
463
+ "normalized": false,
464
+ "rstrip": false,
465
+ "single_word": false,
466
+ "special": false
467
+ },
468
+ "250055": {
469
+ "content": "▁<extra_id_44>",
470
+ "lstrip": false,
471
+ "normalized": false,
472
+ "rstrip": false,
473
+ "single_word": false,
474
+ "special": false
475
+ },
476
+ "250056": {
477
+ "content": "▁<extra_id_43>",
478
+ "lstrip": false,
479
+ "normalized": false,
480
+ "rstrip": false,
481
+ "single_word": false,
482
+ "special": false
483
+ },
484
+ "250057": {
485
+ "content": "▁<extra_id_42>",
486
+ "lstrip": false,
487
+ "normalized": false,
488
+ "rstrip": false,
489
+ "single_word": false,
490
+ "special": false
491
+ },
492
+ "250058": {
493
+ "content": "▁<extra_id_41>",
494
+ "lstrip": false,
495
+ "normalized": false,
496
+ "rstrip": false,
497
+ "single_word": false,
498
+ "special": false
499
+ },
500
+ "250059": {
501
+ "content": "▁<extra_id_40>",
502
+ "lstrip": false,
503
+ "normalized": false,
504
+ "rstrip": false,
505
+ "single_word": false,
506
+ "special": false
507
+ },
508
+ "250060": {
509
+ "content": "▁<extra_id_39>",
510
+ "lstrip": false,
511
+ "normalized": false,
512
+ "rstrip": false,
513
+ "single_word": false,
514
+ "special": false
515
+ },
516
+ "250061": {
517
+ "content": "▁<extra_id_38>",
518
+ "lstrip": false,
519
+ "normalized": false,
520
+ "rstrip": false,
521
+ "single_word": false,
522
+ "special": false
523
+ },
524
+ "250062": {
525
+ "content": "▁<extra_id_37>",
526
+ "lstrip": false,
527
+ "normalized": false,
528
+ "rstrip": false,
529
+ "single_word": false,
530
+ "special": false
531
+ },
532
+ "250063": {
533
+ "content": "▁<extra_id_36>",
534
+ "lstrip": false,
535
+ "normalized": false,
536
+ "rstrip": false,
537
+ "single_word": false,
538
+ "special": false
539
+ },
540
+ "250064": {
541
+ "content": "▁<extra_id_35>",
542
+ "lstrip": false,
543
+ "normalized": false,
544
+ "rstrip": false,
545
+ "single_word": false,
546
+ "special": false
547
+ },
548
+ "250065": {
549
+ "content": "▁<extra_id_34>",
550
+ "lstrip": false,
551
+ "normalized": false,
552
+ "rstrip": false,
553
+ "single_word": false,
554
+ "special": false
555
+ },
556
+ "250066": {
557
+ "content": "▁<extra_id_33>",
558
+ "lstrip": false,
559
+ "normalized": false,
560
+ "rstrip": false,
561
+ "single_word": false,
562
+ "special": false
563
+ },
564
+ "250067": {
565
+ "content": "▁<extra_id_32>",
566
+ "lstrip": false,
567
+ "normalized": false,
568
+ "rstrip": false,
569
+ "single_word": false,
570
+ "special": false
571
+ },
572
+ "250068": {
573
+ "content": "▁<extra_id_31>",
574
+ "lstrip": false,
575
+ "normalized": false,
576
+ "rstrip": false,
577
+ "single_word": false,
578
+ "special": false
579
+ },
580
+ "250069": {
581
+ "content": "▁<extra_id_30>",
582
+ "lstrip": false,
583
+ "normalized": false,
584
+ "rstrip": false,
585
+ "single_word": false,
586
+ "special": false
587
+ },
588
+ "250070": {
589
+ "content": "▁<extra_id_29>",
590
+ "lstrip": false,
591
+ "normalized": false,
592
+ "rstrip": false,
593
+ "single_word": false,
594
+ "special": false
595
+ },
596
+ "250071": {
597
+ "content": "▁<extra_id_28>",
598
+ "lstrip": false,
599
+ "normalized": false,
600
+ "rstrip": false,
601
+ "single_word": false,
602
+ "special": false
603
+ },
604
+ "250072": {
605
+ "content": "▁<extra_id_27>",
606
+ "lstrip": false,
607
+ "normalized": false,
608
+ "rstrip": false,
609
+ "single_word": false,
610
+ "special": false
611
+ },
612
+ "250073": {
613
+ "content": "▁<extra_id_26>",
614
+ "lstrip": false,
615
+ "normalized": false,
616
+ "rstrip": false,
617
+ "single_word": false,
618
+ "special": false
619
+ },
620
+ "250074": {
621
+ "content": "▁<extra_id_25>",
622
+ "lstrip": false,
623
+ "normalized": false,
624
+ "rstrip": false,
625
+ "single_word": false,
626
+ "special": false
627
+ },
628
+ "250075": {
629
+ "content": "▁<extra_id_24>",
630
+ "lstrip": false,
631
+ "normalized": false,
632
+ "rstrip": false,
633
+ "single_word": false,
634
+ "special": false
635
+ },
636
+ "250076": {
637
+ "content": "▁<extra_id_23>",
638
+ "lstrip": false,
639
+ "normalized": false,
640
+ "rstrip": false,
641
+ "single_word": false,
642
+ "special": false
643
+ },
644
+ "250077": {
645
+ "content": "▁<extra_id_22>",
646
+ "lstrip": false,
647
+ "normalized": false,
648
+ "rstrip": false,
649
+ "single_word": false,
650
+ "special": false
651
+ },
652
+ "250078": {
653
+ "content": "▁<extra_id_21>",
654
+ "lstrip": false,
655
+ "normalized": false,
656
+ "rstrip": false,
657
+ "single_word": false,
658
+ "special": false
659
+ },
660
+ "250079": {
661
+ "content": "▁<extra_id_20>",
662
+ "lstrip": false,
663
+ "normalized": false,
664
+ "rstrip": false,
665
+ "single_word": false,
666
+ "special": false
667
+ },
668
+ "250080": {
669
+ "content": "▁<extra_id_19>",
670
+ "lstrip": false,
671
+ "normalized": false,
672
+ "rstrip": false,
673
+ "single_word": false,
674
+ "special": false
675
+ },
676
+ "250081": {
677
+ "content": "▁<extra_id_18>",
678
+ "lstrip": false,
679
+ "normalized": false,
680
+ "rstrip": false,
681
+ "single_word": false,
682
+ "special": false
683
+ },
684
+ "250082": {
685
+ "content": "▁<extra_id_17>",
686
+ "lstrip": false,
687
+ "normalized": false,
688
+ "rstrip": false,
689
+ "single_word": false,
690
+ "special": false
691
+ },
692
+ "250083": {
693
+ "content": "▁<extra_id_16>",
694
+ "lstrip": false,
695
+ "normalized": false,
696
+ "rstrip": false,
697
+ "single_word": false,
698
+ "special": false
699
+ },
700
+ "250084": {
701
+ "content": "▁<extra_id_15>",
702
+ "lstrip": false,
703
+ "normalized": false,
704
+ "rstrip": false,
705
+ "single_word": false,
706
+ "special": false
707
+ },
708
+ "250085": {
709
+ "content": "▁<extra_id_14>",
710
+ "lstrip": false,
711
+ "normalized": false,
712
+ "rstrip": false,
713
+ "single_word": false,
714
+ "special": false
715
+ },
716
+ "250086": {
717
+ "content": "▁<extra_id_13>",
718
+ "lstrip": false,
719
+ "normalized": false,
720
+ "rstrip": false,
721
+ "single_word": false,
722
+ "special": false
723
+ },
724
+ "250087": {
725
+ "content": "▁<extra_id_12>",
726
+ "lstrip": false,
727
+ "normalized": false,
728
+ "rstrip": false,
729
+ "single_word": false,
730
+ "special": false
731
+ },
732
+ "250088": {
733
+ "content": "▁<extra_id_11>",
734
+ "lstrip": false,
735
+ "normalized": false,
736
+ "rstrip": false,
737
+ "single_word": false,
738
+ "special": false
739
+ },
740
+ "250089": {
741
+ "content": "▁<extra_id_10>",
742
+ "lstrip": false,
743
+ "normalized": false,
744
+ "rstrip": false,
745
+ "single_word": false,
746
+ "special": false
747
+ },
748
+ "250090": {
749
+ "content": "▁<extra_id_9>",
750
+ "lstrip": false,
751
+ "normalized": false,
752
+ "rstrip": false,
753
+ "single_word": false,
754
+ "special": false
755
+ },
756
+ "250091": {
757
+ "content": "▁<extra_id_8>",
758
+ "lstrip": false,
759
+ "normalized": false,
760
+ "rstrip": false,
761
+ "single_word": false,
762
+ "special": false
763
+ },
764
+ "250092": {
765
+ "content": "▁<extra_id_7>",
766
+ "lstrip": false,
767
+ "normalized": false,
768
+ "rstrip": false,
769
+ "single_word": false,
770
+ "special": false
771
+ },
772
+ "250093": {
773
+ "content": "▁<extra_id_6>",
774
+ "lstrip": false,
775
+ "normalized": false,
776
+ "rstrip": false,
777
+ "single_word": false,
778
+ "special": false
779
+ },
780
+ "250094": {
781
+ "content": "▁<extra_id_5>",
782
+ "lstrip": false,
783
+ "normalized": false,
784
+ "rstrip": false,
785
+ "single_word": false,
786
+ "special": false
787
+ },
788
+ "250095": {
789
+ "content": "▁<extra_id_4>",
790
+ "lstrip": false,
791
+ "normalized": false,
792
+ "rstrip": false,
793
+ "single_word": false,
794
+ "special": false
795
+ },
796
+ "250096": {
797
+ "content": "▁<extra_id_3>",
798
+ "lstrip": false,
799
+ "normalized": false,
800
+ "rstrip": false,
801
+ "single_word": false,
802
+ "special": false
803
+ },
804
+ "250097": {
805
+ "content": "▁<extra_id_2>",
806
+ "lstrip": false,
807
+ "normalized": false,
808
+ "rstrip": false,
809
+ "single_word": false,
810
+ "special": false
811
+ },
812
+ "250098": {
813
+ "content": "▁<extra_id_1>",
814
+ "lstrip": false,
815
+ "normalized": false,
816
+ "rstrip": false,
817
+ "single_word": false,
818
+ "special": false
819
+ },
820
+ "250099": {
821
+ "content": "▁<extra_id_0>",
822
+ "lstrip": false,
823
+ "normalized": false,
824
+ "rstrip": false,
825
+ "single_word": false,
826
+ "special": false
827
+ }
828
+ },
829
+ "additional_special_tokens": [],
830
+ "clean_up_tokenization_spaces": false,
831
+ "eos_token": "</s>",
832
+ "extra_ids": 0,
833
+ "extra_special_tokens": {},
834
+ "legacy": true,
835
+ "model_max_length": 1000000000000000019884624838656,
836
+ "pad_token": "<pad>",
837
+ "sp_model_kwargs": {},
838
+ "tokenizer_class": "T5Tokenizer",
839
+ "unk_token": "<unk>"
840
+ }
utils.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def create_mask_from_length(lengths: torch.Tensor, max_length: int | None = None):
5
+ lengths = torch.as_tensor(lengths)
6
+ if lengths.ndim == 0:
7
+ lengths = lengths.unsqueeze(0)
8
+ lengths = lengths.long()
9
+ if max_length is None:
10
+ if lengths.numel() == 0:
11
+ max_length = 0
12
+ else:
13
+ max_length = int(lengths.max().item())
14
+ idxs = torch.arange(max_length, device=lengths.device).reshape(1, -1)
15
+ mask = idxs < lengths.view(-1, 1)
16
+ return mask
17
+
18
+
19
+ def convert_pad_shape(pad_shape: list[list[int]]):
20
+ l = pad_shape[::-1]
21
+ return [item for sublist in l for item in sublist]
22
+
23
+
24
+ def create_alignment_path(duration: torch.Tensor, mask: torch.Tensor):
25
+ device = duration.device
26
+ b, t_x, t_y = mask.shape
27
+ cum_duration = torch.cumsum(duration, 1)
28
+
29
+ cum_duration_flat = cum_duration.view(b * t_x)
30
+ path = create_mask_from_length(cum_duration_flat, t_y).float()
31
+ path = path.view(b, t_x, t_y)
32
+ path = path - torch.nn.functional.pad(
33
+ path, convert_pad_shape([[0, 0], [1, 0], [0, 0]])
34
+ )[:, :-1]
35
+ path = path * mask
36
+ return path
37
+
38
+
39
+ def trim_or_pad_length(x: torch.Tensor, target_length: int, length_dim: int):
40
+ current_length = x.shape[length_dim]
41
+ if current_length > target_length:
42
+ slices = [slice(None)] * x.ndim
43
+ slices[length_dim] = slice(0, target_length)
44
+ return x[tuple(slices)]
45
+ elif current_length < target_length:
46
+ pad_shape = list(x.shape)
47
+ pad_shape[length_dim] = target_length - current_length
48
+ padding = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
49
+ return torch.cat([x, padding], dim=length_dim)
50
+ return x