Locke committed on
Commit 3e934f1 · 1 Parent(s): dcc1d9d
.gitattributes CHANGED
@@ -1,3 +1,7 @@
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
1
+ *.png filter=lfs diff=lfs merge=lfs -text
2
+ *.wav filter=lfs diff=lfs merge=lfs -text
3
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
4
+ *.jpg filter=lfs diff=lfs merge=lfs -text
5
  *.7z filter=lfs diff=lfs merge=lfs -text
6
  *.arrow filter=lfs diff=lfs merge=lfs -text
7
  *.bin filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Meituan
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in
13
+ all copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,388 @@
1
+ ---
2
+ license: mit
3
+ library_name: LongCat-Next
4
+ pipeline_tag: any-to-any
5
+ tags:
6
+ - transformers
7
+ - multimodal
8
+ ---
9
+
10
+ # LongCat-Next
11
+
12
+ <div align="center">
13
+ <img src="https://raw.githubusercontent.com/meituan-longcat/LongCat-Flash-Chat/main/figures/longcat_logo.svg"
14
+ width="300"
15
+ alt="LongCat Logo"/>
16
+ </div>
17
+
18
+ <hr>
19
+
20
+ <div align="center" style="line-height: 1;">
21
+ <a href="https://longcat.chat/longcat-next/intro" target="_blank" style="margin: 2px;">
22
+ <img alt="Blog" src="https://img.shields.io/badge/Blog-LongCatNext-white?logo=safari&logoColor=white&color=purple" style="display: inline-block; vertical-align: middle;"/>
23
+ </a>
24
+ <a href="https://huggingface.co/meituan-longcat" target="_blank" style="margin: 2px;">
25
+ <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-LongCatNext-ffc107?color=ffc107&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
26
+ </a>
27
+ <a href="https://github.com/meituan-longcat/LongCat-Next" target="_blank" style="margin: 2px;">
28
+ <img alt="GitHub" src="https://img.shields.io/badge/GitHub-LongCatNext-white?logo=github&logoColor=white&color=a4b5d5" style="display: inline-block; vertical-align: middle;"/>
29
+ </a>
30
+ <a href="https://longcat.chat/longcat-next" target="_blank" style="margin: 2px;">
31
+ <img alt="Demo" src="https://img.shields.io/badge/Demo-LongCatNext-white?logo=googleplay&logoColor=white&color=eabcdd" style="display: inline-block; vertical-align: middle;"/>
32
+ </a>
33
+ </div>
34
+
35
+ <div align="center" style="line-height: 1;">
36
+ <a href="https://github.com/meituan-longcat/LongCat-Flash-Chat/blob/main/figures/wechat_official_accounts.png" target="_blank" style="margin: 2px;">
37
+ <img alt="Wechat" src="https://img.shields.io/badge/WeChat-LongCat-brightgreen?logo=wechat&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
38
+ </a>
39
+ <a href="https://x.com/Meituan_LongCat" target="_blank" style="margin: 2px;">
40
+ <img alt="Twitter Follow" src="https://img.shields.io/badge/Twitter-LongCat-white?logo=x&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
41
+ </a>
42
+ </div>
43
+
44
+ <div align="center" style="line-height: 1;">
45
+ <a href="https://huggingface.co/meituan-longcat/LongCat-Next/blob/main/LICENSE" style="margin: 2px;">
46
+ <img alt="License" src="https://img.shields.io/badge/License-MIT-f5de53?&color=f5de53" style="display: inline-block; vertical-align: middle;"/>
47
+ </a>
48
+ </div>
49
+
50
+ <p align="center">
51
+ <a href="https://github.com/meituan-longcat/LongCat-Next/blob/main/tech_report.pdf">
52
+ <b>Tech Report</b>&nbsp;📄
53
+ </a>
54
+ </p>
55
+
56
+
57
+
58
+
59
+
60
+ ## Model Introduction
61
+
62
+ ![overview](./assets/overview.png)
63
+
64
+
65
+ We develop **LongCat-Next**, a native multimodal model that processes text, vision, and audio under a single autoregressive objective with minimal inductive bias beyond the language paradigm. As an industrial-strength foundation model at the A3B scale, it excels at seeing, creating, and talking, achieving strong performance across a wide range of multimodal benchmarks. In particular, by leveraging semantically complete discrete representations, it surpasses the long-standing performance ceiling of discrete vision modeling on understanding tasks and provides a unified solution for visual understanding and generation. This success demonstrates that discrete tokens can universally represent multimodal signals and be deeply internalized within a single discrete embedding space. We further provide extensive experiments that analyze this unified discrete training paradigm and uncover several interesting findings.
66
+
67
+ As a meaningful attempt toward native multimodality, we open-source **LongCat-Next** and its tokenizers, hoping to foster further research and development in the community.
68
+
69
+
70
+ ### Key Features
71
+
72
+ This work primarily addresses the fundamental barrier to native multimodality through a design philosophy that prioritizes simplicity, treating vision and audio as intrinsic extensions of language. As a step toward this goal, we present LongCat-Next, a discrete native multimodal model that achieves industrial-strength performance within discrete frameworks while remaining highly competitive across a wide range of specialized domains. Built upon the LongCat-Flash-Lite MoE backbone (A3B) as a _multi-task_ learner, the model unifies language, vision, and audio within a single discrete framework. Our principal contributions are as follows:
73
+
74
+ #### 🌟 Discrete Native Autoregression Paradigm (DiNA).
75
+ We introduce DiNA, a unified paradigm that extends next-token prediction from language to native multimodality, which internalizes diverse modalities into a shared token space. It simplifies multimodal modeling by creating modality-aware tokenizer-detokenizer pairs and leveraging the established training infrastructure of large language models.
76
+
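+ As a rough illustration (a sketch based on the offsets published in `./config.json`, not the repository's actual embedding code), modality-specific codebook indices are folded into one autoregressive vocabulary by id range, so a single next-token head scores text, audio, and visual tokens alike:
+
+ ```python
+ # Sketch only: vocabulary layout taken from ./config.json.
+ TEXT_VOCAB = 131072      # plain text tokens
+ AUDIO_OFFSET = 131125    # first audio token id (multimodal special tokens sit in between)
+ VISUAL_OFFSET = 150581   # first visual token id
+ VOCAB_SIZE = 282624      # total unified vocabulary size
+
+ def audio_code_to_token_id(code: int) -> int:
+     """Map an audio codebook index into the shared LM vocabulary."""
+     return AUDIO_OFFSET + code
+
+ def visual_code_to_token_id(code: int) -> int:
+     """Map a visual codebook index into the shared LM vocabulary."""
+     return VISUAL_OFFSET + code
+
+ def route_token(token_id: int) -> str:
+     """Route a generated id back to its modality-specific detokenizer."""
+     if token_id < TEXT_VOCAB:
+         return "text"
+     if token_id < AUDIO_OFFSET:
+         return "special"
+     if token_id < VISUAL_OFFSET:
+         return "audio"
+     return "visual"
+ ```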
77
+
78
+ #### 🌟 Semantic Completeness for Discrete Visual Representation.
79
+ We improve discrete visual modeling by combining Semantic-and-Aligned Encoders (SAE) with Residual Vector Quantization (RVQ). This integration creates hierarchical discrete tokens that preserve both semantic abstraction and fine-grained visual details, surpassing traditional representation limitations.
80
+
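+ For intuition, here is a minimal residual-quantization sketch (illustrative only; the actual tokenizer ships with the model code, and `./config.json` reports `depth=8` codebooks of size 16384 over 3584-dimensional features): each level quantizes the residual left by the previous levels, so early codes carry coarse semantics and later codes add fine-grained detail.
+
+ ```python
+ import torch
+
+ def rvq_encode(x, codebooks):
+     """Residual vector quantization sketch.
+
+     x: (B, D) features; codebooks: list of (K, D) tensors, one per depth level.
+     Returns per-level indices (B, depth) and the reconstructed features (B, D).
+     """
+     residual, quantized, indices = x, torch.zeros_like(x), []
+     for codebook in codebooks:
+         dists = torch.cdist(residual, codebook)    # (B, K) distances to all codes
+         idx = dists.argmin(dim=-1)                 # nearest code per vector
+         picked = codebook[idx]
+         quantized = quantized + picked             # accumulate the reconstruction
+         residual = residual - picked               # next level quantizes what is left
+         indices.append(idx)
+     return torch.stack(indices, dim=-1), quantized
+
+ # Toy usage with small shapes (the reported setup is depth=8, 16384 codes, dim 3584).
+ codebooks = [torch.randn(16, 8) for _ in range(8)]
+ ids, x_hat = rvq_encode(torch.randn(4, 8), codebooks)
+ ```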
81
+
82
+ #### 🌟 Discrete Native-Resolution Vision Transformer (dNaViT).
83
+ Analogous to linguistic tokenizers, we propose dNaViT as a highly flexible, unified discrete interface for vision that extracts semantic features as "visual words", constructing a hierarchical representation space supporting dynamic tokenization and detokenization. dNaViT integrates seamlessly with large language models, ensuring high performance without degradation.
84
+
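+ As a rough back-of-the-envelope (an assumption inferred from the `patch_size=14` and `spatial_merge_size=2` values in `./config.json`, not a documented formula), the dynamic token grid for a native-resolution image could scale as sketched below; note that a 1036×1036 input yields a 37×37 grid, matching the default `token_h`/`token_w` in the recommended image-generation parameters later in this card:
+
+ ```python
+ import math
+
+ PATCH_SIZE = 14          # visual patch size reported in ./config.json
+ SPATIAL_MERGE_SIZE = 2   # 2x2 patch merging reported in ./config.json
+
+ def visual_token_grid(height: int, width: int) -> tuple[int, int]:
+     """Hypothetical token-grid size for a native-resolution image."""
+     cell = PATCH_SIZE * SPATIAL_MERGE_SIZE   # 28 pixels per token edge
+     return math.ceil(height / cell), math.ceil(width / cell)
+
+ print(visual_token_grid(1036, 1036))   # -> (37, 37)
+ ```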
85
+ #### 🌟 Excelling in Seeing, Creating, and Talking in a Unified Model.
86
+ Within the framework of DiNA, visual understanding and generation are elegantly reformulated as two manifestations of the same predictive process without performance compromise. This formulation bridges the long-standing architectural divide while introducing minimal interference between these traditionally competing objectives and preserving core language capabilities. Remarkably, LongCat-Next achieves performance competitive with specialized understanding models and maintains strong generative quality, particularly in text rendering, even under a 28× compression ratio, while also excelling in advanced speech comprehension, low-latency voice conversation, and customizable voice cloning.
87
+
88
+
89
+ Please refer to our [technical report](./tech_report.pdf) for details!
90
+
91
+
92
+
93
+ ## Evaluation Results
94
+
95
+ ![evaluation](./assets/evaluation.png)
96
+
97
+
98
+
99
+
100
+ ## Quick Start
101
+ To use LongCat-Next with transformers, you need at least 3 GPUs (80 GB VRAM each, e.g., H100 or A100 80GB). We recommend the following environment:
102
+ * `python` >= 3.10
103
+ * `torch` >= 2.6
104
+ * `transformers` >= 4.57.6
105
+ * `accelerate` >= 1.10.0
106
+
107
+ ```shell
108
+ # (Install python=3.10, ffmpeg<7, soundfile==0.13.1)
109
+ conda env create -f environment.yml -v
110
+
111
+ # (Install torch and other pip dependencies)
112
+ pip install -r requirements.txt && pip install -r requirements-post.txt --no-build-isolation
113
+ ```
114
+
115
+ Basic Usage Example:
116
+ - Remember to replace `WEIGHT_PATH_TO_LONGCAT_NEXT` in `./config.json` with your local weight path, because the decoders are loaded lazily (one way to patch it is sketched below).
117
+
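+ A minimal sketch for patching the placeholder (an assumption, not an official setup script; adjust `snapshot_dir` to wherever you downloaded the weights):
+
+ ```python
+ # Sketch only: rewrite the lazy-loaded decoder weight paths in config.json.
+ import json, pathlib
+
+ snapshot_dir = "/path/to/LongCat-Next"   # hypothetical local download location
+ cfg_path = pathlib.Path(snapshot_dir) / "config.json"
+ cfg_path.write_text(cfg_path.read_text().replace("WEIGHT_PATH_TO_LONGCAT_NEXT", snapshot_dir))
+
+ # The image decoder and vocoder entries should now point at real files.
+ cfg = json.loads(cfg_path.read_text())
+ print(cfg["visual_config"]["visual_decoder_config"]["weight_path"])
+ print(cfg["audio_config"]["cosy24kvocoder_config"]["weight_path"])
+ ```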
118
+ ```python
119
+ import torch
120
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
121
+
122
+ # Load model
123
+ model_name = "meituan-longcat/LongCat-Next"
124
+ model = AutoModelForCausalLM.from_pretrained(
125
+ model_name,
126
+ torch_dtype=torch.bfloat16,
127
+ device_map="auto",
128
+ trust_remote_code=True,
129
+ )
130
+ model.eval()
131
+
132
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, fix_mistral_regex=True)
133
+ model.text_tokenizer = tokenizer # Dynamic binding
134
+ processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
135
+
136
+ # Set messages
137
+ messages = [
138
+ {"role": "system", "content": "You are a helpful assistant."},
139
+ {"role": "user", "content": "What book is this?<longcat_img_start>./assets/book.png<longcat_img_end>"}
140
+ ]
141
+
142
+ # Apply chat-template
143
+ text_input = tokenizer.apply_chat_template(
144
+ messages,
145
+ tokenize=False,
146
+ add_generation_prompt=True,
147
+ )
148
+ print(f"{text_input=}")
149
+
150
+ # Preprocessing
151
+ text_inputs, visual_inputs, audio_inputs = processor(text=text_input, return_tensors="pt")
152
+ text_inputs = text_inputs.to(model.device)
153
+ if visual_inputs is not None:
154
+ visual_inputs = visual_inputs.to(model.device)
155
+ if audio_inputs is not None:
156
+ audio_inputs = audio_inputs.to(model.device)
157
+
158
+ # AR
159
+ with torch.no_grad():
160
+ outputs = model.generate(
161
+ input_ids=text_inputs["input_ids"],
162
+ visual_inputs=visual_inputs,
163
+ audio_inputs=audio_inputs,
164
+ return_dict_in_generate=True,
165
+ )
166
+
167
+ # Text decoding
168
+ output_input_ids = outputs.sequences
169
+ text_output = tokenizer.decode(output_input_ids[0][len(text_inputs["input_ids"][0]):], skip_special_tokens=True)
170
+ print(f"{text_output=}")
171
+
172
+ # Images decoding
173
+ output_visual_ids = outputs.visual_ids
174
+ if output_visual_ids.size(0) > 0:
175
+ image_path_list = model.model.decode_visual_ids_and_save(
176
+ output_visual_ids,
177
+ save_prefix="./output_image",
178
+ **model.generation_config.visual_generation_config["custom_params"],
179
+ )
180
+ print(f"{image_path_list=}")
181
+
182
+ # Audio decoding
183
+ output_audio_text_ids = outputs.audio_text_ids
184
+ output_audio_ids = outputs.audio_ids
185
+ if output_audio_text_ids.size(-1) > 0:
186
+ audio_text = tokenizer.decode(output_audio_text_ids[0], skip_special_tokens=True)
187
+ print(f"{audio_text=}")
188
+ if output_audio_ids.size(0) > 0:
189
+ audio_path_list = model.model.decode_audio_ids_and_save(
190
+ output_audio_ids,
191
+ save_prefix="./output_audio",
192
+ **model.generation_config.audio_generation_config["custom_params"],
193
+ )
194
+ print(f"{audio_path_list=}")
195
+ ```
196
+
197
+
198
+ <details>
199
+ <summary>Text - Tool Calling Example</summary>
200
+
201
+ ```python
202
+ from parse_model_response import parse_model_response
203
+
204
+ tools = [
205
+ {
206
+ "type": "function",
207
+ "function": {
208
+ "name": "func_add",
209
+ "description": "Calculate the sum of two numbers",
210
+ "parameters": {
211
+ "type": "object",
212
+ "properties": {
213
+ "x1": {"type": "number", "description": "The first addend"},
214
+ "x2": {"type": "number", "description": "The second addend"}
215
+ },
216
+ "required": ["x1", "x2"]
217
+ }
218
+ }
219
+ }
220
+ ]
221
+ messages = [
222
+ {"role": "system", "content": "You are a helpful assistant."},
223
+ {"role": "user", "content": "Please tell me what is $$125679 + 234519$$?"},
224
+ {
225
+ "role": "assistant",
226
+ "content": "I'll calculate the sum of 125679 and 234519 for you.",
227
+ "tool_calls": [{"type": "function", "function": {"name": "func_add", "arguments": {"x1": 125679, "x2": 234519}}}]
228
+ },
229
+ {"role": "tool", "name": "func_add", "content": '{"ans": 360198}'}
230
+ ]
231
+
232
+ text_input = tokenizer.apply_chat_template(
233
+ messages,
234
+ tools=tools, # add tools here
235
+ tokenize=False,
236
+ add_generation_prompt=True,
237
+ )
238
+ print(f"{text_input=}")
239
+
240
+
241
+ # Preprocessing - AR - Text decoding
242
+ ...
243
+
244
+ # Results parsing
245
+ parsed_message = parse_model_response(text_output.strip("\n"), tools)
246
+ print(f"{parsed_message=}")
247
+ ```
248
+ See [`parse_model_response.py`](./parse_model_response.py) for detailed implementation and examples.
249
+
250
+ </details>
251
+
252
+
253
+ <details>
254
+ <summary>Image - Understanding Example</summary>
255
+
256
+ ```python
257
+ # Simply replace the messages in the main example with the messages below.
258
+ messages = [
259
+ {"role": "user", "content": "What book is this?<longcat_img_start>./assets/book.png<longcat_img_end>"}
260
+ ]
261
+ ```
262
+
263
+ </details>
264
+
265
+
266
+ <details>
267
+ <summary>Image - Generation Example</summary>
268
+
269
+ ```python
270
+ # Simply replace the messages in the main example with the messages below.
271
+ messages = [
272
+ {"role": "system", "content": ""},
273
+ {"role": "user", "content": "A small kitten sitting naturally on a moss-covered forest floor, centered in the frame, holding a rectangular wooden sign gently with its front paws resting over the top edge. The kitten has soft, fluffy fur, a natural relaxed posture, and a calm, curious expression with a slightly open mouth (not exaggerated), looking directly at the camera.\n\nThe sign is positioned firmly in front of the kitten\'s chest, supported by its paws, with realistic contact and no floating effect. The board reads \"LongCat-Next: When Modalities Internalize as Multilingual Tokens\" in clean, sharp black text, perfectly legible.\n\nThe environment is a lush forest with tall trees, ferns, and soft green foliage. The ground is covered with moss and small plants. Background softly blurred with natural depth of field. Lighting is soft, diffused sunlight filtering through the trees, creating gentle highlights and shadows. Realistic photography style, natural colors, high detail, no cartoonish exaggeration.<longcat_img_start>"}
274
+ ]
275
+ ```
276
+
277
+ </details>
278
+
279
+
280
+ <details>
281
+ <summary>Audio - Audio-to-Text Example</summary>
282
+
283
+ ```python
284
+ # Simply replace the messages in the main example with the messages below.
285
+ messages = [
286
+ {"role": "user", "content": "<longcat_audio_start>./assets/math1.wav<longcat_audio_end>"}
287
+ ]
288
+
289
+ ```
290
+
291
+ </details>
292
+
293
+ <details>
294
+ <summary>Audio - Audio-to-Audio Example</summary>
295
+
296
+ ```python
297
+ # Simply replace the messages in the main example with the messages below.
298
+ messages = [
299
+ {"role": "system", "content": "Replicate the voice in the audio clip to formulate an answer:<longcat_audio_start>./assets/system_audio.wav<longcat_audio_end>"},
300
+ {"role": "user", "content": "<longcat_audio_start>./assets/math1.wav<longcat_audio_end><longcat_audiogen_start>"}
301
+ ]
302
+ ```
303
+
304
+ </details>
305
+
306
+ <details>
307
+ <summary>Audio - Speech Synthesis Example</summary>
308
+
309
+ ```python
310
+ # Simply replace the messages in the main example with the messages below.
311
+ messages = [
312
+ {"role": "system", "content": "Replicate the voice in the audio clip to formulate an answer:<longcat_audio_start>./assets/vc_zh3.wav<longcat_audio_end>"},
313
+ {"role": "user", "content": "用这个声音合成以下内容:明天的meeting在三楼的Conference Room举行。<longcat_audiogen_start>"}
314
+ ]
315
+ ```
316
+
317
+ </details>
318
+
319
+
320
+ <!-- > [!Tip] -->
321
+
322
+ > We recommend using the following set of sampling parameters for generation:
323
+ >
324
+ > - Text: `{"max_new_tokens":2048,"do_sample":false}`
325
+ > - Image - Understanding: `{"max_new_tokens":1024,"do_sample":true,"temperature":0.4,"top_k":40,"top_p":0.85,"repetition_penalty":1.1}`
326
+ > - Image - Generation: `{"max_new_tokens":2048,"do_sample":false,"visual_generation_config":{"do_sample":true,"temperature":0.5,"top_p":0.75,"top_k":1024,"custom_params":{"cfg_scale":3,"token_h":37,"token_w":37,"anyres_prefix":"<longcat_img_token_size>{h} {w}</longcat_img_token_size>"}}}`
327
+ > - Audio - Audio-to-Text: `{"max_new_tokens":1024,"do_sample":true,"temperature":0.2,"top_k":20,"top_p":0.85,"repetition_penalty":1.1}`
328
+ > - Audio - Audio-to-Audio/Speech Synthesis: `{"max_new_tokens":2048,"do_sample":true,"temperature":0.2,"top_k":20,"top_p":0.85,"repetition_penalty":1.1,"audio_generation_config":{"audio_parallel_decoding":false,"do_sample":true,"temperature":0.5,"top_k":5,"top_p":0.85,"repetition_penalty":1.3,"custom_params":{"sampling_rate":24000,"wave_concat_overlap":1200}}}`
329
+ >
330
+ > Please note that support for these sampling parameters varies across inference frameworks (for transformers, the inference parameter configuration is located in `./generation_config.json`).
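+ As a concrete example, the image-understanding settings above can be passed straight through `model.generate` on top of the Quick Start snippet (a sketch assuming the standard transformers sampling kwargs are honored; the nested `visual_generation_config`/`audio_generation_config` entries are read from `./generation_config.json`):
+
+ ```python
+ # Sketch: apply the recommended image-understanding sampling settings.
+ outputs = model.generate(
+     input_ids=text_inputs["input_ids"],
+     visual_inputs=visual_inputs,
+     audio_inputs=audio_inputs,
+     return_dict_in_generate=True,
+     max_new_tokens=1024,
+     do_sample=True,
+     temperature=0.4,
+     top_k=40,
+     top_p=0.85,
+     repetition_penalty=1.1,
+ )
+ ```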
331
+
332
+
333
+
334
+ ## Deployment
335
+
336
+ We have implemented basic adaptations in SGLang (code is being uploaded) to support the deployment of LongCat-Next.
337
+
338
+ ```shell
339
+ git clone [TBU]
340
+ cd nmm_infer
341
+ git checkout master
342
+ sh setup.sh
343
+ ```
344
+
345
+ ```shell
346
+ # Require CUDA >= 12.9
347
+
348
+ # Setup environment
349
+ source create_env.sh
350
+ source set_env.sh
351
+
352
+ # Run tests
353
+ python3 demo.py \
354
+ --model-path meituan-longcat/LongCat-Next \
355
+ --sequential \
356
+ --output-dir output \
357
+ --tasks vis_gen vis_und aud_qa spk_syn
358
+
359
+ ```
360
+
361
+
362
+ ## License Agreement
363
+ This repository, including both the model weights and the source code, is released under the **MIT License**.
364
+
365
+ Any contributions to this repository are licensed under the MIT License, unless otherwise stated. This license does not grant any rights to use Meituan trademarks or patents.
366
+
367
+ For details, see the [LICENSE](./LICENSE) file.
368
+
369
+ ## Usage Considerations
370
+ This model has not been specifically designed or comprehensively evaluated for every possible downstream application.
371
+
372
+ Developers should take into account the known limitations of large language models, including performance variations across different languages, and carefully assess accuracy, safety, and fairness before deploying the model in sensitive or high-risk scenarios.
373
+ It is the responsibility of developers and downstream users to understand and comply with all applicable laws and regulations relevant to their use case, including but not limited to data protection, privacy, and content safety requirements.
374
+
375
+ Nothing in this Model Card should be interpreted as altering or restricting the terms of the MIT License under which the model is released.
376
+
377
+
378
+ <!-- ## Citation
379
+
380
+ We kindly encourage citation of our work if you find it useful.
381
+
382
+ ```
383
+
384
+ ``` -->
385
+
386
+
387
+ ## Contact
388
+ Please contact us at <a href="mailto:longcat-team@meituan.com">longcat-team@meituan.com</a> or open an issue if you have any questions.
assets/book.png ADDED

Git LFS Details

  • SHA256: 973616383a25a76a71b18532452bc3d422516c0ce684895065cbdbaeb7c654e5
  • Pointer size: 131 Bytes
  • Size of remote file: 745 kB
assets/evaluation.png ADDED

Git LFS Details

  • SHA256: 82bc8ab1a053e71f1328241be1739f3b9d0f0c0f84501500070c2e8a49542759
  • Pointer size: 131 Bytes
  • Size of remote file: 267 kB
assets/longcat_logo.svg ADDED
assets/math1.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e88c12d17ba1b6d8a28fa6688311222673db0f958a3679347f03ba4afd4b78c2
3
+ size 1140560
assets/overview.png ADDED

Git LFS Details

  • SHA256: 945a0cf9961850f0db4e81dbf7a2ac588f6206e64b85336c21fed446cf99f8cd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.51 MB
assets/system_audio.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bbb21a5cd57013406e1c18e8f267d05197bbbc3fdb8a65038d9c5a7799b9357a
3
+ size 254478
assets/vc_zh3.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c8313c738deac97e9c36cb861a85a896c9bbdaa22fe9f9f432feace766a75c65
3
+ size 1282618
config.json ADDED
@@ -0,0 +1,285 @@
1
+ {
2
+ "architectures": [
3
+ "LongcatNextForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_longcat_next.LongcatNextConfig",
9
+ "AutoModel": "modeling_longcat_next.LongcatNextModel",
10
+ "AutoModelForCausalLM": "modeling_longcat_next.LongcatNextForCausalLM"
11
+ },
12
+ "vocab_size": 282624,
13
+ "hidden_size": 3072,
14
+ "ffn_hidden_size": 6144,
15
+ "expert_ffn_hidden_size": 1024,
16
+ "num_layers": 14,
17
+ "num_attention_heads": 32,
18
+ "kv_lora_rank": 512,
19
+ "q_lora_rank": 1536,
20
+ "qk_rope_head_dim": 64,
21
+ "v_head_dim": 128,
22
+ "qk_nope_head_dim": 128,
23
+ "mla_scale_q_lora": true,
24
+ "mla_scale_kv_lora": true,
25
+ "routed_scaling_factor": 6.0,
26
+ "n_routed_experts": 256,
27
+ "rms_norm_eps": 1e-5,
28
+ "use_cache": true,
29
+ "bos_token_id": 1,
30
+ "eos_token_id": 2,
31
+ "rope_theta": 10000000,
32
+ "max_position_embeddings": 131072,
33
+ "zero_expert_num": 128,
34
+ "zero_expert_type": "identity",
35
+ "moe_topk": 12,
36
+ "ngram_vocab_size_ratio": 78,
37
+ "emb_neighbor_num": 4,
38
+ "emb_split_num": 4,
39
+ "torch_dtype": "bfloat16",
40
+ "transformers_version": "4.57.6",
41
+
42
+ "text_vocab_size": 131072,
43
+ "text_vocab_plus_multimodal_special_token_size": 131125,
44
+ "visual_embedding_layer_intermediate_size": 8192,
45
+ "visual_embedding_layer_hidden_act": "silu",
46
+ "visual_offset": 150581,
47
+ "audio_offset": 131125,
48
+
49
+
50
+ "visual_config": {
51
+ "image_start_token_id": 131106,
52
+ "image_end_token_id": 131107,
53
+ "image_pad_token_id": 131108,
54
+ "image_newline_token_id": 131109,
55
+
56
+ "_attn_implementation": "flash_attention_2",
57
+ "hidden_size": 1280,
58
+
59
+ "image_head_transformer_dims": 2048,
60
+ "image_head_transformer_ffn_scale": 16,
61
+ "image_head_transformer_layers": 4,
62
+
63
+ "vq_config": {
64
+ "codebook_dim": 3584,
65
+ "codebook_size": 16384,
66
+ "codebook_sizes": [
67
+ 16384,
68
+ 16384,
69
+ 16384,
70
+ 16384,
71
+ 16384,
72
+ 16384,
73
+ 16384,
74
+ 16384
75
+ ],
76
+ "decay": 0.99,
77
+ "depth": 8,
78
+
79
+ "commit_loss_ratio": 0.25,
80
+ "entropy_loss_ratio": 0,
81
+
82
+ "in_channels": 3584,
83
+ "quant_conv": true,
84
+ "quantizer_type": "rq",
85
+ "restart_unused_codes": true,
86
+ "shared_codebook": true,
87
+
88
+ "vq_loss_ratio": 0
89
+ },
90
+
91
+ "visual_decoder_config": {
92
+ "codebook_dim": 3584,
93
+
94
+ "image_decoder_config": {
95
+ "attention_dropout": 0.0,
96
+ "codebook_dim": 3584,
97
+ "distill_taps": [
98
+ 3,
99
+ 7,
100
+ 15,
101
+ 23
102
+ ],
103
+ "hidden_act": "gelu",
104
+ "hidden_size": 1024,
105
+ "intermediate_size": 2730,
106
+ "k_bias": false,
107
+ "layer_norm_eps": 1e-06,
108
+ "num_attention_heads": 16,
109
+ "num_hidden_layers": 32,
110
+ "patch_size": 14,
111
+ "q_bias": true,
112
+ "spatial_merge_size": 2,
113
+ "subln": true,
114
+ "swiglu": true,
115
+ "teacher_dims": {
116
+ "15": 1280,
117
+ "23": 1280,
118
+ "3": 1280,
119
+ "7": 1280
120
+ },
121
+ "temporal_patch_size": 2,
122
+ "v_bias": true
123
+ },
124
+
125
+ "transformer_config": {
126
+ "patch_size": 2,
127
+ "in_channels": 16,
128
+ "hidden_size": 2520,
129
+ "num_layers": 32,
130
+ "num_refiner_layers": 2,
131
+ "num_attention_heads": 21,
132
+ "num_kv_heads": 7,
133
+ "multiple_of": 256,
134
+ "norm_eps": 1e-5,
135
+ "axes_dim_rope": [40, 40, 40],
136
+ "axes_lens": [10000, 10000, 10000],
137
+ "text_feat_dim": 2048,
138
+ "timestep_scale": 1000.0
139
+ },
140
+
141
+ "vae_config": {
142
+ "act_fn": "silu",
143
+ "block_out_channels": [128, 256, 512, 512],
144
+ "down_block_types": [
145
+ "DownEncoderBlock2D",
146
+ "DownEncoderBlock2D",
147
+ "DownEncoderBlock2D",
148
+ "DownEncoderBlock2D"
149
+ ],
150
+ "in_channels": 3,
151
+ "latent_channels": 16,
152
+ "layers_per_block": 2,
153
+ "mid_block_add_attention": true,
154
+ "norm_num_groups": 32,
155
+ "out_channels": 3,
156
+ "sample_size": 1024,
157
+ "scaling_factor": 0.3611,
158
+ "shift_factor": 0.1159,
159
+ "up_block_types": [
160
+ "UpDecoderBlock2D",
161
+ "UpDecoderBlock2D",
162
+ "UpDecoderBlock2D",
163
+ "UpDecoderBlock2D"
164
+ ],
165
+ "use_post_quant_conv": false,
166
+ "use_quant_conv": false,
167
+ "force_upcast": true
168
+ },
169
+
170
+ "scheduler_config": {
171
+ "num_train_timesteps": 1000,
172
+ "dynamic_time_shift": true
173
+ },
174
+
175
+ "weight_path": "WEIGHT_PATH_TO_LONGCAT_NEXT/image_decoder/image_decoder.safetensors"
176
+ }
177
+ },
178
+
179
+
180
+ "audio_config": {
181
+ "audio_head_transformer_dims": 3072,
182
+ "audio_head_transformer_ffn_scale": 16,
183
+ "audio_head_transformer_layers": 4,
184
+
185
+ "audio_delim_token_id": 131116,
186
+ "audio_end_token_id": 131104,
187
+ "audio_pad_token_id": 131105,
188
+ "audio_start_token_id": 131103,
189
+ "audiogen_end_token_id": 131124,
190
+ "audiogen_start_token_id": 131123,
191
+ "audiotext_end_token_id": 131121,
192
+ "audiotext_pad_token_id": 131122,
193
+ "audiotext_start_token_id": 131120,
194
+
195
+ "_attn_implementation": "flash_attention_2",
196
+ "d_model": 1280,
197
+ "decoder_attention_heads": 20,
198
+ "decoder_ffn_dim": 5120,
199
+ "decoder_layers": 8,
200
+ "encoder_attention_heads": 20,
201
+ "encoder_ffn_dim": 5120,
202
+ "encoder_layers": 32,
203
+ "num_mel_bins": 128,
204
+
205
+ "avg_pooler": 4,
206
+ "decoder_kernel_size": 3,
207
+ "decoder_stride_size": 2,
208
+ "hop_length": 160,
209
+ "kernel_size": 3,
210
+ "max_audio_seconds": 30,
211
+ "n_fft": 400,
212
+ "num_hidden_layers": 32,
213
+ "sampling_rate": 16000,
214
+ "stride_size": 2,
215
+
216
+ "vq_config": {
217
+ "codebook_sizes": [
218
+ 8192,
219
+ 4096,
220
+ 2048,
221
+ 1024,
222
+ 1024,
223
+ 1024,
224
+ 1024,
225
+ 1024
226
+ ]
227
+ },
228
+
229
+ "vocoder_config": {
230
+ "channels": [
231
+ 256,
232
+ 256,
233
+ 256,
234
+ 256,
235
+ 256
236
+ ],
237
+ "hop_length": 256,
238
+ "num_mel_bins": 80,
239
+ "sampling_rate": 16000
240
+ },
241
+
242
+
243
+ "flow_matching_config": {
244
+ "in_channels": 80,
245
+ "spk_emb_dim": 0,
246
+ "diffusion_steps": 10,
247
+ "cal_mel_mae": true,
248
+
249
+ "prenet_activation_function": "gelu",
250
+ "prenet_attention_heads": 8,
251
+ "prenet_d_model": 512,
252
+ "prenet_ffn_dim": 2048,
253
+ "prenet_in_dim": 1280,
254
+ "prenet_max_source_positions": 5000,
255
+ "prenet_nlayers": 12,
256
+ "prenet_out_dim": 80,
257
+ "prenet_target_mel_length_scale_ratio": 1.0,
258
+
259
+ "channels": [
260
+ 256
261
+ ],
262
+ "dropout": 0.0,
263
+ "attention_head_dim": 64,
264
+ "n_blocks": 4,
265
+ "num_heads": 8,
266
+ "num_mid_blocks": 12,
267
+ "act_fn": "gelu",
268
+
269
+ "cfm_params": {
270
+ "inference_cfg_rate": 0.7,
271
+ "sigma_min": 1e-06,
272
+ "solver": "euler",
273
+ "t_scheduler": "cosine",
274
+ "training_cfg_rate": 0.2
275
+ },
276
+
277
+ "use_hidden_states_before_dconv2": true
278
+ },
279
+
280
+ "cosy24kvocoder_config": {
281
+ "weight_path": "WEIGHT_PATH_TO_LONGCAT_NEXT/cosy24k_vocoder/hift.pt"
282
+ }
283
+
284
+ }
285
+ }
configuration_longcat_next.py ADDED
@@ -0,0 +1,152 @@
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig
3
+ from transformers.models.whisper.configuration_whisper import WhisperConfig
4
+
5
+ from .configuration_longcat_ngram import LongcatFlashNgramConfig
6
+
7
+ class LongcatNextConfig(LongcatFlashNgramConfig):
8
+ def __init__(
9
+ self,
10
+ vocab_size=131072,
11
+ hidden_size=6144,
12
+ num_hidden_layers=56,
13
+ num_layers=28,
14
+ num_attention_heads=64,
15
+ num_key_value_heads=None,
16
+ hidden_act="silu",
17
+ max_position_embeddings=131072,
18
+ initializer_range=0.02,
19
+ rms_norm_eps=1e-5,
20
+ use_cache=True,
21
+ pad_token_id=None,
22
+ bos_token_id=1,
23
+ eos_token_id=2,
24
+ tie_word_embeddings=False,
25
+ rope_theta=10000000.0,
26
+ rope_scaling=None,
27
+ attention_bias=False,
28
+ attention_dropout=0.0,
29
+ ffn_hidden_size=12288,
30
+ q_lora_rank=1536,
31
+ kv_lora_rank=512,
32
+ qk_nope_head_dim=128,
33
+ qk_rope_head_dim=64,
34
+ head_dim=64,
35
+ v_head_dim=128,
36
+ qk_head_dim=None,
37
+ moe_topk=12,
38
+ n_routed_experts=512,
39
+ zero_expert_num=256,
40
+ expert_ffn_hidden_size=2048,
41
+ routed_scaling_factor=6.0,
42
+ emb_neighbor_num=None,
43
+ emb_split_num=None,
44
+ ngram_vocab_size_ratio=None,
45
+ oe_ignored_token_ids=[],
46
+ text_vocab_size=131072, # text vocab size (vocab_size = text_vocab_size + audio_token + visual_token + multimodal_special_token_list)
47
+ text_vocab_plus_multimodal_special_token_size=131125,
48
+ visual_embedding_layer_intermediate_size=8192,
49
+ visual_embedding_layer_hidden_act="silu",
50
+ visual_offset=150581,
51
+ audio_offset=131125,
52
+ visual_config={},
53
+ audio_config={},
54
+ **kwargs,
55
+ ):
56
+ self.text_vocab_size = text_vocab_size
57
+ self.text_vocab_plus_multimodal_special_token_size = text_vocab_plus_multimodal_special_token_size
58
+ self.visual_embedding_layer_intermediate_size = visual_embedding_layer_intermediate_size
59
+ self.visual_embedding_layer_hidden_act = visual_embedding_layer_hidden_act
60
+ self.visual_offset = visual_offset
61
+ self.audio_offset = audio_offset
62
+ self.visual_config = LongcatNextVisualConfig(**visual_config)
63
+ self.audio_config = LongcatNextAudioConfig(**audio_config)
64
+ oe_ignored_token_ids = oe_ignored_token_ids or list(range(self.text_vocab_size, self.text_vocab_plus_multimodal_special_token_size))
65
+
66
+ super().__init__(
67
+ vocab_size=vocab_size,
68
+ hidden_size=hidden_size,
69
+ num_hidden_layers=num_hidden_layers,
70
+ num_layers=num_layers,
71
+ num_attention_heads=num_attention_heads,
72
+ num_key_value_heads=num_key_value_heads,
73
+ hidden_act=hidden_act,
74
+ max_position_embeddings=max_position_embeddings,
75
+ initializer_range=initializer_range,
76
+ rms_norm_eps=rms_norm_eps,
77
+ use_cache=use_cache,
78
+ pad_token_id=pad_token_id,
79
+ bos_token_id=bos_token_id,
80
+ eos_token_id=eos_token_id,
81
+ tie_word_embeddings=tie_word_embeddings,
82
+ rope_theta=rope_theta,
83
+ rope_scaling=rope_scaling,
84
+ attention_bias=attention_bias,
85
+ attention_dropout=attention_dropout,
86
+ ffn_hidden_size=ffn_hidden_size,
87
+ q_lora_rank=q_lora_rank,
88
+ kv_lora_rank=kv_lora_rank,
89
+ qk_nope_head_dim=qk_nope_head_dim,
90
+ qk_rope_head_dim=qk_rope_head_dim,
91
+ head_dim=head_dim,
92
+ v_head_dim=v_head_dim,
93
+ qk_head_dim=qk_head_dim,
94
+ moe_topk=moe_topk,
95
+ n_routed_experts=n_routed_experts,
96
+ zero_expert_num=zero_expert_num,
97
+ expert_ffn_hidden_size=expert_ffn_hidden_size,
98
+ routed_scaling_factor=routed_scaling_factor,
99
+ emb_neighbor_num=emb_neighbor_num,
100
+ emb_split_num=emb_split_num,
101
+ ngram_vocab_size_ratio=ngram_vocab_size_ratio,
102
+ oe_ignored_token_ids=oe_ignored_token_ids,
103
+ **kwargs,
104
+ )
105
+
106
+ class LongcatNextVisualConfig(Qwen2_5_VLVisionConfig):
107
+ model_type = "longcat_next_visual"
108
+ base_config_key = ""
109
+
110
+ def __init__(
111
+ self,
112
+ image_start_token_id=131106,
113
+ image_end_token_id=131107,
114
+ image_pad_token_id=131108,
115
+ image_newline_token_id=131109,
116
+ vq_config={},
117
+ visual_decoder_config={},
118
+ **kwargs,
119
+ ):
120
+ self.image_start_token_id = image_start_token_id
121
+ self.image_end_token_id = image_end_token_id
122
+ self.image_pad_token_id = image_pad_token_id
123
+ self.image_newline_token_id = image_newline_token_id
124
+ self.vq_config = PretrainedConfig(**vq_config)
125
+ self.visual_decoder_config = PretrainedConfig(**visual_decoder_config)
126
+ self.visual_decoder_config.image_decoder_config = PretrainedConfig(**getattr(self.visual_decoder_config, "image_decoder_config", {}))
127
+ self.visual_decoder_config.transformer_config = PretrainedConfig(**getattr(self.visual_decoder_config, "transformer_config", {}))
128
+ self.visual_decoder_config.vae_config = PretrainedConfig(**getattr(self.visual_decoder_config, "vae_config", {}))
129
+ self.visual_decoder_config.scheduler_config = PretrainedConfig(**getattr(self.visual_decoder_config, "scheduler_config", {}))
130
+ super().__init__(**kwargs)
131
+
132
+ class LongcatNextAudioConfig(WhisperConfig):
133
+ model_type = "longcat_next_audio"
134
+ base_config_key = ""
135
+
136
+ def __init__(
137
+ self,
138
+ vq_config={},
139
+ vocoder_config={},
140
+ flow_matching_config={},
141
+ cosy24kvocoder_config={},
142
+ **kwargs
143
+ ):
144
+ self.vq_config = PretrainedConfig(**vq_config)
145
+ self.vocoder_config = PretrainedConfig(**vocoder_config)
146
+ self.flow_matching_config = PretrainedConfig(**flow_matching_config)
147
+ self.flow_matching_config.cfm_params = PretrainedConfig(**getattr(self.flow_matching_config, "cfm_params", {}))
148
+ self.cosy24kvocoder_config = PretrainedConfig(**cosy24kvocoder_config)
149
+ super().__init__(**kwargs)
150
+
151
+
152
+ __all__ = ["LongcatNextConfig", "LongcatNextVisualConfig", "LongcatNextAudioConfig"]
configuration_longcat_ngram.py ADDED
@@ -0,0 +1,218 @@
1
+ from transformers.models.longcat_flash import LongcatFlashConfig
2
+
3
+
4
+ class LongcatFlashNgramConfig(LongcatFlashConfig):
5
+ r"""
6
+ This is the configuration class to store the configuration of a [`LongcatFlashNgramModel`]. It is used to instantiate
7
+ a LongCat Flash model with N-gram enhanced embeddings according to the specified arguments, defining the model architecture.
8
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
9
+ documentation from [`PretrainedConfig`] for more information.
10
+
11
+
12
+ Args:
13
+ vocab_size (`int`, *optional*, defaults to 131072):
14
+ Vocabulary size of the LongCat Flash model. Defines the number of different tokens that can be represented by the
15
+ `input_ids` passed when calling [`LongcatFlashNgramModel`]
16
+ hidden_size (`int`, *optional*, defaults to 6144):
17
+ Dimension of the hidden representations.
18
+ num_hidden_layers (`int`, *optional*, defaults to 56):
19
+ Number of hidden layers in the Transformer decoder.
20
+ num_layers (`int`, *optional*, defaults to 28):
21
+ Number of layers, each with 2 sublayers.
22
+ num_attention_heads (`int`, *optional*, defaults to 64):
23
+ Number of attention heads for each attention layer in the Transformer decoder.
24
+ num_key_value_heads (`int`, *optional*):
25
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
26
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
27
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
28
+ converting from a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
29
+ constructed by meanpooling all the original heads within that group. For more details checkout [this
30
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
31
+ `num_attention_heads`.
32
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
33
+ The non-linear activation function (function or string) in the decoder.
34
+ max_position_embeddings (`int`, *optional*, defaults to 131072):
35
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
36
+ just in case (e.g., 512 or 1024 or 2048).
37
+ initializer_range (`float`, *optional*, defaults to 0.02):
38
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
39
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
40
+ The epsilon value used by the RMS normalization layers.
41
+ use_cache (`bool`, *optional*, defaults to `True`):
42
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
43
+ relevant if `config.is_decoder=True`.
44
+ pad_token_id (`int`, *optional*):
45
+ Padding token id.
46
+ bos_token_id (`int`, *optional*, defaults to 1):
47
+ Beginning of stream token id.
48
+ eos_token_id (`int`, *optional*, defaults to 2):
49
+ End of stream token id.
50
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
51
+ Whether to tie input and output embeddings.
52
+ rope_theta (`float`, *optional*, defaults to 10000000.0):
53
+ The base period of the RoPE embeddings.
54
+ rope_scaling (`Dict`, *optional*):
55
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
56
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
57
+ `{"type": strategy name, "factor": scaling factor}`.
58
+ attention_bias (`bool`, *optional*, defaults to `False`):
59
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
60
+ attention_dropout (`float`, *optional*, defaults to 0.0):
61
+ The dropout ratio for the attention probabilities.
62
+ ffn_hidden_size (`int`, *optional*, defaults to 12288):
63
+ Dimension of the MLP representations.
64
+ q_lora_rank (`int`, *optional*, defaults to 1536):
65
+ The rank of the query LoRA projection in MLA (Multi-head Latent Attention).
66
+ kv_lora_rank (`int`, *optional*, defaults to 512):
67
+ The rank of the key-value LoRA projection in MLA.
68
+ qk_nope_head_dim (`int`, *optional*, defaults to 128):
69
+ The dimension of the non-position encoding part of query/key heads.
70
+ qk_rope_head_dim (`int`, *optional*, defaults to 64):
71
+ The dimension of the RoPE part of query/key heads.
72
+ head_dim (`int`, *optional*, defaults to 64):
73
+ Standard dimension of qk heads, unused except for CI.
74
+ v_head_dim (`int`, *optional*, defaults to 128):
75
+ The dimension of value heads.
76
+ qk_head_dim (`int`, *optional*):
77
+ The total dimension of query/key heads. If not specified, set to `qk_nope_head_dim + qk_rope_head_dim`.
78
+ moe_topk (`int`, *optional*, defaults to 12):
79
+ Number of experts to route to for each token in the MoE layer.
80
+ n_routed_experts (`int`, *optional*, defaults to 512):
81
+ Number of routed experts in the MoE layer.
82
+ zero_expert_num (`int`, *optional*, defaults to 256):
83
+ Number of zero experts (identity function) to add to the expert pool.
84
+ expert_ffn_hidden_size (`int`, *optional*, defaults to 2048):
85
+ Hidden size of individual expert FFN layers.
86
+ routed_scaling_factor (`float`, *optional*, defaults to 6.0):
87
+ Scaling factor applied to the routing weights.
88
+ emb_neighbor_num (`int`, *optional*):
89
+ Maximum N-gram length for N-gram embeddings. This parameter determines the context window size for N-gram computation. Higher values capture
90
+ longer-range lexical patterns but increase memory usage.
91
+ emb_split_num (`int`, *optional*):
92
+ Number of hash functions (or splits) to use for N-gram embeddings. Multiple hash functions help improve the quality of N-gram representations.
93
+ ngram_vocab_size_ratio (`float`, *optional*):
94
+ Ratio multiplier for N-gram vocabulary size relative to the base vocabulary size. The N-gram vocabulary
95
+ size is calculated as `vocab_size * ngram_vocab_size_ratio`.
96
+
97
+ Example:
98
+ ```python
99
+ >>> from transformers import LongcatFlashNgramModel, LongcatFlashNgramConfig
100
+
101
+ >>> # Initializing a LongCat Flash N-gram style configuration
102
+ >>> configuration = LongcatFlashNgramConfig(
103
+ ... emb_neighbor_num=3,
104
+ ... emb_split_num=4,
105
+ ... ngram_vocab_size_ratio=1.5
106
+ ... )
107
+
108
+ >>> # Initializing a model from the configuration
109
+ >>> model = LongcatFlashNgramModel(configuration)
110
+
111
+ >>> # Accessing the model configuration
112
+ >>> configuration = model.config
113
+ ```"""
114
+
115
+ model_type = "longcat_flash_ngram"
116
+ keys_to_ignore_at_inference = ["past_key_values"]
117
+ base_model_tp_plan = {
118
+ "layers.*.self_attn.*.q_b_proj": "colwise",
119
+ "layers.*.self_attn.*.kv_b_proj": "colwise",
120
+ "layers.*.self_attn.*.o_proj": "rowwise",
121
+ "layers.*.mlps.*.gate_proj": "colwise",
122
+ "layers.*.mlps.*.up_proj": "colwise",
123
+ "layers.*.mlps.*.down_proj": "rowwise",
124
+ "layers.*.mlp.experts.*.gate_proj": "colwise",
125
+ "layers.*.mlp.experts.*.up_proj": "colwise",
126
+ "layers.*.mlp.experts.*.down_proj": "rowwise",
127
+ }
128
+
129
+ base_model_pp_plan = {
130
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
131
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
132
+ "norm": (["hidden_states"], ["hidden_states"]),
133
+ }
134
+
135
+ def __init__(
136
+ self,
137
+ vocab_size=131072,
138
+ hidden_size=6144,
139
+ num_hidden_layers=56,
140
+ num_layers=28,
141
+ num_attention_heads=64,
142
+ num_key_value_heads=None,
143
+ hidden_act="silu",
144
+ max_position_embeddings=131072,
145
+ initializer_range=0.02,
146
+ rms_norm_eps=1e-5,
147
+ use_cache=True,
148
+ pad_token_id=None,
149
+ bos_token_id=1,
150
+ eos_token_id=2,
151
+ tie_word_embeddings=False,
152
+ rope_theta=10000000.0,
153
+ rope_scaling=None,
154
+ attention_bias=False,
155
+ attention_dropout=0.0,
156
+ ffn_hidden_size=12288,
157
+ q_lora_rank=1536,
158
+ kv_lora_rank=512,
159
+ qk_nope_head_dim=128,
160
+ qk_rope_head_dim=64,
161
+ head_dim=64,
162
+ v_head_dim=128,
163
+ qk_head_dim=None,
164
+ moe_topk=12,
165
+ n_routed_experts=512,
166
+ zero_expert_num=256,
167
+ expert_ffn_hidden_size=2048,
168
+ routed_scaling_factor=6.0,
169
+ emb_neighbor_num=None,
170
+ emb_split_num=None,
171
+ ngram_vocab_size_ratio=None,
172
+ oe_ignored_token_ids=[],
173
+ **kwargs,
174
+ ):
175
+ # N-gram embedding specific parameters
176
+ self.emb_neighbor_num = emb_neighbor_num
177
+ self.emb_split_num = emb_split_num
178
+ self.ngram_vocab_size_ratio = ngram_vocab_size_ratio
179
+ self.oe_ignored_token_ids = oe_ignored_token_ids
180
+
181
+ super().__init__(
182
+ vocab_size=vocab_size,
183
+ hidden_size=hidden_size,
184
+ num_hidden_layers=num_hidden_layers,
185
+ num_layers=num_layers,
186
+ num_attention_heads=num_attention_heads,
187
+ num_key_value_heads=num_key_value_heads,
188
+ hidden_act=hidden_act,
189
+ max_position_embeddings=max_position_embeddings,
190
+ initializer_range=initializer_range,
191
+ rms_norm_eps=rms_norm_eps,
192
+ use_cache=use_cache,
193
+ pad_token_id=pad_token_id,
194
+ bos_token_id=bos_token_id,
195
+ eos_token_id=eos_token_id,
196
+ tie_word_embeddings=tie_word_embeddings,
197
+ rope_theta=rope_theta,
198
+ rope_scaling=rope_scaling,
199
+ attention_bias=attention_bias,
200
+ attention_dropout=attention_dropout,
201
+ ffn_hidden_size=ffn_hidden_size,
202
+ q_lora_rank=q_lora_rank,
203
+ kv_lora_rank=kv_lora_rank,
204
+ qk_nope_head_dim=qk_nope_head_dim,
205
+ qk_rope_head_dim=qk_rope_head_dim,
206
+ head_dim=head_dim,
207
+ v_head_dim=v_head_dim,
208
+ qk_head_dim=qk_head_dim,
209
+ moe_topk=moe_topk,
210
+ n_routed_experts=n_routed_experts,
211
+ zero_expert_num=zero_expert_num,
212
+ expert_ffn_hidden_size=expert_ffn_hidden_size,
213
+ routed_scaling_factor=routed_scaling_factor,
214
+ **kwargs,
215
+ )
216
+
217
+
218
+ __all__ = ["LongcatFlashNgramConfig"]
cosy24k_vocoder.py ADDED
@@ -0,0 +1,552 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """HIFI-GAN"""
16
+
17
+ from typing import Dict, Optional, List
18
+ import numpy as np
19
+ from scipy.signal import get_window
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from torch.nn import Conv1d
24
+ from torch.nn import ConvTranspose1d
25
+ from torch.nn.utils import remove_weight_norm
26
+ from torch.nn.utils import weight_norm
27
+ from torch.distributions.uniform import Uniform
28
+ from torch.nn import Parameter
29
+ from torch import nn, sin, pow
30
+
31
+
32
+ class Snake(nn.Module):
33
+ '''
34
+ Implementation of a sine-based periodic activation function
35
+ Shape:
36
+ - Input: (B, C, T)
37
+ - Output: (B, C, T), same shape as the input
38
+ Parameters:
39
+ - alpha - trainable parameter
40
+ References:
41
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
42
+ https://arxiv.org/abs/2006.08195
43
+ Examples:
44
+ >>> a1 = snake(256)
45
+ >>> x = torch.randn(256)
46
+ >>> x = a1(x)
47
+ '''
48
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
49
+ '''
50
+ Initialization.
51
+ INPUT:
52
+ - in_features: shape of the input
53
+ - alpha: trainable parameter
54
+ alpha is initialized to 1 by default, higher values = higher-frequency.
55
+ alpha will be trained along with the rest of your model.
56
+ '''
57
+ super(Snake, self).__init__()
58
+ self.in_features = in_features
59
+
60
+ # initialize alpha
61
+ self.alpha_logscale = alpha_logscale
62
+ if self.alpha_logscale: # log scale alphas initialized to zeros
63
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
64
+ else: # linear scale alphas initialized to ones
65
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
66
+
67
+ self.alpha.requires_grad = alpha_trainable
68
+
69
+ self.no_div_by_zero = 0.000000001
70
+
71
+ def forward(self, x):
72
+ '''
73
+ Forward pass of the function.
74
+ Applies the function to the input elementwise.
75
+ Snake ∶= x + 1/a * sin^2 (xa)
76
+ '''
77
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
78
+ if self.alpha_logscale:
79
+ alpha = torch.exp(alpha)
80
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
81
+
82
+ return x
83
+
84
+ def get_padding(kernel_size, dilation=1):
85
+ return int((kernel_size * dilation - dilation) / 2)
86
+
87
+ def init_weights(m, mean=0.0, std=0.01):
88
+ classname = m.__class__.__name__
89
+ if classname.find("Conv") != -1:
90
+ m.weight.data.normal_(mean, std)
91
+
92
+ """hifigan based generator implementation.
93
+
94
+ This code is modified from https://github.com/jik876/hifi-gan
95
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
96
+ https://github.com/NVIDIA/BigVGAN
97
+
98
+ """
99
+
100
+
101
+ class ResBlock(torch.nn.Module):
102
+ """Residual block module in HiFiGAN/BigVGAN."""
103
+ def __init__(
104
+ self,
105
+ channels: int = 512,
106
+ kernel_size: int = 3,
107
+ dilations: List[int] = [1, 3, 5],
108
+ ):
109
+ super(ResBlock, self).__init__()
110
+ self.convs1 = nn.ModuleList()
111
+ self.convs2 = nn.ModuleList()
112
+
113
+ for dilation in dilations:
114
+ self.convs1.append(
115
+ weight_norm(
116
+ Conv1d(
117
+ channels,
118
+ channels,
119
+ kernel_size,
120
+ 1,
121
+ dilation=dilation,
122
+ padding=get_padding(kernel_size, dilation)
123
+ )
124
+ )
125
+ )
126
+ self.convs2.append(
127
+ weight_norm(
128
+ Conv1d(
129
+ channels,
130
+ channels,
131
+ kernel_size,
132
+ 1,
133
+ dilation=1,
134
+ padding=get_padding(kernel_size, 1)
135
+ )
136
+ )
137
+ )
138
+ self.convs1.apply(init_weights)
139
+ self.convs2.apply(init_weights)
140
+ self.activations1 = nn.ModuleList([
141
+ Snake(channels, alpha_logscale=False)
142
+ for _ in range(len(self.convs1))
143
+ ])
144
+ self.activations2 = nn.ModuleList([
145
+ Snake(channels, alpha_logscale=False)
146
+ for _ in range(len(self.convs2))
147
+ ])
148
+
149
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
150
+ for idx in range(len(self.convs1)):
151
+ xt = self.activations1[idx](x)
152
+ xt = self.convs1[idx](xt)
153
+ xt = self.activations2[idx](xt)
154
+ xt = self.convs2[idx](xt)
155
+ x = xt + x
156
+ return x
157
+
158
+ def remove_weight_norm(self):
159
+ for idx in range(len(self.convs1)):
160
+ remove_weight_norm(self.convs1[idx])
161
+ remove_weight_norm(self.convs2[idx])
162
+
163
+
164
+ class SineGen(torch.nn.Module):
165
+ """ Definition of sine generator
166
+ SineGen(samp_rate, harmonic_num = 0,
167
+ sine_amp = 0.1, noise_std = 0.003,
168
+ voiced_threshold = 0,
169
+ flag_for_pulse=False)
170
+ samp_rate: sampling rate in Hz
171
+ harmonic_num: number of harmonic overtones (default 0)
172
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
173
+ noise_std: std of Gaussian noise (default 0.003)
174
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
175
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
176
+ Note: when flag_for_pulse is True, the first time step of a voiced
177
+ segment is always sin(np.pi) or cos(0)
178
+ """
179
+
180
+ def __init__(self, samp_rate, harmonic_num=0,
181
+ sine_amp=0.1, noise_std=0.003,
182
+ voiced_threshold=0):
183
+ super(SineGen, self).__init__()
184
+ self.sine_amp = sine_amp
185
+ self.noise_std = noise_std
186
+ self.harmonic_num = harmonic_num
187
+ self.sampling_rate = samp_rate
188
+ self.voiced_threshold = voiced_threshold
189
+
190
+ def _f02uv(self, f0):
191
+ # generate uv signal
192
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
193
+ return uv
194
+
195
+ @torch.no_grad()
196
+ def forward(self, f0):
197
+ """
198
+ :param f0: [B, 1, sample_len], Hz
199
+ :return: [B, 1, sample_len]
200
+ """
201
+
202
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
203
+ for i in range(self.harmonic_num + 1):
204
+ F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
205
+
206
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
207
+ u_dist = Uniform(low=-np.pi, high=np.pi)
208
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
209
+ phase_vec[:, 0, :] = 0
210
+
211
+ # generate sine waveforms
212
+ sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
213
+
214
+ # generate uv signal
215
+ uv = self._f02uv(f0)
216
+
217
+ # noise: for unvoiced should be similar to sine_amp
218
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
219
+ # . for voiced regions is self.noise_std
220
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
221
+ noise = noise_amp * torch.randn_like(sine_waves)
222
+
223
+ # first: set the unvoiced part to 0 by uv
224
+ # then: additive noise
225
+ sine_waves = sine_waves * uv + noise
226
+ return sine_waves, uv, noise
227
+
228
+
229
+ class SourceModuleHnNSF(torch.nn.Module):
230
+ """ SourceModule for hn-nsf
231
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
232
+ add_noise_std=0.003, voiced_threshod=0)
233
+ sampling_rate: sampling_rate in Hz
234
+ harmonic_num: number of harmonic above F0 (default: 0)
235
+ sine_amp: amplitude of sine source signal (default: 0.1)
236
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
237
+ note that amplitude of noise in unvoiced is decided
238
+ by sine_amp
239
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
240
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
241
+ F0_sampled (batchsize, length, 1)
242
+ Sine_source (batchsize, length, 1)
243
+ noise_source (batchsize, length 1)
244
+ uv (batchsize, length, 1)
245
+ """
246
+
247
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
248
+ add_noise_std=0.003, voiced_threshod=0):
249
+ super(SourceModuleHnNSF, self).__init__()
250
+
251
+ self.sine_amp = sine_amp
252
+ self.noise_std = add_noise_std
253
+
254
+ # to produce sine waveforms
255
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
256
+ sine_amp, add_noise_std, voiced_threshod)
257
+
258
+ # to merge source harmonics into a single excitation
259
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
260
+ self.l_tanh = torch.nn.Tanh()
261
+
262
+ def forward(self, x):
263
+ """
264
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
265
+ F0_sampled (batchsize, length, 1)
266
+ Sine_source (batchsize, length, 1)
267
+ noise_source (batchsize, length, 1)
268
+ """
269
+ # source for harmonic branch
270
+ with torch.no_grad():
271
+ sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
272
+ sine_wavs = sine_wavs.transpose(1, 2)
273
+ uv = uv.transpose(1, 2)
274
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
275
+
276
+ # source for noise branch, in the same shape as uv
277
+ noise = torch.randn_like(uv) * self.sine_amp / 3
278
+ return sine_merge, noise, uv
279
+
280
+
281
+ class HiFTGenerator(nn.Module):
282
+ """
283
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
284
+ https://arxiv.org/abs/2309.09493
285
+ """
286
+ def __init__(
287
+ self,
288
+ in_channels: int = 80,
289
+ base_channels: int = 512,
290
+ nb_harmonics: int = 8,
291
+ sampling_rate: int = 22050,
292
+ nsf_alpha: float = 0.1,
293
+ nsf_sigma: float = 0.003,
294
+ nsf_voiced_threshold: float = 10,
295
+ upsample_rates: List[int] = [8, 8],
296
+ upsample_kernel_sizes: List[int] = [16, 16],
297
+ istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
298
+ resblock_kernel_sizes: List[int] = [3, 7, 11],
299
+ resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
300
+ source_resblock_kernel_sizes: List[int] = [7, 11],
301
+ source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
302
+ lrelu_slope: float = 0.1,
303
+ audio_limit: float = 0.99,
304
+ f0_predictor: torch.nn.Module = None,
305
+ ):
306
+ super(HiFTGenerator, self).__init__()
307
+
308
+ self.out_channels = 1
309
+ self.nb_harmonics = nb_harmonics
310
+ self.sampling_rate = sampling_rate
311
+ self.istft_params = istft_params
312
+ self.lrelu_slope = lrelu_slope
313
+ self.audio_limit = audio_limit
314
+
315
+ self.num_kernels = len(resblock_kernel_sizes)
316
+ self.num_upsamples = len(upsample_rates)
317
+ self.m_source = SourceModuleHnNSF(
318
+ sampling_rate=sampling_rate,
319
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
320
+ harmonic_num=nb_harmonics,
321
+ sine_amp=nsf_alpha,
322
+ add_noise_std=nsf_sigma,
323
+ voiced_threshod=nsf_voiced_threshold)
324
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
325
+
326
+ self.conv_pre = weight_norm(
327
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
328
+ )
329
+
330
+ # Up
331
+ self.ups = nn.ModuleList()
332
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
333
+ self.ups.append(
334
+ weight_norm(
335
+ ConvTranspose1d(
336
+ base_channels // (2**i),
337
+ base_channels // (2**(i + 1)),
338
+ k,
339
+ u,
340
+ padding=(k - u) // 2,
341
+ )
342
+ )
343
+ )
344
+
345
+ # Down
346
+ self.source_downs = nn.ModuleList()
347
+ self.source_resblocks = nn.ModuleList()
348
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
349
+ downsample_cum_rates = np.cumprod(downsample_rates)
350
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
351
+ if u == 1:
352
+ self.source_downs.append(
353
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
354
+ )
355
+ else:
356
+ self.source_downs.append(
357
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
358
+ )
359
+
360
+ self.source_resblocks.append(
361
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
362
+ )
363
+
364
+ self.resblocks = nn.ModuleList()
365
+ for i in range(len(self.ups)):
366
+ ch = base_channels // (2**(i + 1))
367
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
368
+ self.resblocks.append(ResBlock(ch, k, d))
369
+
370
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
371
+ self.ups.apply(init_weights)
372
+ self.conv_post.apply(init_weights)
373
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
374
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
375
+ self.f0_predictor = f0_predictor
376
+
377
+ def remove_weight_norm(self):
378
+ print('Removing weight norm...')
379
+ for l in self.ups:
380
+ remove_weight_norm(l)
381
+ for l in self.resblocks:
382
+ l.remove_weight_norm()
383
+ remove_weight_norm(self.conv_pre)
384
+ remove_weight_norm(self.conv_post)
385
+ self.m_source.remove_weight_norm()
386
+ for l in self.source_downs:
387
+ remove_weight_norm(l)
388
+ for l in self.source_resblocks:
389
+ l.remove_weight_norm()
390
+
391
+ def _stft(self, x):
392
+ spec = torch.stft(
393
+ x,
394
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
395
+ return_complex=True)
396
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
397
+ return spec[..., 0], spec[..., 1]
398
+
399
+ def _istft(self, magnitude, phase):
400
+ magnitude = torch.clip(magnitude, max=1e2)
401
+ real = magnitude * torch.cos(phase)
402
+ img = magnitude * torch.sin(phase)
403
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
404
+ self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
405
+ return inverse_transform
406
+
407
+ def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
408
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
409
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
410
+
411
+ x = self.conv_pre(x)
412
+ for i in range(self.num_upsamples):
413
+ x = F.leaky_relu(x, self.lrelu_slope)
414
+ x = self.ups[i](x)
415
+
416
+ if i == self.num_upsamples - 1:
417
+ x = self.reflection_pad(x)
418
+
419
+ # fusion
420
+ si = self.source_downs[i](s_stft)
421
+ si = self.source_resblocks[i](si)
422
+ x = x + si
423
+
424
+ xs = None
425
+ for j in range(self.num_kernels):
426
+ if xs is None:
427
+ xs = self.resblocks[i * self.num_kernels + j](x)
428
+ else:
429
+ xs += self.resblocks[i * self.num_kernels + j](x)
430
+ x = xs / self.num_kernels
431
+
432
+ x = F.leaky_relu(x)
433
+ x = self.conv_post(x)
434
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
435
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # note: applying sin here is redundant
436
+
437
+ x = self._istft(magnitude, phase)
438
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
439
+ return x
440
+
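Worked numbers for the ISTFT head in decode above, taken from the defaults in this file: with istft_params = {"n_fft": 16, "hop_len": 4}, conv_post emits n_fft + 2 = 18 channels per frame, which are split into n_fft // 2 + 1 = 9 log-magnitude bins (exponentiated) and 9 phase bins; torch.istft then overlap-adds the frames with a 4-sample hop, so the waveform length is roughly 4x the number of spectrogram frames.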
441
+ def forward(
442
+ self,
443
+ batch: dict,
444
+ # device: torch.device,
445
+ ) -> Dict[str, Optional[torch.Tensor]]:
446
+ speech_feat = batch['speech_feat'].transpose(1, 2) # .to(device)
447
+ # mel->f0
448
+ f0 = self.f0_predictor(speech_feat)
449
+ # f0->source
450
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
451
+ s, _, _ = self.m_source(s)
452
+ s = s.transpose(1, 2)
453
+ # mel+source->speech
454
+ generated_speech = self.decode(x=speech_feat, s=s)
455
+ return generated_speech, f0
456
+
457
+ @torch.inference_mode()
458
+ def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
459
+ # mel->f0
460
+ f0 = self.f0_predictor(speech_feat)
461
+ # f0->source
462
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
463
+ s, _, _ = self.m_source(s)
464
+ s = s.transpose(1, 2)
465
+ # use cache_source to avoid glitch
466
+ if cache_source.shape[2] != 0:
467
+ s[:, :, :cache_source.shape[2]] = cache_source
468
+ generated_speech = self.decode(x=speech_feat, s=s)
469
+ return generated_speech, s
470
+
471
+
472
+ class ConvRNNF0Predictor(nn.Module):
473
+ def __init__(self,
474
+ num_class: int = 1,
475
+ in_channels: int = 80,
476
+ cond_channels: int = 512
477
+ ):
478
+ super().__init__()
479
+
480
+ self.num_class = num_class
481
+ self.condnet = nn.Sequential(
482
+ weight_norm(
483
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
484
+ ),
485
+ nn.ELU(),
486
+ weight_norm(
487
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
488
+ ),
489
+ nn.ELU(),
490
+ weight_norm(
491
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
492
+ ),
493
+ nn.ELU(),
494
+ weight_norm(
495
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
496
+ ),
497
+ nn.ELU(),
498
+ weight_norm(
499
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
500
+ ),
501
+ nn.ELU(),
502
+ )
503
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
504
+
505
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
506
+ x = self.condnet(x)
507
+ x = x.transpose(1, 2)
508
+ return torch.abs(self.classifier(x).squeeze(-1))
509
+
510
+
511
+ class Cosy24kVocoder(nn.Module):
512
+ def __init__(self):
513
+ super().__init__()
514
+ self.hifigan_generator = HiFTGenerator(
515
+ in_channels=80,
516
+ base_channels=512,
517
+ nb_harmonics=8,
518
+ sampling_rate=24000,
519
+ nsf_alpha=0.1,
520
+ nsf_sigma=0.003,
521
+ nsf_voiced_threshold=10,
522
+ upsample_rates=[8, 5, 3],
523
+ upsample_kernel_sizes=[16, 11, 7],
524
+ resblock_kernel_sizes=[3, 7, 11],
525
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
526
+ source_resblock_kernel_sizes=[7, 7, 11],
527
+ source_resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
528
+ lrelu_slope=0.1,
529
+ audio_limit=0.99,
530
+ f0_predictor=ConvRNNF0Predictor(
531
+ num_class=1,
532
+ in_channels=80,
533
+ cond_channels=512,
534
+ ),
535
+ )
536
+
537
+ def decode(self, mel, device="cuda"):
538
+ """
539
+ Args: mel: (batch_size, n_frames, n_mel)
540
+ """
541
+ generated_speech, f0 = self.hifigan_generator.forward(
542
+ {"speech_feat": mel.transpose(1, 2)}, # device=device
543
+ )
544
+ return generated_speech
545
+
546
+ @classmethod
547
+ def from_pretrained(cls, model_path: str):
548
+ """Load a pretrained model from a checkpoint."""
549
+ model = cls()
550
+ model.hifigan_generator.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=True)
551
+ model.eval()
552
+ return model
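A minimal shape-check sketch for the vocoder above (illustrative only: weights are random here, whereas real use loads the shipped checkpoint via Cosy24kVocoder.from_pretrained, and the mel below is just noise):

import torch

vocoder = Cosy24kVocoder().eval()
mel = torch.randn(1, 80, 200)              # (batch, n_mel, n_frames); the F0 predictor's Conv1d expects 80 channels
with torch.no_grad():
    wav, _ = vocoder.hifigan_generator.inference(speech_feat=mel)
print(wav.shape)                           # torch.Size([1, 96000]): 200 frames x 480x upsampling (8*5*3*4) at 24 kHz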
environment.yml ADDED
@@ -0,0 +1,8 @@
1
+ name: longcat_next
2
+
3
+ dependencies:
4
+ - python=3.10
5
+ - ffmpeg<7
6
+ - pip
7
+ - pip:
8
+ - soundfile==0.13.1
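For reference, the usual conda workflow applies to this file: `conda env create -f environment.yml` creates the environment (the pip entries such as soundfile are installed as part of creation), and `conda activate longcat_next` activates it.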
generation_config.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "bos_token_id": 1,
3
+ "eos_token_id": 2,
4
+ "pad_token_id": 3,
5
+
6
+ "max_new_tokens": 2048,
7
+ "do_sample": true,
8
+ "temperature": 0.4,
9
+ "top_k": 20,
10
+ "top_p": 0.85,
11
+ "repetition_penalty": 1.1,
12
+
13
+ "visual_generation_config": {
14
+ "do_sample": true,
15
+ "temperature": 0.5,
16
+ "top_p": 0.75,
17
+ "top_k": 1024,
18
+ "custom_params": {
19
+ "cfg_scale": 3.0,
20
+ "token_h": 37,
21
+ "token_w": 37,
22
+ "anyres_prefix": "<longcat_img_token_size>{h} {w}</longcat_img_token_size>"
23
+ }
24
+ },
25
+
26
+ "audio_generation_config": {
27
+ "audio_parallel_decoding": false,
28
+ "do_sample": true,
29
+ "temperature": 0.5,
30
+ "top_k": 5,
31
+ "top_p": 0.85,
32
+ "repetition_penalty": 1.3,
33
+ "custom_params": {
34
+ "sampling_rate": 24000,
35
+ "wave_concat_overlap": 1200
36
+ }
37
+ },
38
+
39
+ "transformers_version": "4.57.6"
40
+ }
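A minimal sketch of how a file like this is consumed (standard transformers usage; the directory path below is a placeholder for wherever this checkpoint is stored locally):

from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("./")    # reads generation_config.json from that directory
print(gen_cfg.max_new_tokens, gen_cfg.temperature)  # 2048 0.4, matching the defaults above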
image_refiner.py ADDED
@@ -0,0 +1,748 @@
1
+ """Image refiner: refiner pipeline, refiner container, and utilities.
2
+
3
+ Contains:
4
+ - RefinerImageProcessor: Image pre/post-processing for the diffusion pipeline
5
+ - RefinerPipeline: DiffusionPipeline for image refinement
6
+ - ImageRefinerContainer: nn.Module container for refiner sub-modules
7
+ - IdentityWithArgs: Placeholder module for cond_proj
8
+ - de_transform / tensor2pil: Tensor-to-PIL conversion utilities
9
+ """
10
+
11
+ import inspect
12
+ import math
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import numpy as np
18
+ from safetensors.torch import load_file
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ from PIL import Image
23
+
24
+ from diffusers import DiffusionPipeline
25
+ from diffusers.configuration_utils import register_to_config
26
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor, is_valid_image_imagelist
27
+ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
28
+ from .refiner_modules import FlowMatchEulerDiscreteScheduler
29
+
30
+ from .refiner_modules import Transformer2DModel, RotaryPosEmbed
31
+
32
+ # ---------------------------------------------------------------------------
33
+ # Helpers
34
+ # ---------------------------------------------------------------------------
35
+
36
+
37
+ def _clean_config_dict(cfg, cls=None) -> dict:
38
+ """Convert a PretrainedConfig to a clean dict for model construction.
39
+
40
+ If ``cls`` is provided, only keeps keys that match the cls.__init__ params
41
+ (allowlist approach). Otherwise falls back to blocklist filtering.
42
+ """
43
+ if hasattr(cfg, "to_dict"):
44
+ d = cfg.to_dict()
45
+ elif isinstance(cfg, dict):
46
+ d = dict(cfg)
47
+ else:
48
+ d = {k: v for k, v in vars(cfg).items()}
49
+
50
+ if cls is not None:
51
+ import inspect
52
+ sig = inspect.signature(cls.__init__)
53
+ valid_keys = set(sig.parameters.keys()) - {"self"}
54
+ if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()):
55
+ # Has **kwargs — can't filter by allowlist, fall through to blocklist
56
+ pass
57
+ else:
58
+ return {k: v for k, v in d.items() if k in valid_keys}
59
+
60
+ # Blocklist: remove HuggingFace PretrainedConfig metadata
61
+ _PRETRAINED_CONFIG_KEYS = {
62
+ "_name_or_path", "transformers_version", "model_type", "_commit_hash",
63
+ "_attn_implementation", "_attn_implementation_autoset", "return_dict",
64
+ "output_hidden_states", "output_attentions", "use_bfloat16",
65
+ "torchscript", "torch_dtype", "is_encoder_decoder", "is_decoder",
66
+ "add_cross_attention", "tie_encoder_decoder", "tie_word_embeddings",
67
+ "cross_attention_hidden_size", "chunk_size_feed_forward", "decoder_start_token_id",
68
+ "architectures", "finetuning_task", "id2label", "label2id", "prefix",
69
+ "problem_type", "tokenizer_class", "task_specific_params", "pruned_heads",
70
+ "bos_token_id", "eos_token_id", "pad_token_id", "sep_token_id",
71
+ "max_length", "min_length", "do_sample", "early_stopping",
72
+ "num_beams", "num_beam_groups", "diversity_penalty", "temperature",
73
+ "top_k", "top_p", "typical_p", "repetition_penalty", "length_penalty",
74
+ "no_repeat_ngram_size", "encoder_no_repeat_ngram_size", "bad_words_ids",
75
+ "num_return_sequences", "output_scores", "return_dict_in_generate",
76
+ "forced_bos_token_id", "forced_eos_token_id", "remove_invalid_values",
77
+ "exponential_decay_length_penalty", "suppress_tokens", "begin_suppress_tokens",
78
+ "tf_legacy_loss", "dtype",
79
+ }
80
+ return {k: v for k, v in d.items() if not k.startswith("_") and k not in _PRETRAINED_CONFIG_KEYS}
81
+
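A tiny, self-contained illustration of the allowlist path above (the class and keys are made up for the example):

class _Dummy:
    def __init__(self, in_channels: int = 4):
        self.in_channels = in_channels

print(_clean_config_dict({"in_channels": 8, "model_type": "vae"}, cls=_Dummy))
# -> {'in_channels': 8}; PretrainedConfig metadata such as model_type is filtered out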
82
+
83
+ # ---------------------------------------------------------------------------
84
+ # Image Refiner Container (nn.Module for state_dict loading)
85
+ # ---------------------------------------------------------------------------
86
+
87
+
88
+ class ImageRefinerContainer(nn.Module):
89
+ """Container for refiner components.
90
+
91
+ Holds base_transformer, vae, cond_proj as nn.Module children so their
92
+ parameters appear in the parent model's state_dict and are loaded
93
+ automatically via from_pretrained.
94
+ """
95
+
96
+ def __init__(self, visual_decoder_config):
97
+ super().__init__()
98
+
99
+ tc = visual_decoder_config.transformer_config
100
+ vc = visual_decoder_config.vae_config
101
+
102
+ self.base_transformer = Transformer2DModel(**_clean_config_dict(tc))
103
+
104
+ self.vae = AutoencoderKL(**_clean_config_dict(vc))
105
+ self.vae.requires_grad_(False)
106
+
107
+ text_feat_dim = getattr(tc, "text_feat_dim", 3584)
108
+ codebook_dim = getattr(visual_decoder_config, "codebook_dim", text_feat_dim)
109
+ if codebook_dim != text_feat_dim:
110
+ self.cond_proj = nn.Linear(codebook_dim, text_feat_dim)
111
+ else:
112
+ self.cond_proj = IdentityWithArgs()
113
+
114
+ @classmethod
115
+ def from_pretrained(cls, config, model_path: str):
116
+ model = cls(config)
117
+ weight_dict = load_file(model_path, device="cpu")
118
+ model.load_state_dict({k.removeprefix("image_refiner."): v for k, v in weight_dict.items() if k.startswith("image_refiner.")}, strict=True)
119
+ model.eval()
120
+ return model
121
+
122
+ @property
123
+ def device(self):
124
+ return next(self.parameters()).device
125
+
126
+ @property
127
+ def dtype(self):
128
+ return next(self.parameters()).dtype
129
+
130
+
131
+ class RefinerImageProcessor(VaeImageProcessor):
132
+ """Image processor for refiner - extends diffusers' VaeImageProcessor."""
133
+
134
+ @register_to_config
135
+ def __init__(
136
+ self,
137
+ do_resize: bool = True,
138
+ vae_scale_factor: int = 16,
139
+ resample: str = "lanczos",
140
+ max_pixels: Optional[int] = None,
141
+ max_side_length: Optional[int] = None,
142
+ do_normalize: bool = True,
143
+ do_binarize: bool = False,
144
+ do_convert_grayscale: bool = False,
145
+ ):
146
+ super().__init__(
147
+ do_resize=do_resize,
148
+ vae_scale_factor=vae_scale_factor,
149
+ resample=resample,
150
+ do_normalize=do_normalize,
151
+ do_binarize=do_binarize,
152
+ do_convert_grayscale=do_convert_grayscale,
153
+ )
154
+ self.max_pixels = max_pixels
155
+ self.max_side_length = max_side_length
156
+
157
+ def get_new_height_width(
158
+ self,
159
+ image: Union["PIL.Image.Image", np.ndarray, torch.Tensor],
160
+ height: Optional[int] = None,
161
+ width: Optional[int] = None,
162
+ max_pixels: Optional[int] = None,
163
+ max_side_length: Optional[int] = None,
164
+ ) -> Tuple[int, int]:
165
+ import PIL.Image
166
+
167
+ if height is None:
168
+ if isinstance(image, PIL.Image.Image):
169
+ height = image.height
170
+ elif isinstance(image, torch.Tensor):
171
+ height = image.shape[2]
172
+ else:
173
+ height = image.shape[1]
174
+
175
+ if width is None:
176
+ if isinstance(image, PIL.Image.Image):
177
+ width = image.width
178
+ elif isinstance(image, torch.Tensor):
179
+ width = image.shape[3]
180
+ else:
181
+ width = image.shape[2]
182
+
183
+ if max_side_length is None:
184
+ max_side_length = self.max_side_length
185
+ if max_pixels is None:
186
+ max_pixels = self.max_pixels
187
+
188
+ ratio = 1.0
189
+ if max_side_length is not None:
190
+ max_side_length_ratio = max_side_length / max(height, width)
191
+ else:
192
+ max_side_length_ratio = 1.0
193
+
194
+ cur_pixels = height * width
195
+ max_pixels_ratio = (max_pixels / cur_pixels) ** 0.5 if max_pixels is not None else 1.0
196
+ ratio = min(max_pixels_ratio, max_side_length_ratio, 1.0)
197
+
198
+ sf = self.config.vae_scale_factor
199
+ new_height = int(height * ratio) // sf * sf
200
+ new_width = int(width * ratio) // sf * sf
201
+ return new_height, new_width
202
+
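Worked numbers for the rule above (illustrative values; max_pixels is left unset): with vae_scale_factor = 16, max_side_length = 1000 and a 1200 x 2000 (H x W) input, ratio = min(1000 / 2000, 1.0) = 0.5, so new_height = int(1200 * 0.5) // 16 * 16 = 592 and new_width = int(2000 * 0.5) // 16 * 16 = 992; both sides are snapped down to multiples of the VAE scale factor.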
203
+ def preprocess(
204
+ self,
205
+ image: PipelineImageInput,
206
+ height: Optional[int] = None,
207
+ width: Optional[int] = None,
208
+ max_pixels: Optional[int] = None,
209
+ max_side_length: Optional[int] = None,
210
+ resize_mode: str = "default",
211
+ crops_coords: Optional[Tuple[int, int, int, int]] = None,
212
+ ) -> torch.Tensor:
213
+ import PIL.Image
214
+
215
+ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
216
+
217
+ if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
218
+ if isinstance(image, torch.Tensor):
219
+ image = image.unsqueeze(1)
220
+ else:
221
+ if image.shape[-1] == 1:
222
+ image = np.expand_dims(image, axis=0)
223
+ else:
224
+ image = np.expand_dims(image, axis=-1)
225
+
226
+ if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
227
+ warnings.warn(
228
+ "Passing `image` as a list of 4d np.ndarray is deprecated. "
229
+ "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
230
+ FutureWarning,
231
+ )
232
+ image = np.concatenate(image, axis=0)
233
+ if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
234
+ warnings.warn(
235
+ "Passing `image` as a list of 4d torch.Tensor is deprecated. "
236
+ "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
237
+ FutureWarning,
238
+ )
239
+ image = torch.cat(image, axis=0)
240
+
241
+ if not is_valid_image_imagelist(image):
242
+ raise ValueError(
243
+ f"Input is in incorrect format. Currently, we only support "
244
+ f"{', '.join(str(x) for x in supported_formats)}"
245
+ )
246
+ if not isinstance(image, list):
247
+ image = [image]
248
+
249
+ if isinstance(image[0], PIL.Image.Image):
250
+ if crops_coords is not None:
251
+ image = [i.crop(crops_coords) for i in image]
252
+ if self.config.do_resize:
253
+ height, width = self.get_new_height_width(image[0], height, width, max_pixels, max_side_length)
254
+ image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
255
+ if self.config.do_convert_grayscale:
256
+ image = [self.convert_to_grayscale(i) for i in image]
257
+ image = self.pil_to_numpy(image)
258
+ image = self.numpy_to_pt(image)
259
+ elif isinstance(image[0], np.ndarray):
260
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
261
+ image = self.numpy_to_pt(image)
262
+ height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
263
+ if self.config.do_resize:
264
+ image = self.resize(image, height, width)
265
+ elif isinstance(image[0], torch.Tensor):
266
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
267
+ if self.config.do_convert_grayscale and image.ndim == 3:
268
+ image = image.unsqueeze(1)
269
+ channel = image.shape[1]
270
+ if channel == self.config.vae_latent_channels:
271
+ return image
272
+ height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
273
+ if self.config.do_resize:
274
+ image = self.resize(image, height, width)
275
+
276
+ do_normalize = self.config.do_normalize
277
+ if do_normalize and image.min() < 0:
278
+ warnings.warn(
279
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. "
280
+ f"The expected value range for image tensor is [0,1] when passing as pytorch tensor or numpy Array. "
281
+ f"You passed `image` with value range [{image.min()},{image.max()}]",
282
+ FutureWarning,
283
+ )
284
+ do_normalize = False
285
+ if do_normalize:
286
+ image = self.normalize(image)
287
+
288
+ if self.config.do_binarize:
289
+ image = self.binarize(image)
290
+
291
+ return image
292
+
293
+
294
+ @dataclass
295
+ class RefinerOutput:
296
+ images: Union[List[Image.Image], torch.Tensor]
297
+
298
+
299
+ class IdentityWithArgs(nn.Module):
300
+ """Placeholder Identity module for cond_proj."""
301
+
302
+ def __init__(self, dtype=torch.float32, device=None):
303
+ super().__init__()
304
+ self.register_buffer("_dummy", torch.zeros((), dtype=dtype, device=device))
305
+
306
+ @property
307
+ def dtype(self):
308
+ return self._dummy.dtype
309
+
310
+ @property
311
+ def device(self):
312
+ return self._dummy.device
313
+
314
+ def forward(self, x, *args, **kwargs):
315
+ return x
316
+
317
+
318
+ def _retrieve_timesteps(
319
+ scheduler: FlowMatchEulerDiscreteScheduler,
320
+ num_inference_steps: Optional[int] = None,
321
+ device: Optional[Union[str, torch.device]] = None,
322
+ timesteps: Optional[List[int]] = None,
323
+ **kwargs,
324
+ ):
325
+ # If scheduler uses dynamic shifting and caller passed num_tokens, compute mu
326
+ # (same as training code refiner pipeline)
327
+ num_tokens = kwargs.pop("num_tokens", None)
328
+ if num_tokens is not None and getattr(scheduler.config, "use_dynamic_shifting", False):
329
+ # Compute mu from num_tokens using scheduler's linear interpolation
330
+ base_shift = getattr(scheduler.config, "base_shift", 0.5)
331
+ max_shift = getattr(scheduler.config, "max_shift", 1.15)
332
+ base_seq_len = getattr(scheduler.config, "base_image_seq_len", 256)
333
+ max_seq_len = getattr(scheduler.config, "max_image_seq_len", 4096)
334
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
335
+ b = base_shift - m * base_seq_len
336
+ mu = num_tokens * m + b
337
+ kwargs["mu"] = mu
338
+
339
+ accepted = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
340
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in accepted}
341
+
342
+ if timesteps is not None:
343
+ if "timesteps" not in accepted:
344
+ raise ValueError(
345
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
346
+ f" timestep schedules. Please check whether you are using the correct scheduler."
347
+ )
348
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **filtered_kwargs)
349
+ timesteps = scheduler.timesteps
350
+ num_inference_steps = len(timesteps)
351
+ else:
352
+ scheduler.set_timesteps(num_inference_steps, device=device, **filtered_kwargs)
353
+ timesteps = scheduler.timesteps
354
+ return timesteps, num_inference_steps
355
+
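The mu computation above is a straight linear interpolation in the number of latent tokens; a quick numeric check, assuming the getattr fallbacks rather than values read from any shipped scheduler config:

m = (1.15 - 0.5) / (4096 - 256)    # shift gained per extra latent token, ~1.69e-4
b = 0.5 - m * 256
print(256 * m + b, 4096 * m + b)   # ~0.5 for the smallest grid, ~1.15 for the largest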
356
+
357
+ class RefinerPipeline(DiffusionPipeline):
358
+ """
359
+ Image refiner evaluation pipeline.
360
+
361
+ - cond comes from upstream model: encoder_hidden_states (quants / last_latent)
362
+ - grid_thw_list is used to split cond (consistent with training)
363
+ - image as ref image
364
+ - Supports FlowMatchEulerDiscreteScheduler + velocity model
365
+ """
366
+
367
+ def __init__(
368
+ self,
369
+ vae: AutoencoderKL,
370
+ transformer: Transformer2DModel,
371
+ scheduler: FlowMatchEulerDiscreteScheduler,
372
+ cond_proj: Optional[nn.Module] = None,
373
+ ):
374
+ super().__init__()
375
+
376
+ self.register_modules(
377
+ vae=vae,
378
+ transformer=transformer,
379
+ scheduler=scheduler,
380
+ cond_proj=cond_proj if cond_proj is not None else IdentityWithArgs(),
381
+ )
382
+
383
+ self.vae_scale_factor = (
384
+ 2 ** (len(self.vae.config.block_out_channels) - 1)
385
+ if hasattr(self.vae.config, "block_out_channels")
386
+ else 8
387
+ )
388
+ self.image_processor = RefinerImageProcessor(
389
+ vae_scale_factor=self.vae_scale_factor * 2, do_resize=True
390
+ )
391
+ self.patch_size = int(getattr(self.transformer.config, "patch_size", 16))
392
+
393
+ self._num_timesteps: int = 0
394
+ self._current_timestep: Optional[torch.Tensor] = None
395
+ self._interrupt: bool = False
396
+ self._freqs_cis: Optional[torch.Tensor] = None
397
+ self._text_guidance_scale: float = 1.0
398
+ self._image_guidance_scale: float = 1.0
399
+ self._cfg_range: Tuple[float, float] = (0.0, 1.0)
400
+
401
+ @torch.no_grad()
402
+ def _get_freqs_cis(self, device, dtype):
403
+ if self._freqs_cis is None:
404
+ self._freqs_cis = RotaryPosEmbed.get_freqs_cis(
405
+ self.transformer.config.axes_dim_rope,
406
+ self.transformer.config.axes_lens,
407
+ theta=10000,
408
+ )
409
+ return self._freqs_cis
410
+
411
+ @staticmethod
412
+ def _split_tokens(
413
+ encoder_hidden_states: torch.Tensor,
414
+ grid_thw_list: List[Tuple[int, int, int]],
415
+ ) -> List[torch.Tensor]:
416
+ splits = [int(h) * int(w) // 4 for (_, h, w) in grid_thw_list]
417
+ return list(torch.split(encoder_hidden_states, splits, dim=1))
418
+
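A quick sanity check of the split arithmetic above (worked numbers; the 2x2-merged-patch grid convention is an assumption, not something stated in this file): a grid_thw entry of (1, 74, 74) yields 74 * 74 // 4 = 1369 conditioning tokens, i.e. a 37 x 37 token map, which lines up with the token_h / token_w defaults in generation_config.json.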
419
+ @staticmethod
420
+ def _looks_like_latents(x: Union[torch.Tensor, Image.Image], latent_ch_hint: int = 16) -> bool:
421
+ if not isinstance(x, torch.Tensor):
422
+ return False
423
+ if x.ndim not in (3, 4):
424
+ return False
425
+ c = int(x.shape[-3])
426
+ if c == 3:
427
+ return False
428
+ if c == latent_ch_hint:
429
+ return True
430
+ if c > 3 and c <= 32:
431
+ return True
432
+ return False
433
+
434
+ @torch.no_grad()
435
+ def _preprocess_to_vae_range(self, img: torch.Tensor) -> torch.Tensor:
436
+ if img.dtype not in (torch.float32, torch.float16, torch.bfloat16):
437
+ img = img.float()
438
+ if img.max() > 1.5:
439
+ img = img / 255.0
440
+ if img.min() >= 0.0 and img.max() <= 1.0:
441
+ img = img * 2.0 - 1.0
442
+ return img.clamp(-1, 1)
443
+
444
+ @torch.no_grad()
445
+ def _encode_image_to_latents(
446
+ self,
447
+ img_any: Union[Image.Image, torch.Tensor],
448
+ device,
449
+ dtype,
450
+ ) -> Tuple[torch.Tensor, int, int]:
451
+ latent_ch_hint = int(getattr(getattr(self.vae, "config", None), "latent_channels", 16))
452
+
453
+ if self._looks_like_latents(img_any, latent_ch_hint=latent_ch_hint):
454
+ z = img_any
455
+ if z.ndim == 3:
456
+ z = z.unsqueeze(0)
457
+ z = z.to(device=device, dtype=dtype)
458
+ H_lat, W_lat = z.shape[-2], z.shape[-1]
459
+ return z, H_lat, W_lat
460
+
461
+ if isinstance(img_any, Image.Image):
462
+ img = torch.from_numpy(
463
+ np.array(img_any).astype("float32") / 255.0
464
+ ).permute(2, 0, 1).unsqueeze(0)
465
+ elif isinstance(img_any, torch.Tensor):
466
+ img = img_any
467
+ if img.ndim == 3:
468
+ img = img.unsqueeze(0)
469
+ else:
470
+ raise TypeError("Unsupported image type. Use PIL.Image or torch.Tensor or latent Tensor.")
471
+
472
+ img = self._preprocess_to_vae_range(img)
473
+
474
+ H, W = img.shape[-2:]
475
+ base = self.patch_size * self.vae_scale_factor
476
+ target_H = max(base, math.ceil(H / base) * base)
477
+ target_W = max(base, math.ceil(W / base) * base)
478
+ if (H != target_H) or (W != target_W):
479
+ img = F.interpolate(img, size=(target_H, target_W), mode="bilinear", align_corners=False)
480
+
481
+ img = img.to(device=device, dtype=self.vae.dtype)
482
+
483
+ posterior = self.vae.encode(img).latent_dist
484
+ z0 = posterior.sample()
485
+ if getattr(self.vae.config, "shift_factor", None) is not None:
486
+ z0 = z0 - self.vae.config.shift_factor
487
+ if getattr(self.vae.config, "scaling_factor", None) is not None:
488
+ z0 = z0 * self.vae.config.scaling_factor
489
+
490
+ z0 = z0.to(device=device, dtype=dtype)
491
+ H_lat, W_lat = z0.shape[-2], z0.shape[-1]
492
+ return z0, H_lat, W_lat
493
+
494
+ @staticmethod
495
+ def _expand_to_list(x, n):
496
+ if x is None:
497
+ return [None] * n
498
+ if isinstance(x, (Image.Image, torch.Tensor)):
499
+ return [x] * n
500
+ assert isinstance(x, list), "`image` must be PIL / Tensor or list of them."
501
+ assert len(x) == n, "`len(image)` must equal number of image chunks"
502
+ return x
503
+
504
+ @torch.no_grad()
505
+ def _denoise_once(
506
+ self,
507
+ cond_tokens: torch.Tensor,
508
+ ref_img: Optional[Union[Image.Image, torch.Tensor]],
509
+ num_inference_steps: int = 28,
510
+ timesteps: Optional[List[int]] = None,
511
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
512
+ output_type: str = "pil",
513
+ text_guidance_scale: float = 1.0,
514
+ image_guidance_scale: float = 1.0,
515
+ cfg_range: Tuple[float, float] = (0.0, 1.0),
516
+ enable_processor_bar: bool = True,
517
+ ):
518
+ device = cond_tokens.device
519
+ weight_dtype = self.transformer.dtype
520
+
521
+ self._text_guidance_scale = text_guidance_scale
522
+ self._image_guidance_scale = image_guidance_scale
523
+ self._cfg_range = cfg_range
524
+
525
+ cond_tokens = cond_tokens.to(device=device, dtype=weight_dtype)
526
+ text_feats = self.cond_proj(cond_tokens)
527
+ B, L, _ = text_feats.shape
528
+ text_mask = torch.ones(B, L, device=device, dtype=torch.bool)
529
+
530
+ ref_image_hidden_states = None
531
+ H_lat: int
532
+ W_lat: int
533
+
534
+ if ref_img is not None:
535
+ if isinstance(ref_img, torch.Tensor) and ref_img.ndim == 4 and ref_img.shape[0] == B:
536
+ z_ref, H_lat, W_lat = self._encode_image_to_latents(ref_img, device=device, dtype=weight_dtype)
537
+ ref_image_hidden_states = [[z_ref[b]] for b in range(B)]
538
+ else:
539
+ z_ref, H_lat, W_lat = self._encode_image_to_latents(ref_img, device=device, dtype=weight_dtype)
540
+ z_single = z_ref[0]
541
+ ref_image_hidden_states = [[z_single] for _ in range(B)]
542
+ else:
543
+ H_lat = W_lat = 128 // self.vae_scale_factor
544
+
545
+ C_lat = getattr(self.transformer.config, "in_channels", None)
546
+ if C_lat is None:
547
+ if ref_image_hidden_states is not None:
548
+ C_lat = ref_image_hidden_states[0][0].shape[0]
549
+ else:
550
+ raise ValueError("transformer.config.in_channels is None and no ref_img was provided.")
551
+ latents_shape = (B, C_lat, H_lat, W_lat)
552
+
553
+ if isinstance(generator, list):
554
+ if len(generator) != B:
555
+ raise ValueError(
556
+ f"len(generator)={len(generator)} must equal B={B} when passing list of generators."
557
+ )
558
+ latents = torch.stack(
559
+ [
560
+ torch.randn(
561
+ (1, C_lat, H_lat, W_lat),
562
+ generator=generator[i],
563
+ device=device,
564
+ dtype=weight_dtype,
565
+ ).squeeze(0)
566
+ for i in range(B)
567
+ ],
568
+ dim=0,
569
+ )
570
+ else:
571
+ latents = torch.randn(latents_shape, generator=generator, device=device, dtype=weight_dtype)
572
+
573
+ num_tokens = H_lat * W_lat
574
+ timesteps_sched, num_inference_steps = _retrieve_timesteps(
575
+ self.scheduler,
576
+ num_inference_steps=num_inference_steps,
577
+ device=device,
578
+ timesteps=timesteps,
579
+ num_tokens=num_tokens,
580
+ )
581
+ num_warmup_steps = max(len(timesteps_sched) - num_inference_steps * self.scheduler.order, 0)
582
+ self._num_timesteps = len(timesteps_sched)
583
+
584
+ freqs_cis = self._get_freqs_cis(device=device, dtype=weight_dtype)
585
+
586
+ progress_bar = self.progress_bar(total=num_inference_steps) if enable_processor_bar else None
587
+ for i, t in enumerate(timesteps_sched):
588
+ if self._interrupt:
589
+ continue
590
+ self._current_timestep = t
591
+
592
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
593
+
594
+ step_frac = i / max(len(timesteps_sched) - 1, 1)
595
+ use_cfg = (cfg_range[0] <= step_frac <= cfg_range[1]) and (
596
+ text_guidance_scale > 1.0 or image_guidance_scale > 1.0
597
+ )
598
+
599
+ if not use_cfg:
600
+ optional_kwargs: Dict[str, Any] = {}
601
+ if "ref_image_hidden_states" in inspect.signature(self.transformer.forward).parameters:
602
+ optional_kwargs["ref_image_hidden_states"] = ref_image_hidden_states
603
+ model_pred = self.transformer(
604
+ latents, timestep, text_feats, freqs_cis, text_mask, **optional_kwargs
605
+ )
606
+ else:
607
+ text_uncond = torch.zeros_like(text_feats)
608
+
609
+ opt_kwargs_text: Dict[str, Any] = {}
610
+ if "ref_image_hidden_states" in inspect.signature(self.transformer.forward).parameters:
611
+ opt_kwargs_text["ref_image_hidden_states"] = ref_image_hidden_states
612
+
613
+ model_pred_text = self.transformer(
614
+ latents, timestep, text_feats, freqs_cis, text_mask, **opt_kwargs_text
615
+ )
616
+
617
+ opt_kwargs_ref: Dict[str, Any] = {}
618
+ if "ref_image_hidden_states" in inspect.signature(self.transformer.forward).parameters:
619
+ opt_kwargs_ref["ref_image_hidden_states"] = ref_image_hidden_states
620
+
621
+ model_pred_ref = self.transformer(
622
+ latents, timestep, text_uncond, freqs_cis, text_mask, **opt_kwargs_ref
623
+ )
624
+
625
+ opt_kwargs_uncond: Dict[str, Any] = {}
626
+ if "ref_image_hidden_states" in inspect.signature(self.transformer.forward).parameters:
627
+ opt_kwargs_uncond["ref_image_hidden_states"] = None
628
+
629
+ model_pred_uncond = self.transformer(
630
+ latents, timestep, text_uncond, freqs_cis, text_mask, **opt_kwargs_uncond
631
+ )
632
+
633
+ if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
634
+ model_pred = (
635
+ model_pred_uncond
636
+ + image_guidance_scale * (model_pred_ref - model_pred_uncond)
637
+ + text_guidance_scale * (model_pred_text - model_pred_ref)
638
+ )
639
+ elif text_guidance_scale > 1.0:
640
+ model_pred = model_pred_uncond + text_guidance_scale * (model_pred_text - model_pred_uncond)
641
+ elif image_guidance_scale > 1.0:
642
+ model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond)
643
+ else:
644
+ model_pred = model_pred_text
645
+
646
+ latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
647
+ latents = latents.to(dtype=weight_dtype)
648
+
649
+ if progress_bar is not None:
650
+ if i == len(timesteps_sched) - 1 or (
651
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
652
+ ):
653
+ progress_bar.update()
654
+
655
+ if progress_bar is not None:
656
+ progress_bar.close()
657
+
658
+ self._current_timestep = None
659
+
660
+ latents = latents.to(dtype=self.vae.dtype)
661
+ if getattr(self.vae.config, "scaling_factor", None) is not None:
662
+ latents = latents / self.vae.config.scaling_factor
663
+ if getattr(self.vae.config, "shift_factor", None) is not None:
664
+ latents = latents + self.vae.config.shift_factor
665
+ image = self.vae.decode(latents, return_dict=False)[0]
666
+
667
+ images = self.image_processor.postprocess(image, output_type=output_type)
668
+ return images
669
+
670
+ @torch.no_grad()
671
+ def __call__(
672
+ self,
673
+ *,
674
+ encoder_hidden_states: torch.Tensor,
675
+ grid_thw_list: List[Tuple[int, int, int]],
676
+ image: Union[Image.Image, torch.Tensor, List[Union[Image.Image, torch.Tensor]], None] = None,
677
+ num_inference_steps: int = 28,
678
+ timesteps: Optional[List[int]] = None,
679
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
680
+ output_type: str = "pil",
681
+ return_dict: bool = True,
682
+ text_guidance_scale: float = 1.5,
683
+ image_guidance_scale: float = 1.5,
684
+ cfg_range: Tuple[float, float] = (0.0, 1.0),
685
+ enable_processor_bar: bool = True,
686
+ **kwargs,
687
+ ) -> Union[RefinerOutput, List[Image.Image], torch.Tensor]:
688
+ self._interrupt = False
689
+
690
+ token_chunks = self._split_tokens(encoder_hidden_states, grid_thw_list)
691
+ ref_list = self._expand_to_list(image, len(token_chunks))
692
+
693
+ results_pil: List[Image.Image] = []
694
+ results_pt: Optional[torch.Tensor] = None
695
+
696
+ for tok, _, img_any in zip(token_chunks, grid_thw_list, ref_list):
697
+ imgs = self._denoise_once(
698
+ cond_tokens=tok,
699
+ ref_img=img_any,
700
+ num_inference_steps=num_inference_steps,
701
+ timesteps=timesteps,
702
+ generator=generator,
703
+ output_type=output_type,
704
+ text_guidance_scale=text_guidance_scale,
705
+ image_guidance_scale=image_guidance_scale,
706
+ cfg_range=cfg_range,
707
+ enable_processor_bar=enable_processor_bar,
708
+ )
709
+
710
+ if output_type == "pil":
711
+ results_pil += imgs
712
+ else:
713
+ results_pt = imgs if results_pt is None else torch.cat([results_pt, imgs], dim=0)
714
+
715
+ if not return_dict:
716
+ return results_pil if output_type == "pil" else results_pt
717
+ return RefinerOutput(images=results_pil if output_type == "pil" else results_pt)
718
+
719
+
720
+ def de_transform(
721
+ tensor: torch.Tensor,
722
+ mean=(0.48145466, 0.4578275, 0.40821073),
723
+ std=(0.26862954, 0.26130258, 0.27577711),
724
+ rescale_factor: float = 1 / 255,
725
+ ) -> torch.Tensor:
726
+ """De-normalize and de-rescale, suitable for images processed by Qwen2VLImageProcessor."""
727
+ if tensor.ndim == 3:
728
+ tensor = tensor.unsqueeze(0)
729
+ mean_t = torch.tensor(mean).view(1, -1, 1, 1).to(tensor.device)
730
+ std_t = torch.tensor(std).view(1, -1, 1, 1).to(tensor.device)
731
+ tensor = tensor * std_t + mean_t
732
+ tensor = tensor / rescale_factor
733
+ tensor = torch.clamp(tensor / 255.0, 0, 1)
734
+ return tensor
735
+
736
+
737
+ def tensor2pil(image_t: torch.Tensor, image_mean, image_std) -> Image.Image:
738
+ """Convert a tensor to a PIL Image."""
739
+ image_t = image_t.detach().cpu()
740
+ rescale_factor = 1 / 255
741
+ sample = de_transform(
742
+ image_t,
743
+ mean=image_mean,
744
+ std=image_std,
745
+ rescale_factor=rescale_factor,
746
+ )[0]
747
+ ndarr = sample.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
748
+ return Image.fromarray(ndarr)
modeling_longcat_next.py ADDED
@@ -0,0 +1,824 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2026 Meituan
3
+ # This code is licensed under the MIT License, for details, see the ./LICENSE file.
4
+
5
+ import os
6
+ from dataclasses import dataclass
7
+ from tqdm import tqdm
8
+ from typing import Optional, Union
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+
14
+ from transformers.cache_utils import Cache
15
+ from transformers.generation.configuration_utils import GenerationConfig
16
+ from transformers.generation.logits_process import LogitsProcessorList
17
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
18
+ from transformers.generation.utils import GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, GenerateNonBeamOutput
19
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
20
+ from transformers.models.longcat_flash.modeling_longcat_flash import LongcatFlashForCausalLM
21
+ from transformers.processing_utils import Unpack
22
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
23
+
24
+ from .configuration_longcat_next import LongcatNextConfig
25
+ from .modeling_longcat_ngram import LongcatFlashNgramModel, NgramCache
26
+ from .modular_longcat_next import CasualDepthTransformerHead
27
+ from .modular_longcat_next_audio import LongcatNextAudioTokenizer
28
+ from .modular_longcat_next_visual import LongcatNextVisualTokenizer
29
+
30
+ from .cosy24k_vocoder import Cosy24kVocoder
31
+ from .image_refiner import ImageRefinerContainer
32
+ from .refiner_modules import FlowMatchEulerDiscreteScheduler
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+ @dataclass
37
+ class LongcatNextForCausalLMOutputWithPast(CausalLMOutputWithPast):
38
+ visual_loss: Optional[torch.FloatTensor] = None
39
+ visual_logits: Optional[torch.FloatTensor] = None
40
+ visual_ids: Optional[torch.LongTensor] = None
41
+ audio_loss: Optional[torch.FloatTensor] = None
42
+ audio_logits: Optional[torch.FloatTensor] = None
43
+ audio_ids: Optional[torch.LongTensor] = None
44
+
45
+ @dataclass
46
+ class LongcatNextForCausalLMGenerateDecoderOnlyOutput(GenerateDecoderOnlyOutput):
47
+ visual_ids: Optional[torch.LongTensor] = None
48
+ audio_ids: Optional[torch.LongTensor] = None
49
+ audio_text_ids: Optional[torch.LongTensor] = None
50
+
51
+ @dataclass
52
+ class LongcatNextForCausalLMGenerateEncoderDecoderOutput(GenerateEncoderDecoderOutput):
53
+ visual_ids: Optional[torch.LongTensor] = None
54
+ audio_ids: Optional[torch.LongTensor] = None
55
+ audio_text_ids: Optional[torch.LongTensor] = None
56
+
57
+ @dataclass
58
+ class LongcatNextForCausalLMGenerationStatus:
59
+ mode: str = "text"
60
+ current_image_token_num: int = -1
61
+ audio_parallel_decoding: bool = False
62
+ is_audio_text_end: bool = False
63
+ is_audio_start: bool = False
64
+ last_step_mode: str = None
65
+
66
+ def __init__(self, visual_generation_config, audio_generation_config):
67
+ self.visual_generation_config = visual_generation_config
68
+ self.h = self.visual_generation_config.custom_params["token_h"]
69
+ self.w = self.visual_generation_config.custom_params["token_w"]
70
+ self.anyres_prefix = self.visual_generation_config.custom_params["anyres_prefix"].format(h=self.h, w=self.w)
71
+ self.audio_generation_config = audio_generation_config
72
+ self.audio_parallel_decoding = audio_generation_config.audio_parallel_decoding
73
+
74
+ def switch_to(self, modal):
75
+ assert modal in ["text", "visual", "audio"]
76
+ self.mode = modal
77
+ self.current_image_token_num = 0 if modal == "visual" else -1
78
+ self.is_audio_text_end = False
79
+ self.is_audio_start = False
80
+
81
+ @property
82
+ def is_img_newline(self):
83
+ return ((self.current_image_token_num + 1) % (self.w + 1)) == 0 and not self.is_img_end
84
+
85
+ @property
86
+ def is_img_end(self):
87
+ return (self.current_image_token_num + 1) / (self.w + 1) == self.h
88
+
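For reference, worked numbers using the token_h = token_w = 37 defaults from generation_config.json: each generated image row consists of w = 37 visual tokens followed by one image-newline token, so is_img_newline fires at every 38th image position and is_img_end becomes true after 37 * 38 = 1406 positions.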
89
+
90
+ class LongcatNextModel(LongcatFlashNgramModel):
91
+ _keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
92
+ config_class = LongcatNextConfig
93
+
94
+ def __init__(self, config):
95
+ super().__init__(config)
96
+ self.visual_tokenizer = LongcatNextVisualTokenizer(config)
97
+ self.audio_tokenizer = LongcatNextAudioTokenizer(config)
98
+
99
+ self._init_multimodal_constants(config)
100
+ self.post_init()
101
+
102
+ def _init_multimodal_constants(self, config):
103
+ name2id_dict = {
104
+ "image_newline_token_id": self.config.visual_config.image_newline_token_id,
105
+ "image_end_token_id": self.config.visual_config.image_end_token_id,
106
+ "image_pad_token_id": self.config.visual_config.image_pad_token_id,
107
+ "audiotext_start_token_id": config.audio_config.audiotext_start_token_id,
108
+ "audiotext_pad_token_id": self.config.audio_config.audiotext_pad_token_id,
109
+ "audiogen_end_token_id": config.audio_config.audiogen_end_token_id,
110
+ "audio_pad_token_id": self.config.audio_config.audio_pad_token_id,
111
+ }
112
+ for k, v in name2id_dict.items():
113
+ self.register_buffer(k, torch.tensor([v], dtype=torch.long), persistent=False)
114
+ visual_offset_list = [config.visual_offset] + config.visual_config.vq_config.codebook_sizes[:-1]
115
+ visual_offset_vals = torch.cumsum(torch.tensor(visual_offset_list, dtype=torch.long), dim=0)
116
+ self.register_buffer("visual_offset_vals", visual_offset_vals, persistent=False)
117
+ audio_offset_list = [config.audio_offset] + config.audio_config.vq_config.codebook_sizes[:-1]
118
+ audio_offset_vals = torch.cumsum(torch.tensor(audio_offset_list, dtype=torch.long), dim=0)
119
+ self.register_buffer("audio_offset_vals", audio_offset_vals, persistent=False)
120
+ print(f"{self.visual_offset_vals=}")
121
+ print(f"{self.audio_offset_vals=}")
122
+
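A quick illustration of the offset bookkeeping above (hypothetical numbers): with visual_offset = 131072 and codebook_sizes = [8192, 8192, 8192], visual_offset_vals becomes [131072, 139264, 147456], so the ids of each residual codebook occupy disjoint slices of the shared token-embedding table and their embeddings can simply be summed in get_visual_embeddings.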
123
+ def forward(
124
+ self,
125
+ input_ids: Optional[torch.LongTensor] = None,
126
+ attention_mask: Optional[torch.Tensor] = None,
127
+ position_ids: Optional[torch.LongTensor] = None,
128
+ past_key_values: Optional[Cache] = None,
129
+ inputs_embeds: Optional[torch.FloatTensor] = None,
130
+ cache_position: Optional[torch.LongTensor] = None,
131
+ use_cache: Optional[bool] = None,
132
+ visual_inputs=None,
133
+ visual_ids=None,
134
+ audio_inputs=None,
135
+ audio_ids=None,
136
+ audio_text_ids=None,
137
+ multimodal_generation_status=None,
138
+ **kwargs
139
+ ) -> BaseModelOutputWithPast:
140
+
141
+ if input_ids is None:
142
+ raise ValueError("You must specify input_ids")
143
+
144
+ # Extract N-gram context if available
145
+ ngram_context = None
146
+ if isinstance(past_key_values, NgramCache) and past_key_values.ngram_context is not None:
147
+ ngram_context = past_key_values.ngram_context
148
+
149
+ # assert input_ids.size(0) == 1, "only support bs=1 for now" # but when bs=2, idx=1 is for uncond_image_generation
150
+ special_visual_mask, special_audio_mask, special_audio_text_start_mask, special_audio_text_pad_mask = self.get_placeholder_mask(input_ids[:1]) # seq-dim
151
+
152
+ if inputs_embeds is None:
153
+ input_ids[:, special_visual_mask | special_audio_mask | special_audio_text_pad_mask | special_audio_text_start_mask] = 0
154
+ filled_text_pad_mask = torch.ones_like(special_audio_mask)
155
+ audio_text_position_mask = (special_audio_text_pad_mask | special_audio_text_start_mask | special_audio_mask)
156
+
157
+ if audio_text_ids is not None and audio_text_ids.size(1) > 0 and audio_text_position_mask.sum() > 0:
158
+ filled_text = audio_text_ids[:, -audio_text_position_mask.sum():]
159
+ filled_text_pad_mask = (filled_text==self.config.audio_config.audiotext_pad_token_id)[0]
160
+ input_ids[:, audio_text_position_mask] = filled_text
161
+ input_ids[input_ids == self.config.audio_config.audiotext_pad_token_id] = 0
162
+
163
+ inputs_embeds = self.ngram_embeddings(input_ids, ngram_context=ngram_context)
164
+ inputs_embeds[:, (special_visual_mask | (special_audio_mask & filled_text_pad_mask))] = 0
165
+
166
+ if special_audio_text_start_mask.sum() > 0:
167
+ audio_text_start_embedding = self.embed_tokens(self.audiotext_start_token_id)
168
+ if multimodal_generation_status.last_step_mode is None: # prefill
169
+ inputs_embeds[:1, special_audio_text_start_mask] += audio_text_start_embedding
170
+ else:
171
+ inputs_embeds[:, special_audio_text_start_mask] += audio_text_start_embedding
172
+
173
+ if visual_inputs is not None:
174
+ visual_ids = self.get_visual_ids(**visual_inputs) # [<bs=1>*seq, lev]
175
+
176
+ if visual_ids is not None and special_visual_mask.sum() > 0:
177
+ visual_embeddings = self.get_visual_embeddings(visual_ids[-special_visual_mask.sum():]) # -> [seq, dim]
178
+ if multimodal_generation_status.last_step_mode is None: # prefill
179
+ inputs_embeds[:1, special_visual_mask] = visual_embeddings.to(inputs_embeds.device)
180
+ else:
181
+ inputs_embeds[:, special_visual_mask] = visual_embeddings.to(inputs_embeds.device)
182
+
183
+ if audio_inputs is not None:
184
+ audio_ids = self.get_audio_ids(**audio_inputs) # -> [<bs=1>*seq, lev]
185
+
186
+ if audio_ids is not None and special_audio_mask.sum() > 0:
187
+ audio_embeddings = self.get_audio_embeddings(audio_ids[-special_audio_mask.sum():]) # -> [seq, dim]
188
+ if multimodal_generation_status.last_step_mode is None: # prefill
189
+ inputs_embeds[:1, special_audio_mask] += audio_embeddings.to(inputs_embeds.device)
190
+ else:
191
+ inputs_embeds[:, special_audio_mask] += audio_embeddings.to(inputs_embeds.device)
192
+
193
+ # Initialize NgramCache if needed
194
+ if use_cache and past_key_values is None:
195
+ past_key_values = NgramCache(config=self.config)
196
+
197
+ # Update N-gram context
198
+ if use_cache and isinstance(past_key_values, NgramCache):
199
+ past_key_values.update_ngram_context(input_ids)
200
+
201
+ return super().forward(
202
+ input_ids=None,
203
+ attention_mask=attention_mask,
204
+ position_ids=position_ids,
205
+ past_key_values=past_key_values,
206
+ inputs_embeds=inputs_embeds,
207
+ cache_position=cache_position,
208
+ use_cache=use_cache,
209
+ **kwargs
210
+ )
211
+
212
+ def get_visual_ids(self, pixel_values, visual_grid_thw, offset=True):
213
+ visual_ids = self.visual_tokenizer.encode(pixel_values, visual_grid_thw)
214
+ if offset:
215
+ visual_ids += self.visual_offset_vals.to(visual_ids.device)
216
+ return visual_ids
217
+
218
+ def get_audio_ids(self, audio, encoder_length, bridge_length, offset=True):
219
+ audio_ids = self.audio_tokenizer.encode(audio, encoder_length, bridge_length)
220
+ if offset:
221
+ audio_ids += self.audio_offset_vals.to(audio_ids.device)
222
+ return audio_ids
223
+
224
+ @torch.no_grad()
225
+ def decode_visual_ids_and_save(
226
+ self,
227
+ visual_ids,
228
+ save_prefix,
229
+ token_h,
230
+ token_w,
231
+ **kwargs,
232
+ ):
233
+ visual_ids -= self.visual_offset_vals.to(visual_ids.device)
234
+
235
+ if not (save_prefix.startswith("./") or save_prefix.startswith("/")):
236
+ save_prefix = f"./{save_prefix}"
237
+ os.makedirs(os.path.dirname(save_prefix), exist_ok=True)
238
+ return self.visual_tokenizer.lazy_decode_and_save(visual_ids, token_h, token_w, f"{save_prefix}_{0}.png")
239
+
240
+ @torch.no_grad()
241
+ def decode_audio_ids_and_save(
242
+ self,
243
+ audio_ids,
244
+ save_prefix,
245
+ sampling_rate,
246
+ wave_concat_overlap,
247
+ **kwargs,
248
+ ):
249
+ audio_ids -= self.audio_offset_vals.to(audio_ids.device)
250
+
251
+ if not (save_prefix.startswith("./") or save_prefix.startswith("/")):
252
+ save_prefix = f"./{save_prefix}"
253
+ os.makedirs(os.path.dirname(save_prefix), exist_ok=True)
254
+ save_path = f"{save_prefix}_{0}.wav"
255
+ self.audio_tokenizer.lazy_decode_and_save(audio_ids, sampling_rate, wave_concat_overlap, save_path)
256
+ return [save_path]
257
+
258
+ def get_visual_embeddings(self, visual_ids):
259
+ visual_embeddings = self.embed_tokens(visual_ids).sum(dim=1) # [seq, lev] -> [seq, lev, dim] -> [seq, dim]
260
+ visual_embeddings = self.visual_tokenizer.visual_embedding_layer(visual_embeddings)
261
+ return visual_embeddings
262
+
263
+ def get_audio_embeddings(self, audio_ids):
264
+ audio_embeddings = self.embed_tokens(audio_ids).sum(dim=1)
265
+ return audio_embeddings
266
+
267
+ def get_placeholder_mask(self, input_ids: torch.LongTensor):
268
+ special_image_mask = (input_ids == self.config.visual_config.image_pad_token_id).squeeze(0)
269
+ special_audio_mask = (input_ids == self.config.audio_config.audio_pad_token_id).squeeze(0)
270
+ special_audio_text_start_mask = (input_ids == self.config.audio_config.audiotext_start_token_id).squeeze(0)
271
+ special_audio_text_pad_mask = (input_ids == self.config.audio_config.audiotext_pad_token_id).squeeze(0)
272
+ return special_image_mask, special_audio_mask, special_audio_text_start_mask, special_audio_text_pad_mask
273
+
274
+
275
+ class LongcatNextForCausalLM(LongcatFlashForCausalLM):
276
+ _keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
277
+ _no_split_modules = [
278
+ "LongcatFlashDecoderLayer",
279
+ "CasualDepthTransformerHead",
280
+ ]
281
+ config_class = LongcatNextConfig
282
+
283
+ def __init__(self, config):
284
+ super().__init__(config)
285
+ self.config = config
286
+ self.model = LongcatNextModel(config)
287
+ self.lm_head = nn.Linear(config.hidden_size, config.text_vocab_plus_multimodal_special_token_size, bias=False)
288
+
289
+ self.visual_head = CasualDepthTransformerHead(
290
+ hidden_size=config.hidden_size,
291
+ codebook_sizes=config.visual_config.vq_config.codebook_sizes,
292
+ transformer_layer_num=config.visual_config.image_head_transformer_layers,
293
+ transformer_dim=config.visual_config.image_head_transformer_dims,
294
+ transformer_ffn_scale=config.visual_config.image_head_transformer_ffn_scale,
295
+ )
296
+ self.audio_head = CasualDepthTransformerHead(
297
+ hidden_size=config.hidden_size,
298
+ codebook_sizes=config.audio_config.vq_config.codebook_sizes,
299
+ transformer_layer_num=config.audio_config.audio_head_transformer_layers,
300
+ transformer_dim=config.audio_config.audio_head_transformer_dims,
301
+ transformer_ffn_scale=config.audio_config.audio_head_transformer_ffn_scale,
302
+ )
303
+
304
+ self.post_init()
305
+
306
+ @can_return_tuple
307
+ @auto_docstring
308
+ def forward(
309
+ self,
310
+ input_ids: Optional[torch.LongTensor] = None,
311
+ attention_mask: Optional[torch.Tensor] = None,
312
+ position_ids: Optional[torch.LongTensor] = None,
313
+ past_key_values: Optional[Cache] = None,
314
+ inputs_embeds: Optional[torch.FloatTensor] = None,
315
+ labels: Optional[torch.LongTensor] = None,
316
+ use_cache: Optional[bool] = None,
317
+ cache_position: Optional[torch.LongTensor] = None,
318
+ logits_to_keep: Union[int, torch.Tensor] = 0,
319
+ visual_inputs=None,
320
+ visual_ids=None,
321
+ audio_inputs=None,
322
+ audio_ids=None,
323
+ audio_text_ids=None,
324
+ multimodal_generation_status: LongcatNextForCausalLMGenerationStatus = None,
325
+ visual_generation_config: GenerationConfig = None,
326
+ audio_generation_config: GenerationConfig = None,
327
+ **kwargs: Unpack[TransformersKwargs],
328
+ ) -> CausalLMOutputWithPast:
329
+ r"""
330
+ visual_inputs (`BatchFeature`, *optional*):
331
+ Visual inputs returned by the processor, containing pixel values and grid metadata for image encoding.
332
+ visual_ids (`torch.LongTensor` of shape `(num_visual_tokens, num_codebooks)`, *optional*):
333
+ Quantized visual token ids from the visual tokenizer, used to build visual embeddings during generation.
334
+ audio_inputs (`BatchFeature`, *optional*):
335
+ Audio inputs returned by the processor, containing mel-spectrogram features and length metadata.
336
+ audio_ids (`torch.LongTensor` of shape `(num_audio_tokens, num_codebooks)`, *optional*):
337
+ Quantized audio token ids from the audio tokenizer, used to build audio embeddings during generation.
338
+ audio_text_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
339
+ Token ids for the audio text transcript generated alongside audio tokens.
340
+ multimodal_generation_status (`LongcatNextForCausalLMGenerationStatus`, *optional*):
341
+ Stateful object tracking the current multimodal generation mode (text / visual / audio) and
342
+ associated counters used to route logits to the correct head during auto-regressive decoding.
343
+ visual_generation_config (`GenerationConfig`, *optional*):
344
+ Generation configuration for the visual head, controlling sampling parameters such as
345
+ `temperature`, `top_k`, `top_p`, and custom parameters like `cfg_scale` and `anyres_config`.
346
+ audio_generation_config (`GenerationConfig`, *optional*):
347
+ Generation configuration for the audio head, controlling sampling parameters such as
348
+ `temperature`, `top_k`, `top_p`, `repetition_penalty`, and `audio_parallel_decoding`.
349
+ """
350
+
351
+ if multimodal_generation_status.mode == "visual" and visual_generation_config.custom_params["cfg_scale"] != 1.0 and input_ids.size(0) == 1:
352
+ input_ids = input_ids.repeat((2, 1))
353
+
354
+ outputs: BaseModelOutputWithPast = self.model(
355
+ input_ids=input_ids,
356
+ attention_mask=attention_mask,
357
+ position_ids=position_ids,
358
+ past_key_values=past_key_values,
359
+ inputs_embeds=inputs_embeds,
360
+ use_cache=use_cache,
361
+ cache_position=cache_position,
362
+ visual_inputs=visual_inputs,
363
+ visual_ids=visual_ids,
364
+ audio_inputs=audio_inputs,
365
+ audio_ids=audio_ids,
366
+ audio_text_ids=audio_text_ids,
367
+ multimodal_generation_status=multimodal_generation_status,
368
+ **kwargs,
369
+ )
370
+
371
+ hidden_states = outputs.last_hidden_state
372
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
373
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
374
+ slice_hidden_states = hidden_states[:, slice_indices, :]
375
+
376
+ loss, logits = None, None
377
+ if multimodal_generation_status.mode == "visual" and \
378
+ (not multimodal_generation_status.is_img_newline) and (not multimodal_generation_status.is_img_end):
379
+ visual_ids = self.get_multimodal_logits_and_ids(
380
+ self.visual_head,
381
+ visual_ids,
382
+ slice_hidden_states,
383
+ self.model.embed_tokens,
384
+ self.config.visual_config.vq_config.codebook_sizes,
385
+ self.model.visual_offset_vals,
386
+ visual_generation_config,
387
+ )
388
+ else:
389
+ logits = self.lm_head(slice_hidden_states)
390
+
391
+ if multimodal_generation_status.mode == "audio" and multimodal_generation_status.is_audio_start:
392
+ audio_ids = self.get_multimodal_logits_and_ids(
393
+ self.audio_head,
394
+ audio_ids,
395
+ slice_hidden_states,
396
+ self.model.embed_tokens,
397
+ self.config.audio_config.vq_config.codebook_sizes,
398
+ self.model.audio_offset_vals,
399
+ audio_generation_config,
400
+ )
401
+
402
+ return LongcatNextForCausalLMOutputWithPast(
403
+ loss=loss,
404
+ logits=logits,
405
+ past_key_values=outputs.past_key_values,
406
+ hidden_states=outputs.hidden_states,
407
+ attentions=outputs.attentions,
408
+ visual_ids=visual_ids,
409
+ audio_ids=audio_ids,
410
+ )
411
+
412
+ def get_multimodal_logits_and_ids(
413
+ self,
414
+ head_model,
415
+ multimodal_ids,
416
+ hidden_states,
417
+ multimodal_embedding_layer,
418
+ codebook_sizes,
419
+ offset_vals,
420
+ multimodal_generation_config,
421
+ ):
422
+ next_token_ids = torch.zeros(hidden_states.size(0), len(codebook_sizes), dtype=torch.long, device=hidden_states.device)
423
+ multimodal_embedding_layer = multimodal_embedding_layer.to(hidden_states.device)
424
+
425
+ for level, _ in enumerate(codebook_sizes):
426
+ logits = head_model(hidden_states, next_token_ids, multimodal_embedding_layer, level) # -> (bs, 1, dim)
427
+ next_token_id = self.inner_sample(logits, multimodal_ids[None, :, level]-offset_vals[level], multimodal_generation_config) # (bs, 1)
428
+ next_token_id += offset_vals[level]
429
+ next_token_ids[:, level] = next_token_id
430
+
431
+ return next_token_ids[:1]
432
+
433
+ def inner_sample(
434
+ self,
435
+ next_token_logits: torch.Tensor,
436
+ multimodal_ids: torch.LongTensor,
437
+ generation_config: GenerationConfig,
438
+ ) -> torch.Tensor:
439
+ logits_processor = self._get_logits_processor(generation_config)
440
+
441
+ if "cfg_scale" in generation_config.custom_params and generation_config.custom_params["cfg_scale"] != 1.0:
442
+ cond_logits, uncond_logits = next_token_logits.chunk(2, dim=0)
443
+ next_token_logits = generation_config.custom_params["cfg_scale"] * (cond_logits - uncond_logits) + uncond_logits
444
+
445
+ next_token_scores = logits_processor(multimodal_ids, next_token_logits.to(multimodal_ids.device))
446
+ if generation_config.do_sample:
447
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
448
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
449
+ else:
450
+ next_tokens = torch.argmax(next_token_scores, dim=-1)
451
+ return next_tokens
452
+
453
+ @torch.no_grad()
454
+ def generate(self, inputs=None, **kwargs):
455
+ """Override to ensure NgramCache is used."""
456
+
457
+ if "past_key_values" not in kwargs or kwargs["past_key_values"] is None:
458
+ kwargs["past_key_values"] = NgramCache(config=self.config)
459
+
460
+ return super().generate(
461
+ inputs=inputs,
462
+ **kwargs,
463
+ )
464
+
465
+ def prepare_inputs_for_generation(
466
+ self,
467
+ input_ids,
468
+ visual_ids,
469
+ audio_ids,
470
+ audio_text_ids,
471
+ multimodal_generation_status,
472
+ generation_config,
473
+ attention_mask,
474
+ cache_position,
475
+ **kwargs,
476
+ ):
477
+ extra_new_tokens = torch.empty(input_ids.size(0), 0, dtype=torch.long, device=input_ids.device)
478
+ if visual_ids is None:
479
+ visual_ids = torch.empty(0, len(self.config.visual_config.vq_config.codebook_sizes), dtype=torch.long, device=input_ids.device)
480
+ if audio_ids is None:
481
+ audio_ids = torch.empty(0, len(self.config.audio_config.vq_config.codebook_sizes), dtype=torch.long, device=input_ids.device)
482
+ if audio_text_ids is None:
483
+ audio_text_ids = torch.empty(input_ids.size(0), 0, dtype=torch.long, device=input_ids.device)
484
+
485
+ def insert_ids(new_ids, _input_ids, _attention_mask, _cache_position, position=0):
486
+ if position < 0:
487
+ parts = [_input_ids[:, :position], new_ids, _input_ids[:, position:]]
488
+ else:
489
+ parts = [_input_ids, new_ids]
490
+ _input_ids = torch.cat(parts, dim=1)
491
+ insert_len = new_ids.size(1)
492
+ _attention_mask = F.pad(_attention_mask, (0, insert_len), value=1)
493
+ insert_position = _cache_position[-1] + 1 + torch.arange(insert_len, device=_cache_position.device)
494
+ _cache_position = torch.cat([_cache_position, insert_position])
495
+ return _input_ids, _attention_mask, _cache_position
496
+
497
+ # multimodal generation status change
498
+ if cache_position[0] != 0:
499
+ multimodal_generation_status.last_step_mode = multimodal_generation_status.mode
500
+
501
+ if multimodal_generation_status.mode == "visual":
502
+ multimodal_generation_status.current_image_token_num += 1
503
+
504
+ if (input_ids[:, -1] == self.config.visual_config.image_start_token_id).all():
505
+ multimodal_generation_status.switch_to("visual")
506
+ anyres_prefix_ids = self.text_tokenizer.encode(multimodal_generation_status.anyres_prefix, return_tensors="pt")
507
+ anyres_prefix_ids = anyres_prefix_ids.to(input_ids.device)
508
+ extra_new_tokens = torch.cat([extra_new_tokens, anyres_prefix_ids], dim=1)
509
+ input_ids, attention_mask, cache_position = insert_ids(anyres_prefix_ids, input_ids, attention_mask, cache_position, position=-1)
510
+ if input_ids.size(0) == 1: # cfg, change bs=1 -> 2
511
+ input_ids = input_ids.repeat((2, 1))
512
+ input_ids[1, :-(anyres_prefix_ids.size(-1)+1)] = 0
513
+ print(f"change to cfg, input_ids: {input_ids}")
514
+ attention_mask = attention_mask.repeat((2, 1))
515
+
516
+ elif (input_ids[:, -1] == self.config.audio_config.audiogen_start_token_id).all():
517
+ multimodal_generation_status.switch_to("audio")
518
+
519
+ elif (input_ids[:, -1] == self.config.audio_config.audiotext_start_token_id).all():
520
+ multimodal_generation_status.is_audio_start = True
521
+
522
+ elif ((input_ids[:, -1] == self.config.visual_config.image_end_token_id) | (input_ids[:, -1] == self.config.audio_config.audiogen_end_token_id)).all():
523
+ multimodal_generation_status.switch_to("text")
524
+
525
+ model_inputs = super().prepare_inputs_for_generation(
526
+ input_ids=input_ids,
527
+ visual_ids=visual_ids,
528
+ audio_ids=audio_ids,
529
+ audio_text_ids=audio_text_ids,
530
+ attention_mask=attention_mask,
531
+ cache_position=cache_position,
532
+ **kwargs,
533
+ )
534
+
535
+ if model_inputs["cache_position"][0] != 0:
536
+ model_inputs["visual_inputs"] = None
537
+ model_inputs["audio_inputs"] = None
538
+
539
+ return model_inputs, multimodal_generation_status, extra_new_tokens
540
+
541
+ def _sample(
542
+ self,
543
+ input_ids: torch.LongTensor,
544
+ logits_processor: LogitsProcessorList,
545
+ stopping_criteria: StoppingCriteriaList,
546
+ generation_config: GenerationConfig,
547
+ synced_gpus: bool = False,
548
+ streamer: Optional["BaseStreamer"] = None,
549
+ visual_ids=None,
550
+ audio_ids=None,
551
+ audio_text_ids=None,
552
+ **model_kwargs,
553
+ ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
554
+ r"""
555
+ Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
556
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
557
+
558
+ Parameters:
559
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
560
+ The sequence used as a prompt for the generation.
561
+ logits_processor (`LogitsProcessorList`):
562
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
563
+ used to modify the prediction scores of the language modeling head applied at each generation step.
564
+ stopping_criteria (`StoppingCriteriaList`):
565
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
566
+ used to tell if the generation loop should stop.
567
+ generation_config ([`~generation.GenerationConfig`]):
568
+ The generation configuration to be used as parametrization of the decoding method.
569
+ synced_gpus (`bool`):
570
+ Whether to continue running the while loop until max_length (needed to avoid deadlocking with
571
+ `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
572
+ streamer (`BaseStreamer`, *optional*):
573
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
574
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
575
+ model_kwargs:
576
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
577
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
578
+
579
+ Return:
580
+ [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
581
+ A `torch.LongTensor` containing the generated tokens (default behaviour) or a
582
+ [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
583
+ `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
584
+ `model.config.is_encoder_decoder=True`.
585
+ """
586
+ # init values
587
+ pad_token_id = generation_config._pad_token_tensor
588
+ output_attentions = generation_config.output_attentions
589
+ output_hidden_states = generation_config.output_hidden_states
590
+ output_scores = generation_config.output_scores
591
+ output_logits = generation_config.output_logits
592
+ return_dict_in_generate = generation_config.return_dict_in_generate
593
+ has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
594
+ do_sample = generation_config.do_sample
595
+
596
+ # init attention / hidden states / scores tuples
597
+ scores = () if (return_dict_in_generate and output_scores) else None
598
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
599
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
600
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
601
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
602
+
603
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
604
+ if return_dict_in_generate and self.config.is_encoder_decoder:
605
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
606
+ encoder_hidden_states = (
607
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
608
+ )
609
+
610
+ # keep track of which sequences are already finished
611
+ batch_size, cur_len = input_ids.shape[:2]
612
+ this_peer_finished = False
613
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
614
+ model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
615
+
616
+ model_forward = self.__call__
617
+ compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
618
+ if compile_forward:
619
+ os.environ["TOKENIZERS_PARALLELISM"] = "0"
620
+ # If we use FA2 and a static cache, we cannot compile with fullgraph
621
+ if self.config._attn_implementation == "flash_attention_2":
622
+ # only raise warning if the user passed an explicit compile-config
623
+ if generation_config.compile_config is not None and generation_config.compile_config.fullgraph:
624
+ logger.warning_once(
625
+ "When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
626
+ "FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
627
+ )
628
+ generation_config.compile_config.fullgraph = False
629
+ model_forward = self.get_compiled_call(generation_config.compile_config)
630
+
631
+ if generation_config.prefill_chunk_size is not None:
632
+ model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
633
+ is_prefill = False
634
+ else:
635
+ is_prefill = True
636
+
637
+ visual_generation_config = GenerationConfig(**generation_config.visual_generation_config)
638
+ audio_generation_config = GenerationConfig(**generation_config.audio_generation_config)
639
+ multimodal_generation_status = LongcatNextForCausalLMGenerationStatus(visual_generation_config, audio_generation_config)
640
+
641
+ pbar = tqdm(iter(int, 1), desc="Generating", unit="tok")
642
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
643
+ # prepare model inputs
644
+ model_inputs, multimodal_generation_status, extra_new_tokens = self.prepare_inputs_for_generation(
645
+ input_ids,
646
+ visual_ids,
647
+ audio_ids,
648
+ audio_text_ids,
649
+ multimodal_generation_status,
650
+ generation_config,
651
+ **model_kwargs,
652
+ )
653
+ if extra_new_tokens.size(1) > 0:
654
+ input_ids = torch.cat([input_ids[:, :-1], extra_new_tokens, input_ids[:, -1:]], dim=1)
655
+ model_kwargs["attention_mask"] = model_inputs["attention_mask"]
656
+ model_kwargs["cache_position"] = model_inputs["cache_position"]
657
+
658
+ if multimodal_generation_status.mode == "text" and multimodal_generation_status.last_step_mode == "visual":
659
+ next_tokens = generation_config._eos_token_tensor
660
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
661
+ if streamer is not None:
662
+ streamer.put(next_tokens.cpu())
663
+ break
664
+
665
+ visual_ids = model_inputs["visual_ids"]
666
+ audio_ids = model_inputs["audio_ids"]
667
+ audio_text_ids = model_inputs["audio_text_ids"]
668
+
669
+ if is_prefill:
670
+ outputs = self(**model_inputs, return_dict=True, multimodal_generation_status=multimodal_generation_status, visual_generation_config=visual_generation_config, audio_generation_config=audio_generation_config)
671
+ is_prefill = False
672
+ else:
673
+ outputs = model_forward(**model_inputs, return_dict=True, multimodal_generation_status=multimodal_generation_status, visual_generation_config=visual_generation_config, audio_generation_config=audio_generation_config)
674
+
675
+ # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
676
+ model_kwargs = self._update_model_kwargs_for_generation(
677
+ outputs,
678
+ model_kwargs,
679
+ is_encoder_decoder=self.config.is_encoder_decoder,
680
+ num_new_tokens=1,
681
+ )
682
+ if synced_gpus and this_peer_finished:
683
+ continue
684
+
685
+
686
+ # multimodal generation
687
+ if multimodal_generation_status.mode == "text" or \
688
+ (multimodal_generation_status.mode == "audio" and not multimodal_generation_status.is_audio_text_end):
689
+ # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
690
+ # (the clone itself is always small)
691
+ next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
692
+
693
+ # pre-process distribution
694
+ next_token_scores = logits_processor(input_ids, next_token_logits)
695
+
696
+ # Store scores, attentions and hidden_states when required
697
+ if return_dict_in_generate:
698
+ if output_scores:
699
+ scores += (next_token_scores,)
700
+ if output_logits:
701
+ raw_logits += (next_token_logits,)
702
+ if output_attentions:
703
+ decoder_attentions += (
704
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
705
+ )
706
+ if self.config.is_encoder_decoder:
707
+ cross_attentions += (outputs.cross_attentions,)
708
+
709
+ if output_hidden_states:
710
+ decoder_hidden_states += (
711
+ (outputs.decoder_hidden_states,)
712
+ if self.config.is_encoder_decoder
713
+ else (outputs.hidden_states,)
714
+ )
715
+
716
+ # token selection
717
+ if do_sample:
718
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
719
+ # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
720
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
721
+ else:
722
+ next_tokens = torch.argmax(next_token_scores, dim=-1)
723
+
724
+ # audio_text_ids done
725
+ if multimodal_generation_status.mode == "audio" and (next_tokens == self.config.audio_config.audiotext_pad_token_id).all():
726
+ multimodal_generation_status.is_audio_text_end = True
727
+
728
+ elif multimodal_generation_status.mode == "visual":
729
+ if multimodal_generation_status.is_img_end:
730
+ next_tokens = self.model.image_end_token_id.to(input_ids.device)
731
+
732
+ elif multimodal_generation_status.is_img_newline:
733
+ next_tokens = self.model.image_newline_token_id.to(input_ids.device)
734
+
735
+ else:
736
+ visual_ids = torch.cat([visual_ids, outputs.visual_ids], dim=0) # [seq, lev]
737
+ next_tokens = self.model.image_pad_token_id.to(input_ids.device)
738
+
739
+ else: # mode == "audio" and multimodal_generation_status.is_audio_text_end
740
+ next_tokens = self.model.audio_pad_token_id.to(input_ids.device)
741
+
742
+
743
+ if multimodal_generation_status.mode == "audio":
744
+ # audio_text_ids update
745
+ audio_text_next_tokens = self.model.audiotext_pad_token_id.to(input_ids.device)
746
+ if not multimodal_generation_status.is_audio_text_end:
747
+ audio_text_next_tokens, next_tokens = next_tokens, audio_text_next_tokens
748
+ audio_text_ids = torch.cat((audio_text_ids, audio_text_next_tokens[:, None]), dim=1)
749
+
750
+ # audio_ids update
751
+ if multimodal_generation_status.is_audio_start:
752
+ if outputs.audio_ids[-1, 0] == (self.model.audio_offset_vals[1]): # offset + (level_1_len)
753
+ next_tokens = self.model.audiogen_end_token_id.to(input_ids.device)
754
+ else:
755
+ next_tokens = self.model.audio_pad_token_id.to(input_ids.device)
756
+ audio_ids = torch.cat([audio_ids, outputs.audio_ids], dim=0)
757
+
758
+ elif (multimodal_generation_status.audio_parallel_decoding) or \
759
+ (not multimodal_generation_status.audio_parallel_decoding and multimodal_generation_status.is_audio_text_end):
760
+ next_tokens = self.model.audiotext_start_token_id.to(input_ids.device)
761
+
762
+
763
+ # finished sentences should have their next token be a padding token
764
+ if has_eos_stopping_criteria:
765
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
766
+
767
+ # update generated ids, model inputs, and length for next step
768
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
769
+
770
+ # TODO: streaming mm ids
771
+ if streamer is not None:
772
+ streamer.put(next_tokens.cpu())
773
+
774
+ unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
775
+ this_peer_finished = unfinished_sequences.max() == 0
776
+ cur_len += 1
777
+
778
+ # This is needed to properly delete outputs.logits which may be very large for first iteration
779
+ # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
780
+ del outputs
781
+
782
+ pbar.update(1)
783
+ pbar.set_postfix({
784
+ "recent_5toks": f"{input_ids[:, -5:].tolist()}",
785
+ })
786
+
787
+ pbar.close()
788
+
789
+ if streamer is not None:
790
+ streamer.end()
791
+
792
+ if return_dict_in_generate:
793
+ if self.config.is_encoder_decoder:
794
+ return LongcatNextForCausalLMGenerateEncoderDecoderOutput(
795
+ sequences=input_ids,
796
+ scores=scores,
797
+ logits=raw_logits,
798
+ encoder_attentions=encoder_attentions,
799
+ encoder_hidden_states=encoder_hidden_states,
800
+ decoder_attentions=decoder_attentions,
801
+ cross_attentions=cross_attentions,
802
+ decoder_hidden_states=decoder_hidden_states,
803
+ past_key_values=model_kwargs.get("past_key_values"),
804
+ visual_ids=visual_ids,
805
+ audio_ids=audio_ids,
806
+ audio_text_ids=audio_text_ids,
807
+ )
808
+ else:
809
+ return LongcatNextForCausalLMGenerateDecoderOnlyOutput(
810
+ sequences=input_ids,
811
+ scores=scores,
812
+ logits=raw_logits,
813
+ attentions=decoder_attentions,
814
+ hidden_states=decoder_hidden_states,
815
+ past_key_values=model_kwargs.get("past_key_values"),
816
+ visual_ids=visual_ids,
817
+ audio_ids=audio_ids,
818
+ audio_text_ids=audio_text_ids,
819
+ )
820
+ else:
821
+ return input_ids, visual_ids, audio_ids, audio_text_ids
822
+
823
+
824
+ __all__ = ["LongcatNextModel", "LongcatNextForCausalLM"]
modeling_longcat_ngram.py ADDED
@@ -0,0 +1,426 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2026 Meituan
3
+ # This code is licensed under the MIT License, for details, see the ./LICENSE file.
4
+
5
+ from typing import Optional, Tuple, Dict, List
6
+
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+
11
+ from transformers.cache_utils import Cache, DynamicCache
12
+ from transformers.masking_utils import create_causal_mask
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast
14
+ from transformers.processing_utils import Unpack
15
+ from transformers.utils import auto_docstring, logging
16
+ from transformers.models.longcat_flash.modeling_longcat_flash import (
17
+ LongcatFlashForCausalLM,
18
+ LongcatFlashModel,
19
+ LongcatFlashRMSNorm,
20
+ LongcatFlashRotaryEmbedding,
21
+ LongcatFlashDecoderLayer,
22
+ LongcatFlashPreTrainedModel,
23
+ )
24
+ from .configuration_longcat_ngram import LongcatFlashNgramConfig
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ @auto_docstring
30
+ class LongcatFlashNgramPreTrainedModel(LongcatFlashPreTrainedModel):
31
+ pass
32
+
33
+
34
+ class NgramCache(DynamicCache):
35
+ """
36
+ Extended DynamicCache for storing N-gram context alongside KV cache.
37
+ """
38
+ def __init__(self, config=None):
39
+ super().__init__()
40
+ self.ngram_context = None
41
+ # Keep only n-1 tokens (minimum needed for N-gram computation)
42
+ self.max_context_len = config.emb_neighbor_num - 1
43
+ self.oe_ignored_token_ids = torch.tensor(config.oe_ignored_token_ids, dtype=torch.long)
44
+
45
+
46
+ def update_ngram_context(self, new_tokens: torch.Tensor) -> None:
47
+ """
48
+ Update N-gram context with window management.
49
+
50
+ Args:
51
+ new_tokens: New tokens to append, shape (batch_size, seq_len)
52
+ """
53
+ new_tokens = new_tokens.clone()
54
+ new_tokens[torch.isin(new_tokens, self.oe_ignored_token_ids.to(new_tokens.device))] = 0
55
+
56
+ if self.ngram_context is None:
57
+ self.ngram_context = new_tokens
58
+ else:
59
+ self.ngram_context = torch.cat([self.ngram_context, new_tokens], dim=-1)
60
+
61
+ # Truncate to maintain constant memory footprint
62
+ if self.ngram_context.size(-1) > self.max_context_len:
63
+ self.ngram_context = self.ngram_context[..., -self.max_context_len:]
64
+
65
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> "Cache":
66
+ """Reorder cache for beam search."""
67
+ # Reorder parent's KV cache
68
+ super().reorder_cache(beam_idx)
69
+
70
+ # Reorder N-gram context
71
+ if self.ngram_context is not None:
72
+ self.ngram_context = self.ngram_context.index_select(0, beam_idx.to(self.ngram_context.device))
73
+
74
+ return self
75
+
76
+
77
+ class EmbeddingWithMask(nn.Embedding):
78
+ def forward(self, input: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
79
+ """
80
+ Args:
81
+ input (torch.Tensor): Input indices of shape (batch_size, seq_len)
82
+ mask (torch.Tensor): Boolean mask of shape (batch_size, seq_len).
83
+ True means compute, False means skip and return 0.
84
+ Returns:
85
+ torch.Tensor: Embeddings of shape (batch_size, seq_len, embedding_dim)
86
+ """
87
+ if mask is not None:
88
+ # Ensure mask is boolean
89
+ mask = mask.bool()
90
+ else:
91
+ mask = torch.ones_like(input, dtype=torch.bool)
92
+
93
+ batch_size, seq_len = input.shape
94
+ embedding_dim = self.embedding_dim
95
+
96
+ # 1. Initialize the output tensor with zeros on the correct device
97
+ output = torch.zeros(
98
+ (batch_size, seq_len, embedding_dim),
99
+ device=input.device,
100
+ dtype=self.weight.dtype
101
+ )
102
+
103
+ # 2. Filter out the valid indices using the mask
104
+ # valid_indices is a 1D tensor containing only the elements where mask is True
105
+ valid_indices = input[mask]
106
+
107
+ # 3. Only perform the embedding lookup if there is at least one valid index
108
+ if valid_indices.numel() > 0:
109
+ # Look up only the necessary embeddings (saves compute/memory bandwidth)
110
+ valid_embeddings = F.embedding(
111
+ valid_indices, self.weight, self.padding_idx, self.max_norm,
112
+ self.norm_type, self.scale_grad_by_freq, self.sparse)
113
+
114
+ # 4. Scatter the valid embeddings back to their original positions in the output tensor
115
+ output[mask] = valid_embeddings
116
+
117
+ return output
118
+
119
+
120
+ class NgramEmbedding(nn.Module):
121
+ """
122
+ Computes embeddings enriched with N-gram features without maintaining internal state.
123
+ """
124
+ def __init__(self, config, base_embeddings):
125
+ super().__init__()
126
+ self.config = config
127
+ self.word_embeddings = base_embeddings
128
+
129
+ # self.m = config.ngram_vocab_size_ratio * config.vocab_size
130
+ self.m = config.ngram_vocab_size_ratio * config.text_vocab_size
131
+ self.k = config.emb_split_num
132
+ self.n = config.emb_neighbor_num
133
+ self.oe_ignored_token_ids = torch.tensor(config.oe_ignored_token_ids)
134
+
135
+ self._init_ngram_embeddings()
136
+ self._vocab_mods_cache = None
137
+
138
+ def _init_ngram_embeddings(self) -> None:
139
+ """Initialize N-gram embedding and projection layers."""
140
+ num_embedders = self.k * (self.n - 1)
141
+ emb_dim = self.config.hidden_size // num_embedders
142
+
143
+ embedders = []
144
+ post_projs = []
145
+
146
+ for i in range(num_embedders):
147
+ vocab_size = int(self.m + i * 2 + 1)
148
+ emb = EmbeddingWithMask(vocab_size, emb_dim, padding_idx=self.config.pad_token_id)
149
+ proj = nn.Linear(emb_dim, self.config.hidden_size, bias=False)
150
+ embedders.append(emb)
151
+ post_projs.append(proj)
152
+
153
+ self.embedders = nn.ModuleList(embedders)
154
+ self.post_projs = nn.ModuleList(post_projs)
155
+
156
+ def _shift_right_ignore_eos(self, tensor: torch.Tensor, n: int, eos_token_id: int = 2) -> torch.Tensor:
157
+ p, q = tensor.shape
158
+ # special_token / modal positions are set to 0
159
+ special_tokens = 0
160
+
161
+ if n == 0:
162
+ return tensor.clone()
163
+
164
+ if n >= q:
165
+ return torch.zeros_like(tensor)
166
+
167
+ result = torch.zeros_like(tensor)
168
+
169
+ # Find all special_token/modal/EOS locations
170
+ special_mask = (tensor == special_tokens)
171
+ total_mask = (tensor == eos_token_id) | special_mask
172
+
173
+ # Calculate the segment ID to which each position belongs
174
+ eos_cumsum = total_mask.long().cumsum(dim=1)
175
+ # Shift right by 1, so that the first EOS position still belongs to segment 0, and the second EOS position belongs to segment 1
176
+ segment_ids = torch.cat([
177
+ torch.zeros(p, 1, dtype=torch.long, device=tensor.device),
178
+ eos_cumsum[:, :-1]
179
+ ], dim=1)
180
+
181
+ col_indices = torch.arange(q, device=tensor.device).unsqueeze(0).expand(p, q)
182
+ # Number of segments
183
+ max_segments = segment_ids.max().item() + 1
184
+ segment_starts = torch.full((p, max_segments), q, dtype=torch.long, device=tensor.device)
185
+ # Calculate the starting position of each segment
186
+ segment_starts.scatter_reduce_(1, segment_ids, col_indices, reduce='amin', include_self=False)
187
+
188
+ # Get the start position of the segment to which each position belongs
189
+ segment_start_per_pos = torch.gather(segment_starts, 1, segment_ids)
190
+
191
+ # Calculate the offset of each position within the segment
192
+ offset_in_segment = col_indices - segment_start_per_pos
193
+
194
+ # Data for each position should be taken from the position offset -n within the segment
195
+ source_offset = offset_in_segment - n
196
+ valid_mask = source_offset >= 0
197
+
198
+ # Calculate the actual source index
199
+ source_indices = segment_start_per_pos + torch.clamp(source_offset, min=0)
200
+
201
+ # Data is collected by source_indices
202
+ result = torch.gather(tensor, 1, source_indices)
203
+
204
+ # Set invalid position to zero
205
+ result = result * valid_mask * (~special_mask)
206
+
207
+ return result
208
+
209
+ def _precompute_vocab_mods(self) -> Dict[Tuple[int, int], List[int]]:
210
+ """Precompute modular arithmetic values for vocabulary."""
211
+ if self._vocab_mods_cache is not None:
212
+ return self._vocab_mods_cache
213
+
214
+ vocab_mods = {}
215
+ vocab_size = self.config.text_vocab_size
216
+
217
+ for i in range(2, self.n + 1):
218
+ for j in range(self.k):
219
+ index = (i - 2) * self.k + j
220
+ emb_vocab_dim = int(self.m + index * 2 + 1)
221
+
222
+ mods = []
223
+ power_mod = 1
224
+ for _ in range(i - 1):
225
+ power_mod = (power_mod * vocab_size) % emb_vocab_dim
226
+ mods.append(power_mod)
227
+
228
+ vocab_mods[(i, j)] = mods
229
+
230
+ self._vocab_mods_cache = vocab_mods
231
+ return vocab_mods
232
+
233
+ def _get_ngram_ids(
234
+ self,
235
+ input_ids: torch.Tensor,
236
+ shifted_ids: Dict[int, torch.Tensor],
237
+ vocab_mods: List[int],
238
+ ngram: int
239
+ ) -> torch.Tensor:
240
+ """Compute N-gram hash IDs using polynomial rolling hash."""
241
+ ngram_ids = input_ids.clone()
242
+ for k in range(2, ngram + 1):
243
+ ngram_ids = ngram_ids + shifted_ids[k] * vocab_mods[k - 2]
244
+ return ngram_ids
245
+
246
+ def forward(
247
+ self,
248
+ input_ids: torch.Tensor,
249
+ ngram_context: Optional[torch.Tensor] = None
250
+ ) -> torch.Tensor:
251
+ """
252
+ Stateless forward pass.
253
+
254
+ Args:
255
+ input_ids: Current input token IDs of shape (batch_size, seq_len)
256
+ ngram_context: Optional historical context of shape (batch_size, context_len)
257
+
258
+ Returns:
259
+ Embedding tensor of shape (batch_size, seq_len, hidden_size)
260
+ """
261
+ seq_len = input_ids.size(-1)
262
+
263
+ # Determine complete context
264
+ if ngram_context is not None:
265
+ context = torch.cat([ngram_context[..., -(self.n-1):], input_ids], dim=-1)
266
+ else:
267
+ context = input_ids.clone()
268
+
269
+ # Skip N-gram look-up for oe_ignored_token_ids
270
+ oe_ignored_mask = torch.isin(input_ids, self.oe_ignored_token_ids.to(device=input_ids.device))
271
+ context[torch.isin(context, self.oe_ignored_token_ids.to(device=context.device))] = 0
272
+
273
+ # Base word embeddings
274
+ device = self.word_embeddings.weight.device
275
+ x = self.word_embeddings(input_ids.to(device)).clone()
276
+
277
+ # Precompute modular values
278
+ vocab_mods = self._precompute_vocab_mods()
279
+
280
+ # Compute shifted IDs
281
+ shifted_ids = {}
282
+ for i in range(2, self.n + 1):
283
+ shifted_ids[i] = self._shift_right_ignore_eos(
284
+ context, i - 1, eos_token_id=self.config.eos_token_id
285
+ )
286
+
287
+ # Add N-gram embeddings
288
+ for i in range(2, self.n + 1):
289
+ for j in range(self.k):
290
+ index = (i - 2) * self.k + j
291
+ emb_vocab_dim = int(self.m + index * 2 + 1)
292
+
293
+ ngram_ids = self._get_ngram_ids(context, shifted_ids, vocab_mods[(i, j)], ngram=i)
294
+ new_ids = (ngram_ids % emb_vocab_dim)[..., -seq_len:]
295
+ text_mask = new_ids > 0
296
+
297
+ embedder_device = self.embedders[index].weight.device
298
+ x_ngram = self.embedders[index](new_ids.to(embedder_device), text_mask)
299
+
300
+ proj_device = self.post_projs[index].weight.device
301
+ x_proj = self.post_projs[index](x_ngram.to(proj_device))
302
+ x = x + x_proj.to(x.device)
303
+
304
+ # Normalize
305
+ x[~oe_ignored_mask] /= (1 + self.k * (self.n - 1))
306
+
307
+ return x
308
+
309
+
310
+ class LongcatFlashNgramModel(LongcatFlashModel):
311
+ """LongcatFlash model with N-gram enhanced embeddings."""
312
+ _keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
313
+ config_class = LongcatFlashNgramConfig
314
+
315
+ def __init__(self, config):
316
+ super().__init__(config)
317
+
318
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
319
+ self.ngram_embeddings = NgramEmbedding(config, self.embed_tokens)
320
+
321
+ self.layers = nn.ModuleList(
322
+ [LongcatFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_layers)]
323
+ )
324
+
325
+ self.head_dim = config.head_dim
326
+ self.config.num_hidden_layers = 2 * config.num_layers
327
+ self.norm = LongcatFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
328
+ self.rotary_emb = LongcatFlashRotaryEmbedding(config=config)
329
+ self.gradient_checkpointing = False
330
+
331
+ self.post_init()
332
+
333
+ def forward(
334
+ self,
335
+ input_ids: Optional[torch.LongTensor] = None,
336
+ attention_mask: Optional[torch.Tensor] = None,
337
+ position_ids: Optional[torch.LongTensor] = None,
338
+ past_key_values: Optional[Cache] = None,
339
+ inputs_embeds: Optional[torch.FloatTensor] = None,
340
+ cache_position: Optional[torch.LongTensor] = None,
341
+ use_cache: Optional[bool] = None,
342
+ **kwargs
343
+ ) -> BaseModelOutputWithPast:
344
+ if (input_ids is None) ^ (inputs_embeds is not None):
345
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
346
+
347
+ # Extract N-gram context if available
348
+ ngram_context = None
349
+ if isinstance(past_key_values, NgramCache) and past_key_values.ngram_context is not None:
350
+ ngram_context = past_key_values.ngram_context
351
+
352
+ if inputs_embeds is None:
353
+ inputs_embeds = self.ngram_embeddings(input_ids, ngram_context=ngram_context)
354
+
355
+ # Initialize NgramCache if needed
356
+ if use_cache and past_key_values is None:
357
+ past_key_values = NgramCache(config=self.config)
358
+
359
+ # Update N-gram context
360
+ if use_cache and isinstance(past_key_values, NgramCache) and input_ids is not None:
361
+ past_key_values.update_ngram_context(input_ids)
362
+
363
+ # Prepare cache position
364
+ if cache_position is None:
365
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
366
+ cache_position = torch.arange(
367
+ inputs_embeds.shape[1], device=inputs_embeds.device
368
+ ) + past_seen_tokens
369
+
370
+ if position_ids is None:
371
+ position_ids = cache_position.unsqueeze(0)
372
+
373
+ # Create causal mask
374
+ causal_mask = create_causal_mask(
375
+ config=self.config,
376
+ input_embeds=inputs_embeds,
377
+ attention_mask=attention_mask,
378
+ cache_position=cache_position,
379
+ past_key_values=past_key_values,
380
+ position_ids=position_ids,
381
+ )
382
+
383
+ # Forward through decoder layers
384
+ hidden_states = inputs_embeds
385
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
386
+
387
+ for decoder_layer in self.layers[: self.config.num_layers]:
388
+ hidden_states = decoder_layer(
389
+ hidden_states,
390
+ attention_mask=causal_mask,
391
+ position_ids=position_ids,
392
+ past_key_values=past_key_values,
393
+ cache_position=cache_position,
394
+ position_embeddings=position_embeddings,
395
+ **kwargs,
396
+ )
397
+
398
+ hidden_states = self.norm(hidden_states)
399
+
400
+ return BaseModelOutputWithPast(
401
+ last_hidden_state=hidden_states,
402
+ past_key_values=past_key_values,
403
+ hidden_states=None,
404
+ attentions=None,
405
+ )
406
+
407
+
408
+ class LongcatFlashNgramForCausalLM(LongcatFlashForCausalLM):
409
+ """LongcatFlash model for causal language modeling with N-gram embeddings."""
410
+ _keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
411
+ config_class = LongcatFlashNgramConfig
412
+
413
+ def __init__(self, config):
414
+ super().__init__(config)
415
+ self.model = LongcatFlashNgramModel(config)
416
+
417
+ @torch.no_grad()
418
+ def generate(self, inputs=None, generation_config=None, **kwargs):
419
+ """Override to ensure NgramCache is used."""
420
+
421
+ if "past_key_values" not in kwargs or kwargs["past_key_values"] is None:
422
+ kwargs["past_key_values"] = NgramCache(config=self.config)
423
+
424
+ return super().generate(inputs=inputs, generation_config=generation_config, **kwargs)
425
+
426
+ __all__ = ["LongcatFlashNgramPreTrainedModel", "LongcatFlashNgramModel", "LongcatFlashNgramForCausalLM"]
modular_longcat_next.py ADDED
@@ -0,0 +1,157 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+ from flash_attn import flash_attn_varlen_func
6
+
7
+ from transformers.models.t5.modeling_t5 import T5LayerNorm as RMSNorm
8
+
9
+
10
+ class FlashVarLenAttention(nn.Module):
11
+ def __init__(self, embed_dim, num_heads, causal=False, window_size=(-1,-1)):
12
+ super().__init__()
13
+ self.embed_dim = embed_dim
14
+ self.num_heads = num_heads
15
+ self.head_dim = embed_dim // num_heads
16
+
17
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
18
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
19
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
20
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
21
+
22
+ self.causal = causal
23
+ self.window_size = window_size
24
+
25
+ def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
26
+ bsz, _ = hidden_states.size()
27
+
28
+ query_states = self.q_proj(hidden_states)
29
+ query_states = query_states.view(bsz, self.num_heads, self.head_dim).contiguous()
30
+ key_states = self.k_proj(hidden_states)
31
+ key_states = key_states.view(bsz, self.num_heads, self.head_dim).contiguous()
32
+ value_states = self.v_proj(hidden_states)
33
+ value_states = value_states.view(bsz, self.num_heads, self.head_dim).contiguous()
34
+
35
+ cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(torch.int32)
36
+ max_seqlen = torch.max(seq_len).to(torch.int32).detach()
37
+
38
+ attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_len, cu_len, max_seqlen,
39
+ max_seqlen, causal=self.causal, window_size=self.window_size) # (bsz * qlen, nheads, headdim)
40
+ attn_output = attn_output.reshape(bsz, self.embed_dim)
41
+ attn_output = self.out_proj(attn_output)
42
+ return attn_output
43
+
44
+
45
+
46
+ class CasualDepthTransformerLayer(nn.Module):
47
+ def __init__(self, depth, transformer_dim, transformer_ffn_scale):
48
+ super().__init__()
49
+ self.depth = depth
50
+ self.transformer_dim = transformer_dim
51
+ self.transformer_ffn_scale = transformer_ffn_scale
52
+ self.num_heads = self.transformer_dim // 128
53
+
54
+ assert self.transformer_dim % 128 == 0
55
+ assert self.transformer_dim % depth == 0
56
+
57
+ self.self_attention = FlashVarLenAttention(embed_dim=self.transformer_dim, num_heads=self.num_heads, causal=True)
58
+
59
+ self.layernorm1 = RMSNorm(self.transformer_dim)
60
+ self.layernorm2 = RMSNorm(self.transformer_dim)
61
+
62
+ self.linear1 = nn.Linear(self.transformer_dim, self.transformer_ffn_scale * self.transformer_dim)
63
+ self.linear2 = nn.Linear(self.transformer_ffn_scale * self.transformer_dim, self.transformer_dim)
64
+
65
+ def forward(self, x):
66
+ bsz = x.shape[0]
67
+ res = x
68
+ x = self.layernorm1(x)
69
+ seqlens = torch.tensor([self.depth] * bsz, dtype=torch.int32, device=x.device)
70
+ _x = self.self_attention(x.view(-1, self.transformer_dim), seqlens)
71
+ _x = _x.view(bsz, self.depth, self.transformer_dim).contiguous()
72
+
73
+ _res = _x + res # (bs, sl, d)
74
+ res = self.layernorm2(_res)
75
+ x = torch.einsum('bld,tld->blt', res, torch.reshape(self.linear1.weight, (self.transformer_ffn_scale * self.transformer_dim // self.depth, self.depth, self.transformer_dim)))
76
+ x = torch.nn.functional.gelu(x)
77
+ x = torch.einsum('blt,dlt->bld',x, torch.reshape(self.linear2.weight, (self.transformer_dim, self.depth, self.transformer_ffn_scale * self.transformer_dim // self.depth)))
78
+ return _res + x
79
+
80
+
81
+ class CasualDepthTransformerHead(nn.Module):
82
+ """
83
+ Depth-wise causal transformer head shared by image/audio heads.
84
+ """
85
+
86
+ def __init__(
87
+ self,
88
+ hidden_size,
89
+ codebook_sizes,
90
+ transformer_layer_num,
91
+ transformer_dim,
92
+ transformer_ffn_scale,
93
+ gradient_checkpointing=False,
94
+ ):
95
+ super().__init__()
96
+ self.hidden_size = hidden_size
97
+ self.codebook_sizes = codebook_sizes
98
+ self.transformer_ffn_scale = transformer_ffn_scale
99
+ self.gradient_checkpointing = gradient_checkpointing
100
+
101
+ if self.transformer_ffn_scale > 0:
102
+ self.hidden_norm = RMSNorm(self.hidden_size)
103
+ self.hidden_proj = nn.Linear(self.hidden_size, transformer_dim, bias=False)
104
+
105
+ self.transformer_layers = nn.ModuleList(
106
+ [
107
+ CasualDepthTransformerLayer(len(codebook_sizes), transformer_dim, transformer_ffn_scale)
108
+ for _ in range(transformer_layer_num)
109
+ ]
110
+ )
111
+ self.headnorm = RMSNorm(transformer_dim)
112
+ self.heads = nn.ModuleList(
113
+ [nn.Linear(transformer_dim, vq_size + 1) for vq_size in codebook_sizes]
114
+ )
115
+
116
+ for param in self.parameters():
117
+ param.requires_grad = False
118
+
119
+ def forward(self, x, visual_tokens, visual_emb_layers, level):
120
+ main_device = "cuda:0"
121
+ visual_tokens = visual_tokens.to(main_device)
122
+ visual_emb_layers = visual_emb_layers.to(main_device)
123
+
124
+ cumsum_visual_embed = torch.stack([
125
+ visual_emb_layers(visual_tokens[..., i])
126
+ for i, vq_size in enumerate(self.codebook_sizes[:-1])
127
+ ], dim=1).to(x.device)
128
+
129
+ cumsum_visual_embed = torch.cumsum(cumsum_visual_embed, dim=1) # (bs, depth-1, d)
130
+
131
+ hidden_states = torch.concat([x.reshape(-1, 1, self.hidden_size), cumsum_visual_embed], dim=1) # (bs, depth, d)
132
+ assert hidden_states.size(1) == len(self.codebook_sizes)
133
+
134
+ if self.transformer_ffn_scale > 0:
135
+ hidden_states = self.hidden_norm(hidden_states)
136
+ hidden_states = self.hidden_proj(hidden_states)
137
+
138
+ for i, tlayer in enumerate(self.transformer_layers):
139
+ if self.gradient_checkpointing and self.training:
140
+
141
+ def create_custom_forward(module):
142
+ def custom_forward(*inputs):
143
+ # None for past_key_value
144
+ return module(*inputs)
145
+
146
+ return custom_forward
147
+
148
+ hidden_states = torch.utils.checkpoint.checkpoint(
149
+ create_custom_forward(tlayer), hidden_states,
150
+ )
151
+ else:
152
+ hidden_states = tlayer(
153
+ hidden_states,
154
+ )
155
+ hidden_states = self.headnorm(hidden_states)
156
+ logits = self.heads[level](hidden_states[:, level])
157
+ return logits
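The head above is not called once per step but once per codebook level: `LongcatNextForCausalLM.get_multimodal_logits_and_ids` decodes level 0 from the LM hidden state alone, then feeds the embeddings of the already-decoded levels back in for the next level. Below is a shape-level sketch of that loop with a stand-in head so it runs on its own; the real head, embedding table, and offset values come from the model.

```python
import torch

hidden_size, codebook_sizes = 8, [16, 16, 16]       # toy sizes
bs = 1

def stub_head(hidden_states, decoded_ids, embed, level):
    # Stand-in for CasualDepthTransformerHead.forward: per-level logits of
    # size codebook_sizes[level] + 1, matching what the real heads produce.
    return torch.randn(bs, codebook_sizes[level] + 1)

embed = torch.nn.Embedding(sum(s + 1 for s in codebook_sizes), hidden_size)
lm_hidden = torch.randn(bs, 1, hidden_size)          # last LM hidden state

decoded = torch.zeros(bs, len(codebook_sizes), dtype=torch.long)
offset = 0
for level, size in enumerate(codebook_sizes):
    logits = stub_head(lm_hidden, decoded, embed, level)   # (bs, size + 1)
    token = logits.argmax(dim=-1)                          # greedy, for the sketch
    decoded[:, level] = token + offset                     # per-level offset, as in the model
    offset += size + 1

print(decoded)   # one multi-level code per generated image/audio position
```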
modular_longcat_next_audio.py ADDED
@@ -0,0 +1,2039 @@
1
+ import math
2
+ import copy
3
+ from abc import ABC
4
+ from dataclasses import dataclass
5
+ from typing import Any, Dict, Optional
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torchaudio
10
+ from einops import pack, rearrange, repeat
11
+ from flash_attn import flash_attn_varlen_func
12
+ from torch import nn
13
+ from torch.cuda.amp import autocast
14
+ from torch.nn import functional as F
15
+
16
+ from diffusers.models.activations import get_activation
17
+ from diffusers.models.attention import (
18
+ GEGLU,
19
+ GELU,
20
+ AdaLayerNorm,
21
+ AdaLayerNormZero,
22
+ ApproximateGELU,
23
+ )
24
+ from diffusers.models.attention_processor import Attention
25
+ from diffusers.models.lora import LoRACompatibleLinear
26
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.modeling_outputs import ModelOutput
30
+ from transformers.utils import logging
31
+
32
+ from .cosy24k_vocoder import Cosy24kVocoder
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+
37
+ def sinusoids(length, channels, max_timescale=10000):
38
+ """Returns sinusoids for positional embedding"""
39
+ assert channels % 2 == 0
40
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
41
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
42
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
43
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
44
+
45
+
46
+ def get_sequence_mask(inputs, inputs_length):
47
+ if inputs.dim() == 3:
48
+ bsz, tgt_len, _ = inputs.size()
49
+ else:
50
+ bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
51
+ sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
52
+ sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
53
+ unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1 # convert the mask to gather indices
54
+ return sequence_mask, unpacking_index
55
+
56
+
57
+ def unpack_hidden_states(hidden_states, lengths):
58
+ bsz = lengths.shape[0]
59
+ sequence_mask, unpacking_index = get_sequence_mask(hidden_states, lengths)
60
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
61
+ bsz, torch.max(lengths), hidden_states.shape[-1]
62
+ )
63
+ hidden_states = torch.where(
64
+ sequence_mask, hidden_states, 0
65
+ ) # 3d (bsz, max_input_len, d)
66
+ return hidden_states
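# For example, with the two helpers above: two sequences of lengths 3 and 1,
# packed row-wise into a (4, d) tensor, come back zero-padded to (2, 3, d):
#
#   lengths = torch.tensor([3, 1])
#   packed = torch.arange(8, dtype=torch.float32).view(4, 2)   # (sum(lengths), d)
#   padded = unpack_hidden_states(packed, lengths)             # (2, 3, 2)
#   # padded[0] equals packed[0:3]; padded[1, 0] equals packed[3]; padded[1, 1:] is 0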
67
+
68
+
69
+ def uniform_init(*shape):
70
+ t = torch.zeros(shape)
71
+ nn.init.kaiming_uniform_(t)
72
+ return t
73
+
74
+
75
+ def cdist(x, y):
76
+ x2 = torch.sum(x ** 2, dim=-1, keepdims=True) # (b, 1)
77
+ y2 = torch.sum(y ** 2, dim=-1).reshape(1, -1) # (1, c)
78
+ xy = torch.einsum('bd,cd->bc', x, y) * -2
79
+ return (x2 + y2 + xy).clamp(min=0).sqrt() # (b, c)
80
+
81
+
82
+ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
83
+ assert mask.dtype == torch.bool
84
+ assert dtype in [torch.float32, torch.bfloat16, torch.float16]
85
+ mask = mask.to(dtype)
86
+ # attention mask bias
87
+ # NOTE(Mddct): torch.finfo jit issues
88
+ # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
89
+ mask = (1.0 - mask) * torch.finfo(dtype).min
90
+ return mask
91
+
92
+
93
+ def subsequent_chunk_mask(
94
+ size: int,
95
+ chunk_size: int,
96
+ num_left_chunks: int = -1,
97
+ device: torch.device = torch.device("cpu"),
98
+ ) -> torch.Tensor:
99
+ """Create mask for subsequent steps (size, size) with chunk size,
100
+ this is for streaming encoder
101
+
102
+ Args:
103
+ size (int): size of mask
104
+ chunk_size (int): size of chunk
105
+ num_left_chunks (int): number of left chunks
106
+ <0: use full chunk
107
+ >=0: use num_left_chunks
108
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
109
+
110
+ Returns:
111
+ torch.Tensor: mask
112
+
113
+ Examples:
114
+ >>> subsequent_chunk_mask(4, 2)
115
+ [[1, 1, 0, 0],
116
+ [1, 1, 0, 0],
117
+ [1, 1, 1, 1],
118
+ [1, 1, 1, 1]]
119
+ """
120
+ # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
121
+ # actually this is not needed after we have inference cache implemented, will remove it later
122
+ pos_idx = torch.arange(size, device=device)
123
+ block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
124
+ ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
125
+ return ret
126
+
127
+
128
+ def add_optional_chunk_mask(xs: torch.Tensor,
129
+ masks: torch.Tensor,
130
+ use_dynamic_chunk: bool,
131
+ use_dynamic_left_chunk: bool,
132
+ decoding_chunk_size: int,
133
+ static_chunk_size: int,
134
+ num_decoding_left_chunks: int,
135
+ enable_full_context: bool = True):
136
+ """ Apply optional mask for encoder.
137
+
138
+ Args:
139
+ xs (torch.Tensor): padded input, (B, L, D), L for max length
140
+ mask (torch.Tensor): mask for xs, (B, 1, L)
141
+ use_dynamic_chunk (bool): whether to use dynamic chunk or not
142
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
143
+ training.
144
+ decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
145
+ 0: default for training, use random dynamic chunk.
146
+ <0: for decoding, use full chunk.
147
+ >0: for decoding, use fixed chunk size as set.
148
+ static_chunk_size (int): chunk size for static chunk training/decoding
149
+ if it's greater than 0, if use_dynamic_chunk is true,
150
+ this parameter will be ignored
151
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
152
+ the chunk size is decoding_chunk_size.
153
+ >=0: use num_decoding_left_chunks
154
+ <0: use all left chunks
155
+ enable_full_context (bool):
156
+ True: chunk size is either [1, 25] or full context(max_len)
157
+ False: chunk size ~ U[1, 25]
158
+
159
+ Returns:
160
+ torch.Tensor: chunk mask of the input xs.
161
+ """
162
+ # Whether to use chunk mask or not
163
+ if use_dynamic_chunk:
164
+ max_len = xs.size(1)
165
+ if decoding_chunk_size < 0:
166
+ chunk_size = max_len
167
+ num_left_chunks = -1
168
+ elif decoding_chunk_size > 0:
169
+ chunk_size = decoding_chunk_size
170
+ num_left_chunks = num_decoding_left_chunks
171
+ else:
172
+ # chunk size is either [1, 25] or full context(max_len).
173
+ # Since we use 4 times subsampling and allow up to 1s(100 frames)
174
+ # delay, the maximum frame is 100 / 4 = 25.
175
+ chunk_size = torch.randint(1, max_len, (1, )).item()
176
+ num_left_chunks = -1
177
+ if chunk_size > max_len // 2 and enable_full_context:
178
+ chunk_size = max_len
179
+ else:
180
+ chunk_size = chunk_size % 25 + 1
181
+ if use_dynamic_left_chunk:
182
+ max_left_chunks = (max_len - 1) // chunk_size
183
+ num_left_chunks = torch.randint(0, max_left_chunks,
184
+ (1, )).item()
185
+ chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
186
+ num_left_chunks,
187
+ xs.device) # (L, L)
188
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
189
+ chunk_masks = masks & chunk_masks # (B, L, L)
190
+ elif static_chunk_size > 0:
191
+ num_left_chunks = num_decoding_left_chunks
192
+ chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
193
+ num_left_chunks,
194
+ xs.device) # (L, L)
195
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
196
+ chunk_masks = masks & chunk_masks # (B, L, L)
197
+ else:
198
+ chunk_masks = masks
199
+ return chunk_masks
200
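+ # Note: in this file the ConditionalDecoder below always calls add_optional_chunk_mask
+ # with use_dynamic_chunk=False and a positive static_chunk_size, so only the
+ # static-chunk branch is exercised; the dynamic-chunk branches are kept from the
+ # upstream streaming-encoder implementation.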
+
201
+
202
+ class EuclideanCodebook(nn.Module):
203
+ def __init__(
204
+ self,
205
+ dim,
206
+ codebook_size,
207
+ init_std=0.02,
208
+ ):
209
+ super().__init__()
210
+ self.init_std = init_std
211
+ self.dim = dim
212
+ self.codebook_size = codebook_size
213
+
214
+ embed = uniform_init(codebook_size, dim).to(torch.float32)
215
+ self.cluster_size = nn.Parameter(torch.ones(codebook_size))
216
+ self.embed_avg = nn.Parameter(embed.clone())
217
+ self.embed = nn.Parameter(embed)
218
+ del embed
219
+
220
+ @autocast(enabled=True, dtype=torch.float32)
221
+ @torch.no_grad()
222
+ def forward(self, x):
223
+ assert(len(x.shape) == 2)
224
+ assert(x.dtype == torch.float32)
225
+ embed = self.embed.detach().to(x.device)
226
+ dist = -cdist(x, embed) # dist((bs*sl, d), (c, d)) --> (bs*sl, c)
227
+ embed_ind = dist.argmax(dim=-1)
228
+ quantize = embed[embed_ind] # (bs*sl, d)
229
+ return quantize, embed_ind, dist
230
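+ # The lookup above is a plain nearest-neighbour search: dist holds negative Euclidean
+ # distances, so argmax selects the closest code. Illustrative usage (hypothetical
+ # sizes):
+ #   cb = EuclideanCodebook(dim=256, codebook_size=1024)
+ #   q, ids, _ = cb(torch.randn(10, 256))   # q: (10, 256), ids: (10,)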
+
231
+
232
+ class VectorQuantize(nn.Module):
233
+ def __init__(self, config, *args, **kwargs):
234
+ super().__init__(*args, **kwargs)
235
+ self.config = config
236
+ self.codebook = EuclideanCodebook(dim=config.dim, codebook_size=config.codebook_size)
237
+
238
+ def forward(self, x, input_length):
239
+ batch_size, seq_len, _ = x.shape
240
+ mask, unpacking_index = get_sequence_mask(x, input_length)
241
+ if x.dtype != torch.float32:
242
+ x = x.to(torch.float32)
243
+ x = torch.masked_select(x, mask).reshape(-1, self.config.dim) # (bs*sl?, d)
244
+ quantize, embed_ind, _ = self.codebook(x)
245
+ quantize = torch.index_select(quantize, 0, unpacking_index).view(batch_size, seq_len, self.config.dim)
246
+ quantize = torch.where(mask, quantize, 0)
247
+ embed_ind = torch.index_select(embed_ind.reshape(-1, 1), 0, unpacking_index).view(batch_size, seq_len, 1)
248
+ embed_ind = torch.where(mask, embed_ind, -1).squeeze()
249
+ return quantize, embed_ind
250
+
251
+ def get_output_from_indices(self, indices):
252
+ indices = indices.to(self.codebook.embed.device)
253
+ return self.codebook.embed[indices]
254
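+ # VectorQuantize packs the valid frames (via get_sequence_mask) before the codebook
+ # lookup and unpacks afterwards, so padded positions come back as all-zero vectors
+ # with code index -1.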
+
255
+
256
+ class SnakeBeta(nn.Module):
257
+ """
258
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
259
+ Shape:
260
+ - Input: (B, C, T)
261
+ - Output: (B, C, T), same shape as the input
262
+ Parameters:
263
+ - alpha - trainable parameter that controls frequency
264
+ - beta - trainable parameter that controls magnitude
265
+ References:
266
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
267
+ https://arxiv.org/abs/2006.08195
268
+ Examples:
269
+ >>> a1 = SnakeBeta(256, 256)
270
+ >>> x = torch.randn(256)
271
+ >>> x = a1(x)
272
+ """
273
+
274
+ def __init__(
275
+ self,
276
+ in_features,
277
+ out_features,
278
+ alpha=1.0,
279
+ alpha_trainable=True,
280
+ alpha_logscale=True,
281
+ ):
282
+ """
283
+ Initialization.
284
+ INPUT:
285
+ - in_features: shape of the input
286
+ - alpha - trainable parameter that controls frequency
287
+ - beta - trainable parameter that controls magnitude
288
+ alpha is initialized to 1 by default, higher values = higher-frequency.
289
+ beta is initialized to 1 by default, higher values = higher-magnitude.
290
+ alpha will be trained along with the rest of your model.
291
+ """
292
+ super().__init__()
293
+ self.in_features = (
294
+ out_features if isinstance(out_features, list) else [out_features]
295
+ )
296
+ self.proj = LoRACompatibleLinear(in_features, out_features)
297
+
298
+ # initialize alpha
299
+ self.alpha_logscale = alpha_logscale
300
+ if self.alpha_logscale: # log scale alphas initialized to zeros
301
+ self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
302
+ self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
303
+ else: # linear scale alphas initialized to ones
304
+ self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
305
+ self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
306
+
307
+ self.alpha.requires_grad = alpha_trainable
308
+ self.beta.requires_grad = alpha_trainable
309
+
310
+ self.no_div_by_zero = 0.000000001
311
+
312
+ def forward(self, x):
313
+ """
314
+ Forward pass of the function.
315
+ Applies the function to the input elementwise.
316
+ SnakeBeta := x + (1/b) * sin^2(a * x)
317
+ """
318
+ x = self.proj(x)
319
+ if self.alpha_logscale:
320
+ alpha = torch.exp(self.alpha)
321
+ beta = torch.exp(self.beta)
322
+ else:
323
+ alpha = self.alpha
324
+ beta = self.beta
325
+
326
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
327
+ torch.sin(x * alpha), 2
328
+ )
329
+
330
+ return x
331
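+ # With alpha_logscale=True, alpha and beta are stored as logs and exponentiated in
+ # forward, so the activation is
+ #   y = proj(x) + (1 / exp(beta)) * sin^2(exp(alpha) * proj(x))
+ # which at initialisation (alpha = beta = 0) reduces to proj(x) + sin^2(proj(x)).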
+
332
+
333
+ class FeedForward(nn.Module):
334
+ r"""
335
+ A feed-forward layer.
336
+
337
+ Parameters:
338
+ dim (`int`): The number of channels in the input.
339
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
340
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
341
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
342
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
343
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
344
+ """
345
+
346
+ def __init__(
347
+ self,
348
+ dim: int,
349
+ dim_out: Optional[int] = None,
350
+ mult: int = 4,
351
+ dropout: float = 0.0,
352
+ activation_fn: str = "geglu",
353
+ final_dropout: bool = False,
354
+ ):
355
+ super().__init__()
356
+ inner_dim = int(dim * mult)
357
+ dim_out = dim_out if dim_out is not None else dim
358
+
359
+ if activation_fn == "gelu":
360
+ act_fn = GELU(dim, inner_dim)
361
+ if activation_fn == "gelu-approximate":
362
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
363
+ elif activation_fn == "geglu":
364
+ act_fn = GEGLU(dim, inner_dim)
365
+ elif activation_fn == "geglu-approximate":
366
+ act_fn = ApproximateGELU(dim, inner_dim)
367
+ elif activation_fn == "snakebeta":
368
+ act_fn = SnakeBeta(dim, inner_dim)
369
+
370
+ self.net = nn.ModuleList([])
371
+ # project in
372
+ self.net.append(act_fn)
373
+ # project dropout
374
+ self.net.append(nn.Dropout(dropout))
375
+ # project out
376
+ self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
377
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
378
+ if final_dropout:
379
+ self.net.append(nn.Dropout(dropout))
380
+
381
+ def forward(self, hidden_states):
382
+ for module in self.net:
383
+ hidden_states = module(hidden_states)
384
+ return hidden_states
385
+
386
+
387
+ @maybe_allow_in_graph
388
+ class BasicTransformerBlock(nn.Module):
389
+ r"""
390
+ A basic Transformer block.
391
+
392
+ Parameters:
393
+ dim (`int`): The number of channels in the input and output.
394
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
395
+ attention_head_dim (`int`): The number of channels in each head.
396
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
397
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
398
+ only_cross_attention (`bool`, *optional*):
399
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
400
+ double_self_attention (`bool`, *optional*):
401
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
402
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
403
+ num_embeds_ada_norm (:
404
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
405
+ attention_bias (:
406
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
407
+ """
408
+
409
+ def __init__(
410
+ self,
411
+ dim: int,
412
+ num_attention_heads: int,
413
+ attention_head_dim: int,
414
+ dropout=0.0,
415
+ cross_attention_dim: Optional[int] = None,
416
+ activation_fn: str = "geglu",
417
+ num_embeds_ada_norm: Optional[int] = None,
418
+ attention_bias: bool = False,
419
+ only_cross_attention: bool = False,
420
+ double_self_attention: bool = False,
421
+ upcast_attention: bool = False,
422
+ norm_elementwise_affine: bool = True,
423
+ norm_type: str = "layer_norm",
424
+ final_dropout: bool = False,
425
+ use_omni_attn: bool = False,
426
+ ):
427
+ super().__init__()
428
+
429
+ self.use_omni_attn = use_omni_attn
430
+ self.dim = dim
431
+
432
+ self.only_cross_attention = only_cross_attention
433
+
434
+ self.use_ada_layer_norm_zero = (
435
+ num_embeds_ada_norm is not None
436
+ ) and norm_type == "ada_norm_zero"
437
+ self.use_ada_layer_norm = (
438
+ num_embeds_ada_norm is not None
439
+ ) and norm_type == "ada_norm"
440
+
441
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
442
+ raise ValueError(
443
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
444
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
445
+ )
446
+
447
+ # Define 3 blocks. Each block has its own normalization layer.
448
+ # 1. Self-Attn
449
+ if self.use_ada_layer_norm:
450
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
451
+ elif self.use_ada_layer_norm_zero:
452
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
453
+ else:
454
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
455
+
456
+ if self.use_omni_attn:
457
+ if only_cross_attention:
458
+ raise NotImplementedError
459
+ print(
460
+ "Use OmniWhisperAttention with flash attention. Dropout is ignored."
461
+ )
462
+ self.attn1 = OmniWhisperAttention(
463
+ embed_dim=dim, num_heads=num_attention_heads, causal=False
464
+ )
465
+ else:
466
+ self.attn1 = Attention(
467
+ query_dim=dim,
468
+ heads=num_attention_heads,
469
+ dim_head=attention_head_dim,
470
+ dropout=dropout,
471
+ bias=attention_bias,
472
+ cross_attention_dim=(
473
+ cross_attention_dim if only_cross_attention else None
474
+ ),
475
+ upcast_attention=upcast_attention,
476
+ )
477
+
478
+ # 2. Cross-Attn
479
+ if cross_attention_dim is not None or double_self_attention:
480
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
481
+ # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
482
+ # the second cross attention block.
483
+ self.norm2 = (
484
+ AdaLayerNorm(dim, num_embeds_ada_norm)
485
+ if self.use_ada_layer_norm
486
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
487
+ )
488
+ self.attn2 = Attention(
489
+ query_dim=dim,
490
+ cross_attention_dim=(
491
+ cross_attention_dim if not double_self_attention else None
492
+ ),
493
+ heads=num_attention_heads,
494
+ dim_head=attention_head_dim,
495
+ dropout=dropout,
496
+ bias=attention_bias,
497
+ upcast_attention=upcast_attention,
498
+ # scale_qk=False, # uncomment this to not to use flash attention
499
+ ) # is self-attn if encoder_hidden_states is none
500
+ else:
501
+ self.norm2 = None
502
+ self.attn2 = None
503
+
504
+ # 3. Feed-forward
505
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
506
+ self.ff = FeedForward(
507
+ dim,
508
+ dropout=dropout,
509
+ activation_fn=activation_fn,
510
+ final_dropout=final_dropout,
511
+ )
512
+
513
+ # let chunk size default to None
514
+ self._chunk_size = None
515
+ self._chunk_dim = 0
516
+
517
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
518
+ # Sets chunk feed-forward
519
+ self._chunk_size = chunk_size
520
+ self._chunk_dim = dim
521
+
522
+ def forward(
523
+ self,
524
+ hidden_states: torch.FloatTensor,
525
+ attention_mask: Optional[torch.FloatTensor] = None,
526
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
527
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
528
+ timestep: Optional[torch.LongTensor] = None,
529
+ cross_attention_kwargs: Dict[str, Any] = None,
530
+ class_labels: Optional[torch.LongTensor] = None,
531
+ ):
532
+
533
+ bsz, tgt_len, d_model = hidden_states.shape
534
+
535
+ # Notice that normalization is always applied before the real computation in the following blocks.
536
+ # 1. Self-Attention
537
+ if self.use_ada_layer_norm:
538
+ norm_hidden_states = self.norm1(hidden_states, timestep)
539
+ elif self.use_ada_layer_norm_zero:
540
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
541
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
542
+ )
543
+ else:
544
+ norm_hidden_states = self.norm1(hidden_states)
545
+
546
+ cross_attention_kwargs = (
547
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
548
+ )
549
+
550
+ if self.use_omni_attn:
551
+ seq_len = attention_mask[:, 0, :].float().long().sum(dim=1)
552
+ var_len_attention_mask, unpacking_index = get_sequence_mask(
553
+ norm_hidden_states, seq_len
554
+ )
555
+ norm_hidden_states = torch.masked_select(
556
+ norm_hidden_states, var_len_attention_mask
557
+ )
558
+ norm_hidden_states = norm_hidden_states.view(torch.sum(seq_len), self.dim)
559
+ attn_output = self.attn1(norm_hidden_states, seq_len)
560
+ # unpacking
561
+ attn_output = torch.index_select(attn_output, 0, unpacking_index).view(
562
+ bsz, tgt_len, d_model
563
+ )
564
+ attn_output = torch.where(var_len_attention_mask, attn_output, 0)
565
+ else:
566
+ attn_output = self.attn1(
567
+ norm_hidden_states,
568
+ encoder_hidden_states=(
569
+ encoder_hidden_states if self.only_cross_attention else None
570
+ ),
571
+ attention_mask=(
572
+ encoder_attention_mask
573
+ if self.only_cross_attention
574
+ else attention_mask
575
+ ),
576
+ **cross_attention_kwargs,
577
+ )
578
+
579
+ if self.use_ada_layer_norm_zero:
580
+ attn_output = gate_msa.unsqueeze(1) * attn_output
581
+ hidden_states = attn_output + hidden_states
582
+
583
+ # 2. Cross-Attention
584
+ if self.attn2 is not None:
585
+ norm_hidden_states = (
586
+ self.norm2(hidden_states, timestep)
587
+ if self.use_ada_layer_norm
588
+ else self.norm2(hidden_states)
589
+ )
590
+
591
+ attn_output = self.attn2(
592
+ norm_hidden_states,
593
+ encoder_hidden_states=encoder_hidden_states,
594
+ attention_mask=encoder_attention_mask,
595
+ **cross_attention_kwargs,
596
+ )
597
+ hidden_states = attn_output + hidden_states
598
+
599
+ # 3. Feed-forward
600
+ norm_hidden_states = self.norm3(hidden_states)
601
+
602
+ if self.use_ada_layer_norm_zero:
603
+ norm_hidden_states = (
604
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
605
+ )
606
+
607
+ if self._chunk_size is not None:
608
+ # "feed_forward_chunk_size" can be used to save memory
609
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
610
+ raise ValueError(
611
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
612
+ )
613
+
614
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
615
+ ff_output = torch.cat(
616
+ [
617
+ self.ff(hid_slice)
618
+ for hid_slice in norm_hidden_states.chunk(
619
+ num_chunks, dim=self._chunk_dim
620
+ )
621
+ ],
622
+ dim=self._chunk_dim,
623
+ )
624
+ else:
625
+ ff_output = self.ff(norm_hidden_states)
626
+
627
+ if self.use_ada_layer_norm_zero:
628
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
629
+
630
+ hidden_states = ff_output + hidden_states
631
+
632
+ return hidden_states
633
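+ # When use_omni_attn is set, the self-attention step packs the valid frames (dropping
+ # padding), runs the variable-length flash-attention path in OmniWhisperAttention and
+ # scatters the outputs back to the padded layout; the cross-attention and
+ # feed-forward parts are unchanged.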
+
634
+
635
+ class Transpose(torch.nn.Module):
636
+ def __init__(self, dim0: int, dim1: int):
637
+ super().__init__()
638
+ self.dim0 = dim0
639
+ self.dim1 = dim1
640
+
641
+ def forward(self, x: torch.Tensor):
642
+ x = torch.transpose(x, self.dim0, self.dim1)
643
+ return x
644
+
645
+
646
+ class Block1D(torch.nn.Module):
647
+ def __init__(self, dim, dim_out, groups=8):
648
+ super().__init__()
649
+ self.block = torch.nn.Sequential(
650
+ torch.nn.Conv1d(dim, dim_out, 3, padding=1),
651
+ torch.nn.GroupNorm(groups, dim_out),
652
+ nn.Mish(),
653
+ )
654
+
655
+ def forward(self, x, mask):
656
+ output = self.block(x * mask)
657
+ return output * mask
658
+
659
+
660
+ class ResnetBlock1D(torch.nn.Module):
661
+ def __init__(self, dim, dim_out, time_emb_dim, groups=8):
662
+ super().__init__()
663
+ self.mlp = torch.nn.Sequential(
664
+ nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)
665
+ )
666
+
667
+ self.block1 = Block1D(dim, dim_out, groups=groups)
668
+ self.block2 = Block1D(dim_out, dim_out, groups=groups)
669
+
670
+ self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
671
+
672
+ def forward(self, x, mask, time_emb):
673
+ h = self.block1(x, mask)
674
+ h += self.mlp(time_emb).unsqueeze(-1)
675
+ h = self.block2(h, mask)
676
+ output = h + self.res_conv(x * mask)
677
+ return output
678
+
679
+
680
+ class CausalBlock1D(Block1D):
681
+ def __init__(self, dim: int, dim_out: int):
682
+ super(CausalBlock1D, self).__init__(dim, dim_out)
683
+ self.block = torch.nn.Sequential(
684
+ CausalConv1d(dim, dim_out, 3),
685
+ Transpose(1, 2),
686
+ nn.LayerNorm(dim_out),
687
+ Transpose(1, 2),
688
+ nn.Mish(),
689
+ )
690
+
691
+ def forward(self, x: torch.Tensor, mask: torch.Tensor):
692
+ output = self.block(x * mask)
693
+ return output * mask
694
+
695
+
696
+ class CausalResnetBlock1D(ResnetBlock1D):
697
+ def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
698
+ super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
699
+ self.block1 = CausalBlock1D(dim, dim_out)
700
+ self.block2 = CausalBlock1D(dim_out, dim_out)
701
+
702
+
703
+ class CausalConv1d(torch.nn.Conv1d):
704
+ def __init__(
705
+ self,
706
+ in_channels: int,
707
+ out_channels: int,
708
+ kernel_size: int,
709
+ stride: int = 1,
710
+ dilation: int = 1,
711
+ groups: int = 1,
712
+ bias: bool = True,
713
+ padding_mode: str = 'zeros',
714
+ device=None,
715
+ dtype=None
716
+ ) -> None:
717
+ super(CausalConv1d, self).__init__(in_channels, out_channels,
718
+ kernel_size, stride,
719
+ padding=0, dilation=dilation,
720
+ groups=groups, bias=bias,
721
+ padding_mode=padding_mode,
722
+ device=device, dtype=dtype)
723
+ assert stride == 1
724
+ self.causal_padding = (kernel_size - 1, 0)
725
+
726
+ def forward(self, x: torch.Tensor):
727
+ x = F.pad(x, self.causal_padding)
728
+ x = super(CausalConv1d, self).forward(x)
729
+ return x
730
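+ # CausalConv1d left-pads by (kernel_size - 1) frames and uses no right padding, so
+ # the output at frame t only depends on inputs up to t. Illustrative check
+ # (hypothetical sizes):
+ #   conv = CausalConv1d(80, 80, kernel_size=3)
+ #   y = conv(torch.randn(1, 80, 100))   # y.shape == (1, 80, 100)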
+
731
+
732
+ class BASECFM(torch.nn.Module, ABC):
733
+ def __init__(
734
+ self,
735
+ n_feats,
736
+ cfm_params,
737
+ n_spks=1,
738
+ spk_emb_dim=128,
739
+ ):
740
+ super().__init__()
741
+ self.n_feats = n_feats
742
+ self.n_spks = n_spks
743
+ self.spk_emb_dim = spk_emb_dim
744
+ self.solver = cfm_params.solver
745
+ if hasattr(cfm_params, "sigma_min"):
746
+ self.sigma_min = cfm_params.sigma_min
747
+ else:
748
+ self.sigma_min = 1e-4
749
+
750
+ self.estimator = None
751
+
752
+ @torch.inference_mode()
753
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
754
+ """Forward diffusion
755
+
756
+ Args:
757
+ mu (torch.Tensor): output of encoder
758
+ shape: (batch_size, n_feats, mel_timesteps)
759
+ mask (torch.Tensor): output_mask
760
+ shape: (batch_size, 1, mel_timesteps)
761
+ n_timesteps (int): number of diffusion steps
762
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
763
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
764
+ shape: (batch_size, spk_emb_dim)
765
+ cond: Not used but kept for future purposes
766
+
767
+ Returns:
768
+ sample: generated mel-spectrogram
769
+ shape: (batch_size, n_feats, mel_timesteps)
770
+ """
771
+ z = torch.randn_like(mu) * temperature
772
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
773
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
774
+
775
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
776
+ """
777
+ Fixed euler solver for ODEs.
778
+ Args:
779
+ x (torch.Tensor): random noise
780
+ t_span (torch.Tensor): n_timesteps interpolated
781
+ shape: (n_timesteps + 1,)
782
+ mu (torch.Tensor): output of encoder
783
+ shape: (batch_size, n_feats, mel_timesteps)
784
+ mask (torch.Tensor): output_mask
785
+ shape: (batch_size, 1, mel_timesteps)
786
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
787
+ shape: (batch_size, spk_emb_dim)
788
+ cond: Not used but kept for future purposes
789
+ """
790
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
791
+
792
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
793
+ # Or in future might add like a return_all_steps flag
794
+ sol = []
795
+
796
+ for step in range(1, len(t_span)):
797
+ dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
798
+
799
+ x = x + dt * dphi_dt
800
+ t = t + dt
801
+ sol.append(x)
802
+ if step < len(t_span) - 1:
803
+ dt = t_span[step + 1] - t
804
+
805
+ return sol[-1]
806
+
807
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
808
+ """Computes diffusion loss
809
+
810
+ Args:
811
+ x1 (torch.Tensor): Target
812
+ shape: (batch_size, n_feats, mel_timesteps)
813
+ mask (torch.Tensor): target mask
814
+ shape: (batch_size, 1, mel_timesteps)
815
+ mu (torch.Tensor): output of encoder
816
+ shape: (batch_size, n_feats, mel_timesteps)
817
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
818
+ shape: (batch_size, spk_emb_dim)
819
+
820
+ Returns:
821
+ loss: conditional flow matching loss
822
+ y: conditional flow
823
+ shape: (batch_size, n_feats, mel_timesteps)
824
+ """
825
+ b, _, t = mu.shape
826
+
827
+ # random timestep
828
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
829
+ # sample noise p(x_0)
830
+ z = torch.randn_like(x1)
831
+
832
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
833
+ u = x1 - (1 - self.sigma_min) * z
834
+
835
+ loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
836
+ torch.sum(mask) * u.shape[1]
837
+ )
838
+ return loss, y
839
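+ # Conditional flow matching in brief: a point on the probability path is
+ #   y_t = (1 - (1 - sigma_min) * t) * z + t * x1,   z ~ N(0, I),
+ # and the regression target for the estimator is the constant velocity
+ #   u = x1 - (1 - sigma_min) * z,
+ # i.e. the time derivative of y_t. ConditionalCFM.compute_loss below reuses the same
+ # construction with classifier-free-guidance dropout on the conditioning.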
+
840
+
841
+ class ConditionalDecoder(nn.Module):
842
+ def __init__(
843
+ self,
844
+ in_channels,
845
+ out_channels,
846
+ causal=False,
847
+ channels=(256, 256),
848
+ dropout=0.05,
849
+ attention_head_dim=64,
850
+ n_blocks=1,
851
+ num_mid_blocks=2,
852
+ num_heads=4,
853
+ act_fn="snake",
854
+ gradient_checkpointing=False,
855
+ ):
856
+ """
857
+ This decoder requires an input with the same shape as the target. So, if your text content
858
+ is shorter or longer than the outputs, please resample it before feeding it to the decoder.
859
+ """
860
+ super().__init__()
861
+ channels = tuple(channels)
862
+ self.in_channels = in_channels
863
+ self.out_channels = out_channels
864
+ self.causal = causal
865
+ self.static_chunk_size = 2 * 25 * 2 # 2*input_frame_rate*token_mel_ratio
866
+ self.gradient_checkpointing = gradient_checkpointing
867
+
868
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
869
+ time_embed_dim = channels[0] * 4
870
+ self.time_mlp = TimestepEmbedding(
871
+ in_channels=in_channels,
872
+ time_embed_dim=time_embed_dim,
873
+ act_fn="silu",
874
+ )
875
+ self.down_blocks = nn.ModuleList([])
876
+ self.mid_blocks = nn.ModuleList([])
877
+ self.up_blocks = nn.ModuleList([])
878
+
879
+ output_channel = in_channels
880
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
881
+ input_channel = output_channel
882
+ output_channel = channels[i]
883
+ is_last = i == len(channels) - 1
884
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
885
+ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
886
+ transformer_blocks = nn.ModuleList(
887
+ [
888
+ BasicTransformerBlock(
889
+ dim=output_channel,
890
+ num_attention_heads=num_heads,
891
+ attention_head_dim=attention_head_dim,
892
+ dropout=dropout,
893
+ activation_fn=act_fn,
894
+ )
895
+ for _ in range(n_blocks)
896
+ ]
897
+ )
898
+ downsample = (
899
+ Downsample1D(output_channel) if not is_last else
900
+ CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
901
+ )
902
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
903
+
904
+ for _ in range(num_mid_blocks):
905
+ input_channel = channels[-1]
906
+ out_channels = channels[-1]
907
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
908
+ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
909
+
910
+ transformer_blocks = nn.ModuleList(
911
+ [
912
+ BasicTransformerBlock(
913
+ dim=output_channel,
914
+ num_attention_heads=num_heads,
915
+ attention_head_dim=attention_head_dim,
916
+ dropout=dropout,
917
+ activation_fn=act_fn,
918
+ )
919
+ for _ in range(n_blocks)
920
+ ]
921
+ )
922
+
923
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
924
+
925
+ channels = channels[::-1] + (channels[0],)
926
+ for i in range(len(channels) - 1):
927
+ input_channel = channels[i] * 2
928
+ output_channel = channels[i + 1]
929
+ is_last = i == len(channels) - 2
930
+ resnet = CausalResnetBlock1D(
931
+ dim=input_channel,
932
+ dim_out=output_channel,
933
+ time_emb_dim=time_embed_dim,
934
+ ) if self.causal else ResnetBlock1D(
935
+ dim=input_channel,
936
+ dim_out=output_channel,
937
+ time_emb_dim=time_embed_dim,
938
+ )
939
+ transformer_blocks = nn.ModuleList(
940
+ [
941
+ BasicTransformerBlock(
942
+ dim=output_channel,
943
+ num_attention_heads=num_heads,
944
+ attention_head_dim=attention_head_dim,
945
+ dropout=dropout,
946
+ activation_fn=act_fn,
947
+ )
948
+ for _ in range(n_blocks)
949
+ ]
950
+ )
951
+ upsample = (
952
+ Upsample1D(output_channel, use_conv_transpose=True)
953
+ if not is_last
954
+ else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
955
+ )
956
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
957
+ self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
958
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
959
+ self.initialize_weights()
960
+
961
+ def initialize_weights(self):
962
+ for m in self.modules():
963
+ if isinstance(m, nn.Conv1d):
964
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
965
+ if m.bias is not None:
966
+ nn.init.constant_(m.bias, 0)
967
+ elif isinstance(m, nn.GroupNorm):
968
+ nn.init.constant_(m.weight, 1)
969
+ nn.init.constant_(m.bias, 0)
970
+ elif isinstance(m, nn.Linear):
971
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
972
+ if m.bias is not None:
973
+ nn.init.constant_(m.bias, 0)
974
+
975
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
976
+ """Forward pass of the UNet1DConditional model.
977
+
978
+ Args:
979
+ x (torch.Tensor): shape (batch_size, in_channels, time)
980
+ mask (_type_): shape (batch_size, 1, time)
981
+ t (_type_): shape (batch_size)
982
+ spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
983
+ cond (_type_, optional): placeholder for future use. Defaults to None.
984
+
985
+ Raises:
986
+ ValueError: _description_
987
+ ValueError: _description_
988
+
989
+ Returns:
990
+ _type_: _description_
991
+ """
992
+ t = self.time_embeddings(t)
993
+ t = t.to(x.dtype)
994
+ t = self.time_mlp(t)
995
+ x = pack([x, mu], "b * t")[0]
996
+ mask = mask.to(x.dtype)
997
+ if spks is not None:
998
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
999
+ x = pack([x, spks], "b * t")[0]
1000
+ if cond is not None:
1001
+ x = pack([x, cond], "b * t")[0]
1002
+
1003
+ hiddens = []
1004
+ masks = [mask]
1005
+ for resnet, transformer_blocks, downsample in self.down_blocks:
1006
+ mask_down = masks[-1]
1007
+ x = resnet(x, mask_down, t)
1008
+ x = rearrange(x, "b c t -> b t c").contiguous()
1009
+ # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
1010
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
1011
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
1012
+ for transformer_block in transformer_blocks:
1013
+ if self.gradient_checkpointing and self.training:
1014
+ def create_custom_forward(module):
1015
+ def custom_forward(*inputs):
1016
+ return module(*inputs)
1017
+ return custom_forward
1018
+ x = torch.utils.checkpoint.checkpoint(
1019
+ create_custom_forward(transformer_block),
1020
+ x,
1021
+ attn_mask,
1022
+ t,
1023
+ )
1024
+ else:
1025
+ x = transformer_block(
1026
+ hidden_states=x,
1027
+ attention_mask=attn_mask,
1028
+ timestep=t,
1029
+ )
1030
+ x = rearrange(x, "b t c -> b c t").contiguous()
1031
+ hiddens.append(x) # Save hidden states for skip connections
1032
+ x = downsample(x * mask_down)
1033
+ masks.append(mask_down[:, :, ::2])
1034
+ masks = masks[:-1]
1035
+ mask_mid = masks[-1]
1036
+
1037
+ for resnet, transformer_blocks in self.mid_blocks:
1038
+ x = resnet(x, mask_mid, t)
1039
+ x = rearrange(x, "b c t -> b t c").contiguous()
1040
+ # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
1041
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
1042
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
1043
+ for transformer_block in transformer_blocks:
1044
+ if self.gradient_checkpointing and self.training:
1045
+ def create_custom_forward(module):
1046
+ def custom_forward(*inputs):
1047
+ return module(*inputs)
1048
+ return custom_forward
1049
+ x = torch.utils.checkpoint.checkpoint(
1050
+ create_custom_forward(transformer_block),
1051
+ x,
1052
+ attn_mask,
1053
+ t,
1054
+ )
1055
+ else:
1056
+ x = transformer_block(
1057
+ hidden_states=x,
1058
+ attention_mask=attn_mask,
1059
+ timestep=t,
1060
+ )
1061
+ x = rearrange(x, "b t c -> b c t").contiguous()
1062
+
1063
+ for resnet, transformer_blocks, upsample in self.up_blocks:
1064
+ mask_up = masks.pop()
1065
+ skip = hiddens.pop()
1066
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
1067
+ x = resnet(x, mask_up, t)
1068
+ x = rearrange(x, "b c t -> b t c").contiguous()
1069
+ # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
1070
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
1071
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
1072
+ for transformer_block in transformer_blocks:
1073
+ if self.gradient_checkpointing and self.training:
1074
+ def create_custom_forward(module):
1075
+ def custom_forward(*inputs):
1076
+ return module(*inputs)
1077
+ return custom_forward
1078
+ x = torch.utils.checkpoint.checkpoint(
1079
+ create_custom_forward(transformer_block),
1080
+ x,
1081
+ attn_mask,
1082
+ t,
1083
+ )
1084
+ else:
1085
+ x = transformer_block(
1086
+ hidden_states=x,
1087
+ attention_mask=attn_mask,
1088
+ timestep=t,
1089
+ )
1090
+ x = rearrange(x, "b t c -> b c t").contiguous()
1091
+ x = upsample(x * mask_up)
1092
+ x = self.final_block(x, mask_up)
1093
+ output = self.final_proj(x * mask_up)
1094
+ return output * mask
1095
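+ # ConditionalDecoder is a 1D U-Net over mel frames: resnet + transformer blocks on
+ # the way down (with the mask downsampled alongside the features), a stack of mid
+ # blocks at the lowest resolution, and an up path that concatenates the saved skip
+ # connections before each upsample; the chunked attention masks prevent attending
+ # past the end of the current chunk.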
+
1096
+
1097
+ class ConditionalCFM(BASECFM):
1098
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64):
1099
+ super().__init__(
1100
+ n_feats=in_channels,
1101
+ cfm_params=cfm_params,
1102
+ n_spks=n_spks,
1103
+ spk_emb_dim=spk_emb_dim,
1104
+ )
1105
+ self.t_scheduler = cfm_params.t_scheduler
1106
+ self.training_cfg_rate = cfm_params.training_cfg_rate
1107
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
1108
+
1109
+ @torch.inference_mode()
1110
+ def forward(self, estimator, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
1111
+ """Forward diffusion
1112
+
1113
+ Args:
1114
+ mu (torch.Tensor): output of encoder
1115
+ shape: (batch_size, n_feats, mel_timesteps)
1116
+ mask (torch.Tensor): output_mask
1117
+ shape: (batch_size, 1, mel_timesteps)
1118
+ n_timesteps (int): number of diffusion steps
1119
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
1120
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
1121
+ shape: (batch_size, spk_emb_dim)
1122
+ cond: Not used but kept for future purposes
1123
+
1124
+ Returns:
1125
+ sample: generated mel-spectrogram
1126
+ shape: (batch_size, n_feats, mel_timesteps)
1127
+ """
1128
+ z = torch.randn_like(mu) * temperature
1129
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
1130
+ if self.t_scheduler == 'cosine':
1131
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
1132
+ return self.solve_euler(estimator, z, t_span=t_span.to(mu.dtype), mu=mu, mask=mask, spks=spks, cond=cond)
1133
+
1134
+ def solve_euler(self, estimator, x, t_span, mu, mask, spks, cond):
1135
+ """
1136
+ Fixed euler solver for ODEs.
1137
+ Args:
1138
+ x (torch.Tensor): random noise
1139
+ t_span (torch.Tensor): n_timesteps interpolated
1140
+ shape: (n_timesteps + 1,)
1141
+ mu (torch.Tensor): output of encoder
1142
+ shape: (batch_size, n_feats, mel_timesteps)
1143
+ mask (torch.Tensor): output_mask
1144
+ shape: (batch_size, 1, mel_timesteps)
1145
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
1146
+ shape: (batch_size, spk_emb_dim)
1147
+ cond: Not used but kept for future purposes
1148
+ """
1149
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
1150
+
1151
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
1152
+ # Or in future might add like a return_all_steps flag
1153
+ sol = []
1154
+
1155
+ for step in range(1, len(t_span)):
1156
+ dphi_dt = estimator(x, mask, mu, t, spks, cond)
1157
+ # Classifier-Free Guidance inference introduced in VoiceBox
1158
+ if self.inference_cfg_rate > 0:
1159
+ cfg_dphi_dt = estimator(
1160
+ x, mask,
1161
+ torch.zeros_like(mu), t,
1162
+ torch.zeros_like(spks) if spks is not None else None,
1163
+ cond=cond
1164
+ )
1165
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
1166
+ self.inference_cfg_rate * cfg_dphi_dt)
1167
+ x = x + dt * dphi_dt
1168
+ t = t + dt
1169
+ sol.append(x)
1170
+ if step < len(t_span) - 1:
1171
+ dt = t_span[step + 1] - t
1172
+
1173
+ return sol[-1]
1174
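+ # Classifier-free guidance sketch: when inference_cfg_rate > 0, each Euler step also
+ # runs the estimator with the conditioning (mu, spks) zeroed out and extrapolates
+ #   v = (1 + w) * v_cond - w * v_uncond,   w = inference_cfg_rate,
+ # before taking the step x <- x + dt * v.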
+
1175
+ def compute_loss(self, estimator, x1, mask, mu, spks=None, cond=None):
1176
+ """Computes diffusion loss
1177
+
1178
+ Args:
1179
+ x1 (torch.Tensor): Target
1180
+ shape: (batch_size, n_feats, mel_timesteps)
1181
+ mask (torch.Tensor): target mask
1182
+ shape: (batch_size, 1, mel_timesteps)
1183
+ mu (torch.Tensor): output of encoder
1184
+ shape: (batch_size, n_feats, mel_timesteps)
1185
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
1186
+ shape: (batch_size, spk_emb_dim)
1187
+
1188
+ Returns:
1189
+ loss: conditional flow matching loss
1190
+ y: conditional flow
1191
+ shape: (batch_size, n_feats, mel_timesteps)
1192
+ """
1193
+ org_dtype = x1.dtype
1194
+
1195
+ b, _, t = mu.shape
1196
+ # random timestep
1197
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
1198
+ if self.t_scheduler == 'cosine':
1199
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
1200
+ # sample noise p(x_0)
1201
+ z = torch.randn_like(x1)
1202
+
1203
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
1204
+ u = x1 - (1 - self.sigma_min) * z
1205
+
1206
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
1207
+ if self.training_cfg_rate > 0:
1208
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
1209
+ mu = mu * cfg_mask.view(-1, 1, 1)
1210
+ if spks is not None:
1211
+ spks = spks * cfg_mask.view(-1, 1)
1212
+ if cond is not None:
1213
+ cond = cond * cfg_mask.view(-1, 1, 1)
1214
+
1215
+ pred = estimator(y, mask, mu, t.squeeze(), spks, cond)
1216
+ pred = pred.float()
1217
+ u = u.float()
1218
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
1219
+ loss = loss.to(org_dtype)
1220
+ return loss, y
1221
+
1222
+
1223
+ class SinusoidalPosEmb(torch.nn.Module):
1224
+ def __init__(self, dim):
1225
+ super().__init__()
1226
+ self.dim = dim
1227
+ assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
1228
+
1229
+ def forward(self, x, scale=1000):
1230
+ if x.ndim < 1:
1231
+ x = x.unsqueeze(0)
1232
+ device = x.device
1233
+ half_dim = self.dim // 2
1234
+ emb = math.log(10000) / (half_dim - 1)
1235
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
1236
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
1237
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
1238
+ return emb
1239
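+ # SinusoidalPosEmb maps a scalar timestep to a dim-sized vector of sin/cos features
+ # at geometrically spaced frequencies (the standard Transformer encoding, with the
+ # input scaled by 1000). Illustrative: SinusoidalPosEmb(256)(torch.tensor([0.5]))
+ # returns a (1, 256) embedding.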
+
1240
+
1241
+ class Downsample1D(nn.Module):
1242
+ def __init__(self, dim):
1243
+ super().__init__()
1244
+ self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
1245
+
1246
+ def forward(self, x):
1247
+ return self.conv(x)
1248
+
1249
+
1250
+ class TimestepEmbedding(nn.Module):
1251
+ def __init__(
1252
+ self,
1253
+ in_channels: int,
1254
+ time_embed_dim: int,
1255
+ act_fn: str = "silu",
1256
+ out_dim: int = None,
1257
+ post_act_fn: Optional[str] = None,
1258
+ cond_proj_dim=None,
1259
+ ):
1260
+ super().__init__()
1261
+
1262
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
1263
+
1264
+ if cond_proj_dim is not None:
1265
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
1266
+ else:
1267
+ self.cond_proj = None
1268
+
1269
+ self.act = get_activation(act_fn)
1270
+
1271
+ if out_dim is not None:
1272
+ time_embed_dim_out = out_dim
1273
+ else:
1274
+ time_embed_dim_out = time_embed_dim
1275
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
1276
+
1277
+ if post_act_fn is None:
1278
+ self.post_act = None
1279
+ else:
1280
+ self.post_act = get_activation(post_act_fn)
1281
+
1282
+ def forward(self, sample, condition=None):
1283
+ if condition is not None:
1284
+ sample = sample + self.cond_proj(condition)
1285
+ sample = self.linear_1(sample)
1286
+
1287
+ if self.act is not None:
1288
+ sample = self.act(sample)
1289
+
1290
+ sample = self.linear_2(sample)
1291
+
1292
+ if self.post_act is not None:
1293
+ sample = self.post_act(sample)
1294
+ return sample
1295
+
1296
+
1297
+ class Upsample1D(nn.Module):
1298
+ """A 1D upsampling layer with an optional convolution.
1299
+
1300
+ Parameters:
1301
+ channels (`int`):
1302
+ number of channels in the inputs and outputs.
1303
+ use_conv (`bool`, default `False`):
1304
+ option to use a convolution.
1305
+ use_conv_transpose (`bool`, default `False`):
1306
+ option to use a convolution transpose.
1307
+ out_channels (`int`, optional):
1308
+ number of output channels. Defaults to `channels`.
1309
+ """
1310
+
1311
+ def __init__(
1312
+ self,
1313
+ channels,
1314
+ use_conv=False,
1315
+ use_conv_transpose=True,
1316
+ out_channels=None,
1317
+ name="conv",
1318
+ ):
1319
+ super().__init__()
1320
+ self.channels = channels
1321
+ self.out_channels = out_channels or channels
1322
+ self.use_conv = use_conv
1323
+ self.use_conv_transpose = use_conv_transpose
1324
+ self.name = name
1325
+
1326
+ self.conv = None
1327
+ if use_conv_transpose:
1328
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
1329
+ elif use_conv:
1330
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
1331
+
1332
+ def forward(self, inputs):
1333
+ assert inputs.shape[1] == self.channels
1334
+ if self.use_conv_transpose:
1335
+ return self.conv(inputs)
1336
+
1337
+ outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
1338
+
1339
+ if self.use_conv:
1340
+ outputs = self.conv(outputs)
1341
+
1342
+ return outputs
1343
+
1344
+
1345
+ class RMSNorm(nn.Module):
1346
+ def __init__(self, hidden_size, eps=1e-6):
1347
+ """
1348
+ RMSNorm is equivalent to T5LayerNorm
1349
+ """
1350
+ super().__init__()
1351
+ self.weight = nn.Parameter(torch.ones(hidden_size))
1352
+ self.variance_epsilon = eps
1353
+
1354
+ def forward(self, hidden_states):
1355
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
1356
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
1357
+
1358
+ # convert into half-precision if necessary
1359
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
1360
+ hidden_states = hidden_states.to(self.weight.dtype)
1361
+
1362
+ return self.weight * hidden_states
1363
+
1364
+
1365
+ class OmniWhisperAttention(nn.Module):
1366
+ def __init__(self, embed_dim, num_heads, causal=False):
1367
+ super().__init__()
1368
+ self.embed_dim = embed_dim
1369
+ self.num_heads = num_heads
1370
+ self.head_dim = embed_dim // num_heads
1371
+
1372
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
1373
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
1374
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
1375
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
1376
+
1377
+ self.causal = causal
1378
+
1379
+ def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
1380
+ bsz, _ = hidden_states.size()
1381
+
1382
+ query_states = self.q_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
1383
+ key_states = self.k_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
1384
+ value_states = self.v_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
1385
+
1386
+ cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(torch.int32)
1387
+ max_seqlen = torch.max(seq_len).to(torch.int32).detach()
1388
+ attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_len, cu_len, max_seqlen, max_seqlen, causal=self.causal) # (bsz * qlen, nheads, headdim)
1389
+ attn_output = attn_output.reshape(bsz, self.embed_dim)
1390
+ attn_output = self.out_proj(attn_output)
1391
+ return attn_output
1392
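+ # OmniWhisperAttention works on packed sequences: hidden_states is the concatenation
+ # of all valid frames, shape (sum(seq_len), embed_dim), and cu_len holds the
+ # cumulative sequence boundaries that flash_attn_varlen_func expects, so padding is
+ # never attended to.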
+
1393
+
1394
+ class OmniWhisperTransformerLayer(nn.Module):
1395
+ def __init__(
1396
+ self,
1397
+ act,
1398
+ d_model,
1399
+ encoder_attention_heads,
1400
+ encoder_ffn_dim,
1401
+ causal,
1402
+ ln_type="LayerNorm",
1403
+ ):
1404
+ super().__init__()
1405
+ self.embed_dim = d_model
1406
+ self.self_attn = OmniWhisperAttention(
1407
+ self.embed_dim, encoder_attention_heads, causal
1408
+ )
1409
+
1410
+ if ln_type == "LayerNorm":
1411
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
1412
+ elif ln_type == "RMSNorm":
1413
+ self.self_attn_layer_norm = RMSNorm(self.embed_dim)
1414
+ else:
1415
+ raise ValueError(f"Unknown ln_type: {ln_type}")
1416
+
1417
+ self.activation_fn = act
1418
+ self.fc1 = nn.Linear(self.embed_dim, encoder_ffn_dim)
1419
+ self.fc2 = nn.Linear(encoder_ffn_dim, self.embed_dim)
1420
+
1421
+ if ln_type == "LayerNorm":
1422
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
1423
+ elif ln_type == "RMSNorm":
1424
+ self.final_layer_norm = RMSNorm(self.embed_dim)
1425
+ else:
1426
+ raise ValueError(f"Unknown ln_type: {ln_type}")
1427
+
1428
+ def forward(
1429
+ self, hidden_states: torch.Tensor, seq_len: torch.Tensor
1430
+ ) -> torch.Tensor:
1431
+ residual = hidden_states
1432
+ hidden_states = self.self_attn_layer_norm(hidden_states)
1433
+ hidden_states = self.self_attn(hidden_states, seq_len)
1434
+ hidden_states = residual + hidden_states
1435
+ residual = hidden_states
1436
+ hidden_states = self.final_layer_norm(hidden_states)
1437
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
1438
+ hidden_states = self.fc2(hidden_states)
1439
+ hidden_states = residual + hidden_states
1440
+
1441
+ if (
1442
+ hidden_states.dtype == torch.float16
1443
+ or hidden_states.dtype == torch.bfloat16
1444
+ ) and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
1445
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
1446
+ hidden_states = torch.clamp(
1447
+ hidden_states, min=-clamp_value, max=clamp_value
1448
+ )
1449
+ return hidden_states
1450
+
1451
+
1452
+
1453
+ class LongcatNextAudioEncoder(nn.Module):
1454
+ def __init__(self, config):
1455
+ super().__init__()
1456
+ self.config = config
1457
+ self.max_source_positions = (config.max_audio_seconds * config.sampling_rate // config.hop_length) // config.stride_size
1458
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
1459
+
1460
+ self.conv1 = nn.Conv1d(config.num_mel_bins, config.d_model, kernel_size=config.kernel_size, padding=1)
1461
+ self.conv2 = nn.Conv1d(config.d_model, config.d_model, kernel_size=config.kernel_size,
1462
+ stride=config.stride_size, padding=1)
1463
+ self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, config.d_model)) # 1500 * d
1464
+
1465
+ self.layers = nn.ModuleList([OmniWhisperTransformerLayer(
1466
+ ACT2FN[config.activation_function],
1467
+ config.d_model,
1468
+ config.encoder_attention_heads,
1469
+ config.encoder_ffn_dim,
1470
+ False) for _ in range(config.encoder_layers)])
1471
+ self.layer_norm = nn.LayerNorm(config.d_model)
1472
+
1473
+ def forward(
1474
+ self,
1475
+ input_features,
1476
+ output_length,
1477
+ ):
1478
+ input_features = input_features.to(self.conv1.weight.dtype)
1479
+ inputs_embeds = nn.functional.gelu(self.conv1(input_features)) # (bs, channels, frames)
1480
+ inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) # (bs, channels, frames // 2)
1481
+ inputs_embeds = inputs_embeds.permute(0, 2, 1) # (bs, frames, channels)
1482
+ bsz, tgt_len, _ = inputs_embeds.size()
1483
+ if tgt_len < self.positional_embedding.shape[0]:
1484
+ current_positional_embedding = self.positional_embedding[:tgt_len]
1485
+ else:
1486
+ current_positional_embedding = self.positional_embedding
1487
+ hidden_states = (inputs_embeds.to(torch.float32) + current_positional_embedding).to(inputs_embeds.dtype)
1488
+
1489
+ # packing hidden states
1490
+ attention_mask, unpacking_index = get_sequence_mask(hidden_states, output_length)
1491
+ hidden_states = torch.masked_select(hidden_states, attention_mask).view(torch.sum(output_length),
1492
+ self.config.d_model)
1493
+
1494
+ for idx, encoder_layer in enumerate(self.layers):
1495
+ hidden_states = encoder_layer(hidden_states, output_length)
1496
+ hidden_states = self.layer_norm(hidden_states)
1497
+ # unpacking
1498
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, tgt_len, self.config.d_model)
1499
+ hidden_states = torch.where(attention_mask, hidden_states, 0)
1500
+ return hidden_states
1501
+
1502
+
1503
+ class CasualConvTranspose1d(nn.Module):
1504
+ def __init__(self, in_channels, out_channels, kernel_size, stride):
1505
+ super().__init__()
1506
+ self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
1507
+ self.norm = nn.GroupNorm(1, out_channels)
1508
+ self.in_channels = in_channels
1509
+ self.out_channels = out_channels
1510
+
1511
+ def forward(self, hidden_states, input_length, output_dim=None):
1512
+ kernel_size = self.conv.kernel_size[0]
1513
+ stride = self.conv.stride[0]
1514
+ bsz = input_length.shape[0]
1515
+
1516
+ if output_dim is None:
1517
+ output_dim = hidden_states.dim()
1518
+ if hidden_states.dim() <= 2: # unpack sequence to 3d
1519
+ sequence_mask, unpacking_index = get_sequence_mask(hidden_states, input_length)
1520
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, torch.max(input_length),
1521
+ self.in_channels)
1522
+ hidden_states = torch.where(sequence_mask, hidden_states, 0) # 3d (bsz, max_input_len, d)
1523
+
1524
+ hidden_states = hidden_states.transpose(2, 1) # (N, L, C) -> (N, C, L)
1525
+ hidden_states = self.conv(hidden_states)
1526
+ hidden_states = self.norm(hidden_states)
1527
+ hidden_states = hidden_states.transpose(2, 1) # (N, C, L) -> (N, L, C)
1528
+
1529
+ casual_padding_right = max(0, kernel_size - stride)
1530
+ hidden_states = hidden_states[:, :hidden_states.shape[1] - casual_padding_right,
1531
+ :]
1532
+ output_length = (input_length - 1) * stride + kernel_size - casual_padding_right
1533
+ sequence_mask, _ = get_sequence_mask(hidden_states, output_length)
1534
+ if output_dim <= 2:
1535
+ hidden_states = torch.masked_select(hidden_states, sequence_mask).view(-1, self.out_channels)
1536
+ else:
1537
+ hidden_states = torch.where(sequence_mask, hidden_states, 0)
1538
+ hidden_states = hidden_states[:, :torch.max(output_length), :]
1539
+ return hidden_states, output_length
1540
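+ # The transposed convolution above upsamples by `stride` and trims
+ # max(0, kernel_size - stride) frames on the right to stay causal; the returned
+ # lengths follow output_length = (input_length - 1) * stride + kernel_size - trim.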
+
1541
+
1542
+ class MelSpecRefineNet(nn.Module):
1543
+ """
1544
+ # post net, coarse to refined mel-spectrogram frames
1545
+ # ref1: Autoregressive Speech Synthesis without Vector Quantization
1546
+ # ref2: CosyVoice length_regulator.py
1547
+ # ref3: Neural Speech Synthesis with Transformer Network https://github.com/soobinseo/Transformer-TTS/blob/master/network.py
1548
+ """
1549
+
1550
+ def __init__(self, encoder_config, vocoder_config):
1551
+ super().__init__()
1552
+ self.encoder_config = encoder_config
1553
+ self.vocoder_config = vocoder_config
1554
+
1555
+ layers = nn.ModuleList([])
1556
+ in_channels = self.vocoder_config.num_mel_bins
1557
+ for i, out_channels in enumerate(self.vocoder_config.channels[:-1]):
1558
+ module = nn.Conv1d(in_channels, out_channels, 5, 1, 2) # cosyvoice kernel=3, stride=1, pad=1
1559
+ in_channels = out_channels
1560
+ norm = nn.GroupNorm(1, out_channels)
1561
+ act = nn.Mish()
1562
+ layers.extend([module, norm, act])
1563
+ layers.append(nn.Conv1d(in_channels, self.vocoder_config.num_mel_bins, 1, 1)) # projector
1564
+ self.layers = nn.Sequential(*layers)
1565
+
1566
+ def compute_output_length(self, input_length):
1567
+ output_length = input_length.to(
1568
+ torch.float32) * self.encoder_config.hop_length / self.encoder_config.sampling_rate
1569
+ output_length = output_length * self.vocoder_config.sampling_rate / self.vocoder_config.hop_length
1570
+ return output_length.to(torch.int64)
1571
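+ # compute_output_length converts encoder frame counts to vocoder mel-frame counts by
+ # going through seconds: frames * hop_length / sampling_rate at the encoder rate,
+ # then back to frames at the vocoder's sampling rate and hop length.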
+
1572
+ def forward(self, coarse_mel, input_length, output_length=None):
1573
+ bsz, _, d = coarse_mel.shape
1574
+ assert (d == self.vocoder_config.num_mel_bins)
1575
+ if output_length is None or not self.training:
1576
+ output_length = self.compute_output_length(input_length)
1577
+ coarse_mel, default_dtype = coarse_mel[:, :torch.max(input_length), :], coarse_mel.dtype
1578
+ coarse_mel = F.interpolate(coarse_mel.to(torch.float32).transpose(1, 2).contiguous(), size=output_length.max(),
1579
+ mode='nearest').to(default_dtype)
1580
+ refined_mel = self.layers(coarse_mel).transpose(1, 2).contiguous() # (bs, t, d)
1581
+ coarse_mel = coarse_mel.transpose(1, 2) # (bs, max(output_length), d)
1582
+ refined_mel += coarse_mel # residual connection
1583
+ sequence_mask, _ = get_sequence_mask(refined_mel, output_length)
1584
+ coarse_mel = torch.where(sequence_mask, coarse_mel, 0)
1585
+ refined_mel = torch.where(sequence_mask, refined_mel, 0)
1586
+ return refined_mel, coarse_mel, output_length
1587
+
1588
+
1589
+ @dataclass
1590
+ class OmniAudioDecoderOutput(ModelOutput):
1591
+ refined_mel: Optional[torch.FloatTensor] = None
1592
+ coarse_mel: Optional[torch.FloatTensor] = None
1593
+ mel_length: Optional[torch.Tensor] = None
1594
+ hidden_states_before_dconv2: Optional[torch.FloatTensor] = None
1595
+ output_length_before_dconv2: Optional[torch.Tensor] = None
1596
+
1597
+
1598
+ class LongcatNextAudioDecoder(nn.Module):
1599
+ def __init__(self, config):
1600
+ super().__init__()
1601
+ self.config = config
1602
+ self.vocoder_config = config.vocoder_config
1603
+ self.max_source_positions = self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length
1604
+
1605
+ self.dconv1 = CasualConvTranspose1d(
1606
+ self.config.d_model,
1607
+ self.config.d_model,
1608
+ self.config.decoder_kernel_size,
1609
+ self.config.avg_pooler,
1610
+ )
1611
+ self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, self.config.d_model))
1612
+ # causal transformer layers
1613
+ self.layers = nn.ModuleList(
1614
+ [OmniWhisperTransformerLayer(
1615
+ ACT2FN[self.config.activation_function],
1616
+ self.config.d_model,
1617
+ self.config.decoder_attention_heads,
1618
+ self.config.decoder_ffn_dim,
1619
+ True # causal
1620
+ ) for _ in range(self.config.decoder_layers)
1621
+ ])
1622
+ self.layer_norm = nn.LayerNorm(self.config.d_model)
1623
+ self.dconv2 = CasualConvTranspose1d(
1624
+ self.config.d_model,
1625
+ self.vocoder_config.num_mel_bins,
1626
+ self.config.decoder_kernel_size,
1627
+ self.config.decoder_stride_size
1628
+ )
1629
+ self.post_net = MelSpecRefineNet(self.config, self.vocoder_config)
1630
+ self.gradient_checkpointing = False
1631
+
1632
+ def forward(self,
1633
+ audio_embed,
1634
+ input_length,
1635
+ mel_labels=None,
1636
+ mel_labels_length=None,
1637
+ ):
1638
+ assert (audio_embed.shape[-1] == self.config.d_model)
1639
+ audio_embed = audio_embed.to(self.layer_norm.weight) # device and type
1640
+ audio_embed, output_length = self.dconv1(audio_embed, input_length, output_dim=3) # (b, l*2, d_model)
1641
+ _, tgt_len, _ = audio_embed.size()
1642
+ if tgt_len < self.positional_embedding.shape[0]:
1643
+ current_positional_embedding = self.positional_embedding[:tgt_len]
1644
+ else:
1645
+ current_positional_embedding = self.positional_embedding
1646
+ hidden_states = (audio_embed.to(torch.float32) + current_positional_embedding).to(audio_embed.dtype)
1647
+
1648
+ # packing hidden states
1649
+ attention_mask, _ = get_sequence_mask(hidden_states, output_length)
1650
+ hidden_states = torch.masked_select(hidden_states, attention_mask).view(torch.sum(output_length), self.config.d_model)
1651
+
1652
+ for idx, encoder_layer in enumerate(self.layers):
1653
+ hidden_states = encoder_layer(hidden_states, output_length)
1654
+
1655
+ hidden_states = self.layer_norm(hidden_states)
1656
+ hidden_states_before_dconv2 = hidden_states
1657
+ output_length_before_dconv2 = output_length
1658
+
1659
+ coarse_mel, output_length = self.dconv2(hidden_states, output_length, output_dim=3)
1660
+ refined_mel, coarse_mel, mel_labels_length = self.post_net(coarse_mel, output_length, mel_labels_length)
1661
+
1662
+ return OmniAudioDecoderOutput(
1663
+ refined_mel=refined_mel,
1664
+ coarse_mel=coarse_mel,
1665
+ mel_length=mel_labels_length,
1666
+ hidden_states_before_dconv2=hidden_states_before_dconv2,
1667
+ output_length_before_dconv2=output_length_before_dconv2,
1668
+ )
1669
+
1670
+
1671
+ class LongcatNextAudioVQBridger(nn.Module):
1672
+ def __init__(self, config):
1673
+ super().__init__()
1674
+ self.config = config
1675
+ self.gradient_checkpointing = False
1676
+ self.intermediate_dim = self.config.d_model * self.config.avg_pooler
1677
+ self.gate_proj = nn.Conv1d(self.config.d_model, self.intermediate_dim, self.config.avg_pooler, self.config.avg_pooler, bias=False)
1678
+ self.up_proj = nn.Conv1d(self.config.d_model, self.intermediate_dim, self.config.avg_pooler, self.config.avg_pooler, bias=False)
1679
+
1680
+ self.down_proj = nn.Linear(self.intermediate_dim, self.intermediate_dim, bias=False)
1681
+ self.act_fn = ACT2FN['silu']
1682
+ self.layer_norm = nn.LayerNorm(self.intermediate_dim)
1683
+ self.proj_decoder = nn.Linear(self.intermediate_dim, self.config.d_model)
1684
+
1685
+ self.vq_list = nn.ModuleList([])
1686
+ for idx, codebook_size in enumerate(self.config.vq_config.codebook_sizes):
1687
+ vq_config = copy.deepcopy(self.config.vq_config)
1688
+ vq_config.dim = self.intermediate_dim
1689
+ vq_config.codebook_size = codebook_size
1690
+ self.vq_list.append(VectorQuantize(vq_config))
1691
+
1692
+ def rvq_op(self, inputs, output_length):
1693
+ def rvq_layer_op(vq_layer, residual_encoding, output_length):
1694
+ q_v_i, code_ids_i = vq_layer(residual_encoding, output_length)
1695
+ residual_encoding = residual_encoding.float() - q_v_i.float()
1696
+ residual_encoding = residual_encoding.to(inputs.dtype)
1697
+ return residual_encoding, code_ids_i
1698
+
1699
+ cmt_loss, residual_encoding = 0, inputs
1700
+ code_ids_list = []
1701
+ for i, vq_layer in enumerate(self.vq_list):
1702
+ residual_encoding, code_ids_i = rvq_layer_op(vq_layer, residual_encoding, output_length)
1703
+ code_ids_list.append(code_ids_i)
1704
+ return torch.stack(code_ids_list, -1)
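+ # Residual VQ: each codebook quantizes whatever the previous stage left over
+ # (the residual is updated after every layer), so the stacked result holds one
+ # code index per codebook for every frame, shape (..., len(self.vq_list)).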
1705
+
1706
+ def forward(self, x, output_length):
1707
+ batch_size, _, _ = x.shape
1708
+ output_length = output_length.to(x.device)
1709
+
1710
+ if x.shape[1] % self.config.avg_pooler != 0:
1711
+ x = F.pad(x, (0, 0, 0, self.config.avg_pooler - x.shape[1] % self.config.avg_pooler), "constant", 0)
1712
+ xt = x.permute(0, 2, 1)
1713
+ g = self.gate_proj(xt).permute(0, 2, 1) # (bs, sl//pooler_size+1, d*2)
1714
+ u = self.up_proj(xt).permute(0, 2, 1)
1715
+ x = x.reshape(batch_size, -1, self.intermediate_dim) # (bs, sl//pooler_size+1, d*2)
1716
+
1717
+ c = self.down_proj(self.act_fn(g) * u)
1718
+ res = self.layer_norm(c + x)
1719
+ valid_mask, _ = get_sequence_mask(res, output_length)
1720
+ code_ids = self.rvq_op(res, output_length)
1721
+ code_ids = torch.masked_select(code_ids, valid_mask).reshape(-1, len(self.vq_list)) # (sum(valid_sequence_length), vq_num)
1722
+ return code_ids
1723
+
1724
+ @torch.no_grad()
1725
+ def decode(self, code_ids):
1726
+ vq_num = code_ids.shape[-1]
1727
+ res = sum(self.vq_list[i].get_output_from_indices(code_ids[:, i]).float() for i in range(vq_num-1,-1,-1)).to(self.proj_decoder.weight)
1728
+ decoder_emb = self.proj_decoder(res.to(self.proj_decoder.weight))
1729
+ return decoder_emb
1730
+
1731
+ @torch.no_grad()
1732
+ def recover(self, code_ids):
1733
+ vq_num = code_ids.shape[-1]
1734
+ res = sum(self.vq_list[i].get_output_from_indices(code_ids[:, i]).float() for i in range(vq_num-1,-1,-1)).to(self.proj_decoder.weight)
1735
+ return res
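+ # `decode` projects the summed codebook embeddings back to d_model through
+ # proj_decoder, whereas `recover` returns the raw summed RVQ reconstruction
+ # without that final projection.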
1736
+
1737
+
1738
+ class FlowmatchingPrenet(nn.Module):
1739
+ def __init__(
1740
+ self,
1741
+ input_feat_dim,
1742
+ out_feat_dim,
1743
+ d_model,
1744
+ attention_heads,
1745
+ ffn_dim,
1746
+ nlayers,
1747
+ activation_function,
1748
+ max_source_positions,
1749
+ target_mel_length_scale_ratio,
1750
+ ):
1751
+ super().__init__()
1752
+
1753
+ self.d_model = d_model
1754
+ self.target_mel_length_scale_ratio = target_mel_length_scale_ratio
1755
+ self.gradient_checkpointing = False
1756
+
1757
+ self.register_buffer(
1758
+ "positional_embedding", sinusoids(max_source_positions, d_model)
1759
+ )
1760
+
1761
+ self.in_mlp = nn.Sequential(
1762
+ nn.Linear(input_feat_dim, d_model * 4),
1763
+ nn.SiLU(),
1764
+ nn.Linear(d_model * 4, d_model),
1765
+ )
1766
+
1767
+ self.transformer_layers = nn.ModuleList(
1768
+ [
1769
+ OmniWhisperTransformerLayer(
1770
+ act=ACT2FN[activation_function],
1771
+ d_model=d_model,
1772
+ encoder_attention_heads=attention_heads,
1773
+ encoder_ffn_dim=ffn_dim,
1774
+ causal=True, # causal
1775
+ ln_type="RMSNorm",
1776
+ )
1777
+ for _ in range(nlayers)
1778
+ ]
1779
+ )
1780
+
1781
+ self.final_norm = RMSNorm(self.d_model)
1782
+ self.out_proj = nn.Linear(d_model, out_feat_dim, bias=False)
1783
+
1784
+ def compute_output_length(self, input_length):
1785
+ output_length = input_length.float() * self.target_mel_length_scale_ratio
1786
+ return output_length.to(torch.int64)
1787
+
1788
+ def forward(self, input_feat, input_length, output_length=None):
1789
+ """
1790
+ Args:
1791
+ input_feat: [B, T, input_feat_dim]
1792
+ input_length: [B]
1793
+ output_length: [B]
1794
+
1795
+ """
1796
+ if output_length is None or not self.training:
1797
+ output_length = self.compute_output_length(input_length)
1798
+
1799
+ input_feat = input_feat[:, : input_length.max(), :] # [B, T, D]
1800
+ orig_dtype = input_feat.dtype
1801
+
1802
+ input_feat = F.interpolate(
1803
+ input=input_feat.to(torch.float32).transpose(1, 2).contiguous(),
1804
+ size=output_length.max(),
1805
+ mode="nearest",
1806
+ ).to(orig_dtype)
1807
+ input_feat = input_feat.transpose(1, 2).contiguous() # [B, T, D]
1808
+ hidden_states = self.in_mlp(input_feat)
1809
+
1810
+ # packing hidden states
1811
+ bsz, tgt_len, d_model = hidden_states.shape
1812
+ attention_mask, unpacking_index = get_sequence_mask(
1813
+ hidden_states, output_length
1814
+ )
1815
+ hidden_states = torch.masked_select(hidden_states, attention_mask).view(
1816
+ torch.sum(output_length), self.d_model
1817
+ )
1818
+
1819
+ for idx, encoder_layer in enumerate(self.transformer_layers):
1820
+ hidden_states = encoder_layer(hidden_states, output_length)
1821
+
1822
+ # unpacking
1823
+ hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
1824
+ bsz, tgt_len, d_model
1825
+ )
1826
+ hidden_states = torch.where(attention_mask, hidden_states, 0)
1827
+
1828
+ hidden_states = self.final_norm(hidden_states)
1829
+ output = self.out_proj(hidden_states)
1830
+ return output, output_length
1831
+
1832
+
1833
+ @dataclass
1834
+ class OmniAudioFlowMatchingDecoderOutput(ModelOutput):
1835
+ flow_matching_mel: Optional[torch.FloatTensor] = None
1836
+ flow_matching_mel_lengths: Optional[torch.FloatTensor] = None
1837
+
1838
+
1839
+ class LongcatNextAudioFlowMatchingDecoder(nn.Module):
1840
+ def __init__(self, config):
1841
+ super().__init__()
1842
+ self.config = config.flow_matching_config
1843
+ self.in_channels = self.config.in_channels
1844
+ self.spk_emb_dim = self.config.spk_emb_dim
1845
+ self.diffusion_steps = self.config.diffusion_steps
1846
+ self.cal_mel_mae = self.config.cal_mel_mae
1847
+ self.forward_step = -1
1848
+
1849
+ self.prenet = FlowmatchingPrenet(
1850
+ input_feat_dim=self.config.prenet_in_dim,
1851
+ out_feat_dim=self.config.prenet_out_dim,
1852
+ d_model=self.config.prenet_d_model,
1853
+ attention_heads=self.config.prenet_attention_heads,
1854
+ ffn_dim=self.config.prenet_ffn_dim,
1855
+ nlayers=self.config.prenet_nlayers,
1856
+ activation_function=self.config.prenet_activation_function,
1857
+ max_source_positions=self.config.prenet_max_source_positions,
1858
+ target_mel_length_scale_ratio=self.config.prenet_target_mel_length_scale_ratio,
1859
+ )
1860
+
1861
+ self.conditional_decoder = ConditionalDecoder(
1862
+ in_channels=self.in_channels * 2 + self.spk_emb_dim,
1863
+ out_channels=self.in_channels,
1864
+ causal=True,
1865
+ channels=self.config.channels,
1866
+ dropout=self.config.dropout,
1867
+ attention_head_dim=self.config.attention_head_dim,
1868
+ n_blocks=self.config.n_blocks,
1869
+ num_mid_blocks=self.config.num_mid_blocks,
1870
+ num_heads=self.config.num_heads,
1871
+ act_fn=self.config.act_fn,
1872
+ )
1873
+
1874
+ self.cfm = ConditionalCFM(
1875
+ in_channels=self.in_channels,
1876
+ cfm_params=self.config.cfm_params,
1877
+ n_spks=0,
1878
+ spk_emb_dim=self.spk_emb_dim,
1879
+ )
1880
+
1881
+
1882
+ def unpack_hidden_states(self, hidden_states, output_length):
1883
+ unpacked = unpack_hidden_states(hidden_states, output_length)
1884
+ return unpacked, output_length
1885
+
1886
+ def forward(
1887
+ self, refined_mel, input_length, mel_labels=None, mel_labels_length=None
1888
+ ):
1889
+ """
1890
+ :param refined_mel: [bs, max_input_len, mel_bin]
1891
+ :param input_length: [batch_size]
1892
+ :param mel_labels: optional target mel spectrogram
+ :param mel_labels_length: [batch_size]
1893
+ :return: OmniAudioFlowMatchingDecoderOutput with the flow-matched mel and its lengths
1894
+ """
1895
+ self.forward_step += 1
1896
+
1897
+ orig_dtype = refined_mel.dtype
1898
+ prenet_mae_metric = torch.tensor(0.0).to(refined_mel.device)
1899
+ prenet_regression_loss = torch.tensor(0.0).to(refined_mel.device)
1900
+
1901
+ if self.prenet is not None:
1902
+ refined_mel = refined_mel[:, : torch.max(input_length), :]
1903
+ if mel_labels_length is None:
1904
+ mel_labels_length = self.prenet.compute_output_length(input_length)
1905
+ refined_mel, input_length = self.prenet(
1906
+ refined_mel, input_length, mel_labels_length
1907
+ )
1908
+
1909
+ float_dtype = refined_mel.dtype
1910
+ refined_mel = refined_mel.float()
1911
+ input_length = input_length.long()
1912
+
1913
+ refined_mel = refined_mel[:, : torch.max(input_length), :]
1914
+ sequence_mask, unpacking_index = get_sequence_mask(refined_mel, input_length)
1915
+ refined_mel = refined_mel.transpose(1, 2) # (bs, mel_bin, max_input_len)
1916
+ sequence_mask = sequence_mask.transpose(2, 1) # (bs, 1, sl)
1917
+
1918
+ fm_mel = self.cfm.forward(
1919
+ estimator=self.conditional_decoder,
1920
+ mu=refined_mel.to(float_dtype),
1921
+ mask=sequence_mask.float(),
1922
+ n_timesteps=self.diffusion_steps,
1923
+ )
1924
+ return OmniAudioFlowMatchingDecoderOutput(
1925
+ flow_matching_mel=fm_mel.transpose(1, 2),
1926
+ flow_matching_mel_lengths=mel_labels_length,
1927
+ )
1928
+
1929
+
1930
+ @torch.no_grad()
1931
+ def decode_wave_vocoder2(response, vocoder, audio_tokenizer):
1932
+ response_len = (response[:,:,0] == audio_tokenizer.config.audio_config.vq_config.codebook_sizes[0]).long().argmax(dim=1)
1933
+ valid_response_list = [response[i, :response_len[i], :] for i in range(response.shape[0]) if int(response_len[i])>0]
1934
+
1935
+ if len(valid_response_list)==0:
1936
+ return []
1937
+ flatten_response = torch.cat(valid_response_list, dim=0) if len(valid_response_list)>1 else valid_response_list[0]
1938
+ valid_response_len = response_len[response_len>0]
1939
+ ret = audio_tokenizer.decode(flatten_response.view(-1,response.shape[-1]),
1940
+ bridge_length=valid_response_len)
1941
+ batch_size = response.shape[0]
1942
+ valid_start = 0
1943
+ r = []
1944
+ for i in range(batch_size):
1945
+ if response_len[i]==0:
1946
+ r.append(None)
1947
+ continue
1948
+ if isinstance(ret, torch.Tensor):
1949
+ r.append(ret[valid_start:valid_start+1])
1950
+ valid_start+=1
1951
+ continue
1952
+ decode_wave = vocoder.decode(ret.flow_matching_mel[valid_start][:ret.flow_matching_mel_lengths[valid_start], :].transpose(0, 1).to(torch.float32).unsqueeze(0))
1953
+ r.append(decode_wave.cpu())
1954
+ valid_start+=1
1955
+ return r
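+ # The first codebook's EOS id (codebook_sizes[0]) marks where each response ends:
+ # argmax over the boolean match finds the first EOS per sample, samples with no
+ # valid frames get None, and the remaining chunks are decoded in one packed call
+ # to audio_tokenizer.decode before the vocoder is run per sample.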
1956
+
1957
+
1958
+ @torch.no_grad()
1959
+ def decode_save_concat2(response_list, vocoder, model, path, sampling_rate=16000, wave_concat_overlap=800):
1960
+ wave_list = []
1961
+ for response in response_list:
1962
+ wave_list.extend([wave_i for wave_i in decode_wave_vocoder2(response, vocoder, model) if wave_i is not None])
1963
+ new_wave_list = [wave_list[0]]
1964
+ for w in wave_list[1:]:
1965
+ if new_wave_list[-1].shape[1] > wave_concat_overlap and w.shape[1] > wave_concat_overlap:
1966
+ new_wave_list.append((new_wave_list[-1][:, -wave_concat_overlap:] * torch.linspace(1.0, 0.0, wave_concat_overlap, device=new_wave_list[-1].device)[None, :]
1967
+ + w[:, :wave_concat_overlap] * torch.linspace(0.0, 1.0, wave_concat_overlap, device=new_wave_list[-1].device)[None, :]))
1968
+ new_wave_list.append(w)
1969
+ full_wave = torch.cat(new_wave_list, dim=1) if len(new_wave_list) > 1 else new_wave_list[0]
1970
+ torchaudio.save(path, full_wave, sampling_rate)
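+ # Adjacent waveform chunks are joined by inserting a linear crossfade segment:
+ # the last wave_concat_overlap samples of the previous chunk fade out while the
+ # first wave_concat_overlap samples of the next chunk fade in, then all pieces
+ # are concatenated and written with torchaudio.save.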
1971
+
1972
+
1973
+ class LongcatNextAudioTokenizer(nn.Module):
1974
+
1975
+ def __init__(self, config):
1976
+ super().__init__()
1977
+ self.config = config
1978
+ self.audio_model = LongcatNextAudioEncoder(config.audio_config)
1979
+ self.audio_bridge_model = LongcatNextAudioVQBridger(config.audio_config)
1980
+ self.audio_decoder = LongcatNextAudioDecoder(config.audio_config)
1981
+ self.audio_flow_matching_decoder = LongcatNextAudioFlowMatchingDecoder(config.audio_config)
1982
+ self.cosy24kvocoder = None
1983
+
1984
+ @torch.no_grad()
1985
+ def encode(self, x, encoder_length: Optional[torch.Tensor] = None, bridge_length: Optional[torch.Tensor] = None):
1986
+ audio_emb = self.audio_model(x, encoder_length)
1987
+ audio_tokens = self.audio_bridge_model(audio_emb, bridge_length)
1988
+ return audio_tokens
1989
+
1990
+ @torch.no_grad()
1991
+ def decode(self, audio_ids, bridge_length: Optional[torch.Tensor] = None):
1992
+ audio_emb = self.audio_bridge_model.decode(audio_ids)
1993
+ audio_dec = self.audio_decoder(
1994
+ audio_emb.to(next(self.audio_decoder.parameters())), bridge_length
1995
+ )
1996
+ if self.config.audio_config.flow_matching_config.use_hidden_states_before_dconv2:
1997
+ hidden_states, hidden_states_length = (
1998
+ self.audio_flow_matching_decoder.unpack_hidden_states(
1999
+ audio_dec.hidden_states_before_dconv2,
2000
+ audio_dec.output_length_before_dconv2,
2001
+ )
2002
+ )
2003
+ audio_flow_matching_decoder_ret = self.audio_flow_matching_decoder(
2004
+ hidden_states, hidden_states_length
2005
+ )
2006
+ else:
2007
+ audio_flow_matching_decoder_ret = self.audio_flow_matching_decoder(
2008
+ audio_dec.refined_mel, audio_dec.mel_length
2009
+ )
2010
+ return audio_flow_matching_decoder_ret
2011
+
2012
+ @torch.no_grad()
2013
+ def lazy_decode_and_save(self, audio_ids, sampling_rate, wave_concat_overlap, save_path):
2014
+ if self.cosy24kvocoder is None:
2015
+ print("lazy load cosy24kvocoder ...")
2016
+ device = next(self.parameters()).device
2017
+ self.cosy24kvocoder = Cosy24kVocoder.from_pretrained(self.config.audio_config.cosy24kvocoder_config.weight_path).to(device)
2018
+
2019
+ if audio_ids[-1, 0] != self.config.audio_config.vq_config.codebook_sizes[0]: # exceed max_new_tokens
2020
+ audio_ids = F.pad(audio_ids, (0, 0, 0, 1), value=self.config.audio_config.vq_config.codebook_sizes[0])
2021
+
2022
+ audio_end_pos = [-1] + (audio_ids[:, 0] == self.config.audio_config.vq_config.codebook_sizes[0]).nonzero().view(-1).tolist()
2023
+
2024
+ audio_ids_chunk = []
2025
+ for i in range(len(audio_end_pos) - 1):
2026
+ start = audio_end_pos[i] + 1
2027
+ end = audio_end_pos[i+1] + 1
2028
+ audio_ids_chunk.append(audio_ids[start:end].unsqueeze(0))
2029
+
2030
+ audio_ids = audio_ids_chunk
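+ # The generated id stream is cut into per-utterance chunks at every EOS code
+ # (codebook_sizes[0]); if generation stopped at max_new_tokens without emitting
+ # an EOS, one is appended first so the final chunk is still terminated.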
2031
+
2032
+ decode_save_concat2(
2033
+ response_list=audio_ids,
2034
+ vocoder=self.cosy24kvocoder,
2035
+ model=self,
2036
+ path=save_path,
2037
+ sampling_rate=sampling_rate,
2038
+ wave_concat_overlap=wave_concat_overlap,
2039
+ )
modular_longcat_next_visual.py ADDED
@@ -0,0 +1,1077 @@
1
+ from typing import Iterable, Optional, Tuple
2
+
3
+ import numpy as np
4
+ from safetensors.torch import load_file
5
+ import torch
6
+ import torch.utils.checkpoint
7
+ from torch import nn
8
+ from torch.amp import autocast
9
+ from torch.nn import functional as F
10
+
11
+ from einops import rearrange
12
+ from flash_attn import flash_attn_varlen_func
13
+
14
+ from transformers.activations import ACT2FN
15
+ from transformers.modeling_outputs import BaseModelOutput
16
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
17
+ Qwen2RMSNorm,
18
+ Qwen2_5_VisionTransformerPretrainedModel,
19
+ )
20
+ from transformers.utils import logging
21
+
22
+ from .image_refiner import (
23
+ ImageRefinerContainer,
24
+ RefinerImageProcessor,
25
+ RefinerPipeline,
26
+ de_transform,
27
+ tensor2pil,
28
+ )
29
+ from .refiner_modules import FlowMatchEulerDiscreteScheduler
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ def uniform_init(*shape):
35
+ t = torch.zeros(shape)
36
+ nn.init.kaiming_uniform_(t)
37
+ return t
38
+
39
+ class VQEmbedding(nn.Module):
40
+ """VQ embedding module with ema update."""
41
+
42
+ def __init__(self, n_embed, embed_dim, ema=True, decay=0.99, restart_unused_codes=True, eps=1e-5, init_std=0.02):
43
+ super().__init__()
44
+
45
+ self.ema = ema
46
+ self.decay = decay
47
+ self.eps = eps
48
+ self.restart_unused_codes = restart_unused_codes
49
+ self.n_embed = n_embed
50
+ self.init_std = init_std
51
+
52
+ assert self.ema
53
+ embed = uniform_init(n_embed + 1, embed_dim).to(torch.float32)
54
+ self.embed = nn.Parameter(embed)
55
+ self.embed_ema = nn.Parameter(embed[:-1, :].clone())
56
+ self.cluster_size_ema = nn.Parameter(torch.ones(n_embed))
57
+ del embed
58
+ _ = [p.requires_grad_(False) for p in self.parameters()]
59
+
60
+ @torch.no_grad()
61
+ def compute_distances(self, inputs):
62
+ codebook_t = self.embed[:-1, :].t()
63
+
64
+ (embed_dim, _) = codebook_t.shape
65
+ inputs_shape = inputs.shape
66
+ assert inputs_shape[-1] == embed_dim
67
+
68
+ inputs_flat = inputs.reshape(-1, embed_dim)
69
+
70
+ inputs_norm_sq = inputs_flat.pow(2.).sum(dim=1, keepdim=True)
71
+ codebook_t_norm_sq = codebook_t.pow(2.).sum(dim=0, keepdim=True)
72
+ distances = torch.addmm(
73
+ inputs_norm_sq + codebook_t_norm_sq,
74
+ inputs_flat,
75
+ codebook_t,
76
+ alpha=-2.0,
77
+ )
78
+ distances = distances.reshape(*inputs_shape[:-1], -1) # [B, h, w, n_embed or n_embed+1]
79
+ return distances
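+ # Squared Euclidean distances via ||a - b||^2 = ||a||^2 + ||b||^2 - 2*a.b,
+ # computed against every codebook entry in a single torch.addmm (alpha=-2.0).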
80
+
81
+ @torch.no_grad()
82
+ def find_nearest_embedding(self, inputs):
83
+ distances = self.compute_distances(inputs) # [B, h, w, n_embed or n_embed+1]
84
+ embed_idxs = distances.argmin(dim=-1) # use padding index or not
85
+
86
+ return embed_idxs
87
+
88
+ @autocast('cuda', enabled=True, dtype=torch.float32)
89
+ @torch.no_grad()
90
+ def forward(self, inputs):
91
+ if inputs.dtype != torch.float32:
92
+ inputs = inputs.to(torch.float32)
93
+ embed_idxs = self.find_nearest_embedding(inputs)
94
+ embeds = self.embed[embed_idxs]
95
+ return embeds, embed_idxs
96
+
97
+
98
+ class RQBottleneck(nn.Module):
99
+ """
100
+ Quantization bottleneck via Residual Quantization.
101
+
102
+ Arguments:
103
+ latent_shape (Tuple[int, int, int]): the shape of latents, denoted (H, W, D)
104
+ code_shape (Tuple[int, int, int]): the shape of codes, denoted (h, w, d)
105
+ n_embed (int, List, or Tuple): the number of embeddings (i.e., the size of codebook)
106
+ If isinstance(n_embed, int), the sizes of all codebooks are same.
107
+ shared_codebook (bool): If True, codebooks are shared in all locations. If False,
108
+ uses separate codebooks along the ``depth'' dimension. (default: False)
109
+ restart_unused_codes (bool): If True, it randomly assigns a feature vector in the current batch
110
+ as the new embedding of unused codes in training. (default: True)
111
+ """
112
+
113
+ def __init__(self,
114
+ latent_shape,
115
+ code_shape,
116
+ n_embed,
117
+ decay=0.99,
118
+ shared_codebook=False,
119
+ restart_unused_codes=True,
120
+ commitment_loss='cumsum'
121
+ ):
122
+ super().__init__()
123
+
124
+ if not len(code_shape) == len(latent_shape) == 3:
125
+ raise ValueError("incompatible code shape or latent shape")
126
+ if any([y % x != 0 for x, y in zip(code_shape[:2], latent_shape[:2])]):
127
+ raise ValueError("incompatible code shape or latent shape")
128
+
129
+ # residual quantization does not divide feature dims for quantization.
130
+ embed_dim = np.prod(latent_shape[:2]) // np.prod(code_shape[:2]) * latent_shape[2]
131
+
132
+ self.latent_shape = torch.Size(latent_shape)
133
+ self.code_shape = torch.Size(code_shape)
134
+ self.shape_divisor = torch.Size([latent_shape[i] // code_shape[i] for i in range(len(latent_shape))])
135
+
136
+ self.shared_codebook = shared_codebook
137
+ if self.shared_codebook:
138
+ if isinstance(n_embed, Iterable) or isinstance(decay, Iterable):
139
+ raise ValueError("Shared codebooks are incompatible \
140
+ with list types of momentums or sizes: Change it into int")
141
+
142
+ self.restart_unused_codes = restart_unused_codes
143
+ self.n_embed = n_embed if isinstance(n_embed, Iterable) else [n_embed for _ in range(self.code_shape[-1])]
144
+ self.decay = decay if isinstance(decay, Iterable) else [decay for _ in range(self.code_shape[-1])]
145
+ assert len(self.n_embed) == self.code_shape[-1]
146
+ assert len(self.decay) == self.code_shape[-1]
147
+
148
+ if self.shared_codebook:
149
+ codebook0 = VQEmbedding(self.n_embed[0],
150
+ embed_dim,
151
+ decay=self.decay[0],
152
+ restart_unused_codes=restart_unused_codes,
153
+ ).to(torch.float32)
154
+ self.codebooks = nn.ModuleList([codebook0 for _ in range(self.code_shape[-1])])
155
+ else:
156
+ codebooks = [VQEmbedding(self.n_embed[idx],
157
+ embed_dim,
158
+ decay=self.decay[idx],
159
+ restart_unused_codes=restart_unused_codes,
160
+ ).to(torch.float32) for idx in range(self.code_shape[-1])]
161
+ self.codebooks = nn.ModuleList(codebooks)
162
+
163
+ self.commitment_loss = commitment_loss
164
+
165
+ def to_code_shape(self, x):
166
+ (B, H, W, D) = x.shape
167
+ (rH, rW, _) = self.shape_divisor
168
+
169
+ x = x.reshape(B, H//rH, rH, W//rW, rW, D)
170
+ x = x.permute(0, 1, 3, 2, 4, 5)
171
+ x = x.reshape(B, H//rH, W//rW, -1)
172
+
173
+ return x
174
+
175
+ def to_latent_shape(self, x):
176
+ (B, h, w, _) = x.shape
177
+ (_, _, D) = self.latent_shape
178
+ (rH, rW, _) = self.shape_divisor
179
+
180
+ x = x.reshape(B, h, w, rH, rW, D)
181
+ x = x.permute(0, 1, 3, 2, 4, 5)
182
+ x = x.reshape(B, h*rH, w*rW, D)
183
+
184
+ return x
185
+
186
+ def quantize(self, x):
187
+ r"""
188
+ Return list of quantized features and the selected codewords by the residual quantization.
189
+ The code is selected by the residuals between x and quantized features by the previous codebooks.
190
+
191
+ Arguments:
192
+ x (Tensor): bottleneck feature maps to quantize.
193
+
194
+ Returns:
195
+ quant_list (list): list of sequentially aggregated and quantized feature maps by codebooks.
196
+ codes (LongTensor): codewords index, corresponding to quants.
197
+
198
+ Shape:
199
+ - x: (B, h, w, embed_dim)
200
+ - quant_list[i]: (B, h, w, embed_dim)
201
+ - codes: (B, h, w, d)
202
+ """
203
+ B, h, w, embed_dim = x.shape
204
+ ori_dtype = x.dtype
205
+ x = x.to(torch.float32)
206
+ self.codebooks = self.codebooks.to(torch.float32)
207
+
208
+ residual_feature = x.detach().clone()
209
+
210
+ quant_list = []
211
+ code_list = []
212
+ aggregated_quants = torch.zeros_like(x)
213
+ for i in range(self.code_shape[-1]):
214
+ quant, code = self.codebooks[i](residual_feature)
215
+ residual_feature.sub_(quant)
216
+ aggregated_quants.add_(quant)
217
+ quant_list.append(aggregated_quants.clone().to(dtype=ori_dtype))
218
+ code_list.append(code.unsqueeze(-1))
219
+
220
+ codes = torch.cat(code_list, dim=-1)
221
+ return quant_list, codes
222
+
223
+ def forward(self, x):
224
+ x_reshaped = self.to_code_shape(x)
225
+ # force execution in float32 precision
226
+ quant_list, codes = self.quantize(x_reshaped)
227
+ # quant_list, codes = self.quantize(x_reshaped)
228
+
229
+ commitment_loss = self.compute_commitment_loss(x_reshaped, quant_list)
230
+ quants_trunc = self.to_latent_shape(quant_list[-1])
231
+ quants_trunc = x + (quants_trunc - x).detach()
232
+
233
+ '''
234
+ if self.shared_codebook:
235
+ cur_len = codes.view(-1).shape[0]
236
+ self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone()
237
+ self.codebook_used[-cur_len:] = codes.view(-1)
238
+ codebook_usage = len(torch.unique(self.codebook_used)) / self.n_embed[0]
239
+ else:
240
+ # info|code: torch.Size([10, 16, 16, 4])
241
+ codebook_usage = 0
242
+ for idx in range(self.code_shape[-1]):
243
+ cur_len = codes[..., idx].view(-1).shape[0]
244
+ self.codebook_used[idx, :-cur_len] = self.codebook_used[idx, cur_len:].clone()
245
+ self.codebook_used[idx, -cur_len:] = codes[..., idx].view(-1)
246
+ codebook_usage += len(torch.unique(self.codebook_used[idx]))
247
+ codebook_usage /= (self.n_embed[0] * self.code_shape[-1])
248
+ '''
249
+ codebook_usage = 0
250
+ # (vq_loss, commit_loss, entropy_loss, codebook_usage) # keep the return format aligned
251
+ codebook_loss = [0, commitment_loss, 0, codebook_usage]
252
+
253
+ return quants_trunc, codebook_loss, codes
254
+
255
+ def compute_commitment_loss(self, x, quant_list):
256
+ r"""
257
+ Compute the commitment loss for the residual quantization.
258
+ The loss is iteratively computed by aggregating quantized features.
259
+ """
260
+ loss_list = []
261
+
262
+ for idx, quant in enumerate(quant_list):
263
+ partial_loss = (x-quant.detach()).pow(2.0).mean()
264
+ loss_list.append(partial_loss)
265
+
266
+ commitment_loss = torch.mean(torch.stack(loss_list))
267
+ return commitment_loss
268
+
269
+
270
+
271
+ class Qwen2_5_VisionRotaryEmbedding_Modified(nn.Module):
272
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
273
+ super().__init__()
274
+ self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
275
+ # self.register_buffer("inv_freq", inv_freq, persistent=False)
276
+
277
+ def forward(self, seqlen: int, device: torch.device) -> torch.Tensor:
278
+ self.inv_freq = self.inv_freq.to(device)
279
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
280
+ freqs = torch.outer(seq, self.inv_freq)
281
+ return freqs
282
+
283
+ class VisualEncoder(Qwen2_5_VisionTransformerPretrainedModel):
284
+
285
+ def __init__(self, config):
286
+ config._attn_implementation = 'flash_attention_2'
287
+ super().__init__(config)
288
+ self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding_Modified(config.hidden_size // config.num_heads // 2)
289
+ self.gradient_checkpointing = False
290
+ self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
291
+ self.merge_size = config.merge_size if hasattr(config, 'merge_size') else 2
292
+ del self.merger # register visual.merger in visual_bridge_model
293
+
294
+ def get_dtype(self) -> torch.dtype:
295
+ return self.blocks[0].mlp.down_proj.weight.dtype
296
+
297
+ def get_device(self) -> torch.device:
298
+ return self.blocks[0].mlp.down_proj.weight.device
299
+
300
+ def rot_pos_emb(self, grid_thw):
301
+ pos_ids = []
302
+ for t, h, w in grid_thw:
303
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
304
+ hpos_ids = hpos_ids.reshape(
305
+ h // self.spatial_merge_size,
306
+ self.spatial_merge_size,
307
+ w // self.spatial_merge_size,
308
+ self.spatial_merge_size,
309
+ )
310
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
311
+ hpos_ids = hpos_ids.flatten()
312
+
313
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
314
+ wpos_ids = wpos_ids.reshape(
315
+ h // self.spatial_merge_size,
316
+ self.spatial_merge_size,
317
+ w // self.spatial_merge_size,
318
+ self.spatial_merge_size,
319
+ )
320
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
321
+ wpos_ids = wpos_ids.flatten()
322
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
323
+ pos_ids = torch.cat(pos_ids, dim=0)
324
+ max_grid_size = grid_thw[:, 1:].max()
325
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device=grid_thw.device)
326
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
327
+ return rotary_pos_emb
328
+
329
+ def forward(
330
+ self,
331
+ pixel_values: torch.Tensor,
332
+ grid_thw: torch.Tensor,
333
+ require_window_index: bool = False,
334
+ ):
335
+ '''
336
+ pixel_values.shape=[NumOfPatches, 1176]
337
+ grid_thw.shape=[NumOfSamples, 3], each row is [grid_t, grid_h, grid_w]
338
+ '''
339
+ hidden_states = pixel_values.to(torch.bfloat16)
340
+ grid_thw = grid_thw.to(pixel_values.device)
341
+
342
+ hidden_states = self.patch_embed(hidden_states)
343
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
344
+ window_index, cu_window_seqlens = self.get_window_index(grid_thw)
345
+ cu_window_seqlens = torch.tensor(
346
+ cu_window_seqlens,
347
+ device=hidden_states.device,
348
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
349
+ )
350
+ cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
351
+
352
+ seq_len, _ = hidden_states.size()
353
+ hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
354
+ hidden_states = hidden_states[window_index, :, :]
355
+ hidden_states = hidden_states.reshape(seq_len, -1)
356
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
357
+ rotary_pos_emb = rotary_pos_emb[window_index, :, :]
358
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
359
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
360
+ position_embeddings = (emb.cos(), emb.sin())
361
+
362
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
363
+ dim=0,
364
+ # Select dtype based on the following factors:
365
+ # - FA2 requires that cu_seqlens_q must have dtype int32
366
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
367
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
368
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
369
+ )
370
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
371
+
372
+ for layer_num, blk in enumerate(self.blocks):
373
+ if layer_num in self.fullatt_block_indexes:
374
+ cu_seqlens_now = cu_seqlens
375
+ else:
376
+ cu_seqlens_now = cu_window_seqlens
377
+ if self.gradient_checkpointing and self.training:
378
+ hidden_states = self._gradient_checkpointing_func(blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings)
379
+ else:
380
+ hidden_states = blk(
381
+ hidden_states,
382
+ cu_seqlens=cu_seqlens_now,
383
+ position_embeddings=position_embeddings,
384
+ )
385
+
386
+ if require_window_index:
387
+ return hidden_states, window_index
388
+ return hidden_states
389
+
390
+
391
+ class OmniVisualBridge(nn.Module):
392
+ def __init__(self, config):
393
+ super().__init__()
394
+ self.config = config
395
+ self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
396
+ self.hidden_size = self.config.hidden_size * (self.merge_size**2)
397
+ self.window_index = self.config.window_size
398
+ self.ln_q = Qwen2RMSNorm(self.config.hidden_size, eps=1e-6)
399
+ self.mlp = nn.Sequential(
400
+ nn.Linear(self.hidden_size, self.hidden_size),
401
+ nn.GELU(),
402
+ nn.Linear(self.hidden_size, self.config.out_hidden_size),
403
+ )
404
+
405
+ def forward(self, x: torch.Tensor, window_index) -> torch.Tensor:
406
+ x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
407
+ reverse_indices = torch.argsort(window_index)
408
+ x = x[reverse_indices, :]
409
+
410
+ return x
411
+
412
+
413
+ class VisualQuantizer(nn.Module):
414
+ def __init__(self, quantizer_config):
415
+ super().__init__()
416
+
417
+ self.config = quantizer_config
418
+ self.depth = self.config.depth
419
+ self.decay = self.config.decay
420
+ self.codebook_size = self.config.codebook_size
421
+ self.codebook_dim = self.config.codebook_dim
422
+ self.shared_codebook = self.config.shared_codebook
423
+ self.restart_unused_codes = self.config.restart_unused_codes
424
+ self.in_channels = self.config.in_channels
425
+
426
+ self.vq_loss_ratio = self.config.vq_loss_ratio
427
+ self.entropy_loss_ratio = self.config.entropy_loss_ratio
428
+ self.commit_loss_ratio = self.config.commit_loss_ratio
429
+
430
+ code_h_w = int(448 / 14)
431
+ latent_shape = [code_h_w, code_h_w, self.codebook_dim]
432
+ code_shape = [code_h_w, code_h_w, self.depth]
433
+
434
+ self.quantize = RQBottleneck(
435
+ latent_shape=latent_shape,
436
+ code_shape=code_shape,
437
+ n_embed=self.codebook_size,
438
+ decay=self.decay,
439
+ shared_codebook=self.shared_codebook,
440
+ restart_unused_codes=self.restart_unused_codes,
441
+ )
442
+
443
+ if self.config.quant_conv:
444
+ self.quant_conv = nn.Sequential(
445
+ nn.LayerNorm(self.in_channels),
446
+ nn.Linear(self.in_channels, self.in_channels),
447
+ nn.GELU(),
448
+ nn.Linear(self.in_channels, self.codebook_dim)
449
+ )
450
+ else:
451
+ self.quant_conv = None
452
+
453
+ def encode(self, x):
454
+ L, D = x.shape
455
+ to_qnt_feat = x.clone()
456
+ to_qnt_feat = to_qnt_feat.unsqueeze(0) # [L, D] -> [1, L, D]
457
+ N = 1
458
+
459
+ if self.quant_conv is not None:
460
+ to_qnt_feat = self.quant_conv(to_qnt_feat)
461
+
462
+ # quantizer needs nchw format. N,L,d -> N,1,L,d -> N,d,1,L
463
+ to_qnt_feat = to_qnt_feat.reshape(N, 1, L, self.codebook_dim).permute(0,3,1,2)
464
+ if self.config.quantizer_type == "rq":
465
+ to_qnt_feat = to_qnt_feat.permute(0, 2, 3, 1).contiguous() # N,d,1,L -> N,1,L,d
466
+ quant, emb_loss, info = self.quantize(to_qnt_feat)
467
+ info = info.reshape(-1, info.shape[-1]) # n,h,w,lv -> n*h*w,lv
468
+ info = [None, None, info]
469
+ quant = quant.permute(0, 3, 1, 2).contiguous() # N,1,L,d -> N,d,1,L
470
+ else:
471
+ quant, emb_loss, info = self.quantize(to_qnt_feat)
472
+ return quant, emb_loss, info, x.detach()
473
+
474
+ def forward(self, x):
475
+ quant, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices), align_feature = \
476
+ self.encode(x)
477
+ return min_encoding_indices
478
+
479
+
480
+ class MLP(nn.Module):
481
+ def __init__(
482
+ self,
483
+ hidden_size: int,
484
+ intermediate_size: int,
485
+ hidden_act: str,
486
+ ):
487
+ super().__init__()
488
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
489
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
490
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
491
+ self.act_fn = ACT2FN[hidden_act]
492
+
493
+ def forward(self, x):
494
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
495
+
496
+ class DecoderLayer(nn.Module):
497
+ def __init__(self, config):
498
+ super().__init__()
499
+ self.hidden_size = config.hidden_size
500
+ self.mlp = MLP(
501
+ hidden_size=self.hidden_size,
502
+ intermediate_size=config.visual_embedding_layer_intermediate_size,
503
+ hidden_act=config.visual_embedding_layer_hidden_act,
504
+ )
505
+ self.pre_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
506
+
507
+ def forward(
508
+ self,
509
+ hidden_states: torch.Tensor,
510
+ ):
511
+ residual = hidden_states
512
+ hidden_states = self.pre_layernorm(hidden_states)
513
+ hidden_states = self.mlp(hidden_states)
514
+ hidden_states = residual + hidden_states
515
+
516
+ return hidden_states
517
+
518
+
519
+ class VisualEmbeddingBridge(nn.Module):
520
+ def __init__(self, config):
521
+ super().__init__()
522
+ self.pre_buffer = DecoderLayer(config)
523
+
524
+ def forward(self, embedding):
525
+ return self.pre_buffer(embedding)
526
+
527
+
528
+ class VisualVQBridge(nn.Module):
529
+ def __init__(self, visual_config):
530
+ super().__init__()
531
+ self.bridge = OmniVisualBridge(visual_config)
532
+ self.quantizer = VisualQuantizer(visual_config.vq_config)
533
+
534
+ def forward(
535
+ self,
536
+ visual_embed: torch.Tensor,
537
+ window_index: torch.Tensor,
538
+ ):
539
+ visual_embed = self.bridge(visual_embed, window_index)
540
+ indices = self.quantizer(visual_embed)
541
+ return indices
542
+
543
+
544
+ class LongcatNextVisualTokenizer(nn.Module):
545
+
546
+ def __init__(self, config):
547
+ super().__init__()
548
+ self.config = config
549
+ self.visual_model = VisualEncoder(config.visual_config)
550
+ self.visual_bridge_model = VisualVQBridge(config.visual_config)
551
+ self.visual_embedding_layer = VisualEmbeddingBridge(config)
552
+ self.image_decoder = None
553
+ self._refiner_pipeline = None
554
+
555
+ @torch.no_grad()
556
+ def encode(self, pixel_values: torch.Tensor, visual_grid_thw: torch.Tensor):
557
+ visual_embed, window_index = self.visual_model(pixel_values, grid_thw=visual_grid_thw, require_window_index=True)
558
+ indices = self.visual_bridge_model(visual_embed, window_index)
559
+ return indices
560
+
561
+ @torch.no_grad()
562
+ def lazy_decode_and_save(self, visual_ids, tokens_h, tokens_w, save_path):
563
+ device = next(self.parameters()).device
564
+ if self.image_decoder is None:
565
+ print("lazy load image_decoder / image_refiner / _refiner_pipeline ...")
566
+ vdc = self.config.visual_config.visual_decoder_config
567
+ self.image_decoder = VisionTransformerDecoder.from_pretrained(
568
+ vdc.image_decoder_config,
569
+ vdc.weight_path,
570
+ ).to(device=device, dtype=torch.bfloat16)
571
+ image_refiner = ImageRefinerContainer.from_pretrained(vdc, vdc.weight_path).to(device=device, dtype=torch.bfloat16)
572
+
573
+ sc = vdc.scheduler_config
574
+ scheduler = FlowMatchEulerDiscreteScheduler(
575
+ num_train_timesteps=sc.num_train_timesteps,
576
+ dynamic_time_shift=sc.dynamic_time_shift)
577
+ self._refiner_pipeline = RefinerPipeline(
578
+ vae=image_refiner.vae,
579
+ transformer=image_refiner.base_transformer,
580
+ scheduler=scheduler,
581
+ cond_proj=image_refiner.cond_proj,
582
+ )
583
+ self._refiner_pipeline.set_progress_bar_config(disable=False)
584
+
585
+ data = torch.as_tensor(visual_ids, dtype=torch.long)
586
+ if data.ndim == 1:
587
+ data = data.view(-1, len(self.config.visual_config.vq_config.codebook_sizes))
588
+ if data.ndim == 2:
589
+ data = data.unsqueeze(0)
590
+ batch_size = data.shape[0]
591
+
592
+ quant_features = None
593
+ for idx in range(len(self.config.visual_config.vq_config.codebook_sizes)):
594
+ embed = self.visual_bridge_model.quantizer.quantize.codebooks[idx].embed
595
+ feat = embed[data[..., idx].to(embed.device)]
596
+ quant_features = feat if quant_features is None else quant_features + feat
597
+ quant_features = quant_features.to(device)
598
+
599
+ # tokens_h/tokens_w are the merged grid; expand to the full (unmerged) grid
600
+ s = self.image_decoder.spatial_merge_size
601
+ grid_thw_list = [(1, tokens_h * s, tokens_w * s)]
602
+ grid_thw_batch = list(grid_thw_list) * batch_size
603
+
604
+ image_mean = [0.48145466, 0.4578275, 0.40821073]
605
+ image_std = [0.26862954, 0.26130258, 0.27577711]
606
+
607
+ emb_2d = quant_features.reshape(-1, quant_features.shape[-1]).contiguous()
608
+ device_type = "cuda" if str(device).startswith("cuda") else str(device)
609
+ with torch.amp.autocast(device_type=device_type, enabled=True, dtype=torch.float32):
610
+ decoder_out = self.image_decoder(emb_2d, grid_thw_batch, return_pixel_features=False)
611
+
612
+ decoded_tensors = decoder_out.get("images") or []
613
+ decoded_images = [tensor2pil(t, image_mean, image_std) for t in decoded_tensors]
614
+ decoded_path = save_path.replace(".png", "_decoded.png")
615
+ # decoded_images[0].save(decoded_path)
616
+
617
+
618
+ ref_input = []
619
+ for t in decoded_tensors:
620
+ img_01 = de_transform(t, mean=image_mean, std=image_std, rescale_factor=1 / 255)
621
+ img_norm = RefinerImageProcessor.normalize(img_01)
622
+ ref_input.append(img_norm.squeeze(0).to(device))
623
+
624
+ generators = [torch.Generator(device=device).manual_seed(42 + b) for b in range(batch_size)]
625
+ out = self._refiner_pipeline(
626
+ encoder_hidden_states=quant_features,
627
+ grid_thw_list=grid_thw_list,
628
+ image=ref_input,
629
+ generator=generators[0] if batch_size == 1 else generators,
630
+ output_type="pil",
631
+ return_dict=True,
632
+ )
633
+ refined_images = out.images
634
+ refined_path = save_path.replace(".png", "_refined.png")
635
+ refined_images[0].save(refined_path)
636
+
637
+ return [refined_path]
638
+
639
+
640
+ # ---------------------------------------------------------------------------
641
+ # Vision Transformer Decoder
642
+ # ---------------------------------------------------------------------------
643
+
644
+ def _rotate_half(x):
645
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
646
+ x1, x2 = x.unbind(dim=-1)
647
+ x = torch.stack((-x2, x1), dim=-1)
648
+ return rearrange(x, "... d r -> ... (d r)")
649
+
650
+
651
+ class VisionRoPE2D(nn.Module):
652
+ """2D Rotary Position Embedding for Q/K in vision decoder attention."""
653
+
654
+ def __init__(self, theta: float = 10000.0):
655
+ super().__init__()
656
+ self.theta = theta
657
+
658
+ def _rope_half(self, x_half, pos_1d, theta):
659
+ BH, T, d_half = x_half.shape
660
+ idx = torch.arange(0, d_half, 2, device=x_half.device, dtype=torch.float32)
661
+ inv_freq = (1.0 / (theta ** (idx / d_half))).to(x_half.dtype)
662
+ angles = pos_1d.to(x_half.dtype)[:, None] * inv_freq[None, :]
663
+ cos = torch.repeat_interleave(torch.cos(angles), 2, dim=-1).unsqueeze(0)
664
+ sin = torch.repeat_interleave(torch.sin(angles), 2, dim=-1).unsqueeze(0)
665
+ return x_half * cos + _rotate_half(x_half) * sin
666
+
667
+ def forward(self, x, positions_2d):
668
+ d_half = x.shape[-1] // 2
669
+ x_y = self._rope_half(x[:, :, :d_half], positions_2d[:, 0], self.theta)
670
+ x_x = self._rope_half(x[:, :, d_half:], positions_2d[:, 1], self.theta)
671
+ return torch.cat([x_y, x_x], dim=-1)
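+ # The head dimension is split in half: the first half is rotated by the row (y)
+ # position and the second half by the column (x) position, giving a 2D rotary
+ # embedding over the patch grid.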
672
+
673
+
674
+ class VisionAttention(nn.Module):
675
+ """Multi-headed attention with 2D RoPE + FlashAttention varlen."""
676
+
677
+ def __init__(self, config, rope=None, rope_shift=0):
678
+ super().__init__()
679
+ self.config = config
680
+ self.embed_dim = config.hidden_size
681
+ self.num_heads = config.num_attention_heads
682
+ self.head_dim = self.embed_dim // self.num_heads
683
+ if self.head_dim * self.num_heads != self.embed_dim:
684
+ raise ValueError(
685
+ f"embed_dim must be divisible by num_heads (got embed_dim={self.embed_dim}, num_heads={self.num_heads})"
686
+ )
687
+ self.scale = self.head_dim ** -0.5
688
+ self.dropout = config.attention_dropout
689
+ self.subln = config.subln
690
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=getattr(config, "k_bias", True))
691
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=getattr(config, "v_bias", True))
692
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=getattr(config, "q_bias", True))
693
+ self.inner_attn_ln = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps) if config.subln else nn.Identity()
694
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
695
+ self.rope = rope
696
+ self.rope_shift = int(rope_shift)
697
+
698
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
699
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
700
+
701
+ def _maybe_flash_attention(self, query_states, key_states, value_states, seq_lens, training):
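+ # Best-effort FlashAttention path: Q/K/V are repacked into the varlen layout
+ # expected by flash_attn_varlen_func (cumulative sequence lengths, fp16/bf16),
+ # and any mismatch or exception returns None so the caller falls back to the
+ # dense bmm + softmax attention below.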
702
+ if not (query_states.is_cuda and (query_states.dtype in (torch.float16, torch.bfloat16, torch.float32))):
703
+ return None
704
+ if seq_lens is None:
705
+ return None
706
+ try:
707
+ BxH, T, hd = query_states.shape
708
+ H = self.num_heads
709
+ assert BxH % H == 0
710
+ B = BxH // H
711
+ if int(seq_lens.sum().item()) != T:
712
+ return None
713
+ q = query_states.view(B, H, T, hd).transpose(1, 2).reshape(-1, H, hd).contiguous()
714
+ k = key_states.view(B, H, T, hd).transpose(1, 2).reshape(-1, H, hd).contiguous()
715
+ v = value_states.view(B, H, T, hd).transpose(1, 2).reshape(-1, H, hd).contiguous()
716
+ cu_q = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32, device=seq_lens.device)
717
+ cu_q[1:] = torch.cumsum(seq_lens.to(torch.int32), dim=0)
718
+ cu_k = cu_q
719
+ max_seqlen = int(seq_lens.max().item())
720
+ orig_dtype = q.dtype
721
+ use_dtype = q.dtype if q.dtype in (torch.float16, torch.bfloat16) else torch.float16
722
+ if q.dtype != use_dtype:
723
+ q = q.to(use_dtype)
724
+ k = k.to(use_dtype)
725
+ v = v.to(use_dtype)
726
+ out = flash_attn_varlen_func(
727
+ q, k, v, cu_q, cu_k, max_seqlen, max_seqlen,
728
+ dropout_p=self.dropout if training else 0.0,
729
+ softmax_scale=None, causal=False, return_attn_probs=False
730
+ )
731
+ if out.dtype != orig_dtype:
732
+ out = out.to(orig_dtype)
733
+ return out.view(B, -1, H, hd).transpose(1, 2).contiguous().view(B * H, T, hd)
734
+ except Exception:
735
+ return None
736
+
737
+ def forward(
738
+ self,
739
+ hidden_states: torch.Tensor,
740
+ attention_mask: Optional[torch.Tensor] = None,
741
+ causal_attention_mask: Optional[torch.Tensor] = None,
742
+ output_attentions: Optional[bool] = False,
743
+ positions_2d: Optional[torch.Tensor] = None,
744
+ seq_lens: Optional[torch.Tensor] = None,
745
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
746
+ bsz, tgt_len, embed_dim = hidden_states.size()
747
+ query_states = self.q_proj(hidden_states) * self.scale
748
+ key_states = self.k_proj(hidden_states)
749
+ value_states = self.v_proj(hidden_states)
750
+ query_states = self._shape(query_states, tgt_len, bsz).view(bsz * self.num_heads, tgt_len, self.head_dim)
751
+ key_states = self._shape(key_states, tgt_len, bsz).view(bsz * self.num_heads, tgt_len, self.head_dim)
752
+ value_states = self._shape(value_states, tgt_len, bsz).view(bsz * self.num_heads, tgt_len, self.head_dim)
753
+ if self.rope is not None and positions_2d is not None:
754
+ if self.rope_shift > 0:
755
+ q_pref = query_states[:, :self.rope_shift, :]
756
+ k_pref = key_states[:, :self.rope_shift, :]
757
+ q_rot = self.rope(query_states[:, self.rope_shift:, :], positions_2d[self.rope_shift:])
758
+ k_rot = self.rope(key_states[:, self.rope_shift:, :], positions_2d[self.rope_shift:])
759
+ query_states = torch.cat([q_pref, q_rot], dim=1).type_as(value_states)
760
+ key_states = torch.cat([k_pref, k_rot], dim=1).type_as(value_states)
761
+ else:
762
+ query_states = self.rope(query_states, positions_2d).type_as(value_states)
763
+ key_states = self.rope(key_states, positions_2d).type_as(value_states)
764
+ attn_output = self._maybe_flash_attention(
765
+ query_states, key_states, value_states, seq_lens=seq_lens, training=self.training
766
+ )
767
+ if attn_output is not None:
768
+ attn_weights_reshaped = None
769
+ else:
770
+ src_len = key_states.size(1)
771
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
772
+ if causal_attention_mask is not None:
773
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
774
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
775
+ if attention_mask is not None:
776
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
777
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
778
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
779
+ if output_attentions:
780
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
781
+ else:
782
+ attn_weights_reshaped = None
783
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
784
+ attn_output = torch.bmm(attn_probs, value_states)
785
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
786
+ attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim)
787
+ attn_output = self.inner_attn_ln(attn_output)
788
+ attn_output = self.out_proj(attn_output)
789
+ return attn_output, attn_weights_reshaped
790
+
791
+
792
+ class VisionSwiGLU(nn.Module):
793
+ def __init__(self, config):
794
+ super().__init__()
795
+ self.config = config
796
+ self.hidden_size = config.hidden_size
797
+ self.intermediate_size = config.intermediate_size
798
+ self.w1 = nn.Linear(self.hidden_size, self.intermediate_size)
799
+ self.w2 = nn.Linear(self.hidden_size, self.intermediate_size)
800
+ self.w3 = nn.Linear(self.intermediate_size, self.hidden_size)
801
+ self.act_fn = nn.SiLU()
802
+ self.ffn_ln = Qwen2RMSNorm(self.intermediate_size, eps=config.layer_norm_eps) if config.subln else nn.Identity()
803
+
804
+ def forward(self, x):
805
+ x1 = self.w1(x)
806
+ x2 = self.w2(x)
807
+ hidden = self.act_fn(x1) * x2
808
+ x = self.ffn_ln(hidden)
809
+ x = self.w3(x)
810
+ return x
811
+
812
+
813
+ class VisionMLP(nn.Module):
814
+ def __init__(self, config):
815
+ super().__init__()
816
+ self.config = config
817
+ self.activation_fn = ACT2FN[config.hidden_act]
818
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
819
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
820
+ self.ffn_ln = Qwen2RMSNorm(config.intermediate_size, eps=config.layer_norm_eps) if config.subln else nn.Identity()
821
+
822
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
823
+ hidden_states = self.fc1(hidden_states)
824
+ hidden_states = self.activation_fn(hidden_states)
825
+ hidden_states = self.ffn_ln(hidden_states)
826
+ hidden_states = self.fc2(hidden_states)
827
+ return hidden_states
828
+
829
+
830
+ class VisionEncoderLayer(nn.Module):
831
+ def __init__(self, config, rope=None, rope_shift=0):
832
+ super().__init__()
833
+ self.embed_dim = config.hidden_size
834
+ self.self_attn = VisionAttention(config, rope=rope, rope_shift=rope_shift)
835
+ self.layer_norm1 = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
836
+ self.mlp = VisionSwiGLU(config) if config.swiglu else VisionMLP(config)
837
+ self.layer_norm2 = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
838
+
839
+ def forward(
840
+ self,
841
+ hidden_states: torch.Tensor,
842
+ attention_mask: Optional[torch.Tensor],
843
+ causal_attention_mask: Optional[torch.Tensor],
844
+ output_attentions: Optional[bool] = False,
845
+ positions_2d: Optional[torch.Tensor] = None,
846
+ seq_lens: Optional[torch.Tensor] = None,
847
+ ) -> Tuple[torch.FloatTensor, Optional[torch.Tensor]]:
848
+ residual = hidden_states
849
+ hidden_states = self.layer_norm1(hidden_states)
850
+ hidden_states, attn_weights = self.self_attn(
851
+ hidden_states=hidden_states,
852
+ attention_mask=attention_mask,
853
+ causal_attention_mask=causal_attention_mask,
854
+ output_attentions=output_attentions,
855
+ positions_2d=positions_2d,
856
+ seq_lens=seq_lens,
857
+ )
858
+ hidden_states = residual + hidden_states
859
+ residual = hidden_states
860
+ hidden_states = self.layer_norm2(hidden_states)
861
+ hidden_states = self.mlp(hidden_states)
862
+ hidden_states = residual + hidden_states
863
+ outputs = (hidden_states,)
864
+ if output_attentions:
865
+ outputs += (attn_weights,)
866
+ return outputs
867
+
868
+
869
+ class VisionEncoder(nn.Module):
870
+ def __init__(self, config, rope=None, rope_shift=0):
871
+ super().__init__()
872
+ self.config = config
873
+ self.layers = nn.ModuleList(
874
+ [VisionEncoderLayer(config, rope=rope, rope_shift=rope_shift) for _ in range(config.num_hidden_layers)]
875
+ )
876
+ self.gradient_checkpointing = False
877
+ self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
878
+
879
+ def forward(
880
+ self,
881
+ inputs_embeds: torch.Tensor,
882
+ attention_mask: Optional[torch.Tensor] = None,
883
+ causal_attention_mask: Optional[torch.Tensor] = None,
884
+ output_attentions: Optional[bool] = None,
885
+ output_hidden_states: Optional[bool] = None,
886
+ return_dict: Optional[bool] = None,
887
+ positions_2d: Optional[torch.Tensor] = None,
888
+ seq_lens: Optional[torch.Tensor] = None,
889
+ ):
890
+ output_attentions = output_attentions if output_attentions is not None else False
891
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else False
892
+ return_dict = True if return_dict is None else return_dict
893
+
894
+ encoder_states = () if output_hidden_states else None
895
+ all_attentions = () if output_attentions else None
896
+ hidden_states = inputs_embeds
897
+
898
+ for layer in self.layers:
899
+ if output_hidden_states:
900
+ encoder_states = encoder_states + (hidden_states,)
901
+ if self.gradient_checkpointing and self.training:
902
+ def custom_forward(hs, attn, causal, pos2d, seqlens):
903
+ return layer(
904
+ hs,
905
+ attention_mask=attn,
906
+ causal_attention_mask=causal,
907
+ output_attentions=False,
908
+ positions_2d=pos2d,
909
+ seq_lens=seqlens,
910
+ )[0]
911
+ hidden_states = self._gradient_checkpointing_func(
912
+ custom_forward,
913
+ hidden_states,
914
+ attention_mask if attention_mask is not None else torch.tensor(0., device=hidden_states.device),
915
+ causal_attention_mask if causal_attention_mask is not None else torch.tensor(0., device=hidden_states.device),
916
+ positions_2d,
917
+ seq_lens if seq_lens is not None else torch.tensor([], device=hidden_states.device),
918
+ use_reentrant=False,
919
+ )
920
+ else:
921
+ layer_outputs = layer(
922
+ hidden_states,
923
+ attention_mask,
924
+ causal_attention_mask,
925
+ output_attentions=output_attentions,
926
+ positions_2d=positions_2d,
927
+ seq_lens=seq_lens,
928
+ )
929
+ hidden_states = layer_outputs[0]
930
+ if output_attentions:
931
+ all_attentions = all_attentions + (layer_outputs[1],)
932
+
933
+ if output_hidden_states:
934
+ encoder_states = encoder_states + (hidden_states,)
935
+
936
+ if not return_dict:
937
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
938
+
939
+ return BaseModelOutput(
940
+ last_hidden_state=hidden_states,
941
+ hidden_states=encoder_states,
942
+ attentions=all_attentions,
943
+ )
944
+
945
+
946
+ class PatchUnMerger(nn.Module):
947
+ """Learnable inverse of Qwen2_5_VLPatchMerger."""
948
+ def __init__(self, dim, context_dim, spatial_merge_size=2):
949
+ super().__init__()
950
+ self.spatial_merge_size = spatial_merge_size
951
+ self.context_dim = context_dim
952
+ hidden = context_dim * (spatial_merge_size ** 2)
953
+ self.ln_q = Qwen2RMSNorm(dim, eps=1e-6)
954
+ self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, hidden))
955
+
956
+ def forward(self, x):
957
+ x = self.mlp(self.ln_q(x))
958
+ return x.view(x.shape[0] * (self.spatial_merge_size ** 2), self.context_dim)
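+ # Each merged token is expanded by the MLP into spatial_merge_size**2 patch-level
+ # vectors of width context_dim, inverting the patch-merging step of the encoder.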
959
+
960
+
961
+ def restore_spatial_structure_and_convert_to_images(patches, grid_thw_list, patch_size,
962
+ channel_dim=3, temporal_patch_size=2, merge_size=2):
963
+ """Convert decoder pixel features back to image tensors [3, H, W]."""
964
+ if isinstance(patches, tuple):
965
+ patches = patches[0]
966
+ image_tensors = []
967
+ ptr = 0
968
+ for grid in grid_thw_list:
969
+ gt, gh, gw = (int(x) for x in (grid if not isinstance(grid, torch.Tensor) else grid.tolist()))
970
+ n = gt * gh * gw
971
+ chunk = patches[ptr:ptr + n]
972
+ ptr += n
973
+ r = chunk.reshape(gt, gh // merge_size, gw // merge_size, merge_size, merge_size,
974
+ channel_dim, temporal_patch_size, patch_size, patch_size)
975
+ r = r.permute(0, 6, 5, 1, 3, 7, 2, 4, 8)
976
+ image_tensors.append(r.reshape(gt * temporal_patch_size, channel_dim, gh * patch_size, gw * patch_size)[0])
977
+ return image_tensors
978
+
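+ # Shape sketch (illustrative): for a grid (t, gh, gw), `patches` supplies t * gh * gw rows of
+ # 3 * temporal_patch_size * patch_size**2 values, and the returned image tensor is
+ # [3, gh * patch_size, gw * patch_size] (only the first temporal frame is kept).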
979
+
980
+ class VisionTransformerDecoder(nn.Module):
981
+ def __init__(self, config):
982
+ super().__init__()
983
+ self.config = config
984
+ self.embed_dim = config.hidden_size
985
+ self.patch_size = config.patch_size
986
+ self.spatial_merge_size = config.spatial_merge_size
987
+ self.codebook_dim = config.codebook_dim
988
+ self.temporal_patch_size = config.temporal_patch_size
989
+
990
+ self.rope2d = VisionRoPE2D(theta=10000.0)
991
+ self.post_quant_conv = nn.Linear(self.codebook_dim, self.embed_dim)
992
+ self.post_quant_norm = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
993
+ self.patch_unmerger = PatchUnMerger(self.embed_dim, self.embed_dim, self.spatial_merge_size)
994
+ self.norm_in = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
995
+ self.encoder = VisionEncoder(config, rope=self.rope2d, rope_shift=0)
996
+ self.norm_out = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
997
+ self.decoder_head = nn.Sequential(
998
+ nn.Linear(self.embed_dim, config.intermediate_size), nn.GELU(),
999
+ nn.Linear(config.intermediate_size, 3 * self.patch_size * self.patch_size * self.temporal_patch_size),
1000
+ )
1001
+
1002
+ @classmethod
1003
+ def from_pretrained(cls, config, model_path: str):
1004
+ """Load a pretrained model from a checkpoint."""
1005
+ model = cls(config)
1006
+ weight_dict = load_file(model_path, device="cpu")
1007
+ model.load_state_dict({k.removeprefix("image_decoder."): v for k, v in weight_dict.items() if k.startswith("image_decoder.")}, strict=True)
1008
+ model.eval()
1009
+ return model
1010
+
1011
+ def _build_2d_positions(self, grid_thw_list):
1012
+ pos_list = []
1013
+ for (t, gh, gw) in grid_thw_list:
1014
+ for _ in range(int(t)):
1015
+ for y in range(int(gh)):
1016
+ for x in range(int(gw)):
1017
+ pos_list.append([y, x])
1018
+ return torch.tensor(pos_list, dtype=torch.long)
1019
+
1020
+ def _build_attention_mask(self, grid_thw_list, device, dtype, B, num_heads):
1021
+ counts = [int(t) * int(h) * int(w) for (t, h, w) in grid_thw_list]
1022
+ L = sum(counts)
1023
+ mask = torch.zeros((B, num_heads, L, L), device=device, dtype=dtype)
1024
+ s = 0
1025
+ for c in counts:
1026
+ e = s + c
1027
+ if s > 0:
1028
+ mask[:, :, s:e, :s] = float("-inf")
1029
+ if e < L:
1030
+ mask[:, :, s:e, e:] = float("-inf")
1031
+ s = e
1032
+ return mask
1033
+
1034
+ def forward(self, embeddings, grid_thw, return_pixel_features=False, return_last_latent=False):
1035
+ device = embeddings.device
1036
+ grid_thw_list = ([(int(t), int(h), int(w)) for t, h, w in grid_thw.detach().cpu().numpy()]
1037
+ if isinstance(grid_thw, torch.Tensor) else list(grid_thw))
1038
+
1039
+ if embeddings.shape[-1] == self.codebook_dim:
1040
+ embeddings = self.post_quant_conv(embeddings)
1041
+ embeddings = self.post_quant_norm(embeddings)
1042
+
1043
+ unmerged = self.patch_unmerger(embeddings)
1044
+ if unmerged.dim() == 2:
1045
+ unmerged = unmerged.unsqueeze(0)
1046
+ B, L, D = unmerged.shape
1047
+ hidden_states = self.norm_in(unmerged)
1048
+
1049
+ positions_2d = self._build_2d_positions(grid_thw_list).to(device)
1050
+ seq_lens = torch.tensor([int(t) * int(h) * int(w) for (t, h, w) in grid_thw_list],
1051
+ device=device, dtype=torch.int32)
1052
+ assert positions_2d.shape[0] == L, f"positions_2d {positions_2d.shape[0]} != L {L}"
1053
+
1054
+ last_latent = hidden_states.detach().squeeze(0) if return_last_latent else None
1055
+ enc_out = self.encoder(
1056
+ inputs_embeds=hidden_states,
1057
+ attention_mask=None,
1058
+ causal_attention_mask=None,
1059
+ output_attentions=False,
1060
+ output_hidden_states=False,
1061
+ return_dict=True,
1062
+ positions_2d=positions_2d,
1063
+ seq_lens=seq_lens,
1064
+ )
1065
+ hidden_states = enc_out.last_hidden_state
1066
+
1067
+ hidden_states = self.norm_out(hidden_states)
1068
+ pixel_features = self.decoder_head(hidden_states).squeeze(0)
1069
+
1070
+ out_imgs = (None if return_pixel_features else
1071
+ restore_spatial_structure_and_convert_to_images(
1072
+ pixel_features, grid_thw_list, self.patch_size,
1073
+ temporal_patch_size=self.temporal_patch_size, merge_size=self.spatial_merge_size))
1074
+ ret = {"images": out_imgs, "pixel_features": pixel_features}
1075
+ if last_latent is not None:
1076
+ ret["last_latent"] = last_latent
1077
+ return ret
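+
+ # Hedged usage sketch (config object, checkpoint path, and grid values are illustrative only):
+ # decoder = VisionTransformerDecoder.from_pretrained(vision_config, "image_decoder.safetensors")
+ # tokens = torch.randn(256, vision_config.codebook_dim) # 256 merged tokens for a 32x32 patch grid
+ # out = decoder(tokens, grid_thw=torch.tensor([[1, 32, 32]]))
+ # out["images"][0].shape # -> [3, 32 * patch_size, 32 * patch_size]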
parse_model_response.py ADDED
@@ -0,0 +1,158 @@
1
+ import re
2
+ import json
3
+ import uuid
4
+
5
+ def parse_arguments(json_value):
6
+ """
7
+ Attempt to parse a string as JSON
8
+
9
+ Args:
10
+ json_value: String to parse
11
+
12
+ Returns:
13
+ tuple: (parsed_value, is_valid_json)
14
+ """
15
+ try:
16
+ parsed_value = json.loads(json_value)
17
+ return parsed_value, True
18
+ except (json.JSONDecodeError, TypeError):
19
+ return json_value, False
20
+
21
+ def get_argument_type(func_name: str, arg_key: str, defined_tools: list):
22
+ """
23
+ Get the type definition of a tool parameter
24
+
25
+ Args:
26
+ func_name: Name of the function/tool
27
+ arg_key: Parameter key name
28
+ defined_tools: List of tool definitions
29
+
30
+ Returns:
31
+ str or None: Type of the parameter ('string', 'object', 'array', 'integer', 'number', 'boolean')
32
+ """
33
+ name2tool = {tool["name"]: tool for tool in defined_tools}
34
+ if func_name not in name2tool:
35
+ return None
36
+ tool = name2tool[func_name]
37
+ if "parameters" not in tool or "properties" not in tool["parameters"]:
38
+ return None
39
+ if arg_key not in tool["parameters"]["properties"]:
40
+ return None
41
+ return tool["parameters"]["properties"][arg_key].get("type")
42
+
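+ # Hedged example of the tool schema this helper expects (names and fields are illustrative):
+ # tools = [{"name": "get_weather",
+ # "parameters": {"properties": {"city": {"type": "string"},
+ # "days": {"type": "integer"}}}}]
+ # get_argument_type("get_weather", "days", tools) # -> "integer"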
43
+ def parse_model_response(response: str, defined_tools: list=[]):
44
+ """
45
+ Parse model response to extract reasoning_content, content, and tool_calls
46
+
47
+ Args:
48
+ response: Raw response text from the model
49
+ defined_tools: List of tool definitions
50
+
51
+ Returns:
52
+ dict: Message containing role, reasoning_content (optional), content (optional),
53
+ and tool_calls (optional)
54
+ """
55
+ text = response
56
+ reasoning_content = None
57
+ content = None
58
+ tool_calls = []
59
+
60
+ formatted_tools = []
61
+ for tool in defined_tools:
62
+ if "function" in tool:
63
+ formatted_tools.append(tool['function'])
64
+ else:
65
+ formatted_tools.append(tool)
66
+
67
+ if '</longcat_think>' in text:
68
+ text = text.replace('<longcat_think>', '')
69
+ thinking_end = text.find('</longcat_think>')
70
+ reasoning_content = text[: thinking_end].strip()
71
+ text = text[thinking_end + len('</longcat_think>'):].lstrip()
72
+
73
+ assert '<longcat_think>' not in text, "Unclosed <longcat_think> tag found in remaining text"
74
+ assert '</longcat_think>' not in text, "Unexpected </longcat_think> tag found without opening tag"
75
+
76
+ if '<longcat_tool_call>' in text:
77
+ index = text.find('<longcat_tool_call>')
78
+ content = text[:index]
79
+ text = text[index:].strip()
80
+ else:
81
+ content = text
82
+ text = ""
83
+
84
+ open_tags = text.count('<longcat_tool_call>')
85
+ close_tags = text.count('</longcat_tool_call>')
86
+ assert open_tags == close_tags, \
87
+ f"Mismatched tool_call tags: {open_tags} opening tags, {close_tags} closing tags"
88
+
89
+ tool_call_strs = re.findall(
90
+ r'<longcat_tool_call>(.*?)</longcat_tool_call>',
91
+ text,
92
+ re.DOTALL
93
+ )
94
+
95
+ for call in tool_call_strs:
96
+ func_name_match = re.match(r'([^\n<]+)', call.strip())
97
+ assert func_name_match, f"Missing function name in tool call: {call[:100]}"
98
+
99
+ func_name = func_name_match.group(1).strip()
100
+ assert func_name, "Empty function name in tool call"
101
+
102
+ # Verify argument tags are properly paired
103
+ arg_key_count = call.count('<longcat_arg_key>')
104
+ arg_key_close_count = call.count('</longcat_arg_key>')
105
+ arg_value_count = call.count('<longcat_arg_value>')
106
+ arg_value_close_count = call.count('</longcat_arg_value>')
107
+
108
+ assert arg_key_count == arg_key_close_count, \
109
+ f"Mismatched arg_key tags in function {func_name}: {arg_key_count} opening, {arg_key_close_count} closing"
110
+ assert arg_value_count == arg_value_close_count, \
111
+ f"Mismatched arg_value tags in function {func_name}: {arg_value_count} opening, {arg_value_close_count} closing"
112
+ assert arg_key_count == arg_value_count, \
113
+ f"Mismatched arg_key and arg_value count in function {func_name}: {arg_key_count} keys, {arg_value_count} values"
114
+
115
+ pairs = re.findall(
116
+ r'<longcat_arg_key>(.*?)</longcat_arg_key>\s*<longcat_arg_value>(.*?)</longcat_arg_value>',
117
+ call,
118
+ re.DOTALL
119
+ )
120
+
121
+ assert len(pairs) == arg_key_count, \
122
+ f"Failed to parse all arguments in function {func_name}: expected {arg_key_count}, got {len(pairs)}"
123
+
124
+ arguments = {}
125
+ for arg_key, arg_value in pairs:
126
+ arg_key = arg_key.strip()
127
+ arg_value = arg_value.strip()
128
+
129
+ assert arg_key, f"Empty argument key in function {func_name}"
130
+ assert arg_key not in arguments, \
131
+ f"Duplicate argument key '{arg_key}' in function {func_name}"
132
+
133
+ arg_type = get_argument_type(func_name, arg_key, formatted_tools)
134
+
135
+ if arg_type and arg_type != 'string':
136
+ parsed_value, is_good_json = parse_arguments(arg_value)
137
+ arg_value = parsed_value
138
+
139
+ arguments[arg_key] = arg_value
140
+
141
+ tool_calls.append({
142
+ 'id': "tool-call-" + str(uuid.uuid4()),
143
+ 'type': "function",
144
+ 'function': {
145
+ 'name': func_name,
146
+ 'arguments': arguments
147
+ }
148
+ })
149
+
150
+ message = {'role': 'assistant'}
151
+
152
+ if reasoning_content:
153
+ message['reasoning_content'] = reasoning_content
154
+ message['content'] = content
155
+ if tool_calls:
156
+ message['tool_calls'] = tool_calls
157
+
158
+ return message
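+
+ # Hedged usage sketch (the response string and tool definition are illustrative, not real model output):
+ # msg = parse_model_response(
+ # "<longcat_think>look up the weather</longcat_think>"
+ # "<longcat_tool_call>get_weather\n"
+ # "<longcat_arg_key>city</longcat_arg_key><longcat_arg_value>Paris</longcat_arg_value>"
+ # "</longcat_tool_call>",
+ # defined_tools=[{"name": "get_weather",
+ # "parameters": {"properties": {"city": {"type": "string"}}}}],
+ # )
+ # msg["reasoning_content"] # -> "look up the weather"
+ # msg["tool_calls"][0]["function"] # -> {"name": "get_weather", "arguments": {"city": "Paris"}}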
preprocessor_config.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "processor_class": "LongcatNextProcessor",
3
+ "auto_map": {
4
+ "AutoProcessor": "processing_longcat_next.LongcatNextProcessor"
5
+ },
6
+ "spatial_merge_size": 2,
7
+ "max_pixels": 3211264,
8
+ "min_pixels": 50176,
9
+
10
+ "n_fft": 400,
11
+ "num_mel_bins": 128,
12
+ "sampling_rate": 16000,
13
+ "max_audio_seconds": 30,
14
+ "hop_length": 160,
15
+ "kernel_size": 3,
16
+ "stride_size": 2,
17
+ "split_overlap": 0.0,
18
+ "avg_pooler": 4
19
+ }
processing_longcat_next.py ADDED
@@ -0,0 +1,279 @@
1
+ import re
2
+ from typing import Union, List
3
+ from types import SimpleNamespace
4
+
5
+ import torch
6
+ import librosa
7
+ import soundfile as sf
8
+ import numpy as np
9
+ from transformers import AutoFeatureExtractor
10
+ from transformers.audio_utils import mel_filter_bank
11
+ from transformers.configuration_utils import PretrainedConfig
12
+ from transformers.feature_extraction_utils import BatchFeature, FeatureExtractionMixin
13
+ from transformers.processing_utils import (
14
+ AudioKwargs,
15
+ ImagesKwargs,
16
+ ProcessingKwargs,
17
+ ProcessorMixin,
18
+ VideosKwargs,
19
+ )
20
+ from transformers.utils import logging
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class LongcatNextProcessorKwargs(ProcessingKwargs, total=False):
26
+ images_kwargs: ImagesKwargs
27
+ videos_kwargs: VideosKwargs
28
+ audio_kwargs: AudioKwargs
29
+ _defaults = {
30
+ "text_kwargs": {
31
+ "padding": False,
32
+ "padding_side": "left",
33
+ "return_attention_mask": False,
34
+ }
35
+ }
36
+
37
+
38
+ class LongcatNextAudioProcessor(FeatureExtractionMixin):
39
+
40
+ def __init__(self, **kwargs):
41
+ super().__init__(**kwargs)
42
+ self.mel_filters = mel_filter_bank(
43
+ num_frequency_bins=1 + self.n_fft // 2,
44
+ num_mel_filters=self.num_mel_bins,
45
+ min_frequency=0.0,
46
+ max_frequency=self.sampling_rate / 2.0,
47
+ sampling_rate=self.sampling_rate,
48
+ norm="slaney",
49
+ mel_scale="slaney",
50
+ )
51
+ self.window = torch.hann_window(self.n_fft)
52
+
53
+ @staticmethod
54
+ def zero_mean_unit_var_norm(x):
55
+ return (x - x.mean()) / torch.sqrt(x.var() + 1e-8)
56
+
57
+ def load_audio_waveform(self, uri, metadata=None, waveform_tensor=None, return_tensors=True, do_normalize=False):
58
+ if metadata is None or waveform_tensor is None:
59
+ # Use librosa to handle all audio formats uniformly (including mp3, wav, flac, etc.)
60
+ # librosa.load already returns normalized float32 data
61
+ waveform_np, sample_rate = librosa.load(uri, sr=None, mono=False)
62
+
63
+ # Convert to a tensor and ensure the shape is (channels, samples)
64
+ if waveform_np.ndim == 1:
65
+ waveform_tensor = torch.from_numpy(waveform_np).unsqueeze(0)
66
+ else:
67
+ waveform_tensor = torch.from_numpy(waveform_np)
68
+
69
+ # Get the audio metadata
70
+ try:
71
+ sf_info = sf.info(uri)
72
+ metadata = SimpleNamespace(
73
+ sample_rate=sample_rate,
74
+ num_frames=waveform_tensor.shape[1],
75
+ num_channels=waveform_tensor.shape[0],
76
+ bits_per_sample=getattr(sf_info, 'bits_per_sample', 16),
77
+ encoding=getattr(sf_info, 'subtype', 'PCM_F')
78
+ )
79
+ except Exception:
80
+ # If soundfile.info fails, fall back to the information provided by librosa
81
+ metadata = SimpleNamespace(
82
+ sample_rate=sample_rate,
83
+ num_frames=waveform_tensor.shape[1],
84
+ num_channels=waveform_tensor.shape[0],
85
+ bits_per_sample=16,
86
+ encoding='PCM_F'
87
+ )
88
+
89
+ assert(metadata.num_channels <= 2), "acoustic file with {} channels.".format(metadata.num_channels) # Whisper only accepts mono audio; stereo is downmixed to mono below
90
+
91
+ if self.sampling_rate != metadata.sample_rate:
92
+ # Resample using torch.nn.functional.interpolate
93
+ waveform_tensor = torch.nn.functional.interpolate(
94
+ waveform_tensor.unsqueeze(0),
95
+ size=int(waveform_tensor.shape[1] * self.sampling_rate / metadata.sample_rate),
96
+ mode='linear',
97
+ align_corners=False
98
+ ).squeeze(0)
99
+
100
+ # downmix to mono channel https://trac.ffmpeg.org/wiki/AudioChannelManipulation
101
+ if metadata.num_channels > 1:
102
+ waveform_tensor = torch.mean(waveform_tensor, dim=0, keepdim=True)
103
+
104
+ # Normalize to zero mean (skipped in Qwen Audio, but done in the official Whisper implementation)
105
+ if do_normalize:
106
+ waveform_tensor = self.zero_mean_unit_var_norm(waveform_tensor)
107
+
108
+ if return_tensors: # (channels, samples)
109
+ return waveform_tensor
110
+ else:
111
+ return waveform_tensor.numpy()
112
+
113
+ def split_with_overlap(self, waveform): # Split into multiple overlapping segments when the waveform exceeds the maximum length
114
+ channels, wave_samples = waveform.shape
115
+ max_audio_samples = self.max_audio_seconds * self.sampling_rate
116
+ if wave_samples <= max_audio_samples or self.split_overlap < 0:
117
+ return [waveform] # Within the maximum length (or no truncation needed); always return a list
118
+
119
+ split_waveform, start = [], 0
120
+ while start < wave_samples: # Align the overlap consistently in whole seconds
121
+ if start > int(self.sampling_rate * self.split_overlap):
122
+ start -= int(self.sampling_rate * self.split_overlap) # 0 means no overlap; > 0 is the overlap in seconds
123
+ end = min(start + max_audio_samples, wave_samples)
124
+ if end - start >= self.n_fft: # Ensure at least one frame of data
125
+ split_waveform.append(waveform[:, start:end]) # Note: this may produce very short segments, which must be detected and dropped during preprocessing
126
+ start = end
127
+ return split_waveform
128
+
129
+ @staticmethod
130
+ def inference_output_length(input_length, kernel_size, stride_size, avg_pooler):
131
+ # for whisper + bridge
132
+ encoder_length = (input_length + 2 * (kernel_size // 2) - kernel_size) // 1 + 1 # conv layer1 with pad=1
133
+ encoder_length = (encoder_length + 2 * (kernel_size // 2) - kernel_size) // stride_size + 1 # conv layer2 with pad=1
134
+ if avg_pooler > 1:
135
+ bridge_length = encoder_length // avg_pooler
+ else:
+ bridge_length = encoder_length
136
+ return encoder_length, bridge_length
137
+
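+ # Worked example with the shipped preprocessor defaults (kernel_size=3, stride_size=2, avg_pooler=4,
+ # hop_length=160, sampling_rate=16000): a full 30 s clip yields 3000 mel frames, so
+ # encoder_length = (3000 + 2 - 3) // 1 + 1 = 3000 -> (3000 + 2 - 3) // 2 + 1 = 1500
+ # bridge_length = 1500 // 4 = 375 audio tokens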
138
+ def extract_fbank_features(self, waveform):
139
+ # ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py
140
+ channels, wave_samples = waveform.shape
141
+ assert(wave_samples >= self.n_fft)
142
+ valid_frame_nums = min(self.max_audio_seconds * self.sampling_rate // self.hop_length, wave_samples // self.hop_length + 1)
143
+ if wave_samples < self.max_audio_seconds * self.sampling_rate:
144
+ waveform = torch.nn.functional.pad(waveform, (0, self.max_audio_seconds * self.sampling_rate - wave_samples), "constant", 0)
145
+ else:
146
+ waveform = waveform[:, :self.max_audio_seconds * self.sampling_rate]
147
+
148
+ # window = torch.hann_window(self.n_fft)
149
+ stft = torch.stft(waveform, self.n_fft, self.hop_length, window=self.window, return_complex=True) # fft, len(wave) // n_fft // 2 + 1
150
+ magnitudes = stft[..., :-1].abs() ** 2
151
+
152
+ mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
153
+ mel_spec = mel_filters.T @ magnitudes
154
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
155
+ if waveform.dim() == 2:
156
+ max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
157
+ log_spec = torch.maximum(log_spec, max_val - 8.0)
158
+ else:
159
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
160
+ log_spec = (log_spec + 4.0) / 4.0
161
+
162
+ log_spec = log_spec[0].numpy() # (channel, filters, samples) -> (filters, samples)
163
+ log_spec[:, valid_frame_nums:] = 0.0 # pad0
164
+
165
+ return log_spec, valid_frame_nums
166
+
167
+ def process(self, audio_path, **kwargs):
168
+ metadata, waveform_tensors = None, None
169
+ waveforms = self.load_audio_waveform(audio_path, metadata, waveform_tensors, True)
170
+ waveforms = self.split_with_overlap(waveforms)
171
+
172
+ ret_audio, ret_encoder_length, ret_bridge_length = [], [], []
173
+ for i, waveform in enumerate(waveforms):
174
+ audio, input_length = self.extract_fbank_features(waveform)
175
+ encoder_length, bridge_length = self.inference_output_length(input_length, self.kernel_size, self.stride_size, self.avg_pooler)
176
+ if bridge_length <= 0:
177
+ continue
178
+
179
+ ret_audio.append(audio)
180
+ ret_encoder_length.append(encoder_length)
181
+ ret_bridge_length.append(bridge_length)
182
+ return ret_audio, ret_encoder_length, ret_bridge_length
183
+
184
+ def __call__(self, audio: Union[str, List[str]], **kwargs):
185
+ if isinstance(audio, str):
186
+ audio = [audio]
187
+ results = {
188
+ "audio": [],
189
+ "encoder_length": [],
190
+ "bridge_length": [],
191
+ }
192
+ for audio_path in audio:
193
+ audio, encoder_length, bridge_length = self.process(audio_path, **kwargs)
194
+ results["audio"].append(audio)
195
+ results["encoder_length"].append(encoder_length)
196
+ results["bridge_length"].append(bridge_length)
197
+ return results
198
+
199
+
200
+ class LongcatNextProcessor(ProcessorMixin):
201
+
202
+ attributes = ["image_processor", "video_processor", "audio_processor", "tokenizer"]
203
+
204
+ image_processor_class = "Qwen2VLImageProcessor"
205
+ video_processor_class = "Qwen2VLImageProcessor"
206
+ audio_processor_class = "LongcatNextAudioProcessor"
207
+ tokenizer_class = "AutoTokenizer"
208
+
209
+ def __init__(self, image_processor=None, video_processor=None, audio_processor=None, tokenizer=None, chat_template=None, **kwargs):
210
+ super().__init__(image_processor, video_processor, audio_processor, tokenizer, chat_template=chat_template)
211
+ init_token_list = [
212
+ "image_start_token", "image_end_token", "image_pad_token", "image_newline_token",
213
+ "audio_start_token", "audio_end_token", "audio_pad_token",
214
+ ]
215
+ for attr in init_token_list:
216
+ token_str = self.tokenizer.init_kwargs.get(attr)
217
+ token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
218
+ assert len(token_ids) == 1, (f"{attr}='{token_str}' encodes to {len(token_ids)} id(s) {token_ids}, expected exactly 1")
219
+ setattr(self, f"{attr}", token_str)
220
+ setattr(self, f"{attr}_id", token_ids[0])
221
+
222
+ def __call__(
223
+ self,
224
+ text: str,
225
+ **kwargs,
226
+ ) -> tuple:
227
+
228
+ if text is None:
229
+ raise ValueError("You need to specify either a `text` input to process.")
230
+
231
+ output_kwargs = self._merge_kwargs(
232
+ LongcatNextProcessorKwargs,
233
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
234
+ **kwargs,
235
+ )
236
+
237
+ assert isinstance(text, str)
238
+
239
+ image_path_list = re.findall(rf"{self.image_start_token}(.*?){self.image_end_token}", text)
240
+ audio_path_list = re.findall(rf"{self.audio_start_token}(.*?){self.audio_end_token}", text)
241
+
242
+ if len(image_path_list) > 0:
243
+ images_inputs = self.image_processor(images=image_path_list, **output_kwargs["images_kwargs"])
244
+ image_grid_thw = images_inputs["image_grid_thw"]
245
+ for i, image_path in enumerate(image_path_list):
246
+ image_token_num = image_grid_thw[i][0] * (image_grid_thw[i][1]//self.image_processor.spatial_merge_size) * (image_grid_thw[i][2]//self.image_processor.spatial_merge_size)
247
+ text = text.replace(f"{self.image_start_token}{image_path}{self.image_end_token}", f"{self.image_start_token}{self.image_pad_token * image_token_num}{self.image_end_token}")
248
+ else:
249
+ images_inputs = {}
250
+
251
+ if len(audio_path_list) > 0:
252
+ audio_inputs = self.audio_processor(audio=audio_path_list, **output_kwargs["audio_kwargs"])
253
+ for i, audio_path in enumerate(audio_path_list):
254
+ audio_token_num = np.sum(audio_inputs["bridge_length"][i])
255
+ text = text.replace(f"{self.audio_start_token}{audio_path}{self.audio_end_token}", f"{self.audio_start_token}{self.audio_pad_token * audio_token_num}{self.audio_end_token}")
256
+ for key in audio_inputs:
257
+ audio_inputs[key] = [val for b_val in audio_inputs[key] for val in b_val]
258
+ else:
259
+ audio_inputs = {}
260
+
261
+ texts_inputs = self.tokenizer([text], **output_kwargs["text_kwargs"])
262
+
263
+ batch_feature_func = lambda x: BatchFeature(
264
+ data={**x},
265
+ tensor_type=kwargs.get("return_tensors"),
266
+ )
267
+ return (
268
+ batch_feature_func(texts_inputs),
269
+ batch_feature_func({k.replace("image", "visual"): v for k, v in images_inputs.items()}) if len(images_inputs) > 0 else None,
270
+ batch_feature_func(audio_inputs) if len(audio_inputs) > 0 else None,
271
+ )
272
+
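+ # Hedged usage sketch (model directory, image path, and prompt layout are illustrative only):
+ # processor = LongcatNextProcessor.from_pretrained(model_dir, trust_remote_code=True)
+ # text_inputs, visual_inputs, audio_inputs = processor(
+ # f"{processor.image_start_token}demo.jpg{processor.image_end_token} Describe the image.",
+ # return_tensors="pt",
+ # )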
273
+
274
+ class LongcatNextAudioProcessorConfig(PretrainedConfig):
275
+ pass
276
+ AutoFeatureExtractor.register(LongcatNextAudioProcessorConfig, LongcatNextAudioProcessor)
277
+
278
+
279
+ __all__ = ["LongcatNextAudioProcessor", "LongcatNextProcessor"]
refiner_modules.py ADDED
@@ -0,0 +1,1330 @@
1
+ # ---------------------------------------------------------------------------
2
+ # Standard / third-party imports shared by all sections
3
+ # ---------------------------------------------------------------------------
4
+
5
+ import itertools
6
+ import math
7
+ from dataclasses import dataclass
8
+ from typing import Any, Dict, List, Optional, Tuple, Union
9
+
10
+ from flash_attn import flash_attn_varlen_func # type: ignore
11
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # type: ignore
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from torch.nn import RMSNorm
17
+
18
+ from einops import rearrange, repeat
19
+
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.loaders import PeftAdapterMixin
22
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
23
+ from diffusers.models.activations import get_activation
24
+ from diffusers.models.attention_processor import Attention
25
+ from diffusers.models.embeddings import Timesteps, get_1d_rotary_pos_embed
26
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
27
+ from diffusers.models.modeling_utils import ModelMixin
28
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
29
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ def swiglu(x, y):
35
+ return F.silu(x.float(), inplace=False).to(x.dtype) * y
36
+
37
+
38
+ class TimestepEmbedding(nn.Module):
39
+ def __init__(
40
+ self,
41
+ in_channels: int,
42
+ time_embed_dim: int,
43
+ act_fn: str = "silu",
44
+ out_dim: int = None,
45
+ post_act_fn: Optional[str] = None,
46
+ cond_proj_dim=None,
47
+ sample_proj_bias=True,
48
+ ):
49
+ super().__init__()
50
+
51
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
52
+
53
+ if cond_proj_dim is not None:
54
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
55
+ else:
56
+ self.cond_proj = None
57
+
58
+ self.act = get_activation(act_fn)
59
+
60
+ if out_dim is not None:
61
+ time_embed_dim_out = out_dim
62
+ else:
63
+ time_embed_dim_out = time_embed_dim
64
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
65
+
66
+ if post_act_fn is None:
67
+ self.post_act = None
68
+ else:
69
+ self.post_act = get_activation(post_act_fn)
70
+
71
+ self.initialize_weights()
72
+
73
+ def initialize_weights(self):
74
+ nn.init.normal_(self.linear_1.weight, std=0.02)
75
+ nn.init.zeros_(self.linear_1.bias)
76
+ nn.init.normal_(self.linear_2.weight, std=0.02)
77
+ nn.init.zeros_(self.linear_2.bias)
78
+
79
+ def forward(self, sample, condition=None):
80
+ if condition is not None:
81
+ sample = sample + self.cond_proj(condition)
82
+ sample = self.linear_1(sample)
83
+ if self.act is not None:
84
+ sample = self.act(sample)
85
+ sample = self.linear_2(sample)
86
+ if self.post_act is not None:
87
+ sample = self.post_act(sample)
88
+ return sample
89
+
90
+
91
+ def apply_rotary_emb(
92
+ x: torch.Tensor,
93
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
94
+ use_real: bool = True,
95
+ use_real_unbind_dim: int = -1,
96
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
97
+ """
98
+ Apply rotary embeddings to input tensors using the given frequency tensor.
99
+ """
100
+ if use_real:
101
+ cos, sin = freqs_cis # [S, D]
102
+ cos = cos[None, None]
103
+ sin = sin[None, None]
104
+ cos, sin = cos.to(x.device), sin.to(x.device)
105
+
106
+ if use_real_unbind_dim == -1:
107
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
108
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
109
+ elif use_real_unbind_dim == -2:
110
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
111
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
112
+ else:
113
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
114
+
115
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
116
+ return out
117
+ else:
118
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
119
+ freqs_cis = freqs_cis.unsqueeze(2)
120
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
121
+ return x_out.type_as(x)
122
+
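+ # As used by the attention processors below (use_real=False): x is [batch, seq, heads, head_dim]
+ # and freqs_cis is a complex tensor of shape [batch, seq, head_dim // 2], broadcast over the head axis.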
123
+
124
+ @dataclass
125
+ class TeaCacheParams:
126
+ """
127
+ TeaCache parameters for Transformer2DModel.
128
+ See https://github.com/ali-vilab/TeaCache/ for a more comprehensive understanding.
129
+ """
130
+ previous_residual: Optional[torch.Tensor] = None
131
+ previous_modulated_inp: Optional[torch.Tensor] = None
132
+ accumulated_rel_l1_distance: float = 0
133
+ is_first_or_last_step: bool = False
134
+
135
+
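+ # The cache helpers below are no-op stubs; the TaylorSeer branch in TransformerBlock.forward
+ # assumes real implementations are supplied before `enable_taylorseer` is turned on.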
136
+ def derivative_approximation(*args, **kwargs):
137
+ pass
138
+
139
+
140
+ def taylor_formula(*args, **kwargs):
141
+ pass
142
+
143
+
144
+ def taylor_cache_init(*args, **kwargs):
145
+ pass
146
+
147
+
148
+ def cache_init(*args, **kwargs):
149
+ pass
150
+
151
+
152
+ def cal_type(*args, **kwargs):
153
+ pass
154
+
155
+
156
+ class LuminaRMSNormZero(nn.Module):
157
+ """
158
+ Norm layer adaptive RMS normalization zero.
159
+ """
160
+
161
+ def __init__(
162
+ self,
163
+ embedding_dim: int,
164
+ norm_eps: float,
165
+ norm_elementwise_affine: bool,
166
+ ):
167
+ super().__init__()
168
+ self.silu = nn.SiLU()
169
+ self.linear = nn.Linear(
170
+ min(embedding_dim, 1024),
171
+ 4 * embedding_dim,
172
+ bias=True,
173
+ )
174
+ self.norm = RMSNorm(embedding_dim, eps=norm_eps)
175
+
176
+ def forward(
177
+ self,
178
+ x: torch.Tensor,
179
+ emb: Optional[torch.Tensor] = None,
180
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
181
+ emb = self.linear(self.silu(emb))
182
+ scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
183
+ x = self.norm(x) * (1 + scale_msa[:, None])
184
+ return x, gate_msa, scale_mlp, gate_mlp
185
+
186
+
187
+ class LuminaLayerNormContinuous(nn.Module):
188
+ def __init__(
189
+ self,
190
+ embedding_dim: int,
191
+ conditioning_embedding_dim: int,
192
+ elementwise_affine=True,
193
+ eps=1e-5,
194
+ bias=True,
195
+ norm_type="layer_norm",
196
+ out_dim: Optional[int] = None,
197
+ ):
198
+ super().__init__()
199
+
200
+ self.silu = nn.SiLU()
201
+ self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
202
+
203
+ if norm_type == "layer_norm":
204
+ self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
205
+ elif norm_type == "rms_norm":
206
+ self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
207
+ else:
208
+ raise ValueError(f"unknown norm_type {norm_type}")
209
+
210
+ self.linear_2 = None
211
+ if out_dim is not None:
212
+ self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
213
+
214
+ def forward(
215
+ self,
216
+ x: torch.Tensor,
217
+ conditioning_embedding: torch.Tensor,
218
+ ) -> torch.Tensor:
219
+ emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
220
+ scale = emb
221
+ x = self.norm(x) * (1 + scale)[:, None, :]
222
+ if self.linear_2 is not None:
223
+ x = self.linear_2(x)
224
+ return x
225
+
226
+
227
+ class LuminaFeedForward(nn.Module):
228
+ def __init__(
229
+ self,
230
+ dim: int,
231
+ inner_dim: int,
232
+ multiple_of: Optional[int] = 256,
233
+ ffn_dim_multiplier: Optional[float] = None,
234
+ ):
235
+ super().__init__()
236
+
237
+ if ffn_dim_multiplier is not None:
238
+ inner_dim = int(ffn_dim_multiplier * inner_dim)
239
+ inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
240
+
241
+ self.linear_1 = nn.Linear(dim, inner_dim, bias=False)
242
+ self.linear_2 = nn.Linear(inner_dim, dim, bias=False)
243
+ self.linear_3 = nn.Linear(dim, inner_dim, bias=False)
244
+
245
+ def forward(self, x):
246
+ h1, h2 = self.linear_1(x), self.linear_3(x)
247
+ return self.linear_2(swiglu(h1, h2))
248
+
249
+
250
+ class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
251
+ def __init__(
252
+ self,
253
+ hidden_size: int = 4096,
254
+ text_feat_dim: int = 2048,
255
+ frequency_embedding_size: int = 256,
256
+ norm_eps: float = 1e-5,
257
+ timestep_scale: float = 1.0,
258
+ ) -> None:
259
+ super().__init__()
260
+
261
+ self.time_proj = Timesteps(
262
+ num_channels=frequency_embedding_size,
263
+ flip_sin_to_cos=True,
264
+ downscale_freq_shift=0.0,
265
+ scale=timestep_scale,
266
+ )
267
+ self.timestep_embedder = TimestepEmbedding(
268
+ in_channels=frequency_embedding_size,
269
+ time_embed_dim=min(hidden_size, 1024),
270
+ )
271
+ self.caption_embedder = nn.Sequential(
272
+ RMSNorm(text_feat_dim, eps=norm_eps),
273
+ nn.Linear(text_feat_dim, hidden_size, bias=True),
274
+ )
275
+ self._initialize_weights()
276
+
277
+ def _initialize_weights(self):
278
+ nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02)
279
+ nn.init.zeros_(self.caption_embedder[1].bias)
280
+
281
+ def forward(
282
+ self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
283
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
284
+ timestep_proj = self.time_proj(timestep).to(dtype=dtype)
285
+ time_embed = self.timestep_embedder(timestep_proj)
286
+ caption_embed = self.caption_embedder(text_hidden_states)
287
+ return time_embed, caption_embed
288
+
289
+
290
+ class AttnProcessorFlash2Varlen:
291
+ """
292
+ Processor for implementing scaled dot-product attention with flash attention
293
+ and variable length sequences.
294
+ """
295
+
296
+ def __init__(self) -> None:
297
+ pass
298
+ # if not is_flash_attn_available():
299
+ # raise ImportError(
300
+ # "AttnProcessorFlash2Varlen requires flash_attn. "
301
+ # "Please install flash_attn."
302
+ # )
303
+
304
+ def _upad_input(
305
+ self,
306
+ query_layer: torch.Tensor,
307
+ key_layer: torch.Tensor,
308
+ value_layer: torch.Tensor,
309
+ attention_mask: torch.Tensor,
310
+ query_length: int,
311
+ num_heads: int,
312
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
313
+ def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
314
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
315
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
316
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
317
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
318
+ return indices, cu_seqlens, max_seqlen_in_batch
319
+
320
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
321
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
322
+
323
+ key_layer = index_first_axis(
324
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k,
325
+ )
326
+ value_layer = index_first_axis(
327
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k,
328
+ )
329
+
330
+ if query_length == kv_seq_len:
331
+ query_layer = index_first_axis(
332
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k,
333
+ )
334
+ cu_seqlens_q = cu_seqlens_k
335
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
336
+ indices_q = indices_k
337
+ elif query_length == 1:
338
+ max_seqlen_in_batch_q = 1
339
+ cu_seqlens_q = torch.arange(
340
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
341
+ )
342
+ indices_q = cu_seqlens_q[:-1]
343
+ query_layer = query_layer.squeeze(1)
344
+ else:
345
+ attention_mask = attention_mask[:, -query_length:]
346
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
347
+ query_layer, attention_mask
348
+ )
349
+
350
+ return (
351
+ query_layer, key_layer, value_layer, indices_q,
352
+ (cu_seqlens_q, cu_seqlens_k),
353
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
354
+ )
355
+
356
+ def __call__(
357
+ self,
358
+ attn: Attention,
359
+ hidden_states: torch.Tensor,
360
+ encoder_hidden_states: torch.Tensor,
361
+ attention_mask: Optional[torch.Tensor] = None,
362
+ image_rotary_emb: Optional[torch.Tensor] = None,
363
+ base_sequence_length: Optional[int] = None,
364
+ ) -> torch.Tensor:
365
+ batch_size, sequence_length, _ = hidden_states.shape
366
+
367
+ query = attn.to_q(hidden_states)
368
+ key = attn.to_k(encoder_hidden_states)
369
+ value = attn.to_v(encoder_hidden_states)
370
+
371
+ query_dim = query.shape[-1]
372
+ inner_dim = key.shape[-1]
373
+ head_dim = query_dim // attn.heads
374
+ dtype = query.dtype
375
+ kv_heads = inner_dim // head_dim
376
+
377
+ query = query.view(batch_size, -1, attn.heads, head_dim)
378
+ key = key.view(batch_size, -1, kv_heads, head_dim)
379
+ value = value.view(batch_size, -1, kv_heads, head_dim)
380
+
381
+ if attn.norm_q is not None:
382
+ query = attn.norm_q(query)
383
+ if attn.norm_k is not None:
384
+ key = attn.norm_k(key)
385
+
386
+ if image_rotary_emb is not None:
387
+ query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
388
+ key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
389
+
390
+ query, key = query.to(dtype), key.to(dtype)
391
+
392
+ if base_sequence_length is not None:
393
+ softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
394
+ else:
395
+ softmax_scale = attn.scale
396
+
397
+ (
398
+ query_states, key_states, value_states, indices_q,
399
+ cu_seq_lens, max_seq_lens,
400
+ ) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads)
401
+
402
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
403
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
404
+
405
+ if kv_heads < attn.heads:
406
+ key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
407
+ value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
408
+
409
+ attn_output_unpad = flash_attn_varlen_func(
410
+ query_states, key_states, value_states,
411
+ cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
412
+ max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
413
+ dropout_p=0.0, causal=False, softmax_scale=softmax_scale,
414
+ )
415
+
416
+ hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
417
+ hidden_states = hidden_states.flatten(-2)
418
+ hidden_states = hidden_states.type_as(query)
419
+
420
+ hidden_states = attn.to_out[0](hidden_states)
421
+ hidden_states = attn.to_out[1](hidden_states)
422
+ return hidden_states
423
+
424
+
425
+ class AttnProcessor:
426
+ """
427
+ Processor for implementing scaled dot-product attention (PyTorch 2.0+).
428
+ """
429
+
430
+ def __init__(self) -> None:
431
+ if not hasattr(F, "scaled_dot_product_attention"):
432
+ raise ImportError(
433
+ "AttnProcessor requires PyTorch 2.0. "
434
+ "Please upgrade PyTorch to version 2.0 or later."
435
+ )
436
+
437
+ def __call__(
438
+ self,
439
+ attn: Attention,
440
+ hidden_states: torch.Tensor,
441
+ encoder_hidden_states: torch.Tensor,
442
+ attention_mask: Optional[torch.Tensor] = None,
443
+ image_rotary_emb: Optional[torch.Tensor] = None,
444
+ base_sequence_length: Optional[int] = None,
445
+ ) -> torch.Tensor:
446
+ batch_size, sequence_length, _ = hidden_states.shape
447
+
448
+ query = attn.to_q(hidden_states)
449
+ key = attn.to_k(encoder_hidden_states)
450
+ value = attn.to_v(encoder_hidden_states)
451
+
452
+ query_dim = query.shape[-1]
453
+ inner_dim = key.shape[-1]
454
+ head_dim = query_dim // attn.heads
455
+ dtype = query.dtype
456
+ kv_heads = inner_dim // head_dim
457
+
458
+ query = query.view(batch_size, -1, attn.heads, head_dim)
459
+ key = key.view(batch_size, -1, kv_heads, head_dim)
460
+ value = value.view(batch_size, -1, kv_heads, head_dim)
461
+
462
+ if attn.norm_q is not None:
463
+ query = attn.norm_q(query)
464
+ if attn.norm_k is not None:
465
+ key = attn.norm_k(key)
466
+
467
+ if image_rotary_emb is not None:
468
+ query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
469
+ key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
470
+
471
+ query, key = query.to(dtype), key.to(dtype)
472
+
473
+ if base_sequence_length is not None:
474
+ softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
475
+ else:
476
+ softmax_scale = attn.scale
477
+
478
+ if attention_mask is not None:
479
+ attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
480
+
481
+ query = query.transpose(1, 2)
482
+ key = key.transpose(1, 2)
483
+ value = value.transpose(1, 2)
484
+
485
+ key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
486
+ value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
487
+
488
+ hidden_states = F.scaled_dot_product_attention(
489
+ query, key, value, attn_mask=attention_mask, scale=softmax_scale
490
+ )
491
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
492
+ hidden_states = hidden_states.type_as(query)
493
+
494
+ hidden_states = attn.to_out[0](hidden_states)
495
+ hidden_states = attn.to_out[1](hidden_states)
496
+ return hidden_states
497
+
498
+
499
+
500
+ class RotaryPosEmbed(nn.Module):
501
+ def __init__(
502
+ self,
503
+ theta: int,
504
+ axes_dim: Tuple[int, int, int],
505
+ axes_lens: Tuple[int, int, int] = (300, 512, 512),
506
+ patch_size: int = 2,
507
+ ):
508
+ super().__init__()
509
+ self.theta = theta
510
+ self.axes_dim = axes_dim
511
+ self.axes_lens = axes_lens
512
+ self.patch_size = patch_size
513
+
514
+ @staticmethod
515
+ def get_freqs_cis(
516
+ axes_dim: Tuple[int, int, int],
517
+ axes_lens: Tuple[int, int, int],
518
+ theta: int,
519
+ ) -> List[torch.Tensor]:
520
+ freqs_cis = []
521
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
522
+ for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
523
+ emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
524
+ freqs_cis.append(emb)
525
+ return freqs_cis
526
+
527
+ def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
528
+ device = ids.device
529
+ if ids.device.type == "mps":
530
+ ids = ids.to("cpu")
531
+
532
+ result = []
533
+ for i in range(len(self.axes_dim)):
534
+ freqs = freqs_cis[i].to(ids.device)
535
+ index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
536
+ result.append(
537
+ torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)
538
+ )
539
+ return torch.cat(result, dim=-1).to(device)
540
+
541
+ def forward(
542
+ self,
543
+ freqs_cis,
544
+ attention_mask,
545
+ l_effective_ref_img_len,
546
+ l_effective_img_len,
547
+ ref_img_sizes,
548
+ img_sizes,
549
+ device,
550
+ ):
551
+ batch_size = len(attention_mask)
552
+ p = self.patch_size
553
+
554
+ encoder_seq_len = attention_mask.shape[1]
555
+ l_effective_cap_len = attention_mask.sum(dim=1).tolist()
556
+
557
+ seq_lengths = [
558
+ cap_len + sum(ref_img_len) + img_len
559
+ for cap_len, ref_img_len, img_len in zip(
560
+ l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len
561
+ )
562
+ ]
563
+
564
+ max_seq_len = max(seq_lengths)
565
+ max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
566
+ max_img_len = max(l_effective_img_len)
567
+
568
+ position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
569
+
570
+ for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
571
+ position_ids[i, :cap_seq_len] = repeat(
572
+ torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3"
573
+ )
574
+
575
+ pe_shift = cap_seq_len
576
+ pe_shift_len = cap_seq_len
577
+
578
+ if ref_img_sizes[i] is not None:
579
+ for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
580
+ H, W = ref_img_size
581
+ ref_H_tokens, ref_W_tokens = H // p, W // p
582
+ assert ref_H_tokens * ref_W_tokens == ref_img_len
583
+
584
+ row_ids = repeat(
585
+ torch.arange(ref_H_tokens, dtype=torch.int32, device=device),
586
+ "h -> h w", w=ref_W_tokens,
587
+ ).flatten()
588
+ col_ids = repeat(
589
+ torch.arange(ref_W_tokens, dtype=torch.int32, device=device),
590
+ "w -> h w", h=ref_H_tokens,
591
+ ).flatten()
592
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
593
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
594
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
595
+
596
+ pe_shift += max(ref_H_tokens, ref_W_tokens)
597
+ pe_shift_len += ref_img_len
598
+
599
+ H, W = img_sizes[i]
600
+ H_tokens, W_tokens = H // p, W // p
601
+ assert H_tokens * W_tokens == l_effective_img_len[i]
602
+
603
+ row_ids = repeat(
604
+ torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens
605
+ ).flatten()
606
+ col_ids = repeat(
607
+ torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens
608
+ ).flatten()
609
+
610
+ assert pe_shift_len + l_effective_img_len[i] == seq_len
611
+ position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
612
+ position_ids[i, pe_shift_len: seq_len, 1] = row_ids
613
+ position_ids[i, pe_shift_len: seq_len, 2] = col_ids
614
+
615
+ freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)
616
+
617
+ cap_freqs_cis = torch.zeros(
618
+ batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
619
+ )
620
+ ref_img_freqs_cis = torch.zeros(
621
+ batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
622
+ )
623
+ img_freqs_cis = torch.zeros(
624
+ batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
625
+ )
626
+
627
+ for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(
628
+ zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)
629
+ ):
630
+ cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
631
+ ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[
632
+ i, cap_seq_len:cap_seq_len + sum(ref_img_len)
633
+ ]
634
+ img_freqs_cis[i, :img_len] = freqs_cis[
635
+ i,
636
+ cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len,
637
+ ]
638
+
639
+ return (
640
+ cap_freqs_cis,
641
+ ref_img_freqs_cis,
642
+ img_freqs_cis,
643
+ freqs_cis,
644
+ l_effective_cap_len,
645
+ seq_lengths,
646
+ )
647
+
648
+
649
+ class TransformerBlock(nn.Module):
650
+ """
651
+ Transformer block for refiner model.
652
+ """
653
+
654
+ def __init__(
655
+ self,
656
+ dim: int,
657
+ num_attention_heads: int,
658
+ num_kv_heads: int,
659
+ multiple_of: int,
660
+ ffn_dim_multiplier: float,
661
+ norm_eps: float,
662
+ modulation: bool = True,
663
+ ) -> None:
664
+ super().__init__()
665
+ self.head_dim = dim // num_attention_heads
666
+ self.modulation = modulation
667
+
668
+ try:
669
+ processor = AttnProcessorFlash2Varlen()
670
+ except ImportError:
671
+ processor = AttnProcessor()
672
+
673
+ self.attn = Attention(
674
+ query_dim=dim,
675
+ cross_attention_dim=None,
676
+ dim_head=dim // num_attention_heads,
677
+ qk_norm="rms_norm",
678
+ heads=num_attention_heads,
679
+ kv_heads=num_kv_heads,
680
+ eps=1e-5,
681
+ bias=False,
682
+ out_bias=False,
683
+ processor=processor,
684
+ )
685
+
686
+ self.feed_forward = LuminaFeedForward(
687
+ dim=dim,
688
+ inner_dim=4 * dim,
689
+ multiple_of=multiple_of,
690
+ ffn_dim_multiplier=ffn_dim_multiplier,
691
+ )
692
+
693
+ if modulation:
694
+ self.norm1 = LuminaRMSNormZero(
695
+ embedding_dim=dim,
696
+ norm_eps=norm_eps,
697
+ norm_elementwise_affine=True,
698
+ )
699
+ else:
700
+ self.norm1 = RMSNorm(dim, eps=norm_eps)
701
+
702
+ self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
703
+ self.norm2 = RMSNorm(dim, eps=norm_eps)
704
+ self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
705
+
706
+ self.initialize_weights()
707
+
708
+ def initialize_weights(self) -> None:
709
+ nn.init.xavier_uniform_(self.attn.to_q.weight)
710
+ nn.init.xavier_uniform_(self.attn.to_k.weight)
711
+ nn.init.xavier_uniform_(self.attn.to_v.weight)
712
+ nn.init.xavier_uniform_(self.attn.to_out[0].weight)
713
+
714
+ nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
715
+ nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
716
+ nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)
717
+
718
+ if self.modulation:
719
+ nn.init.zeros_(self.norm1.linear.weight)
720
+ nn.init.zeros_(self.norm1.linear.bias)
721
+
722
+ def forward(
723
+ self,
724
+ hidden_states: torch.Tensor,
725
+ attention_mask: torch.Tensor,
726
+ image_rotary_emb: torch.Tensor,
727
+ temb: Optional[torch.Tensor] = None,
728
+ ) -> torch.Tensor:
729
+ enable_taylorseer = getattr(self, 'enable_taylorseer', False)
730
+ if enable_taylorseer:
731
+ if self.modulation:
732
+ if temb is None:
733
+ raise ValueError("temb must be provided when modulation is enabled")
734
+
735
+ if self.current['type'] == 'full':
736
+ self.current['module'] = 'total'
737
+ taylor_cache_init(cache_dic=self.cache_dic, current=self.current)
738
+
739
+ norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
740
+ attn_output = self.attn(
741
+ hidden_states=norm_hidden_states,
742
+ encoder_hidden_states=norm_hidden_states,
743
+ attention_mask=attention_mask,
744
+ image_rotary_emb=image_rotary_emb,
745
+ )
746
+ hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
747
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
748
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
749
+
750
+ derivative_approximation(cache_dic=self.cache_dic, current=self.current, feature=hidden_states)
751
+
752
+ elif self.current['type'] == 'Taylor':
753
+ self.current['module'] = 'total'
754
+ hidden_states = taylor_formula(cache_dic=self.cache_dic, current=self.current)
755
+ else:
756
+ norm_hidden_states = self.norm1(hidden_states)
757
+ attn_output = self.attn(
758
+ hidden_states=norm_hidden_states,
759
+ encoder_hidden_states=norm_hidden_states,
760
+ attention_mask=attention_mask,
761
+ image_rotary_emb=image_rotary_emb,
762
+ )
763
+ hidden_states = hidden_states + self.norm2(attn_output)
764
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
765
+ hidden_states = hidden_states + self.ffn_norm2(mlp_output)
766
+ else:
767
+ if self.modulation:
768
+ if temb is None:
769
+ raise ValueError("temb must be provided when modulation is enabled")
770
+
771
+ norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
772
+ attn_output = self.attn(
773
+ hidden_states=norm_hidden_states,
774
+ encoder_hidden_states=norm_hidden_states,
775
+ attention_mask=attention_mask,
776
+ image_rotary_emb=image_rotary_emb,
777
+ )
778
+ hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
779
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
780
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
781
+ else:
782
+ norm_hidden_states = self.norm1(hidden_states)
783
+ attn_output = self.attn(
784
+ hidden_states=norm_hidden_states,
785
+ encoder_hidden_states=norm_hidden_states,
786
+ attention_mask=attention_mask,
787
+ image_rotary_emb=image_rotary_emb,
788
+ )
789
+ hidden_states = hidden_states + self.norm2(attn_output)
790
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
791
+ hidden_states = hidden_states + self.ffn_norm2(mlp_output)
792
+
793
+ return hidden_states
794
+
795
+
796
+ class Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
797
+ """
798
+ Transformer 2D Model.
799
+ """
800
+
801
+ _supports_gradient_checkpointing = True
802
+ _no_split_modules = ["TransformerBlock"]
803
+ _skip_layerwise_casting_patterns = ["x_embedder", "norm"]
804
+
805
+ @register_to_config
806
+ def __init__(
807
+ self,
808
+ patch_size: int = 2,
809
+ in_channels: int = 16,
810
+ out_channels: Optional[int] = None,
811
+ hidden_size: int = 2304,
812
+ num_layers: int = 26,
813
+ num_refiner_layers: int = 2,
814
+ num_attention_heads: int = 24,
815
+ num_kv_heads: int = 8,
816
+ multiple_of: int = 256,
817
+ ffn_dim_multiplier: Optional[float] = None,
818
+ norm_eps: float = 1e-5,
819
+ axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
820
+ axes_lens: Tuple[int, int, int] = (300, 512, 512),
821
+ text_feat_dim: int = 1024,
822
+ timestep_scale: float = 1.0,
823
+ ) -> None:
824
+ super().__init__()
825
+
826
+ if (hidden_size // num_attention_heads) != sum(axes_dim_rope):
827
+ raise ValueError(
828
+ f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) "
829
+ f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})"
830
+ )
831
+
832
+ self.out_channels = out_channels or in_channels
833
+
834
+ self.rope_embedder = RotaryPosEmbed(
835
+ theta=10000,
836
+ axes_dim=axes_dim_rope,
837
+ axes_lens=axes_lens,
838
+ patch_size=patch_size,
839
+ )
840
+
841
+ self.x_embedder = nn.Linear(
842
+ in_features=patch_size * patch_size * in_channels,
843
+ out_features=hidden_size,
844
+ )
845
+
846
+ self.ref_image_patch_embedder = nn.Linear(
847
+ in_features=patch_size * patch_size * in_channels,
848
+ out_features=hidden_size,
849
+ )
850
+
851
+ self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
852
+ hidden_size=hidden_size,
853
+ text_feat_dim=text_feat_dim,
854
+ norm_eps=norm_eps,
855
+ timestep_scale=timestep_scale,
856
+ )
857
+
858
+ self.noise_refiner = nn.ModuleList([
859
+ TransformerBlock(
860
+ hidden_size, num_attention_heads, num_kv_heads,
861
+ multiple_of, ffn_dim_multiplier, norm_eps, modulation=True,
862
+ )
863
+ for _ in range(num_refiner_layers)
864
+ ])
865
+
866
+ self.ref_image_refiner = nn.ModuleList([
867
+ TransformerBlock(
868
+ hidden_size, num_attention_heads, num_kv_heads,
869
+ multiple_of, ffn_dim_multiplier, norm_eps, modulation=True,
870
+ )
871
+ for _ in range(num_refiner_layers)
872
+ ])
873
+
874
+ self.context_refiner = nn.ModuleList([
875
+ TransformerBlock(
876
+ hidden_size, num_attention_heads, num_kv_heads,
877
+ multiple_of, ffn_dim_multiplier, norm_eps, modulation=False,
878
+ )
879
+ for _ in range(num_refiner_layers)
880
+ ])
881
+
882
+ self.layers = nn.ModuleList([
883
+ TransformerBlock(
884
+ hidden_size, num_attention_heads, num_kv_heads,
885
+ multiple_of, ffn_dim_multiplier, norm_eps, modulation=True,
886
+ )
887
+ for _ in range(num_layers)
888
+ ])
889
+
890
+ self.norm_out = LuminaLayerNormContinuous(
891
+ embedding_dim=hidden_size,
892
+ conditioning_embedding_dim=min(hidden_size, 1024),
893
+ elementwise_affine=False,
894
+ eps=1e-6,
895
+ bias=True,
896
+ out_dim=patch_size * patch_size * self.out_channels,
897
+ )
898
+
899
+ self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size))
900
+
901
+ self.gradient_checkpointing = False
902
+
903
+ self.initialize_weights()
904
+
905
+ self.enable_teacache = False
906
+ self.teacache_rel_l1_thresh = 0.05
907
+ self.teacache_params = TeaCacheParams()
908
+
909
+ coefficients = [-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487]
910
+ self.rescale_func = np.poly1d(coefficients)
911
+
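The constructor enforces that the per-head channel count equals the sum of the three RoPE axis dimensions, because the rotary embedding splits each head's channels across three position axes (presumably sequence, height, and width, given `axes_lens=(300, 512, 512)`). With the defaults above this works out to 2304 // 24 = 96 = 32 + 32 + 32. A tiny check using only the default values shown in the signature:

```python
# Sanity check of the default configuration above.
hidden_size = 2304
num_attention_heads = 24
axes_dim_rope = (32, 32, 32)

head_dim = hidden_size // num_attention_heads
assert head_dim == sum(axes_dim_rope), (head_dim, sum(axes_dim_rope))
print(head_dim)  # 96: each head's channels are split 32/32/32 across the three RoPE axes
```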
912
+ def initialize_weights(self) -> None:
913
+ nn.init.xavier_uniform_(self.x_embedder.weight)
914
+ nn.init.constant_(self.x_embedder.bias, 0.0)
915
+
916
+ nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
917
+ nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)
918
+
919
+ nn.init.zeros_(self.norm_out.linear_1.weight)
920
+ nn.init.zeros_(self.norm_out.linear_1.bias)
921
+ nn.init.zeros_(self.norm_out.linear_2.weight)
922
+ nn.init.zeros_(self.norm_out.linear_2.bias)
923
+
924
+ nn.init.normal_(self.image_index_embedding, std=0.02)
925
+
926
+ def img_patch_embed_and_refine(
927
+ self,
928
+ hidden_states,
929
+ ref_image_hidden_states,
930
+ padded_img_mask,
931
+ padded_ref_img_mask,
932
+ noise_rotary_emb,
933
+ ref_img_rotary_emb,
934
+ l_effective_ref_img_len,
935
+ l_effective_img_len,
936
+ temb,
937
+ ):
938
+ batch_size = len(hidden_states)
939
+ max_combined_img_len = max([
940
+ img_len + sum(ref_img_len)
941
+ for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)
942
+ ])
943
+
944
+ hidden_states = self.x_embedder(hidden_states)
945
+ ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
946
+
947
+ for i in range(batch_size):
948
+ shift = 0
949
+ for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
950
+ ref_image_hidden_states[i, shift:shift + ref_img_len, :] = (
951
+ ref_image_hidden_states[i, shift:shift + ref_img_len, :]
952
+ + self.image_index_embedding[j]
953
+ )
954
+ shift += ref_img_len
955
+
956
+ for layer in self.noise_refiner:
957
+ hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
958
+
959
+ flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
960
+ num_ref_images = len(flat_l_effective_ref_img_len)
961
+ max_ref_img_len = max(flat_l_effective_ref_img_len)
962
+
963
+ batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool)
964
+ batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(
965
+ num_ref_images, max_ref_img_len, self.config.hidden_size
966
+ )
967
+ batch_ref_img_rotary_emb = hidden_states.new_zeros(
968
+ num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype
969
+ )
970
+ batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)
971
+
972
+ idx = 0
973
+ for i in range(batch_size):
974
+ shift = 0
975
+ for ref_img_len in l_effective_ref_img_len[i]:
976
+ batch_ref_img_mask[idx, :ref_img_len] = True
977
+ batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len]
978
+ batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len]
979
+ batch_temb[idx] = temb[i]
980
+ shift += ref_img_len
981
+ idx += 1
982
+
983
+ for layer in self.ref_image_refiner:
984
+ batch_ref_image_hidden_states = layer(
985
+ batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb
986
+ )
987
+
988
+ idx = 0
989
+ for i in range(batch_size):
990
+ shift = 0
991
+ for ref_img_len in l_effective_ref_img_len[i]:
992
+ ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len]
993
+ shift += ref_img_len
994
+ idx += 1
995
+
996
+ combined_img_hidden_states = hidden_states.new_zeros(
997
+ batch_size, max_combined_img_len, self.config.hidden_size
998
+ )
999
+ for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
1000
+ combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)]
1001
+ combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len]
1002
+
1003
+ return combined_img_hidden_states
1004
+
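After both refiners have run, `img_patch_embed_and_refine` packs each sample's tokens as [all reference-image tokens, then noise-image tokens], right-padded to the longest combined length in the batch. A toy illustration of that packing with made-up lengths (constant stand-in tensors instead of real refined tokens):

```python
import torch

hidden = 4
l_effective_ref_img_len = [[3, 2], [4]]   # per-sample reference-image token counts (illustrative)
l_effective_img_len = [5, 6]              # per-sample noise-image token counts (illustrative)

max_len = max(img + sum(ref) for img, ref in zip(l_effective_img_len, l_effective_ref_img_len))
combined = torch.zeros(2, max_len, hidden)
for i, (ref, img) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
    ref_tokens = torch.ones(sum(ref), hidden)    # stand-in for refined reference tokens
    img_tokens = torch.full((img, hidden), 2.0)  # stand-in for refined noise tokens
    combined[i, :sum(ref)] = ref_tokens
    combined[i, sum(ref):sum(ref) + img] = img_tokens
print(combined.shape)  # torch.Size([2, 10, 4]); any trailing positions stay zero-padded
```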
1005
+ def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
1006
+ batch_size = len(hidden_states)
1007
+ p = self.config.patch_size
1008
+ device = hidden_states[0].device
1009
+
1010
+ img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
1011
+ l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
1012
+
1013
+ if ref_image_hidden_states is not None and len(ref_image_hidden_states) > 0:
1014
+ ref_img_sizes = [
1015
+ [(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None
1016
+ for imgs in ref_image_hidden_states
1017
+ ]
1018
+ l_effective_ref_img_len = [
1019
+ [(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes]
1020
+ if _ref_img_sizes is not None else [0]
1021
+ for _ref_img_sizes in ref_img_sizes
1022
+ ]
1023
+ else:
1024
+ ref_img_sizes = [None for _ in range(batch_size)]
1025
+ l_effective_ref_img_len = [[0] for _ in range(batch_size)]
1026
+
1027
+ max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
1028
+ max_img_len = max(l_effective_img_len)
1029
+
1030
+ flat_ref_img_hidden_states = []
1031
+ for i in range(batch_size):
1032
+ if ref_img_sizes[i] is not None:
1033
+ imgs = []
1034
+ for ref_img in ref_image_hidden_states[i]:
1035
+ C, H, W = ref_img.size()
1036
+ ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
1037
+ imgs.append(ref_img)
1038
+ flat_ref_img_hidden_states.append(torch.cat(imgs, dim=0))
1039
+ else:
1040
+ flat_ref_img_hidden_states.append(None)
1041
+
1042
+ flat_hidden_states = []
1043
+ for i in range(batch_size):
1044
+ img = hidden_states[i]
1045
+ C, H, W = img.size()
1046
+ img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
1047
+ flat_hidden_states.append(img)
1048
+
1049
+ padded_ref_img_hidden_states = torch.zeros(
1050
+ batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1],
1051
+ device=device, dtype=flat_hidden_states[0].dtype,
1052
+ )
1053
+ padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device)
1054
+ for i in range(batch_size):
1055
+ if ref_img_sizes[i] is not None:
1056
+ padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i]
1057
+ padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True
1058
+
1059
+ padded_hidden_states = torch.zeros(
1060
+ batch_size, max_img_len, flat_hidden_states[0].shape[-1],
1061
+ device=device, dtype=flat_hidden_states[0].dtype,
1062
+ )
1063
+ padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
1064
+ for i in range(batch_size):
1065
+ padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i]
1066
+ padded_img_mask[i, :l_effective_img_len[i]] = True
1067
+
1068
+ return (
1069
+ padded_hidden_states,
1070
+ padded_ref_img_hidden_states,
1071
+ padded_img_mask,
1072
+ padded_ref_img_mask,
1073
+ l_effective_ref_img_len,
1074
+ l_effective_img_len,
1075
+ ref_img_sizes,
1076
+ img_sizes,
1077
+ )
1078
+
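`flat_and_pad_to_seq` turns each latent of shape (C, H, W) into a sequence of (H/p)·(W/p) patch tokens with p·p·C features each via the einops pattern used above, then right-pads every sample to the longest sequence in the batch. A minimal, standalone sketch of just the patchify step (shapes are illustrative):

```python
import torch
from einops import rearrange

p = 2                              # patch_size, as in the default config
latent = torch.randn(16, 8, 12)    # one sample's (C, H, W) latent
tokens = rearrange(latent, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
print(tokens.shape)                # torch.Size([24, 64]): (8/2)*(12/2) tokens, each 2*2*16 features
```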
1079
+ def forward(
1080
+ self,
1081
+ hidden_states: Union[torch.Tensor, List[torch.Tensor]],
1082
+ timestep: torch.Tensor,
1083
+ text_hidden_states: torch.Tensor,
1084
+ freqs_cis: torch.Tensor,
1085
+ text_attention_mask: torch.Tensor,
1086
+ ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None,
1087
+ attention_kwargs: Optional[Dict[str, Any]] = None,
1088
+ return_dict: bool = False,
1089
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
1090
+ enable_taylorseer = getattr(self, 'enable_taylorseer', False)
1091
+ if enable_taylorseer:
1092
+ cal_type(self.cache_dic, self.current)
1093
+
1094
+ if attention_kwargs is not None:
1095
+ attention_kwargs = attention_kwargs.copy()
1096
+ lora_scale = attention_kwargs.pop("scale", 1.0)
1097
+ else:
1098
+ lora_scale = 1.0
1099
+
1100
+ if USE_PEFT_BACKEND:
1101
+ scale_lora_layers(self, lora_scale)
1102
+ else:
1103
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
1104
+ logger.warning(
1105
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
1106
+ )
1107
+
1108
+ batch_size = len(hidden_states)
1109
+ is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)
1110
+
1111
+ if is_hidden_states_tensor:
1112
+ assert hidden_states.ndim == 4
1113
+ hidden_states = [_hidden_states for _hidden_states in hidden_states]
1114
+
1115
+ device = hidden_states[0].device
1116
+
1117
+ assert isinstance(text_hidden_states, torch.Tensor), \
1118
+ f"text_hidden_states must be Tensor, got {type(text_hidden_states)}. " \
1119
+ f"Check if freqs_cis and text_hidden_states are swapped in the caller."
1120
+
1121
+ temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
1122
+
1123
+ (
1124
+ hidden_states,
1125
+ ref_image_hidden_states,
1126
+ img_mask,
1127
+ ref_img_mask,
1128
+ l_effective_ref_img_len,
1129
+ l_effective_img_len,
1130
+ ref_img_sizes,
1131
+ img_sizes,
1132
+ ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
1133
+
1134
+ (
1135
+ context_rotary_emb,
1136
+ ref_img_rotary_emb,
1137
+ noise_rotary_emb,
1138
+ rotary_emb,
1139
+ encoder_seq_lengths,
1140
+ seq_lengths,
1141
+ ) = self.rope_embedder(
1142
+ freqs_cis,
1143
+ text_attention_mask,
1144
+ l_effective_ref_img_len,
1145
+ l_effective_img_len,
1146
+ ref_img_sizes,
1147
+ img_sizes,
1148
+ device,
1149
+ )
1150
+
1151
+ # 2. Context refinement
1152
+ for layer in self.context_refiner:
1153
+ text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
1154
+
1155
+ combined_img_hidden_states = self.img_patch_embed_and_refine(
1156
+ hidden_states,
1157
+ ref_image_hidden_states,
1158
+ img_mask,
1159
+ ref_img_mask,
1160
+ noise_rotary_emb,
1161
+ ref_img_rotary_emb,
1162
+ l_effective_ref_img_len,
1163
+ l_effective_img_len,
1164
+ temb,
1165
+ )
1166
+
1167
+ # 3. Joint Transformer blocks
1168
+ max_seq_len = max(seq_lengths)
1169
+
1170
+ attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
1171
+ joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
1172
+ for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
1173
+ attention_mask[i, :seq_len] = True
1174
+ joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len]
1175
+ joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len]
1176
+
1177
+ hidden_states = joint_hidden_states
1178
+
1179
+ if self.enable_teacache:
1180
+ teacache_hidden_states = hidden_states.clone()
1181
+ teacache_temb = temb.clone()
1182
+ modulated_inp, _, _, _ = self.layers[0].norm1(teacache_hidden_states, teacache_temb)
1183
+ if self.teacache_params.is_first_or_last_step:
1184
+ should_calc = True
1185
+ self.teacache_params.accumulated_rel_l1_distance = 0
1186
+ else:
1187
+ self.teacache_params.accumulated_rel_l1_distance += self.rescale_func(
1188
+ ((modulated_inp - self.teacache_params.previous_modulated_inp).abs().mean()
1189
+ / self.teacache_params.previous_modulated_inp.abs().mean()).cpu().item()
1190
+ )
1191
+ if self.teacache_params.accumulated_rel_l1_distance < self.teacache_rel_l1_thresh:
1192
+ should_calc = False
1193
+ else:
1194
+ should_calc = True
1195
+ self.teacache_params.accumulated_rel_l1_distance = 0
1196
+ self.teacache_params.previous_modulated_inp = modulated_inp
1197
+
1198
+ if self.enable_teacache:
1199
+ if not should_calc:
1200
+ hidden_states += self.teacache_params.previous_residual
1201
+ else:
1202
+ ori_hidden_states = hidden_states.clone()
1203
+ for layer_idx, layer in enumerate(self.layers):
1204
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1205
+ hidden_states = self._gradient_checkpointing_func(
1206
+ layer, hidden_states, attention_mask, rotary_emb, temb
1207
+ )
1208
+ else:
1209
+ hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
1210
+ self.teacache_params.previous_residual = hidden_states - ori_hidden_states
1211
+ else:
1212
+ if enable_taylorseer:
1213
+ self.current['stream'] = 'layers_stream'
1214
+
1215
+ for layer_idx, layer in enumerate(self.layers):
1216
+ if enable_taylorseer:
1217
+ layer.current = self.current
1218
+ layer.cache_dic = self.cache_dic
1219
+ layer.enable_taylorseer = True
1220
+ self.current['layer'] = layer_idx
1221
+
1222
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1223
+ hidden_states = self._gradient_checkpointing_func(
1224
+ layer, hidden_states, attention_mask, rotary_emb, temb
1225
+ )
1226
+ else:
1227
+ hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
1228
+
1229
+ hidden_states = self.norm_out(hidden_states, temb)
1230
+
1231
+ p = self.config.patch_size
1232
+ output = []
1233
+ for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
1234
+ height, width = img_size
1235
+ output.append(rearrange(
1236
+ hidden_states[i][seq_len - img_len:seq_len],
1237
+ '(h w) (p1 p2 c) -> c (h p1) (w p2)',
1238
+ h=height // p, w=width // p, p1=p, p2=p,
1239
+ ))
1240
+ if is_hidden_states_tensor:
1241
+ output = torch.stack(output, dim=0)
1242
+
1243
+ if USE_PEFT_BACKEND:
1244
+ unscale_lora_layers(self, lora_scale)
1245
+
1246
+ if enable_taylorseer:
1247
+ self.current['step'] += 1
1248
+
1249
+ if not return_dict:
1250
+ return output
1251
+ return Transformer2DModelOutput(sample=output)
1252
+
1253
+
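The `enable_teacache` branch in `forward` implements a TeaCache-style layer skip: at each step it measures how much the first block's modulated input changed relative to the previous step, maps that through the fitted rescale polynomial, and accumulates the result; while the accumulated value stays below `teacache_rel_l1_thresh`, the whole transformer stack is skipped and the previously cached residual is re-added instead. Below is a condensed, self-contained sketch of that decision rule; the polynomial coefficients are copied from above, everything else (names, shapes, the `run_layers` callable) is illustrative.

```python
import numpy as np
import torch

rescale = np.poly1d([-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487])

class TeaCacheSketch:
    def __init__(self, thresh: float = 0.05):
        self.thresh = thresh
        self.accumulated = 0.0
        self.prev_inp = None        # modulated input from the previous step
        self.prev_residual = None   # hidden-state change produced by the full stack

    def step(self, modulated_inp, hidden_states, run_layers, first_or_last: bool):
        if first_or_last or self.prev_inp is None:
            should_calc = True
            self.accumulated = 0.0
        else:
            rel = ((modulated_inp - self.prev_inp).abs().mean() / self.prev_inp.abs().mean()).item()
            self.accumulated += float(rescale(rel))
            should_calc = self.accumulated >= self.thresh
            if should_calc:
                self.accumulated = 0.0
        self.prev_inp = modulated_inp

        if not should_calc:
            return hidden_states + self.prev_residual   # reuse cached residual, skip the layers
        out = run_layers(hidden_states)                  # run the full transformer stack
        self.prev_residual = out - hidden_states
        return out

# Toy usage: a dummy "stack" that slightly perturbs its input.
cache = TeaCacheSketch(thresh=0.05)
h = torch.randn(1, 8, 16)
for step in range(4):
    h = cache.step(h.clone(), h, lambda x: x * 1.01, first_or_last=(step == 0))
```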
1254
+ # ---------------------------------------------------------------------------
1255
+ # FlowMatch Euler Discrete Scheduler (merged from scheduling_flow_match_euler_discrete.py)
1256
+ # ---------------------------------------------------------------------------
1257
+
1258
+ @dataclass
1259
+ class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
1260
+ prev_sample: torch.FloatTensor
1261
+
1262
+
1263
+ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
1264
+ _compatibles = []
1265
+ order = 1
1266
+
1267
+ @register_to_config
1268
+ def __init__(self, num_train_timesteps: int = 1000, dynamic_time_shift: bool = False):
1269
+ timesteps = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float32)[:-1]
1270
+ self.timesteps = timesteps
1271
+ self._step_index = None
1272
+ self._begin_index = None
1273
+
1274
+ @property
1275
+ def step_index(self):
1276
+ return self._step_index
1277
+
1278
+ @property
1279
+ def begin_index(self):
1280
+ return self._begin_index
1281
+
1282
+ def set_begin_index(self, begin_index: int = 0):
1283
+ self._begin_index = begin_index
1284
+
1285
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
1286
+ if schedule_timesteps is None:
1287
+ schedule_timesteps = self._timesteps
1288
+ indices = (schedule_timesteps == timestep).nonzero()
1289
+ pos = 1 if len(indices) > 1 else 0
1290
+ return indices[pos].item()
1291
+
1292
+ def set_timesteps(self, num_inference_steps=None, device=None, timesteps=None, num_tokens=None):
1293
+ if timesteps is None:
1294
+ self.num_inference_steps = num_inference_steps
1295
+ timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
1296
+ if self.config.dynamic_time_shift and num_tokens is not None:
1297
+ m = np.sqrt(num_tokens) / 40
1298
+ timesteps = timesteps / (m - m * timesteps + timesteps)
1299
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
1300
+ _timesteps = torch.cat([timesteps, torch.ones(1, device=timesteps.device)])
1301
+ self.timesteps = timesteps
1302
+ self._timesteps = _timesteps
1303
+ self._step_index = None
1304
+ self._begin_index = None
1305
+
1306
+ def _init_step_index(self, timestep):
1307
+ if self.begin_index is None:
1308
+ if isinstance(timestep, torch.Tensor):
1309
+ timestep = timestep.to(self.timesteps.device)
1310
+ self._step_index = self.index_for_timestep(timestep)
1311
+ else:
1312
+ self._step_index = self._begin_index
1313
+
1314
+ def step(self, model_output, timestep, sample, generator=None, return_dict=True):
1315
+ if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
1316
+ raise ValueError("Pass scheduler.timesteps values, not integer indices.")
1317
+ if self.step_index is None:
1318
+ self._init_step_index(timestep)
1319
+ sample = sample.to(torch.float32)
1320
+ t = self._timesteps[self.step_index]
1321
+ t_next = self._timesteps[self.step_index + 1]
1322
+ prev_sample = sample + (t_next - t) * model_output
1323
+ prev_sample = prev_sample.to(model_output.dtype)
1324
+ self._step_index += 1
1325
+ if not return_dict:
1326
+ return (prev_sample,)
1327
+ return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
1328
+
1329
+ def __len__(self):
1330
+ return self.config.num_train_timesteps
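The scheduler integrates the flow-matching ODE with plain Euler steps: given a velocity prediction v at time t, the update is x(t_next) = x(t) + (t_next - t)·v. When `dynamic_time_shift` is enabled and `num_tokens` is passed to `set_timesteps`, the timestep grid is warped as t ← t / (m - m·t + t) with m = sqrt(num_tokens) / 40. A hedged sketch of how a caller might drive it; the model call is a stand-in, not this repository's pipeline API, and the scheduler class defined above is assumed to be importable in scope.

```python
import torch

# Assumes FlowMatchEulerDiscreteScheduler (defined above) is available in scope.
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, dynamic_time_shift=True)
scheduler.set_timesteps(num_inference_steps=30, device="cpu", num_tokens=64 * 64)

sample = torch.randn(1, 16, 64, 64)      # latent state (illustrative shape)
for t in scheduler.timesteps:
    velocity = -sample                   # stand-in for transformer(sample, t, ...)
    sample = scheduler.step(velocity, t, sample).prev_sample
```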
requirements-post.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ flash-attn==2.7.4.post1
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch==2.6.0
2
+ torchvision==0.21.0
3
+ torchaudio==2.6.0
4
+ accelerate==1.10.0
5
+ transformers==4.57.6
6
+ librosa==0.11.0
7
+ diffusers==0.34.0
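These pins presumably reflect the versions the merged model file was developed against; `flash-attn` sits in a separate `requirements-post.txt`, most likely because it must be compiled against the already-installed torch build. A small, hedged post-install sanity check one might run, assuming only the packages listed here:

```python
# Print installed versions against the pins above (no assertions, local builds may carry suffixes).
import torch, torchvision, transformers, diffusers, accelerate, librosa

for mod, want in [(torch, "2.6.0"), (torchvision, "0.21.0"), (transformers, "4.57.6"),
                  (diffusers, "0.34.0"), (accelerate, "1.10.0"), (librosa, "0.11.0")]:
    print(mod.__name__, mod.__version__, "(expected", want + ")")

print("CUDA available:", torch.cuda.is_available())
```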
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,2294 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": true,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<longcat_unk>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<longcat_s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "</longcat_s>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ },
30
+ "3": {
31
+ "content": "<longcat_pad>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": true
37
+ },
38
+ "4": {
39
+ "content": "<shift_unk>",
40
+ "lstrip": false,
41
+ "normalized": false,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": true
45
+ },
46
+ "5": {
47
+ "content": "<shift_s>",
48
+ "lstrip": false,
49
+ "normalized": false,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": true
53
+ },
54
+ "6": {
55
+ "content": "</shift_s>",
56
+ "lstrip": false,
57
+ "normalized": false,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": true
61
+ },
62
+ "7": {
63
+ "content": "<shift_pad>",
64
+ "lstrip": false,
65
+ "normalized": false,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": true
69
+ },
70
+ "8": {
71
+ "content": "<mask_0>",
72
+ "lstrip": false,
73
+ "normalized": false,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": true
77
+ },
78
+ "9": {
79
+ "content": "<reponame>",
80
+ "lstrip": false,
81
+ "normalized": false,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": true
85
+ },
86
+ "10": {
87
+ "content": "<filename>",
88
+ "lstrip": false,
89
+ "normalized": false,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": true
93
+ },
94
+ "11": {
95
+ "content": "<gh_stars>",
96
+ "lstrip": false,
97
+ "normalized": false,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": true
101
+ },
102
+ "12": {
103
+ "content": "<issue_start>",
104
+ "lstrip": false,
105
+ "normalized": false,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": true
109
+ },
110
+ "13": {
111
+ "content": "<issue_comment>",
112
+ "lstrip": false,
113
+ "normalized": false,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": true
117
+ },
118
+ "14": {
119
+ "content": "<issue_closed>",
120
+ "lstrip": false,
121
+ "normalized": false,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": true
125
+ },
126
+ "15": {
127
+ "content": "<jupyter_start>",
128
+ "lstrip": false,
129
+ "normalized": false,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": true
133
+ },
134
+ "16": {
135
+ "content": "<jupyter_text>",
136
+ "lstrip": false,
137
+ "normalized": false,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": true
141
+ },
142
+ "17": {
143
+ "content": "<jupyter_code>",
144
+ "lstrip": false,
145
+ "normalized": false,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": true
149
+ },
150
+ "18": {
151
+ "content": "<jupyter_output>",
152
+ "lstrip": false,
153
+ "normalized": false,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": true
157
+ },
158
+ "19": {
159
+ "content": "<empty_output>",
160
+ "lstrip": false,
161
+ "normalized": false,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": true
165
+ },
166
+ "20": {
167
+ "content": "<commit_before>",
168
+ "lstrip": false,
169
+ "normalized": false,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": true
173
+ },
174
+ "21": {
175
+ "content": "<commit_msg>",
176
+ "lstrip": false,
177
+ "normalized": false,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": true
181
+ },
182
+ "22": {
183
+ "content": "<commit_after>",
184
+ "lstrip": false,
185
+ "normalized": false,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": true
189
+ },
190
+ "23": {
191
+ "content": "<program_lang>",
192
+ "lstrip": false,
193
+ "normalized": false,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": true
197
+ },
198
+ "24": {
199
+ "content": "<|image_placeholder|>",
200
+ "lstrip": false,
201
+ "normalized": false,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": true
205
+ },
206
+ "25": {
207
+ "content": "<|url_placeholder|>",
208
+ "lstrip": false,
209
+ "normalized": false,
210
+ "rstrip": false,
211
+ "single_word": false,
212
+ "special": true
213
+ },
214
+ "26": {
215
+ "content": "<|hyperlink_placeholder|>",
216
+ "lstrip": false,
217
+ "normalized": false,
218
+ "rstrip": false,
219
+ "single_word": false,
220
+ "special": true
221
+ },
222
+ "27": {
223
+ "content": "<|table_placeholder|>",
224
+ "lstrip": false,
225
+ "normalized": false,
226
+ "rstrip": false,
227
+ "single_word": false,
228
+ "special": true
229
+ },
230
+ "28": {
231
+ "content": "<|equation_placeholder|>",
232
+ "lstrip": false,
233
+ "normalized": false,
234
+ "rstrip": false,
235
+ "single_word": false,
236
+ "special": true
237
+ },
238
+ "29": {
239
+ "content": "<|code_placeholder|>",
240
+ "lstrip": false,
241
+ "normalized": false,
242
+ "rstrip": false,
243
+ "single_word": false,
244
+ "special": true
245
+ },
246
+ "30": {
247
+ "content": "<|reference_placeholder|>",
248
+ "lstrip": false,
249
+ "normalized": false,
250
+ "rstrip": false,
251
+ "single_word": false,
252
+ "special": true
253
+ },
254
+ "31": {
255
+ "content": "<|endoftext|>",
256
+ "lstrip": false,
257
+ "normalized": false,
258
+ "rstrip": false,
259
+ "single_word": false,
260
+ "special": true
261
+ },
262
+ "32": {
263
+ "content": "<fim_prefix>",
264
+ "lstrip": false,
265
+ "normalized": false,
266
+ "rstrip": false,
267
+ "single_word": false,
268
+ "special": true
269
+ },
270
+ "33": {
271
+ "content": "<fim_middle>",
272
+ "lstrip": false,
273
+ "normalized": false,
274
+ "rstrip": false,
275
+ "single_word": false,
276
+ "special": true
277
+ },
278
+ "34": {
279
+ "content": "<fim_suffix>",
280
+ "lstrip": false,
281
+ "normalized": false,
282
+ "rstrip": false,
283
+ "single_word": false,
284
+ "special": true
285
+ },
286
+ "35": {
287
+ "content": "<fim_pad>",
288
+ "lstrip": false,
289
+ "normalized": false,
290
+ "rstrip": false,
291
+ "single_word": false,
292
+ "special": true
293
+ },
294
+ "36": {
295
+ "content": "<longcat_think>",
296
+ "lstrip": false,
297
+ "normalized": false,
298
+ "rstrip": false,
299
+ "single_word": false,
300
+ "special": false
301
+ },
302
+ "37": {
303
+ "content": "</longcat_think>",
304
+ "lstrip": false,
305
+ "normalized": false,
306
+ "rstrip": false,
307
+ "single_word": false,
308
+ "special": false
309
+ },
310
+ "38": {
311
+ "content": "<longcat_answer>",
312
+ "lstrip": false,
313
+ "normalized": false,
314
+ "rstrip": false,
315
+ "single_word": false,
316
+ "special": false
317
+ },
318
+ "39": {
319
+ "content": "</longcat_answer>",
320
+ "lstrip": false,
321
+ "normalized": false,
322
+ "rstrip": false,
323
+ "single_word": false,
324
+ "special": false
325
+ },
326
+ "40": {
327
+ "content": "<longcat_files>",
328
+ "lstrip": false,
329
+ "normalized": false,
330
+ "rstrip": false,
331
+ "single_word": false,
332
+ "special": false
333
+ },
334
+ "41": {
335
+ "content": "</longcat_files>",
336
+ "lstrip": false,
337
+ "normalized": false,
338
+ "rstrip": false,
339
+ "single_word": false,
340
+ "special": false
341
+ },
342
+ "42": {
343
+ "content": "<longcat_tool_call>",
344
+ "lstrip": false,
345
+ "normalized": false,
346
+ "rstrip": false,
347
+ "single_word": false,
348
+ "special": false
349
+ },
350
+ "43": {
351
+ "content": "</longcat_tool_call>",
352
+ "lstrip": false,
353
+ "normalized": false,
354
+ "rstrip": false,
355
+ "single_word": false,
356
+ "special": false
357
+ },
358
+ "44": {
359
+ "content": "<longcat_tool_declare>",
360
+ "lstrip": false,
361
+ "normalized": false,
362
+ "rstrip": false,
363
+ "single_word": false,
364
+ "special": true
365
+ },
366
+ "45": {
367
+ "content": "</longcat_tool_declare>",
368
+ "lstrip": false,
369
+ "normalized": false,
370
+ "rstrip": false,
371
+ "single_word": false,
372
+ "special": true
373
+ },
374
+ "46": {
375
+ "content": "<longcat_system>",
376
+ "lstrip": false,
377
+ "normalized": false,
378
+ "rstrip": false,
379
+ "single_word": false,
380
+ "special": true
381
+ },
382
+ "47": {
383
+ "content": "<longcat_user>",
384
+ "lstrip": false,
385
+ "normalized": false,
386
+ "rstrip": false,
387
+ "single_word": false,
388
+ "special": true
389
+ },
390
+ "48": {
391
+ "content": "<longcat_assistant>",
392
+ "lstrip": false,
393
+ "normalized": false,
394
+ "rstrip": false,
395
+ "single_word": false,
396
+ "special": true
397
+ },
398
+ "49": {
399
+ "content": "<longcat_tool_response>",
400
+ "lstrip": false,
401
+ "normalized": false,
402
+ "rstrip": false,
403
+ "single_word": false,
404
+ "special": false
405
+ },
406
+ "50": {
407
+ "content": "</longcat_tool_response>",
408
+ "lstrip": false,
409
+ "normalized": false,
410
+ "rstrip": false,
411
+ "single_word": false,
412
+ "special": false
413
+ },
414
+ "51": {
415
+ "content": "<longcat_arg_key>",
416
+ "lstrip": false,
417
+ "normalized": false,
418
+ "rstrip": false,
419
+ "single_word": false,
420
+ "special": false
421
+ },
422
+ "52": {
423
+ "content": "</longcat_arg_key>",
424
+ "lstrip": false,
425
+ "normalized": false,
426
+ "rstrip": false,
427
+ "single_word": false,
428
+ "special": false
429
+ },
430
+ "53": {
431
+ "content": "<longcat_arg_value>",
432
+ "lstrip": false,
433
+ "normalized": false,
434
+ "rstrip": false,
435
+ "single_word": false,
436
+ "special": false
437
+ },
438
+ "54": {
439
+ "content": "</longcat_arg_value>",
440
+ "lstrip": false,
441
+ "normalized": false,
442
+ "rstrip": false,
443
+ "single_word": false,
444
+ "special": false
445
+ },
446
+ "55": {
447
+ "content": "<mask_31>",
448
+ "lstrip": false,
449
+ "normalized": false,
450
+ "rstrip": false,
451
+ "single_word": false,
452
+ "special": true
453
+ },
454
+ "56": {
455
+ "content": "<mask_32>",
456
+ "lstrip": false,
457
+ "normalized": false,
458
+ "rstrip": false,
459
+ "single_word": false,
460
+ "special": true
461
+ },
462
+ "57": {
463
+ "content": "<mask_33>",
464
+ "lstrip": false,
465
+ "normalized": false,
466
+ "rstrip": false,
467
+ "single_word": false,
468
+ "special": true
469
+ },
470
+ "58": {
471
+ "content": "<mask_34>",
472
+ "lstrip": false,
473
+ "normalized": false,
474
+ "rstrip": false,
475
+ "single_word": false,
476
+ "special": true
477
+ },
478
+ "59": {
479
+ "content": "<mask_35>",
480
+ "lstrip": false,
481
+ "normalized": false,
482
+ "rstrip": false,
483
+ "single_word": false,
484
+ "special": true
485
+ },
486
+ "60": {
487
+ "content": "<mask_36>",
488
+ "lstrip": false,
489
+ "normalized": false,
490
+ "rstrip": false,
491
+ "single_word": false,
492
+ "special": true
493
+ },
494
+ "61": {
495
+ "content": "<mask_37>",
496
+ "lstrip": false,
497
+ "normalized": false,
498
+ "rstrip": false,
499
+ "single_word": false,
500
+ "special": true
501
+ },
502
+ "62": {
503
+ "content": "<mask_38>",
504
+ "lstrip": false,
505
+ "normalized": false,
506
+ "rstrip": false,
507
+ "single_word": false,
508
+ "special": true
509
+ },
510
+ "63": {
511
+ "content": "<mask_39>",
512
+ "lstrip": false,
513
+ "normalized": false,
514
+ "rstrip": false,
515
+ "single_word": false,
516
+ "special": true
517
+ },
518
+ "64": {
519
+ "content": "<mask_40>",
520
+ "lstrip": false,
521
+ "normalized": false,
522
+ "rstrip": false,
523
+ "single_word": false,
524
+ "special": true
525
+ },
526
+ "65": {
527
+ "content": "<mask_41>",
528
+ "lstrip": false,
529
+ "normalized": false,
530
+ "rstrip": false,
531
+ "single_word": false,
532
+ "special": true
533
+ },
534
+ "66": {
535
+ "content": "<mask_42>",
536
+ "lstrip": false,
537
+ "normalized": false,
538
+ "rstrip": false,
539
+ "single_word": false,
540
+ "special": true
541
+ },
542
+ "67": {
543
+ "content": "<mask_43>",
544
+ "lstrip": false,
545
+ "normalized": false,
546
+ "rstrip": false,
547
+ "single_word": false,
548
+ "special": true
549
+ },
550
+ "68": {
551
+ "content": "<mask_44>",
552
+ "lstrip": false,
553
+ "normalized": false,
554
+ "rstrip": false,
555
+ "single_word": false,
556
+ "special": true
557
+ },
558
+ "69": {
559
+ "content": "<mask_45>",
560
+ "lstrip": false,
561
+ "normalized": false,
562
+ "rstrip": false,
563
+ "single_word": false,
564
+ "special": true
565
+ },
566
+ "70": {
567
+ "content": "<mask_46>",
568
+ "lstrip": false,
569
+ "normalized": false,
570
+ "rstrip": false,
571
+ "single_word": false,
572
+ "special": true
573
+ },
574
+ "71": {
575
+ "content": "<mask_47>",
576
+ "lstrip": false,
577
+ "normalized": false,
578
+ "rstrip": false,
579
+ "single_word": false,
580
+ "special": true
581
+ },
582
+ "72": {
583
+ "content": "<mask_48>",
584
+ "lstrip": false,
585
+ "normalized": false,
586
+ "rstrip": false,
587
+ "single_word": false,
588
+ "special": true
589
+ },
590
+ "73": {
591
+ "content": "<mask_49>",
592
+ "lstrip": false,
593
+ "normalized": false,
594
+ "rstrip": false,
595
+ "single_word": false,
596
+ "special": true
597
+ },
598
+ "74": {
599
+ "content": "<mask_50>",
600
+ "lstrip": false,
601
+ "normalized": false,
602
+ "rstrip": false,
603
+ "single_word": false,
604
+ "special": true
605
+ },
606
+ "75": {
607
+ "content": "<mask_51>",
608
+ "lstrip": false,
609
+ "normalized": false,
610
+ "rstrip": false,
611
+ "single_word": false,
612
+ "special": true
613
+ },
614
+ "76": {
615
+ "content": "<mask_52>",
616
+ "lstrip": false,
617
+ "normalized": false,
618
+ "rstrip": false,
619
+ "single_word": false,
620
+ "special": true
621
+ },
622
+ "77": {
623
+ "content": "<mask_53>",
624
+ "lstrip": false,
625
+ "normalized": false,
626
+ "rstrip": false,
627
+ "single_word": false,
628
+ "special": true
629
+ },
630
+ "78": {
631
+ "content": "<mask_54>",
632
+ "lstrip": false,
633
+ "normalized": false,
634
+ "rstrip": false,
635
+ "single_word": false,
636
+ "special": true
637
+ },
638
+ "79": {
639
+ "content": "<mask_55>",
640
+ "lstrip": false,
641
+ "normalized": false,
642
+ "rstrip": false,
643
+ "single_word": false,
644
+ "special": true
645
+ },
646
+ "80": {
647
+ "content": "<mask_56>",
648
+ "lstrip": false,
649
+ "normalized": false,
650
+ "rstrip": false,
651
+ "single_word": false,
652
+ "special": true
653
+ },
654
+ "81": {
655
+ "content": "<mask_57>",
656
+ "lstrip": false,
657
+ "normalized": false,
658
+ "rstrip": false,
659
+ "single_word": false,
660
+ "special": true
661
+ },
662
+ "82": {
663
+ "content": "<mask_58>",
664
+ "lstrip": false,
665
+ "normalized": false,
666
+ "rstrip": false,
667
+ "single_word": false,
668
+ "special": true
669
+ },
670
+ "83": {
671
+ "content": "<mask_59>",
672
+ "lstrip": false,
673
+ "normalized": false,
674
+ "rstrip": false,
675
+ "single_word": false,
676
+ "special": true
677
+ },
678
+ "84": {
679
+ "content": "<mask_60>",
680
+ "lstrip": false,
681
+ "normalized": false,
682
+ "rstrip": false,
683
+ "single_word": false,
684
+ "special": true
685
+ },
686
+ "85": {
687
+ "content": "<mask_61>",
688
+ "lstrip": false,
689
+ "normalized": false,
690
+ "rstrip": false,
691
+ "single_word": false,
692
+ "special": true
693
+ },
694
+ "86": {
695
+ "content": "<mask_62>",
696
+ "lstrip": false,
697
+ "normalized": false,
698
+ "rstrip": false,
699
+ "single_word": false,
700
+ "special": true
701
+ },
702
+ "87": {
703
+ "content": "<mask_63>",
704
+ "lstrip": false,
705
+ "normalized": false,
706
+ "rstrip": false,
707
+ "single_word": false,
708
+ "special": true
709
+ },
710
+ "88": {
711
+ "content": "<mask_64>",
712
+ "lstrip": false,
713
+ "normalized": false,
714
+ "rstrip": false,
715
+ "single_word": false,
716
+ "special": true
717
+ },
718
+ "89": {
719
+ "content": "<mask_65>",
720
+ "lstrip": false,
721
+ "normalized": false,
722
+ "rstrip": false,
723
+ "single_word": false,
724
+ "special": true
725
+ },
726
+ "90": {
727
+ "content": "<mask_66>",
728
+ "lstrip": false,
729
+ "normalized": false,
730
+ "rstrip": false,
731
+ "single_word": false,
732
+ "special": true
733
+ },
734
+ "91": {
735
+ "content": "<mask_67>",
736
+ "lstrip": false,
737
+ "normalized": false,
738
+ "rstrip": false,
739
+ "single_word": false,
740
+ "special": true
741
+ },
742
+ "92": {
743
+ "content": "<mask_68>",
744
+ "lstrip": false,
745
+ "normalized": false,
746
+ "rstrip": false,
747
+ "single_word": false,
748
+ "special": true
749
+ },
750
+ "93": {
751
+ "content": "<mask_69>",
752
+ "lstrip": false,
753
+ "normalized": false,
754
+ "rstrip": false,
755
+ "single_word": false,
756
+ "special": true
757
+ },
758
+ "94": {
759
+ "content": "<mask_70>",
760
+ "lstrip": false,
761
+ "normalized": false,
762
+ "rstrip": false,
763
+ "single_word": false,
764
+ "special": true
765
+ },
766
+ "95": {
767
+ "content": "<mask_71>",
768
+ "lstrip": false,
769
+ "normalized": false,
770
+ "rstrip": false,
771
+ "single_word": false,
772
+ "special": true
773
+ },
774
+ "96": {
775
+ "content": "<mask_72>",
776
+ "lstrip": false,
777
+ "normalized": false,
778
+ "rstrip": false,
779
+ "single_word": false,
780
+ "special": true
781
+ },
782
+ "97": {
783
+ "content": "<mask_73>",
784
+ "lstrip": false,
785
+ "normalized": false,
786
+ "rstrip": false,
787
+ "single_word": false,
788
+ "special": true
789
+ },
790
+ "98": {
791
+ "content": "<mask_74>",
792
+ "lstrip": false,
793
+ "normalized": false,
794
+ "rstrip": false,
795
+ "single_word": false,
796
+ "special": true
797
+ },
798
+ "99": {
799
+ "content": "<mask_75>",
800
+ "lstrip": false,
801
+ "normalized": false,
802
+ "rstrip": false,
803
+ "single_word": false,
804
+ "special": true
805
+ },
806
+ "100": {
807
+ "content": "<mask_76>",
808
+ "lstrip": false,
809
+ "normalized": false,
810
+ "rstrip": false,
811
+ "single_word": false,
812
+ "special": true
813
+ },
814
+ "101": {
815
+ "content": "<mask_77>",
816
+ "lstrip": false,
817
+ "normalized": false,
818
+ "rstrip": false,
819
+ "single_word": false,
820
+ "special": true
821
+ },
822
+ "102": {
823
+ "content": "<mask_78>",
824
+ "lstrip": false,
825
+ "normalized": false,
826
+ "rstrip": false,
827
+ "single_word": false,
828
+ "special": true
829
+ },
830
+ "103": {
831
+ "content": "<mask_79>",
832
+ "lstrip": false,
833
+ "normalized": false,
834
+ "rstrip": false,
835
+ "single_word": false,
836
+ "special": true
837
+ },
838
+ "104": {
839
+ "content": "<mask_80>",
840
+ "lstrip": false,
841
+ "normalized": false,
842
+ "rstrip": false,
843
+ "single_word": false,
844
+ "special": true
845
+ },
846
+ "105": {
847
+ "content": "<mask_81>",
848
+ "lstrip": false,
849
+ "normalized": false,
850
+ "rstrip": false,
851
+ "single_word": false,
852
+ "special": true
853
+ },
854
+ "106": {
855
+ "content": "<mask_82>",
856
+ "lstrip": false,
857
+ "normalized": false,
858
+ "rstrip": false,
859
+ "single_word": false,
860
+ "special": true
861
+ },
862
+ "107": {
863
+ "content": "<mask_83>",
864
+ "lstrip": false,
865
+ "normalized": false,
866
+ "rstrip": false,
867
+ "single_word": false,
868
+ "special": true
869
+ },
870
+ "108": {
871
+ "content": "<mask_84>",
872
+ "lstrip": false,
873
+ "normalized": false,
874
+ "rstrip": false,
875
+ "single_word": false,
876
+ "special": true
877
+ },
878
+ "109": {
879
+ "content": "<mask_85>",
880
+ "lstrip": false,
881
+ "normalized": false,
882
+ "rstrip": false,
883
+ "single_word": false,
884
+ "special": true
885
+ },
886
+ "110": {
887
+ "content": "<mask_86>",
888
+ "lstrip": false,
889
+ "normalized": false,
890
+ "rstrip": false,
891
+ "single_word": false,
892
+ "special": true
893
+ },
894
+ "111": {
895
+ "content": "<mask_87>",
896
+ "lstrip": false,
897
+ "normalized": false,
898
+ "rstrip": false,
899
+ "single_word": false,
900
+ "special": true
901
+ },
902
+ "112": {
903
+ "content": "<mask_88>",
904
+ "lstrip": false,
905
+ "normalized": false,
906
+ "rstrip": false,
907
+ "single_word": false,
908
+ "special": true
909
+ },
910
+ "113": {
911
+ "content": "<mask_89>",
912
+ "lstrip": false,
913
+ "normalized": false,
914
+ "rstrip": false,
915
+ "single_word": false,
916
+ "special": true
917
+ },
918
+ "114": {
919
+ "content": "<mask_90>",
920
+ "lstrip": false,
921
+ "normalized": false,
922
+ "rstrip": false,
923
+ "single_word": false,
924
+ "special": true
925
+ },
926
+ "115": {
927
+ "content": "<mask_91>",
928
+ "lstrip": false,
929
+ "normalized": false,
930
+ "rstrip": false,
931
+ "single_word": false,
932
+ "special": true
933
+ },
934
+ "116": {
935
+ "content": "<mask_92>",
936
+ "lstrip": false,
937
+ "normalized": false,
938
+ "rstrip": false,
939
+ "single_word": false,
940
+ "special": true
941
+ },
942
+ "117": {
943
+ "content": "<mask_93>",
944
+ "lstrip": false,
945
+ "normalized": false,
946
+ "rstrip": false,
947
+ "single_word": false,
948
+ "special": true
949
+ },
950
+ "118": {
951
+ "content": "<mask_94>",
952
+ "lstrip": false,
953
+ "normalized": false,
954
+ "rstrip": false,
955
+ "single_word": false,
956
+ "special": true
957
+ },
958
+ "119": {
959
+ "content": "<mask_95>",
960
+ "lstrip": false,
961
+ "normalized": false,
962
+ "rstrip": false,
963
+ "single_word": false,
964
+ "special": true
965
+ },
966
+ "120": {
967
+ "content": "<mask_96>",
968
+ "lstrip": false,
969
+ "normalized": false,
970
+ "rstrip": false,
971
+ "single_word": false,
972
+ "special": true
973
+ },
974
+ "121": {
975
+ "content": "<mask_97>",
976
+ "lstrip": false,
977
+ "normalized": false,
978
+ "rstrip": false,
979
+ "single_word": false,
980
+ "special": true
981
+ },
982
+ "122": {
983
+ "content": "<mask_98>",
984
+ "lstrip": false,
985
+ "normalized": false,
986
+ "rstrip": false,
987
+ "single_word": false,
988
+ "special": true
989
+ },
990
+ "123": {
991
+ "content": "<mask_99>",
992
+ "lstrip": false,
993
+ "normalized": false,
994
+ "rstrip": false,
995
+ "single_word": false,
996
+ "special": true
997
+ },
998
+ "124": {
999
+ "content": "<mask_100>",
1000
+ "lstrip": false,
1001
+ "normalized": false,
1002
+ "rstrip": false,
1003
+ "single_word": false,
1004
+ "special": true
1005
+ },
1006
+ "125": {
1007
+ "content": "<mask_101>",
1008
+ "lstrip": false,
1009
+ "normalized": false,
1010
+ "rstrip": false,
1011
+ "single_word": false,
1012
+ "special": true
1013
+ },
1014
+ "126": {
1015
+ "content": "<mask_102>",
1016
+ "lstrip": false,
1017
+ "normalized": false,
1018
+ "rstrip": false,
1019
+ "single_word": false,
1020
+ "special": true
1021
+ },
1022
+ "127": {
1023
+ "content": "<mask_103>",
1024
+ "lstrip": false,
1025
+ "normalized": false,
1026
+ "rstrip": false,
1027
+ "single_word": false,
1028
+ "special": true
1029
+ },
1030
+ "128": {
1031
+ "content": "<mask_104>",
1032
+ "lstrip": false,
1033
+ "normalized": false,
1034
+ "rstrip": false,
1035
+ "single_word": false,
1036
+ "special": true
1037
+ },
1038
+ "129": {
1039
+ "content": "<mask_105>",
1040
+ "lstrip": false,
1041
+ "normalized": false,
1042
+ "rstrip": false,
1043
+ "single_word": false,
1044
+ "special": true
1045
+ },
1046
+ "130": {
1047
+ "content": "<mask_106>",
1048
+ "lstrip": false,
1049
+ "normalized": false,
1050
+ "rstrip": false,
1051
+ "single_word": false,
1052
+ "special": true
1053
+ },
1054
+ "131": {
1055
+ "content": "<mask_107>",
1056
+ "lstrip": false,
1057
+ "normalized": false,
1058
+ "rstrip": false,
1059
+ "single_word": false,
1060
+ "special": true
1061
+ },
1062
+ "132": {
1063
+ "content": "<mask_108>",
1064
+ "lstrip": false,
1065
+ "normalized": false,
1066
+ "rstrip": false,
1067
+ "single_word": false,
1068
+ "special": true
1069
+ },
1070
+ "133": {
1071
+ "content": "<mask_109>",
1072
+ "lstrip": false,
1073
+ "normalized": false,
1074
+ "rstrip": false,
1075
+ "single_word": false,
1076
+ "special": true
1077
+ },
1078
+ "134": {
1079
+ "content": "<mask_110>",
1080
+ "lstrip": false,
1081
+ "normalized": false,
1082
+ "rstrip": false,
1083
+ "single_word": false,
1084
+ "special": true
1085
+ },
1086
+ "135": {
1087
+ "content": "<mask_111>",
1088
+ "lstrip": false,
1089
+ "normalized": false,
1090
+ "rstrip": false,
1091
+ "single_word": false,
1092
+ "special": true
1093
+ },
1094
+ "136": {
1095
+ "content": "<mask_112>",
1096
+ "lstrip": false,
1097
+ "normalized": false,
1098
+ "rstrip": false,
1099
+ "single_word": false,
1100
+ "special": true
1101
+ },
1102
+ "137": {
1103
+ "content": "<mask_113>",
1104
+ "lstrip": false,
1105
+ "normalized": false,
1106
+ "rstrip": false,
1107
+ "single_word": false,
1108
+ "special": true
1109
+ },
1110
+ "138": {
1111
+ "content": "<mask_114>",
1112
+ "lstrip": false,
1113
+ "normalized": false,
1114
+ "rstrip": false,
1115
+ "single_word": false,
1116
+ "special": true
1117
+ },
1118
+ "139": {
1119
+ "content": "<mask_115>",
1120
+ "lstrip": false,
1121
+ "normalized": false,
1122
+ "rstrip": false,
1123
+ "single_word": false,
1124
+ "special": true
1125
+ },
1126
+ "140": {
1127
+ "content": "<mask_116>",
1128
+ "lstrip": false,
1129
+ "normalized": false,
1130
+ "rstrip": false,
1131
+ "single_word": false,
1132
+ "special": true
1133
+ },
1134
+ "141": {
1135
+ "content": "<mask_117>",
1136
+ "lstrip": false,
1137
+ "normalized": false,
1138
+ "rstrip": false,
1139
+ "single_word": false,
1140
+ "special": true
1141
+ },
1142
+ "142": {
1143
+ "content": "<mask_118>",
1144
+ "lstrip": false,
1145
+ "normalized": false,
1146
+ "rstrip": false,
1147
+ "single_word": false,
1148
+ "special": true
1149
+ },
1150
+ "143": {
1151
+ "content": "<mask_119>",
1152
+ "lstrip": false,
1153
+ "normalized": false,
1154
+ "rstrip": false,
1155
+ "single_word": false,
1156
+ "special": true
1157
+ },
1158
+ "144": {
1159
+ "content": "<mask_120>",
1160
+ "lstrip": false,
1161
+ "normalized": false,
1162
+ "rstrip": false,
1163
+ "single_word": false,
1164
+ "special": true
1165
+ },
1166
+ "145": {
1167
+ "content": "<mask_121>",
1168
+ "lstrip": false,
1169
+ "normalized": false,
1170
+ "rstrip": false,
1171
+ "single_word": false,
1172
+ "special": true
1173
+ },
1174
+ "146": {
1175
+ "content": "<mask_122>",
1176
+ "lstrip": false,
1177
+ "normalized": false,
1178
+ "rstrip": false,
1179
+ "single_word": false,
1180
+ "special": true
1181
+ },
1182
+ "147": {
1183
+ "content": "<mask_123>",
1184
+ "lstrip": false,
1185
+ "normalized": false,
1186
+ "rstrip": false,
1187
+ "single_word": false,
1188
+ "special": true
1189
+ },
1190
+ "148": {
1191
+ "content": "<mask_124>",
1192
+ "lstrip": false,
1193
+ "normalized": false,
1194
+ "rstrip": false,
1195
+ "single_word": false,
1196
+ "special": true
1197
+ },
1198
+ "149": {
1199
+ "content": "<mask_125>",
1200
+ "lstrip": false,
1201
+ "normalized": false,
1202
+ "rstrip": false,
1203
+ "single_word": false,
1204
+ "special": true
1205
+ },
1206
+ "150": {
1207
+ "content": "<mask_126>",
1208
+ "lstrip": false,
1209
+ "normalized": false,
1210
+ "rstrip": false,
1211
+ "single_word": false,
1212
+ "special": true
1213
+ },
1214
+ "151": {
1215
+ "content": "<mask_127>",
1216
+ "lstrip": false,
1217
+ "normalized": false,
1218
+ "rstrip": false,
1219
+ "single_word": false,
1220
+ "special": true
1221
+ },
1222
+ "152": {
1223
+ "content": "<mask_128>",
1224
+ "lstrip": false,
1225
+ "normalized": false,
1226
+ "rstrip": false,
1227
+ "single_word": false,
1228
+ "special": true
1229
+ },
1230
+ "153": {
1231
+ "content": "<mask_129>",
1232
+ "lstrip": false,
1233
+ "normalized": false,
1234
+ "rstrip": false,
1235
+ "single_word": false,
1236
+ "special": true
1237
+ },
1238
+ "154": {
1239
+ "content": "<mask_130>",
1240
+ "lstrip": false,
1241
+ "normalized": false,
1242
+ "rstrip": false,
1243
+ "single_word": false,
1244
+ "special": true
1245
+ },
1246
+ "155": {
1247
+ "content": "<mask_131>",
1248
+ "lstrip": false,
1249
+ "normalized": false,
1250
+ "rstrip": false,
1251
+ "single_word": false,
1252
+ "special": true
1253
+ },
1254
+ "156": {
1255
+ "content": "<mask_132>",
1256
+ "lstrip": false,
1257
+ "normalized": false,
1258
+ "rstrip": false,
1259
+ "single_word": false,
1260
+ "special": true
1261
+ },
1262
+ "157": {
1263
+ "content": "<mask_133>",
1264
+ "lstrip": false,
1265
+ "normalized": false,
1266
+ "rstrip": false,
1267
+ "single_word": false,
1268
+ "special": true
1269
+ },
1270
+ "158": {
1271
+ "content": "<mask_134>",
1272
+ "lstrip": false,
1273
+ "normalized": false,
1274
+ "rstrip": false,
1275
+ "single_word": false,
1276
+ "special": true
1277
+ },
1278
+ "159": {
1279
+ "content": "<mask_135>",
1280
+ "lstrip": false,
1281
+ "normalized": false,
1282
+ "rstrip": false,
1283
+ "single_word": false,
1284
+ "special": true
1285
+ },
1286
+ "160": {
1287
+ "content": "<mask_136>",
1288
+ "lstrip": false,
1289
+ "normalized": false,
1290
+ "rstrip": false,
1291
+ "single_word": false,
1292
+ "special": true
1293
+ },
1294
+ "161": {
1295
+ "content": "<mask_137>",
1296
+ "lstrip": false,
1297
+ "normalized": false,
1298
+ "rstrip": false,
1299
+ "single_word": false,
1300
+ "special": true
1301
+ },
1302
+ "162": {
1303
+ "content": "<mask_138>",
1304
+ "lstrip": false,
1305
+ "normalized": false,
1306
+ "rstrip": false,
1307
+ "single_word": false,
1308
+ "special": true
1309
+ },
1310
+ "163": {
1311
+ "content": "<mask_139>",
1312
+ "lstrip": false,
1313
+ "normalized": false,
1314
+ "rstrip": false,
1315
+ "single_word": false,
1316
+ "special": true
1317
+ },
1318
+ "164": {
1319
+ "content": "<mask_140>",
1320
+ "lstrip": false,
1321
+ "normalized": false,
1322
+ "rstrip": false,
1323
+ "single_word": false,
1324
+ "special": true
1325
+ },
1326
+ "165": {
1327
+ "content": "<mask_141>",
1328
+ "lstrip": false,
1329
+ "normalized": false,
1330
+ "rstrip": false,
1331
+ "single_word": false,
1332
+ "special": true
1333
+ },
1334
+ "166": {
1335
+ "content": "<mask_142>",
1336
+ "lstrip": false,
1337
+ "normalized": false,
1338
+ "rstrip": false,
1339
+ "single_word": false,
1340
+ "special": true
1341
+ },
1342
+ "167": {
1343
+ "content": "<mask_143>",
1344
+ "lstrip": false,
1345
+ "normalized": false,
1346
+ "rstrip": false,
1347
+ "single_word": false,
1348
+ "special": true
1349
+ },
1350
+ "168": {
1351
+ "content": "<mask_144>",
1352
+ "lstrip": false,
1353
+ "normalized": false,
1354
+ "rstrip": false,
1355
+ "single_word": false,
1356
+ "special": true
1357
+ },
1358
+ "169": {
1359
+ "content": "<mask_145>",
1360
+ "lstrip": false,
1361
+ "normalized": false,
1362
+ "rstrip": false,
1363
+ "single_word": false,
1364
+ "special": true
1365
+ },
1366
+ "170": {
1367
+ "content": "<mask_146>",
1368
+ "lstrip": false,
1369
+ "normalized": false,
1370
+ "rstrip": false,
1371
+ "single_word": false,
1372
+ "special": true
1373
+ },
1374
+ "171": {
1375
+ "content": "<mask_147>",
1376
+ "lstrip": false,
1377
+ "normalized": false,
1378
+ "rstrip": false,
1379
+ "single_word": false,
1380
+ "special": true
1381
+ },
1382
+ "172": {
1383
+ "content": "<mask_148>",
1384
+ "lstrip": false,
1385
+ "normalized": false,
1386
+ "rstrip": false,
1387
+ "single_word": false,
1388
+ "special": true
1389
+ },
1390
+ "173": {
1391
+ "content": "<mask_149>",
1392
+ "lstrip": false,
1393
+ "normalized": false,
1394
+ "rstrip": false,
1395
+ "single_word": false,
1396
+ "special": true
1397
+ },
1398
+ "174": {
1399
+ "content": "<mask_150>",
1400
+ "lstrip": false,
1401
+ "normalized": false,
1402
+ "rstrip": false,
1403
+ "single_word": false,
1404
+ "special": true
1405
+ },
1406
+ "175": {
1407
+ "content": "<mask_151>",
1408
+ "lstrip": false,
1409
+ "normalized": false,
1410
+ "rstrip": false,
1411
+ "single_word": false,
1412
+ "special": true
1413
+ },
1414
+ "176": {
1415
+ "content": "<mask_152>",
1416
+ "lstrip": false,
1417
+ "normalized": false,
1418
+ "rstrip": false,
1419
+ "single_word": false,
1420
+ "special": true
1421
+ },
1422
+ "177": {
1423
+ "content": "<mask_153>",
1424
+ "lstrip": false,
1425
+ "normalized": false,
1426
+ "rstrip": false,
1427
+ "single_word": false,
1428
+ "special": true
1429
+ },
1430
+ "178": {
1431
+ "content": "<mask_154>",
1432
+ "lstrip": false,
1433
+ "normalized": false,
1434
+ "rstrip": false,
1435
+ "single_word": false,
1436
+ "special": true
1437
+ },
1438
+ "179": {
1439
+ "content": "<mask_155>",
1440
+ "lstrip": false,
1441
+ "normalized": false,
1442
+ "rstrip": false,
1443
+ "single_word": false,
1444
+ "special": true
1445
+ },
1446
+ "180": {
1447
+ "content": "<mask_156>",
1448
+ "lstrip": false,
1449
+ "normalized": false,
1450
+ "rstrip": false,
1451
+ "single_word": false,
1452
+ "special": true
1453
+ },
1454
+ "181": {
1455
+ "content": "<mask_157>",
1456
+ "lstrip": false,
1457
+ "normalized": false,
1458
+ "rstrip": false,
1459
+ "single_word": false,
1460
+ "special": true
1461
+ },
1462
+ "182": {
1463
+ "content": "<mask_158>",
1464
+ "lstrip": false,
1465
+ "normalized": false,
1466
+ "rstrip": false,
1467
+ "single_word": false,
1468
+ "special": true
1469
+ },
1470
+ "183": {
1471
+ "content": "<mask_159>",
1472
+ "lstrip": false,
1473
+ "normalized": false,
1474
+ "rstrip": false,
1475
+ "single_word": false,
1476
+ "special": true
1477
+ },
1478
+ "184": {
1479
+ "content": "<mask_160>",
1480
+ "lstrip": false,
1481
+ "normalized": false,
1482
+ "rstrip": false,
1483
+ "single_word": false,
1484
+ "special": true
1485
+ },
1486
+ "185": {
1487
+ "content": "<mask_161>",
1488
+ "lstrip": false,
1489
+ "normalized": false,
1490
+ "rstrip": false,
1491
+ "single_word": false,
1492
+ "special": true
1493
+ },
1494
+ "186": {
1495
+ "content": "<mask_162>",
1496
+ "lstrip": false,
1497
+ "normalized": false,
1498
+ "rstrip": false,
1499
+ "single_word": false,
1500
+ "special": true
1501
+ },
1502
+ "187": {
1503
+ "content": "<mask_163>",
1504
+ "lstrip": false,
1505
+ "normalized": false,
1506
+ "rstrip": false,
1507
+ "single_word": false,
1508
+ "special": true
1509
+ },
1510
+ "188": {
1511
+ "content": "<mask_164>",
1512
+ "lstrip": false,
1513
+ "normalized": false,
1514
+ "rstrip": false,
1515
+ "single_word": false,
1516
+ "special": true
1517
+ },
1518
+ "189": {
1519
+ "content": "<mask_165>",
1520
+ "lstrip": false,
1521
+ "normalized": false,
1522
+ "rstrip": false,
1523
+ "single_word": false,
1524
+ "special": true
1525
+ },
1526
+ "190": {
1527
+ "content": "<mask_166>",
1528
+ "lstrip": false,
1529
+ "normalized": false,
1530
+ "rstrip": false,
1531
+ "single_word": false,
1532
+ "special": true
1533
+ },
1534
+ "191": {
1535
+ "content": "<mask_167>",
1536
+ "lstrip": false,
1537
+ "normalized": false,
1538
+ "rstrip": false,
1539
+ "single_word": false,
1540
+ "special": true
1541
+ },
1542
+ "192": {
1543
+ "content": "<mask_168>",
1544
+ "lstrip": false,
1545
+ "normalized": false,
1546
+ "rstrip": false,
1547
+ "single_word": false,
1548
+ "special": true
1549
+ },
1550
+ "193": {
1551
+ "content": "<mask_169>",
1552
+ "lstrip": false,
1553
+ "normalized": false,
1554
+ "rstrip": false,
1555
+ "single_word": false,
1556
+ "special": true
1557
+ },
1558
+ "194": {
1559
+ "content": "<mask_170>",
1560
+ "lstrip": false,
1561
+ "normalized": false,
1562
+ "rstrip": false,
1563
+ "single_word": false,
1564
+ "special": true
1565
+ },
1566
+ "195": {
1567
+ "content": "<mask_171>",
1568
+ "lstrip": false,
1569
+ "normalized": false,
1570
+ "rstrip": false,
1571
+ "single_word": false,
1572
+ "special": true
1573
+ },
1574
+ "196": {
1575
+ "content": "<mask_172>",
1576
+ "lstrip": false,
1577
+ "normalized": false,
1578
+ "rstrip": false,
1579
+ "single_word": false,
1580
+ "special": true
1581
+ },
1582
+ "197": {
1583
+ "content": "<mask_173>",
1584
+ "lstrip": false,
1585
+ "normalized": false,
1586
+ "rstrip": false,
1587
+ "single_word": false,
1588
+ "special": true
1589
+ },
1590
+ "198": {
1591
+ "content": "<mask_174>",
1592
+ "lstrip": false,
1593
+ "normalized": false,
1594
+ "rstrip": false,
1595
+ "single_word": false,
1596
+ "special": true
1597
+ },
1598
+ "199": {
1599
+ "content": "<mask_175>",
1600
+ "lstrip": false,
1601
+ "normalized": false,
1602
+ "rstrip": false,
1603
+ "single_word": false,
1604
+ "special": true
1605
+ },
1606
+ "200": {
1607
+ "content": "<mask_176>",
1608
+ "lstrip": false,
1609
+ "normalized": false,
1610
+ "rstrip": false,
1611
+ "single_word": false,
1612
+ "special": true
1613
+ },
1614
+ "201": {
1615
+ "content": "<mask_177>",
1616
+ "lstrip": false,
1617
+ "normalized": false,
1618
+ "rstrip": false,
1619
+ "single_word": false,
1620
+ "special": true
1621
+ },
1622
+ "202": {
1623
+ "content": "<mask_178>",
1624
+ "lstrip": false,
1625
+ "normalized": false,
1626
+ "rstrip": false,
1627
+ "single_word": false,
1628
+ "special": true
1629
+ },
1630
+ "203": {
1631
+ "content": "<mask_179>",
1632
+ "lstrip": false,
1633
+ "normalized": false,
1634
+ "rstrip": false,
1635
+ "single_word": false,
1636
+ "special": true
1637
+ },
1638
+ "204": {
1639
+ "content": "<mask_180>",
1640
+ "lstrip": false,
1641
+ "normalized": false,
1642
+ "rstrip": false,
1643
+ "single_word": false,
1644
+ "special": true
1645
+ },
1646
+ "205": {
1647
+ "content": "<mask_181>",
1648
+ "lstrip": false,
1649
+ "normalized": false,
1650
+ "rstrip": false,
1651
+ "single_word": false,
1652
+ "special": true
1653
+ },
1654
+ "206": {
1655
+ "content": "<mask_182>",
1656
+ "lstrip": false,
1657
+ "normalized": false,
1658
+ "rstrip": false,
1659
+ "single_word": false,
1660
+ "special": true
1661
+ },
1662
+ "207": {
1663
+ "content": "<mask_183>",
1664
+ "lstrip": false,
1665
+ "normalized": false,
1666
+ "rstrip": false,
1667
+ "single_word": false,
1668
+ "special": true
1669
+ },
1670
+ "208": {
1671
+ "content": "<mask_184>",
1672
+ "lstrip": false,
1673
+ "normalized": false,
1674
+ "rstrip": false,
1675
+ "single_word": false,
1676
+ "special": true
1677
+ },
1678
+ "209": {
1679
+ "content": "<mask_185>",
1680
+ "lstrip": false,
1681
+ "normalized": false,
1682
+ "rstrip": false,
1683
+ "single_word": false,
1684
+ "special": true
1685
+ },
1686
+ "210": {
1687
+ "content": "<mask_186>",
1688
+ "lstrip": false,
1689
+ "normalized": false,
1690
+ "rstrip": false,
1691
+ "single_word": false,
1692
+ "special": true
1693
+ },
1694
+ "211": {
1695
+ "content": "<mask_187>",
1696
+ "lstrip": false,
1697
+ "normalized": false,
1698
+ "rstrip": false,
1699
+ "single_word": false,
1700
+ "special": true
1701
+ },
1702
+ "212": {
1703
+ "content": "<mask_188>",
1704
+ "lstrip": false,
1705
+ "normalized": false,
1706
+ "rstrip": false,
1707
+ "single_word": false,
1708
+ "special": true
1709
+ },
1710
+ "213": {
1711
+ "content": "<mask_189>",
1712
+ "lstrip": false,
1713
+ "normalized": false,
1714
+ "rstrip": false,
1715
+ "single_word": false,
1716
+ "special": true
1717
+ },
1718
+ "214": {
1719
+ "content": "<mask_190>",
1720
+ "lstrip": false,
1721
+ "normalized": false,
1722
+ "rstrip": false,
1723
+ "single_word": false,
1724
+ "special": true
1725
+ },
1726
+ "215": {
1727
+ "content": "<mask_191>",
1728
+ "lstrip": false,
1729
+ "normalized": false,
1730
+ "rstrip": false,
1731
+ "single_word": false,
1732
+ "special": true
1733
+ },
1734
+ "216": {
1735
+ "content": "<mask_192>",
1736
+ "lstrip": false,
1737
+ "normalized": false,
1738
+ "rstrip": false,
1739
+ "single_word": false,
1740
+ "special": true
1741
+ },
1742
+ "217": {
1743
+ "content": "<mask_193>",
1744
+ "lstrip": false,
1745
+ "normalized": false,
1746
+ "rstrip": false,
1747
+ "single_word": false,
1748
+ "special": true
1749
+ },
1750
+ "218": {
1751
+ "content": "<mask_194>",
1752
+ "lstrip": false,
1753
+ "normalized": false,
1754
+ "rstrip": false,
1755
+ "single_word": false,
1756
+ "special": true
1757
+ },
1758
+ "219": {
1759
+ "content": "<mask_195>",
1760
+ "lstrip": false,
1761
+ "normalized": false,
1762
+ "rstrip": false,
1763
+ "single_word": false,
1764
+ "special": true
1765
+ },
1766
+ "220": {
1767
+ "content": "<mask_196>",
1768
+ "lstrip": false,
1769
+ "normalized": false,
1770
+ "rstrip": false,
1771
+ "single_word": false,
1772
+ "special": true
1773
+ },
1774
+ "221": {
1775
+ "content": "<mask_197>",
1776
+ "lstrip": false,
1777
+ "normalized": false,
1778
+ "rstrip": false,
1779
+ "single_word": false,
1780
+ "special": true
1781
+ },
1782
+ "222": {
1783
+ "content": "<mask_198>",
1784
+ "lstrip": false,
1785
+ "normalized": false,
1786
+ "rstrip": false,
1787
+ "single_word": false,
1788
+ "special": true
1789
+ },
1790
+ "223": {
1791
+ "content": "<mask_199>",
1792
+ "lstrip": false,
1793
+ "normalized": false,
1794
+ "rstrip": false,
1795
+ "single_word": false,
1796
+ "special": true
1797
+ },
1798
+ "131072": {
1799
+ "content": "<mask_131048>",
1800
+ "lstrip": false,
1801
+ "normalized": false,
1802
+ "rstrip": false,
1803
+ "single_word": false,
1804
+ "special": true
1805
+ },
1806
+ "131073": {
1807
+ "content": "<mask_131049>",
1808
+ "lstrip": false,
1809
+ "normalized": false,
1810
+ "rstrip": false,
1811
+ "single_word": false,
1812
+ "special": true
1813
+ },
1814
+ "131074": {
1815
+ "content": "<mask_131050>",
1816
+ "lstrip": false,
1817
+ "normalized": false,
1818
+ "rstrip": false,
1819
+ "single_word": false,
1820
+ "special": true
1821
+ },
1822
+ "131075": {
1823
+ "content": "<mask_131051>",
1824
+ "lstrip": false,
1825
+ "normalized": false,
1826
+ "rstrip": false,
1827
+ "single_word": false,
1828
+ "special": true
1829
+ },
1830
+ "131076": {
1831
+ "content": "<mask_131052>",
1832
+ "lstrip": false,
1833
+ "normalized": false,
1834
+ "rstrip": false,
1835
+ "single_word": false,
1836
+ "special": true
1837
+ },
1838
+ "131077": {
1839
+ "content": "<mask_131053>",
1840
+ "lstrip": false,
1841
+ "normalized": false,
1842
+ "rstrip": false,
1843
+ "single_word": false,
1844
+ "special": true
1845
+ },
1846
+ "131078": {
1847
+ "content": "<mask_131054>",
1848
+ "lstrip": false,
1849
+ "normalized": false,
1850
+ "rstrip": false,
1851
+ "single_word": false,
1852
+ "special": true
1853
+ },
1854
+ "131079": {
1855
+ "content": "<mask_131055>",
1856
+ "lstrip": false,
1857
+ "normalized": false,
1858
+ "rstrip": false,
1859
+ "single_word": false,
1860
+ "special": true
1861
+ },
1862
+ "131080": {
1863
+ "content": "<mask_131056>",
1864
+ "lstrip": false,
1865
+ "normalized": false,
1866
+ "rstrip": false,
1867
+ "single_word": false,
1868
+ "special": true
1869
+ },
1870
+ "131081": {
1871
+ "content": "<mask_131057>",
1872
+ "lstrip": false,
1873
+ "normalized": false,
1874
+ "rstrip": false,
1875
+ "single_word": false,
1876
+ "special": true
1877
+ },
1878
+ "131082": {
1879
+ "content": "<mask_131058>",
1880
+ "lstrip": false,
1881
+ "normalized": false,
1882
+ "rstrip": false,
1883
+ "single_word": false,
1884
+ "special": true
1885
+ },
1886
+ "131083": {
1887
+ "content": "<mask_131059>",
1888
+ "lstrip": false,
1889
+ "normalized": false,
1890
+ "rstrip": false,
1891
+ "single_word": false,
1892
+ "special": true
1893
+ },
1894
+ "131084": {
1895
+ "content": "<mask_131060>",
1896
+ "lstrip": false,
1897
+ "normalized": false,
1898
+ "rstrip": false,
1899
+ "single_word": false,
1900
+ "special": true
1901
+ },
1902
+ "131085": {
1903
+ "content": "<mask_131061>",
1904
+ "lstrip": false,
1905
+ "normalized": false,
1906
+ "rstrip": false,
1907
+ "single_word": false,
1908
+ "special": true
1909
+ },
1910
+ "131086": {
1911
+ "content": "<mask_131062>",
1912
+ "lstrip": false,
1913
+ "normalized": false,
1914
+ "rstrip": false,
1915
+ "single_word": false,
1916
+ "special": true
1917
+ },
1918
+ "131087": {
1919
+ "content": "<mask_131063>",
1920
+ "lstrip": false,
1921
+ "normalized": false,
1922
+ "rstrip": false,
1923
+ "single_word": false,
1924
+ "special": true
1925
+ },
1926
+ "131088": {
1927
+ "content": "<mask_131064>",
1928
+ "lstrip": false,
1929
+ "normalized": false,
1930
+ "rstrip": false,
1931
+ "single_word": false,
1932
+ "special": true
1933
+ },
1934
+ "131089": {
1935
+ "content": "<mask_131065>",
1936
+ "lstrip": false,
1937
+ "normalized": false,
1938
+ "rstrip": false,
1939
+ "single_word": false,
1940
+ "special": true
1941
+ },
1942
+ "131090": {
1943
+ "content": "<longcat_img_token_size>",
1944
+ "lstrip": false,
1945
+ "normalized": false,
1946
+ "rstrip": false,
1947
+ "single_word": false,
1948
+ "special": true
1949
+ },
1950
+ "131091": {
1951
+ "content": "</longcat_img_token_size>",
1952
+ "lstrip": false,
1953
+ "normalized": false,
1954
+ "rstrip": false,
1955
+ "single_word": false,
1956
+ "special": true
1957
+ },
1958
+ "131092": {
1959
+ "content": "<mask_131068>",
1960
+ "lstrip": false,
1961
+ "normalized": false,
1962
+ "rstrip": false,
1963
+ "single_word": false,
1964
+ "special": true
1965
+ },
1966
+ "131093": {
1967
+ "content": "<mask_131069>",
1968
+ "lstrip": false,
1969
+ "normalized": false,
1970
+ "rstrip": false,
1971
+ "single_word": false,
1972
+ "special": true
1973
+ },
1974
+ "131094": {
1975
+ "content": "<mask_131070>",
1976
+ "lstrip": false,
1977
+ "normalized": false,
1978
+ "rstrip": false,
1979
+ "single_word": false,
1980
+ "special": true
1981
+ },
1982
+ "131095": {
1983
+ "content": "<mask_131071>",
1984
+ "lstrip": false,
1985
+ "normalized": false,
1986
+ "rstrip": false,
1987
+ "single_word": false,
1988
+ "special": true
1989
+ },
1990
+ "131096": {
1991
+ "content": "<longcat_point_start>",
1992
+ "lstrip": false,
1993
+ "normalized": false,
1994
+ "rstrip": false,
1995
+ "single_word": false,
1996
+ "special": true
1997
+ },
1998
+ "131097": {
1999
+ "content": "<longcat_point_end>",
2000
+ "lstrip": false,
2001
+ "normalized": false,
2002
+ "rstrip": false,
2003
+ "single_word": false,
2004
+ "special": true
2005
+ },
2006
+ "131098": {
2007
+ "content": "<longcat_point_delim>",
2008
+ "lstrip": false,
2009
+ "normalized": false,
2010
+ "rstrip": false,
2011
+ "single_word": false,
2012
+ "special": true
2013
+ },
2014
+ "131099": {
2015
+ "content": "<longcat_polygon_start>",
2016
+ "lstrip": false,
2017
+ "normalized": false,
2018
+ "rstrip": false,
2019
+ "single_word": false,
2020
+ "special": true
2021
+ },
2022
+ "131100": {
2023
+ "content": "<longcat_polygon_end>",
2024
+ "lstrip": false,
2025
+ "normalized": false,
2026
+ "rstrip": false,
2027
+ "single_word": false,
2028
+ "special": true
2029
+ },
2030
+ "131101": {
2031
+ "content": "<mask_131077>",
2032
+ "lstrip": false,
2033
+ "normalized": false,
2034
+ "rstrip": false,
2035
+ "single_word": false,
2036
+ "special": true
2037
+ },
2038
+ "131102": {
2039
+ "content": "<mask_131078>",
2040
+ "lstrip": false,
2041
+ "normalized": false,
2042
+ "rstrip": false,
2043
+ "single_word": false,
2044
+ "special": true
2045
+ },
2046
+ "131103": {
2047
+ "content": "<longcat_audio_start>",
2048
+ "lstrip": false,
2049
+ "normalized": false,
2050
+ "rstrip": false,
2051
+ "single_word": false,
2052
+ "special": true
2053
+ },
2054
+ "131104": {
2055
+ "content": "<longcat_audio_end>",
2056
+ "lstrip": false,
2057
+ "normalized": false,
2058
+ "rstrip": false,
2059
+ "single_word": false,
2060
+ "special": true
2061
+ },
2062
+ "131105": {
2063
+ "content": "<longcat_audio_pad>",
2064
+ "lstrip": false,
2065
+ "normalized": false,
2066
+ "rstrip": false,
2067
+ "single_word": false,
2068
+ "special": true
2069
+ },
2070
+ "131106": {
2071
+ "content": "<longcat_img_start>",
2072
+ "lstrip": false,
2073
+ "normalized": false,
2074
+ "rstrip": false,
2075
+ "single_word": false,
2076
+ "special": true
2077
+ },
2078
+ "131107": {
2079
+ "content": "<longcat_img_end>",
2080
+ "lstrip": false,
2081
+ "normalized": false,
2082
+ "rstrip": false,
2083
+ "single_word": false,
2084
+ "special": true
2085
+ },
2086
+ "131108": {
2087
+ "content": "<longcat_img_pad>",
2088
+ "lstrip": false,
2089
+ "normalized": false,
2090
+ "rstrip": false,
2091
+ "single_word": false,
2092
+ "special": true
2093
+ },
2094
+ "131109": {
2095
+ "content": "<longcat_img_newline>",
2096
+ "lstrip": false,
2097
+ "normalized": false,
2098
+ "rstrip": false,
2099
+ "single_word": false,
2100
+ "special": true
2101
+ },
2102
+ "131110": {
2103
+ "content": "<longcat_box_start>",
2104
+ "lstrip": false,
2105
+ "normalized": false,
2106
+ "rstrip": false,
2107
+ "single_word": false,
2108
+ "special": true
2109
+ },
2110
+ "131111": {
2111
+ "content": "<longcat_box_end>",
2112
+ "lstrip": false,
2113
+ "normalized": false,
2114
+ "rstrip": false,
2115
+ "single_word": false,
2116
+ "special": true
2117
+ },
2118
+ "131112": {
2119
+ "content": "<longcat_box_delim>",
2120
+ "lstrip": false,
2121
+ "normalized": false,
2122
+ "rstrip": false,
2123
+ "single_word": false,
2124
+ "special": true
2125
+ },
2126
+ "131113": {
2127
+ "content": "<longcat_ref_start>",
2128
+ "lstrip": false,
2129
+ "normalized": false,
2130
+ "rstrip": false,
2131
+ "single_word": false,
2132
+ "special": true
2133
+ },
2134
+ "131114": {
2135
+ "content": "<longcat_ref_end>",
2136
+ "lstrip": false,
2137
+ "normalized": false,
2138
+ "rstrip": false,
2139
+ "single_word": false,
2140
+ "special": true
2141
+ },
2142
+ "131115": {
2143
+ "content": "<longcat_img_delim>",
2144
+ "lstrip": false,
2145
+ "normalized": false,
2146
+ "rstrip": false,
2147
+ "single_word": false,
2148
+ "special": true
2149
+ },
2150
+ "131116": {
2151
+ "content": "<longcat_audio_delim>",
2152
+ "lstrip": false,
2153
+ "normalized": false,
2154
+ "rstrip": false,
2155
+ "single_word": false,
2156
+ "special": true
2157
+ },
2158
+ "131117": {
2159
+ "content": "<longcat_video_palce>",
2160
+ "lstrip": false,
2161
+ "normalized": false,
2162
+ "rstrip": false,
2163
+ "single_word": false,
2164
+ "special": true
2165
+ },
2166
+ "131118": {
2167
+ "content": "<longcat_video_start>",
2168
+ "lstrip": false,
2169
+ "normalized": false,
2170
+ "rstrip": false,
2171
+ "single_word": false,
2172
+ "special": true
2173
+ },
2174
+ "131119": {
2175
+ "content": "<longcat_video_end>",
2176
+ "lstrip": false,
2177
+ "normalized": false,
2178
+ "rstrip": false,
2179
+ "single_word": false,
2180
+ "special": true
2181
+ },
2182
+ "131120": {
2183
+ "content": "<longcat_audiotext_start>",
2184
+ "lstrip": false,
2185
+ "normalized": false,
2186
+ "rstrip": false,
2187
+ "single_word": false,
2188
+ "special": true
2189
+ },
2190
+ "131121": {
2191
+ "content": "<longcat_audiotext_end>",
2192
+ "lstrip": false,
2193
+ "normalized": false,
2194
+ "rstrip": false,
2195
+ "single_word": false,
2196
+ "special": true
2197
+ },
2198
+ "131122": {
2199
+ "content": "<longcat_audiotext_pad>",
2200
+ "lstrip": false,
2201
+ "normalized": false,
2202
+ "rstrip": false,
2203
+ "single_word": false,
2204
+ "special": true
2205
+ },
2206
+ "131123": {
2207
+ "content": "<longcat_audiogen_start>",
2208
+ "lstrip": false,
2209
+ "normalized": false,
2210
+ "rstrip": false,
2211
+ "single_word": false,
2212
+ "special": true
2213
+ },
2214
+ "131124": {
2215
+ "content": "<longcat_audiogen_end>",
2216
+ "lstrip": false,
2217
+ "normalized": false,
2218
+ "rstrip": false,
2219
+ "single_word": false,
2220
+ "special": true
2221
+ }
2222
+ },
2223
+ "additional_special_tokens": [
2224
+ "<mask_131048>",
2225
+ "<mask_131049>",
2226
+ "<mask_131050>",
2227
+ "<mask_131051>",
2228
+ "<mask_131052>",
2229
+ "<mask_131053>",
2230
+ "<mask_131054>",
2231
+ "<mask_131055>",
2232
+ "<mask_131056>",
2233
+ "<mask_131057>",
2234
+ "<mask_131058>",
2235
+ "<mask_131059>",
2236
+ "<mask_131060>",
2237
+ "<mask_131061>",
2238
+ "<mask_131062>",
2239
+ "<mask_131063>",
2240
+ "<mask_131064>",
2241
+ "<mask_131065>",
2242
+ "<longcat_img_token_size>",
2243
+ "</longcat_img_token_size>",
2244
+ "<mask_131068>",
2245
+ "<mask_131069>",
2246
+ "<mask_131070>",
2247
+ "<mask_131071>",
2248
+ "<longcat_point_start>",
2249
+ "<longcat_point_end>",
2250
+ "<longcat_point_delim>",
2251
+ "<longcat_polygon_start>",
2252
+ "<longcat_polygon_end>",
2253
+ "<mask_131077>",
2254
+ "<mask_131078>",
2255
+ "<longcat_audio_start>",
2256
+ "<longcat_audio_end>",
2257
+ "<longcat_audio_pad>",
2258
+ "<longcat_img_start>",
2259
+ "<longcat_img_end>",
2260
+ "<longcat_img_pad>",
2261
+ "<longcat_img_newline>",
2262
+ "<longcat_box_start>",
2263
+ "<longcat_box_end>",
2264
+ "<longcat_box_delim>",
2265
+ "<longcat_ref_start>",
2266
+ "<longcat_ref_end>",
2267
+ "<longcat_img_delim>",
2268
+ "<longcat_audio_delim>",
2269
+ "<longcat_video_palce>",
2270
+ "<longcat_video_start>",
2271
+ "<longcat_video_end>",
2272
+ "<longcat_audiotext_start>",
2273
+ "<longcat_audiotext_end>",
2274
+ "<longcat_audiotext_pad>",
2275
+ "<longcat_audiogen_start>",
2276
+ "<longcat_audiogen_end>"
2277
+ ],
2278
+ "bos_token": "<longcat_s>",
2279
+ "chat_template": "{%- set tool_choice = tool_choice | default('auto') %}\n{%- set ns = namespace(tool_types = [], last_query_index = -1, suffix_to_move = '') %}\n\n{%- if tools and tool_choice != 'none' %}\n {{- \"<longcat_tool_declare>\\n\"-}}\n {{- \"# Tools\\n\" }}\n {{- \"You have access to the following tools:\\n\\n\" }}\n {%- for tool in tools %}\n {%- if tool.type not in ns.tool_types %}\n {%- set ns.tool_types = ns.tool_types + [tool.type] %}\n {{- \"## Tool namespace: \" ~ tool.type ~ \"\\n\\n\" }}\n {%- endif %}\n {%- if tool.type == 'code_interpreter' %}\n {%- set tool = {\"type\":\"code_interpreter\",\"function\":{\"name\":\"code_interpreter_preview\",\"description\":\"The code will be executed in a stateful Jupyter notebook sandbox environment, only supports local computation, data processing, and file operations.\\nCode sandbox environment (network isolated) Any external network requests or online API calls are prohibited.\\nIf online functionality is needed, please use other permitted tools.\\nCode will respond with the output of the execution or time out after 60.0 seconds. \",\"parameters\":{\"type\":\"object\",\"properties\":{\"language\":{\"type\":\"string\",\"description\":\"The programming language of the code to be executed. Available values: python (Default), java, go, js, ts, c, c++.\"},\"code\":{\"type\":\"string\",\"description\":\"Python code to be executed must not include the following:\\n- Importing network libraries such as requests, httplib, etc.\\n- Any form of HTTP requests.\\n- External API calls.\\n- Network port operations. Example: ```python\\nimport pandas as pd\\npd.DataFrame({'A':[1,2]})\\n```\"},\"timeout\":{\"type\":\"number\",\"description\":\"The maximum execution time of the code, in seconds. Default is 60.0.\"}}},\"required\":[\"code\"]}} %}\n {%- endif %}\n {{- \"### Tool name: \" + tool.function.name + \"\\n\" }}\n {{- \"Description: \" + tool.function.description + \"\\n\\n\" }}\n {{- \"InputSchema: \" + tool.function.parameters | tojson(ensure_ascii=False) + \"\\n\\n\" }}\n {%- endfor %}\n {{- '**Note**: For each function call, output the function name and arguments within the following XML format:\\n<longcat_tool_call>{function-name}\\n<longcat_arg_key>{arg-key-1}</longcat_arg_key>\\n<longcat_arg_value>{arg-value-1}</longcat_arg_value>\\n<longcat_arg_key>{arg-key-2}</longcat_arg_key>\\n<longcat_arg_value>{arg-value-2}</longcat_arg_value>\\n...\\n</longcat_tool_call>\\n' }}\n {{- \"</longcat_tool_declare>\"-}}\n {%- for idx in range(messages|length - 1) %}\n {%- set msg = messages[idx] %}\n {%- if msg.role == 'assistant' and not msg.tool_calls %}\n {%- set ns.last_query_index = idx %}\n {%- endif %}\n {%- endfor%}\n{%- endif %}\n\n{%- for msg in messages %}\n {%- if msg.role == \"system\" %}\n {{- \"<longcat_system>\" + msg.content }}\n {%- elif msg.role == \"user\" %}\n {{- \"<longcat_user>\" }}\n {%- if msg[\"files\"] %}\n {{- '<longcat_files>\\n' ~ msg.files | tojson(indent=2) ~ '\\n</longcat_files>' }}\n {%- endif %}\n\n {%- if add_generation_prompt and loop.last and msg.content is string and msg.content.endswith(\"<longcat_img_start>\") %}\n {%- set ns.suffix_to_move = \"<longcat_img_start>\" %}\n {{- msg.content[:-19] }}\n {%- elif add_generation_prompt and loop.last and msg.content is string and msg.content.endswith(\"<longcat_audiogen_start>\") %}\n {%- set ns.suffix_to_move = \"<longcat_audiogen_start>\" %}\n {{- msg.content[:-24] }}\n {%- else %}\n {{- msg.content }}\n {%- endif %}\n\n {%- elif msg.role == \"assistant\" %}\n 
{{- \"<longcat_assistant>\" }}\n {%- if enable_thinking == true and msg.reasoning_content and ns.tool_types != [] and loop.index0 > ns.last_query_index %}\n {{- \"\\n<longcat_think>\\n\" ~ msg.reasoning_content ~ \"\\n</longcat_think>\\n\" }}\n {%- endif %}\n {%- if msg.content%}\n {{- msg.content }}\n {%- endif %}\n {%- if msg.tool_calls %}\n {%- for tool_call in msg.tool_calls -%}\n {{- \"<longcat_tool_call>\" ~ tool_call.function.name ~ \"\\n\" -}}\n {% set _args = tool_call.function.arguments %}\n {% for k, v in _args.items() %}\n {{- \"<longcat_arg_key>\" ~ k ~ \"</longcat_arg_key>\\n\" -}}\n {{- \"<longcat_arg_value>\" ~ (v if v is string else v | tojson(ensure_ascii=False)) ~ \"</longcat_arg_value>\\n\" -}}\n {% endfor %}\n {{- \"</longcat_tool_call>\\n\" }}\n {%- endfor %}\n {%- endif %}\n {{- \"</longcat_s>\" -}}\n {%- elif msg.role == \"tool\" %}\n {%- if messages[loop.index0 - 1].role != \"tool\"%}\n {{- \"<longcat_user>\" -}}\n {%- endif %}\n {{- \"<longcat_tool_response>\" ~ msg.content ~ \"</longcat_tool_response>\"-}}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {%- if enable_thinking == true %}\n {{- \" /think_on\" }}\n {%- if thinking_budget %}\n {%- if thinking_budget < 1024 %}\n {%- set thinking_budget = 1024 %}\n {%- endif%}\n {{- \"\\nthinking_budget: < \" ~ thinking_budget ~ \".\"}}\n {%- endif %}\n {{- \" <longcat_assistant><longcat_think>\\n\"}}\n {%- elif enable_thinking == false %}\n {{- \" /think_off <longcat_assistant><longcat_think>\\n\\n</longcat_think>\\n\" }}\n {%- else %}\n {{- \"<longcat_assistant>\" ~ ns.suffix_to_move }}\n {%- endif %}\n{%- endif %}",
2280
+ "clean_up_tokenization_spaces": false,
2281
+ "eos_token": "</longcat_s>",
2282
+ "model_max_length": 131072,
2283
+ "pad_token": "<longcat_pad>",
2284
+ "sp_model_kwargs": {},
2285
+ "tokenizer_class": "BloomTokenizer",
2286
+ "unk_token": "<longcat_unk>",
2287
+ "image_start_token": "<longcat_img_start>",
2288
+ "image_end_token": "<longcat_img_end>",
2289
+ "image_pad_token": "<longcat_img_pad>",
2290
+ "image_newline_token": "<longcat_img_newline>",
2291
+ "audio_start_token": "<longcat_audio_start>",
2292
+ "audio_end_token": "<longcat_audio_end>",
2293
+ "audio_pad_token": "<longcat_audio_pad>"
2294
+ }
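
The `tokenizer_config.json` added above registers LongCat's multimodal special tokens (image, audio, video, box/point/polygon markers), reserves the `<mask_*>` slots, and ships the Jinja chat template that wraps turns in `<longcat_system>`/`<longcat_user>`/`<longcat_assistant>` tags and serializes tool calls as `<longcat_tool_call>` blocks. Below is a minimal sketch of how such a config is typically consumed through Hugging Face `transformers`; the repository id, message contents, and printed output are illustrative assumptions, not part of this commit.

```python
# Minimal usage sketch (assumptions: `transformers` is installed and the repo id
# below points at this model; both are placeholders, not verified by this commit).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "meituan-longcat/LongCat-Next",  # placeholder repo id, or a local checkout of this repo
    trust_remote_code=True,
)

# Special tokens declared in tokenizer_config.json resolve to fixed vocab ids.
print(tokenizer.convert_tokens_to_ids("<longcat_img_start>"))
print(tokenizer.convert_tokens_to_ids("<longcat_audio_pad>"))

# The chat_template field drives apply_chat_template; variables the template
# references (tools, enable_thinking, thinking_budget) can be passed as extra
# keyword arguments.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Describe this image."},
]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,
)
# e.g. "<longcat_system>...<longcat_user>... /think_off <longcat_assistant>..."
print(prompt)
```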