code
Browse files- .gitattributes +4 -0
- LICENSE +21 -0
- README.md +388 -0
- assets/book.png +3 -0
- assets/evaluation.png +3 -0
- assets/longcat_logo.svg +1 -0
- assets/math1.wav +3 -0
- assets/overview.png +3 -0
- assets/system_audio.wav +3 -0
- assets/vc_zh3.wav +3 -0
- config.json +285 -0
- configuration_longcat_next.py +152 -0
- configuration_longcat_ngram.py +218 -0
- cosy24k_vocoder.py +552 -0
- environment.yml +8 -0
- generation_config.json +40 -0
- image_refiner.py +748 -0
- modeling_longcat_next.py +824 -0
- modeling_longcat_ngram.py +426 -0
- modular_longcat_next.py +157 -0
- modular_longcat_next_audio.py +2039 -0
- modular_longcat_next_visual.py +1077 -0
- parse_model_response.py +158 -0
- preprocessor_config.json +19 -0
- processing_longcat_next.py +279 -0
- refiner_modules.py +1330 -0
- requirements-post.txt +1 -0
- requirements.txt +7 -0
- tokenizer.json +0 -0
- tokenizer_config.json +2294 -0
.gitattributes
CHANGED
|
@@ -1,3 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.wav filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 5 |
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 6 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 7 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 Meituan
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in
|
| 13 |
+
all copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
library_name: LongCat-Next
|
| 4 |
+
pipeline_tag: any-to-any
|
| 5 |
+
tags:
|
| 6 |
+
- transformers
|
| 7 |
+
- multimodal
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# LongCat-Next
|
| 11 |
+
|
| 12 |
+
<div align="center">
|
| 13 |
+
<img src="https://raw.githubusercontent.com/meituan-longcat/LongCat-Flash-Chat/main/figures/longcat_logo.svg"
|
| 14 |
+
width="300"
|
| 15 |
+
alt="LongCat Logo"/>
|
| 16 |
+
</div>
|
| 17 |
+
|
| 18 |
+
<hr>
|
| 19 |
+
|
| 20 |
+
<div align="center" style="line-height: 1;">
|
| 21 |
+
<a href="https://longcat.chat/longcat-next/intro" target="_blank" style="margin: 2px;">
|
| 22 |
+
<img alt="Blog" src="https://img.shields.io/badge/Blog-LongCatNext-white?logo=safari&logoColor=white&color=purple" style="display: inline-block; vertical-align: middle;"/>
|
| 23 |
+
</a>
|
| 24 |
+
<a href="https://huggingface.co/meituan-longcat" target="_blank" style="margin: 2px;">
|
| 25 |
+
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-LongCatNext-ffc107?color=ffc107&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
|
| 26 |
+
</a>
|
| 27 |
+
<a href="https://github.com/meituan-longcat/LongCat-Next" target="_blank" style="margin: 2px;">
|
| 28 |
+
<img alt="GitHub" src="https://img.shields.io/badge/GitHub-LongCatNext-white?logo=github&logoColor=white&color=a4b5d5" style="display: inline-block; vertical-align: middle;"/>
|
| 29 |
+
</a>
|
| 30 |
+
<a href="https://longcat.chat/longcat-next" target="_blank" style="margin: 2px;">
|
| 31 |
+
<img alt="Demo" src="https://img.shields.io/badge/Demo-LongCatNext-white?logo=googleplay&logoColor=white&color=eabcdd" style="display: inline-block; vertical-align: middle;"/>
|
| 32 |
+
</a>
|
| 33 |
+
</div>
|
| 34 |
+
|
| 35 |
+
<div align="center" style="line-height: 1;">
|
| 36 |
+
<a href="https://github.com/meituan-longcat/LongCat-Flash-Chat/blob/main/figures/wechat_official_accounts.png" target="_blank" style="margin: 2px;">
|
| 37 |
+
<img alt="Wechat" src="https://img.shields.io/badge/WeChat-LongCat-brightgreen?logo=wechat&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
|
| 38 |
+
</a>
|
| 39 |
+
<a href="https://x.com/Meituan_LongCat" target="_blank" style="margin: 2px;">
|
| 40 |
+
<img alt="Twitter Follow" src="https://img.shields.io/badge/Twitter-LongCat-white?logo=x&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
|
| 41 |
+
</a>
|
| 42 |
+
</div>
|
| 43 |
+
|
| 44 |
+
<div align="center" style="line-height: 1;">
|
| 45 |
+
<a href="https://huggingface.co/meituan-longcat/LongCat-Next/blob/main/LICENSE" style="margin: 2px;">
|
| 46 |
+
<img alt="License" src="https://img.shields.io/badge/License-MIT-f5de53?&color=f5de53" style="display: inline-block; vertical-align: middle;"/>
|
| 47 |
+
</a>
|
| 48 |
+
</div>
|
| 49 |
+
|
| 50 |
+
<p align="center">
|
| 51 |
+
<a href="https://github.com/meituan-longcat/LongCat-Next/blob/main/tech_report.pdf">
|
| 52 |
+
<b>Tech Report</b> 📄
|
| 53 |
+
</a>
|
| 54 |
+
</p>
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
## Model Introduction
|
| 61 |
+
|
| 62 |
+

|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
We develop **LongCat-Next**, a native multimodal model that processes text, vision, and audio under a single autoregressive objective with minimal inductive bias beyond the language paradigm. As an industrial-strength foundation model with A3B model size, it excels at seeing, creating, and talking, achieving strong performance across a wide range of multimodal benchmarks. In particular, leveraging semantically complete discrete representations, it surpasses the long-standing performance ceiling of discrete vision modeling on understanding tasks, and provides a unified solution for visual understanding and generation. This success demonstrates that discrete tokens can universally represent multimodal signals and be deeply internalized within a single discrete embedding space. We further provide extensive experiments to analyze this unified discrete training paradigm and uncover several interesting findings.
|
| 66 |
+
|
| 67 |
+
As a meaningful attempt toward native multimodality, we open-source the **LongCat-Next** and its tokenizers, hoping to foster further research and development in the community.
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
### Key Features
|
| 71 |
+
|
| 72 |
+
This work primarily addresses the fundamental barrier to native multimodality through a design philosophy that prioritizes simplicity, treating vision and audio as intrinsic extensions of language. As a step toward this goal, we present LongCat-Next, a discrete native multimodal model that achieves industrial-strength performance within discrete frameworks while remaining highly competitive across a wide range of specialized domains. Built upon the LongCat-Flash-Lite MoE backbone (A3B) as a _multi-task_ learner, the model unifies language, vision, and audio within a single discrete framework. In this paper, we make the following principal contributions:
|
| 73 |
+
|
| 74 |
+
#### 🌟 Discrete Native Autoregression Paradigm (DiNA).
|
| 75 |
+
We introduce DiNA, a unified paradigm that extends next-token prediction from language to native multimodality, which internalizes diverse modalities into a shared token space. It simplifies multimodal modeling by creating modality-aware tokenizer-detokenizer pairs and leveraging the established training infrastructure of large language models.
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
#### 🌟 Semantic Completeness for Discrete Visual Representation.
|
| 79 |
+
We improve discrete visual modeling by combining Semantic-and-Aligned Encoders (SAE) with Residual Vector Quantization (RVQ). This integration creates hierarchical discrete tokens that preserve both semantic abstraction and fine-grained visual details, surpassing traditional representation limitations.
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
#### 🌟 Discrete Native-Resolution Vision Transformer (dNaViT).
|
| 83 |
+
Analogous to linguistic tokenizers, we propose dNaViT as a highly flexible, unified discrete interface for vision that extracts semantic features as "visual words", constructing a hierarchical representation space supporting dynamic tokenization and detokenization. dNaViT integrates seamlessly with large language models, ensuring high performance without degradation.
|
| 84 |
+
|
| 85 |
+
#### 🌟 Exceling in Seeing, Creating, and Talking in a Unified Model.
|
| 86 |
+
Within the framework of DiNA, visual understanding and generation are elegantly reformulated as two manifestations of the same predictive process without performance compromise. This formulation bridges the long-standing architectural divide while introducing minimal interference between these traditionally competing objectives and preserving core language capabilities. Remarkably, LongCat-Next achieves competitive performance with specialized understanding models, while maintaining strong generative quality even under a 28× compression ratio, particularly in text rendering, while also excelling in advanced speech comprehension, low-latency voice conversation, and customizable voice cloning.
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
Please refer to our [technical report](./tech_report.pdf) for details!
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
## Evaluation Results
|
| 94 |
+
|
| 95 |
+

|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
## Quick Start
|
| 101 |
+
To use LongCat-Next with transformers, we need at least 3 GPUs (80GB VRAM each, e.g., H100/A100 80GB), and we recommend the following environment:
|
| 102 |
+
* `python` >= 3.10
|
| 103 |
+
* `torch` >= 2.6
|
| 104 |
+
* `transformers` >= 4.57.6
|
| 105 |
+
* `accelerate` >= 1.10.0
|
| 106 |
+
|
| 107 |
+
```shell
|
| 108 |
+
# (Install python=3.10, ffffmpeg<7, soundfile==0.13.1)
|
| 109 |
+
conda env create -f environment.yml -v
|
| 110 |
+
|
| 111 |
+
# (Install torch and other pip dependencies)
|
| 112 |
+
pip install -r requirements.txt && pip install -r requirements-post.txt --no-build-isolation
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
Basic Usage Example:
|
| 116 |
+
- Remember to modify `WEIGHT_PATH_TO_LONGCAT_NEXT` in `./config.json`, because decoders use lazy loading.
|
| 117 |
+
|
| 118 |
+
```python
|
| 119 |
+
import torch
|
| 120 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
|
| 121 |
+
|
| 122 |
+
# Load model
|
| 123 |
+
model_name = "meituan-longcat/LongCat-Next"
|
| 124 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 125 |
+
model_name,
|
| 126 |
+
torch_dtype=torch.bfloat16,
|
| 127 |
+
device_map="auto",
|
| 128 |
+
trust_remote_code=True,
|
| 129 |
+
)
|
| 130 |
+
model.eval()
|
| 131 |
+
|
| 132 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, fix_mistral_regex=True)
|
| 133 |
+
model.text_tokenizer = tokenizer # Dynamic binding
|
| 134 |
+
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
| 135 |
+
|
| 136 |
+
# Set messages
|
| 137 |
+
messages = [
|
| 138 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
| 139 |
+
{"role": "user", "content": "What book is this?<longcat_img_start>./assets/book.png<longcat_img_end>"}
|
| 140 |
+
]
|
| 141 |
+
|
| 142 |
+
# Apply chat-template
|
| 143 |
+
text_input = tokenizer.apply_chat_template(
|
| 144 |
+
messages,
|
| 145 |
+
tokenize=False,
|
| 146 |
+
add_generation_prompt=True,
|
| 147 |
+
)
|
| 148 |
+
print(f"{text_input=}")
|
| 149 |
+
|
| 150 |
+
# Preprocessing
|
| 151 |
+
text_inputs, visual_inputs, audio_inputs = processor(text=text_input, return_tensors="pt")
|
| 152 |
+
text_inputs = text_inputs.to(model.device)
|
| 153 |
+
if visual_inputs is not None:
|
| 154 |
+
visual_inputs = visual_inputs.to(model.device)
|
| 155 |
+
if audio_inputs is not None:
|
| 156 |
+
audio_inputs = audio_inputs.to(model.device)
|
| 157 |
+
|
| 158 |
+
# AR
|
| 159 |
+
with torch.no_grad():
|
| 160 |
+
outputs = model.generate(
|
| 161 |
+
input_ids=text_inputs["input_ids"],
|
| 162 |
+
visual_inputs=visual_inputs,
|
| 163 |
+
audio_inputs=audio_inputs,
|
| 164 |
+
return_dict_in_generate=True,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# Text decoding
|
| 168 |
+
output_input_ids = outputs.sequences
|
| 169 |
+
text_output = tokenizer.decode(output_input_ids[0][len(text_inputs["input_ids"][0]):], skip_special_tokens=True)
|
| 170 |
+
print(f"{text_output=}")
|
| 171 |
+
|
| 172 |
+
# Images decoding
|
| 173 |
+
output_visual_ids = outputs.visual_ids
|
| 174 |
+
if output_visual_ids.size(0) > 0:
|
| 175 |
+
image_path_list = model.model.decode_visual_ids_and_save(
|
| 176 |
+
output_visual_ids,
|
| 177 |
+
save_prefix="./output_image",
|
| 178 |
+
**model.generation_config.visual_generation_config["custom_params"],
|
| 179 |
+
)
|
| 180 |
+
print(f"{image_path_list=}")
|
| 181 |
+
|
| 182 |
+
# Audio decoding
|
| 183 |
+
output_audio_text_ids = outputs.audio_text_ids
|
| 184 |
+
output_audio_ids = outputs.audio_ids
|
| 185 |
+
if output_audio_text_ids.size(-1) > 0:
|
| 186 |
+
audio_text = tokenizer.decode(output_audio_text_ids[0], skip_special_tokens=True)
|
| 187 |
+
print(f"{audio_text=}")
|
| 188 |
+
if output_audio_ids.size(0) > 0:
|
| 189 |
+
audio_path_list = model.model.decode_audio_ids_and_save(
|
| 190 |
+
output_audio_ids,
|
| 191 |
+
save_prefix="./output_audio",
|
| 192 |
+
**model.generation_config.audio_generation_config["custom_params"],
|
| 193 |
+
)
|
| 194 |
+
print(f"{audio_path_list=}")
|
| 195 |
+
```
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
<details>
|
| 199 |
+
<summary>Text - Tool Calling Example</summary>
|
| 200 |
+
|
| 201 |
+
```python
|
| 202 |
+
from parse_model_response import parse_model_response
|
| 203 |
+
|
| 204 |
+
tools = [
|
| 205 |
+
{
|
| 206 |
+
"type": "function",
|
| 207 |
+
"function": {
|
| 208 |
+
"name": "func_add",
|
| 209 |
+
"description": "Calculate the sum of two numbers",
|
| 210 |
+
"parameters": {
|
| 211 |
+
"type": "object",
|
| 212 |
+
"properties": {
|
| 213 |
+
"x1": {"type": "number", "description": "The first addend"},
|
| 214 |
+
"x2": {"type": "number", "description": "The second addend"}
|
| 215 |
+
},
|
| 216 |
+
"required": ["x1", "x2"]
|
| 217 |
+
}
|
| 218 |
+
}
|
| 219 |
+
}
|
| 220 |
+
]
|
| 221 |
+
messages = [
|
| 222 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
| 223 |
+
{"role": "user", "content": "Please tell me what is $$125679 + 234519$$?"},
|
| 224 |
+
{
|
| 225 |
+
"role": "assistant",
|
| 226 |
+
"content": "I'll calculate the sum of 125679 and 234519 for you.",
|
| 227 |
+
"tool_calls": [{"type": "function", "function": {"name": "func_add", "arguments": {"x1": 125679, "x2": 234519}}}]
|
| 228 |
+
},
|
| 229 |
+
{"role": "tool", "name": "func_add", "content": '{"ans": 360198}'}
|
| 230 |
+
]
|
| 231 |
+
|
| 232 |
+
text_input = tokenizer.apply_chat_template(
|
| 233 |
+
messages,
|
| 234 |
+
tools=tools, # add tools here
|
| 235 |
+
tokenize=False,
|
| 236 |
+
add_generation_prompt=True,
|
| 237 |
+
)
|
| 238 |
+
print(f"{text_input=}")
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# Preprocessing - AR - Text decoding
|
| 242 |
+
...
|
| 243 |
+
|
| 244 |
+
# Results parsing
|
| 245 |
+
parsed_message = parse_model_response(text_output.strip("\n"), tools)
|
| 246 |
+
print(f"{parsed_message=}")
|
| 247 |
+
```
|
| 248 |
+
See [`parse_model_response.py`](./parse_model_response.py) for detailed implementation and examples.
|
| 249 |
+
|
| 250 |
+
</details>
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
<details>
|
| 254 |
+
<summary>Image - Understanding Example</summary>
|
| 255 |
+
|
| 256 |
+
```python
|
| 257 |
+
# Simply replace the messages in the main example with the messages below.
|
| 258 |
+
messages = [
|
| 259 |
+
{"role": "user", "content": "What book is this?<longcat_img_start>./assets/book.png<longcat_img_end>"}
|
| 260 |
+
]
|
| 261 |
+
```
|
| 262 |
+
|
| 263 |
+
</details>
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
<details>
|
| 267 |
+
<summary>Image - Generation Example</summary>
|
| 268 |
+
|
| 269 |
+
```python
|
| 270 |
+
# Simply replace the messages in the main example with the messages below.
|
| 271 |
+
messages = [
|
| 272 |
+
{"role": "system", "content": ""},
|
| 273 |
+
{"role": "user", "content": "A small kitten sitting naturally on a moss-covered forest floor, centered in the frame, holding a rectangular wooden sign gently with its front paws resting over the top edge. The kitten has soft, fluffy fur, a natural relaxed posture, and a calm, curious expression with a slightly open mouth (not exaggerated), looking directly at the camera.\n\nThe sign is positioned firmly in front of the kitten\'s chest, supported by its paws, with realistic contact and no floating effect. The board reads \"LongCat-Next: When Modalities Internalize as Multilingual Tokens\" in clean, sharp black text, perfectly legible.\n\nThe environment is a lush forest with tall trees, ferns, and soft green foliage. The ground is covered with moss and small plants. Background softly blurred with natural depth of field. Lighting is soft, diffused sunlight filtering through the trees, creating gentle highlights and shadows. Realistic photography style, natural colors, high detail, no cartoonish exaggeration.<longcat_img_start>"}
|
| 274 |
+
]
|
| 275 |
+
```
|
| 276 |
+
|
| 277 |
+
</details>
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
<details>
|
| 281 |
+
<summary>Audio - Audio-to-Text Example</summary>
|
| 282 |
+
|
| 283 |
+
```python
|
| 284 |
+
# Simply replace the messages in the main example with the messages below.
|
| 285 |
+
messages = [
|
| 286 |
+
{"role": "user", "content": "<longcat_audio_start>./assets/math1.wav<longcat_audio_end>"}
|
| 287 |
+
]
|
| 288 |
+
|
| 289 |
+
```
|
| 290 |
+
|
| 291 |
+
</details>
|
| 292 |
+
|
| 293 |
+
<details>
|
| 294 |
+
<summary>Audio - Audio-to-Audio Example</summary>
|
| 295 |
+
|
| 296 |
+
```python
|
| 297 |
+
# Simply replace the messages in the main example with the messages below.
|
| 298 |
+
messages = [
|
| 299 |
+
{"role": "system", "content": "Replicate the voice in the audio clip to formulate an answer:<longcat_audio_start>./assets/system_audio.wav<longcat_audio_end>"},
|
| 300 |
+
{"role": "user", "content": "<longcat_audio_start>./assets/math1.wav<longcat_audio_end><longcat_audiogen_start>"}
|
| 301 |
+
]
|
| 302 |
+
```
|
| 303 |
+
|
| 304 |
+
</details>
|
| 305 |
+
|
| 306 |
+
<details>
|
| 307 |
+
<summary>Audio - Speech Synthesis Example</summary>
|
| 308 |
+
|
| 309 |
+
```python
|
| 310 |
+
# Simply replace the messages in the main example with the messages below.
|
| 311 |
+
messages = [
|
| 312 |
+
{"role": "system", "content": "Replicate the voice in the audio clip to formulate an answer:<longcat_audio_start>./assets/vc_zh3.wav<longcat_audio_end>"},
|
| 313 |
+
{"role": "user", "content": "用这个声音合成以下内容:明天的meeting在三楼的Conference Room举行。<longcat_audiogen_start>"}
|
| 314 |
+
]
|
| 315 |
+
```
|
| 316 |
+
|
| 317 |
+
</details>
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
<!-- > [!Tip] -->
|
| 321 |
+
|
| 322 |
+
> We recommend using the following set of sampling parameters for generation:
|
| 323 |
+
>
|
| 324 |
+
> - Text: `{"max_new_tokens":2048,"do_sample":false}`
|
| 325 |
+
> - Image - Understanding: `{"max_new_tokens":1024,"do_sample":true,"temperature":0.4,"top_k":40,"top_p":0.85,"repetition_penalty":1.1}`
|
| 326 |
+
> - Image - Generation: `{"max_new_tokens":2048,"do_sample":false,"visual_generation_config":{"do_sample":true,"temperature":0.5,"top_p":0.75,"top_k":1024,"custom_params":{"cfg_scale":3,"token_h":37,"token_w":37,"anyres_prefix":"<longcat_img_token_size>{h} {w}</longcat_img_token_size>"}}}`
|
| 327 |
+
> - Audio - Audio-to-Text: `{"max_new_tokens":1024,"do_sample":true,"temperature":0.2,"top_k":20,"top_p":0.85,"repetition_penalty":1.1}`
|
| 328 |
+
> - Audio - Audio-to-Audio/Speech Synthesis: `{"max_new_tokens":2048,"do_sample":true,"temperature":0.2,"top_k":20,"top_p":0.85,"repetition_penalty":1.1,"audio_generation_config":{"audio_parallel_decoding":false,"do_sample":true,"temperature":0.5,"top_k":5,"top_p":0.85,"repetition_penalty":1.3,"custom_params":{"sampling_rate":24000,"wave_concat_overlap":1200}}}`
|
| 329 |
+
>
|
| 330 |
+
> Please note that the support for sampling parameters varies according to inference frameworks(For transformers, the inference parameter configuration is located in `./generation_config.json`).
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
## Deployment
|
| 335 |
+
|
| 336 |
+
We have implemented basic adaptations in SGLang(Code is being uploaded) to support the deployment of LongCat-Next.
|
| 337 |
+
|
| 338 |
+
```shell
|
| 339 |
+
git clone [TBU]
|
| 340 |
+
cd nmm_infer
|
| 341 |
+
git checkout master
|
| 342 |
+
sh setup.sh
|
| 343 |
+
```
|
| 344 |
+
|
| 345 |
+
```shell
|
| 346 |
+
# Require CUDA >= 12.9
|
| 347 |
+
|
| 348 |
+
# Setup environment
|
| 349 |
+
source create_env.sh
|
| 350 |
+
source set_env.sh
|
| 351 |
+
|
| 352 |
+
# Run tests
|
| 353 |
+
python3 demo.py \
|
| 354 |
+
--model-path meituan-longcat/LongCat-Next \
|
| 355 |
+
--sequential \
|
| 356 |
+
--output-dir output \
|
| 357 |
+
--tasks vis_gen vis_und aud_qa spk_syn
|
| 358 |
+
|
| 359 |
+
```
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
## License Agreement
|
| 363 |
+
This repository, including both the model weights and the source code, is released under the **MIT License**.
|
| 364 |
+
|
| 365 |
+
Any contributions to this repository are licensed under the MIT License, unless otherwise stated. This license does not grant any rights to use Meituan trademarks or patents.
|
| 366 |
+
|
| 367 |
+
For details, see the [LICENSE](./LICENSE) file.
|
| 368 |
+
|
| 369 |
+
## Usage Considerations
|
| 370 |
+
This model has not been specifically designed or comprehensively evaluated for every possible downstream application.
|
| 371 |
+
|
| 372 |
+
Developers should take into account the known limitations of large language models, including performance variations across different languages, and carefully assess accuracy, safety, and fairness before deploying the model in sensitive or high-risk scenarios.
|
| 373 |
+
It is the responsibility of developers and downstream users to understand and comply with all applicable laws and regulations relevant to their use case, including but not limited to data protection, privacy, and content safety requirements.
|
| 374 |
+
|
| 375 |
+
Nothing in this Model Card should be interpreted as altering or restricting the terms of the MIT License under which the model is released.
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
<!-- ## Citation
|
| 379 |
+
|
| 380 |
+
We kindly encourage citation of our work if you find it useful.
|
| 381 |
+
|
| 382 |
+
```
|
| 383 |
+
|
| 384 |
+
``` -->
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
## Contact
|
| 388 |
+
Please contact us at <a href="mailto:longcat-team@meituan.com">longcat-team@meituan.com</a> or open an issue if you have any questions.
|
assets/book.png
ADDED
|
Git LFS Details
|
assets/evaluation.png
ADDED
|
Git LFS Details
|
assets/longcat_logo.svg
ADDED
|
|
assets/math1.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e88c12d17ba1b6d8a28fa6688311222673db0f958a3679347f03ba4afd4b78c2
|
| 3 |
+
size 1140560
|
assets/overview.png
ADDED
|
Git LFS Details
|
assets/system_audio.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bbb21a5cd57013406e1c18e8f267d05197bbbc3fdb8a65038d9c5a7799b9357a
|
| 3 |
+
size 254478
|
assets/vc_zh3.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c8313c738deac97e9c36cb861a85a896c9bbdaa22fe9f9f432feace766a75c65
|
| 3 |
+
size 1282618
|
config.json
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"LongcatNextForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_bias": false,
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"auto_map": {
|
| 8 |
+
"AutoConfig": "configuration_longcat_next.LongcatNextConfig",
|
| 9 |
+
"AutoModel": "modeling_longcat_next.LongcatNextModel",
|
| 10 |
+
"AutoModelForCausalLM": "modeling_longcat_next.LongcatNextForCausalLM"
|
| 11 |
+
},
|
| 12 |
+
"vocab_size": 282624,
|
| 13 |
+
"hidden_size": 3072,
|
| 14 |
+
"ffn_hidden_size": 6144,
|
| 15 |
+
"expert_ffn_hidden_size": 1024,
|
| 16 |
+
"num_layers": 14,
|
| 17 |
+
"num_attention_heads": 32,
|
| 18 |
+
"kv_lora_rank": 512,
|
| 19 |
+
"q_lora_rank": 1536,
|
| 20 |
+
"qk_rope_head_dim": 64,
|
| 21 |
+
"v_head_dim": 128,
|
| 22 |
+
"qk_nope_head_dim": 128,
|
| 23 |
+
"mla_scale_q_lora": true,
|
| 24 |
+
"mla_scale_kv_lora": true,
|
| 25 |
+
"routed_scaling_factor": 6.0,
|
| 26 |
+
"n_routed_experts": 256,
|
| 27 |
+
"rms_norm_eps": 1e-5,
|
| 28 |
+
"use_cache": true,
|
| 29 |
+
"bos_token_id": 1,
|
| 30 |
+
"eos_token_id": 2,
|
| 31 |
+
"rope_theta": 10000000,
|
| 32 |
+
"max_position_embeddings": 131072,
|
| 33 |
+
"zero_expert_num": 128,
|
| 34 |
+
"zero_expert_type": "identity",
|
| 35 |
+
"moe_topk": 12,
|
| 36 |
+
"ngram_vocab_size_ratio": 78,
|
| 37 |
+
"emb_neighbor_num": 4,
|
| 38 |
+
"emb_split_num": 4,
|
| 39 |
+
"torch_dtype": "bfloat16",
|
| 40 |
+
"transformers_version": "4.57.6",
|
| 41 |
+
|
| 42 |
+
"text_vocab_size": 131072,
|
| 43 |
+
"text_vocab_plus_multimodal_special_token_size": 131125,
|
| 44 |
+
"visual_embedding_layer_intermediate_size": 8192,
|
| 45 |
+
"visual_embedding_layer_hidden_act": "silu",
|
| 46 |
+
"visual_offset": 150581,
|
| 47 |
+
"audio_offset": 131125,
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
"visual_config": {
|
| 51 |
+
"image_start_token_id": 131106,
|
| 52 |
+
"image_end_token_id": 131107,
|
| 53 |
+
"image_pad_token_id": 131108,
|
| 54 |
+
"image_newline_token_id": 131109,
|
| 55 |
+
|
| 56 |
+
"_attn_implementation": "flash_attention_2",
|
| 57 |
+
"hidden_size": 1280,
|
| 58 |
+
|
| 59 |
+
"image_head_transformer_dims": 2048,
|
| 60 |
+
"image_head_transformer_ffn_scale": 16,
|
| 61 |
+
"image_head_transformer_layers": 4,
|
| 62 |
+
|
| 63 |
+
"vq_config": {
|
| 64 |
+
"codebook_dim": 3584,
|
| 65 |
+
"codebook_size": 16384,
|
| 66 |
+
"codebook_sizes": [
|
| 67 |
+
16384,
|
| 68 |
+
16384,
|
| 69 |
+
16384,
|
| 70 |
+
16384,
|
| 71 |
+
16384,
|
| 72 |
+
16384,
|
| 73 |
+
16384,
|
| 74 |
+
16384
|
| 75 |
+
],
|
| 76 |
+
"decay": 0.99,
|
| 77 |
+
"depth": 8,
|
| 78 |
+
|
| 79 |
+
"commit_loss_ratio": 0.25,
|
| 80 |
+
"entropy_loss_ratio": 0,
|
| 81 |
+
|
| 82 |
+
"in_channels": 3584,
|
| 83 |
+
"quant_conv": true,
|
| 84 |
+
"quantizer_type": "rq",
|
| 85 |
+
"restart_unused_codes": true,
|
| 86 |
+
"shared_codebook": true,
|
| 87 |
+
|
| 88 |
+
"vq_loss_ratio": 0
|
| 89 |
+
},
|
| 90 |
+
|
| 91 |
+
"visual_decoder_config": {
|
| 92 |
+
"codebook_dim": 3584,
|
| 93 |
+
|
| 94 |
+
"image_decoder_config": {
|
| 95 |
+
"attention_dropout": 0.0,
|
| 96 |
+
"codebook_dim": 3584,
|
| 97 |
+
"distill_taps": [
|
| 98 |
+
3,
|
| 99 |
+
7,
|
| 100 |
+
15,
|
| 101 |
+
23
|
| 102 |
+
],
|
| 103 |
+
"hidden_act": "gelu",
|
| 104 |
+
"hidden_size": 1024,
|
| 105 |
+
"intermediate_size": 2730,
|
| 106 |
+
"k_bias": false,
|
| 107 |
+
"layer_norm_eps": 1e-06,
|
| 108 |
+
"num_attention_heads": 16,
|
| 109 |
+
"num_hidden_layers": 32,
|
| 110 |
+
"patch_size": 14,
|
| 111 |
+
"q_bias": true,
|
| 112 |
+
"spatial_merge_size": 2,
|
| 113 |
+
"subln": true,
|
| 114 |
+
"swiglu": true,
|
| 115 |
+
"teacher_dims": {
|
| 116 |
+
"15": 1280,
|
| 117 |
+
"23": 1280,
|
| 118 |
+
"3": 1280,
|
| 119 |
+
"7": 1280
|
| 120 |
+
},
|
| 121 |
+
"temporal_patch_size": 2,
|
| 122 |
+
"v_bias": true
|
| 123 |
+
},
|
| 124 |
+
|
| 125 |
+
"transformer_config": {
|
| 126 |
+
"patch_size": 2,
|
| 127 |
+
"in_channels": 16,
|
| 128 |
+
"hidden_size": 2520,
|
| 129 |
+
"num_layers": 32,
|
| 130 |
+
"num_refiner_layers": 2,
|
| 131 |
+
"num_attention_heads": 21,
|
| 132 |
+
"num_kv_heads": 7,
|
| 133 |
+
"multiple_of": 256,
|
| 134 |
+
"norm_eps": 1e-5,
|
| 135 |
+
"axes_dim_rope": [40, 40, 40],
|
| 136 |
+
"axes_lens": [10000, 10000, 10000],
|
| 137 |
+
"text_feat_dim": 2048,
|
| 138 |
+
"timestep_scale": 1000.0
|
| 139 |
+
},
|
| 140 |
+
|
| 141 |
+
"vae_config": {
|
| 142 |
+
"act_fn": "silu",
|
| 143 |
+
"block_out_channels": [128, 256, 512, 512],
|
| 144 |
+
"down_block_types": [
|
| 145 |
+
"DownEncoderBlock2D",
|
| 146 |
+
"DownEncoderBlock2D",
|
| 147 |
+
"DownEncoderBlock2D",
|
| 148 |
+
"DownEncoderBlock2D"
|
| 149 |
+
],
|
| 150 |
+
"in_channels": 3,
|
| 151 |
+
"latent_channels": 16,
|
| 152 |
+
"layers_per_block": 2,
|
| 153 |
+
"mid_block_add_attention": true,
|
| 154 |
+
"norm_num_groups": 32,
|
| 155 |
+
"out_channels": 3,
|
| 156 |
+
"sample_size": 1024,
|
| 157 |
+
"scaling_factor": 0.3611,
|
| 158 |
+
"shift_factor": 0.1159,
|
| 159 |
+
"up_block_types": [
|
| 160 |
+
"UpDecoderBlock2D",
|
| 161 |
+
"UpDecoderBlock2D",
|
| 162 |
+
"UpDecoderBlock2D",
|
| 163 |
+
"UpDecoderBlock2D"
|
| 164 |
+
],
|
| 165 |
+
"use_post_quant_conv": false,
|
| 166 |
+
"use_quant_conv": false,
|
| 167 |
+
"force_upcast": true
|
| 168 |
+
},
|
| 169 |
+
|
| 170 |
+
"scheduler_config": {
|
| 171 |
+
"num_train_timesteps": 1000,
|
| 172 |
+
"dynamic_time_shift": true
|
| 173 |
+
},
|
| 174 |
+
|
| 175 |
+
"weight_path": "WEIGHT_PATH_TO_LONGCAT_NEXT/image_decoder/image_decoder.safetensors"
|
| 176 |
+
}
|
| 177 |
+
},
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
"audio_config": {
|
| 181 |
+
"audio_head_transformer_dims": 3072,
|
| 182 |
+
"audio_head_transformer_ffn_scale": 16,
|
| 183 |
+
"audio_head_transformer_layers": 4,
|
| 184 |
+
|
| 185 |
+
"audio_delim_token_id": 131116,
|
| 186 |
+
"audio_end_token_id": 131104,
|
| 187 |
+
"audio_pad_token_id": 131105,
|
| 188 |
+
"audio_start_token_id": 131103,
|
| 189 |
+
"audiogen_end_token_id": 131124,
|
| 190 |
+
"audiogen_start_token_id": 131123,
|
| 191 |
+
"audiotext_end_token_id": 131121,
|
| 192 |
+
"audiotext_pad_token_id": 131122,
|
| 193 |
+
"audiotext_start_token_id": 131120,
|
| 194 |
+
|
| 195 |
+
"_attn_implementation": "flash_attention_2",
|
| 196 |
+
"d_model": 1280,
|
| 197 |
+
"decoder_attention_heads": 20,
|
| 198 |
+
"decoder_ffn_dim": 5120,
|
| 199 |
+
"decoder_layers": 8,
|
| 200 |
+
"encoder_attention_heads": 20,
|
| 201 |
+
"encoder_ffn_dim": 5120,
|
| 202 |
+
"encoder_layers": 32,
|
| 203 |
+
"num_mel_bins": 128,
|
| 204 |
+
|
| 205 |
+
"avg_pooler": 4,
|
| 206 |
+
"decoder_kernel_size": 3,
|
| 207 |
+
"decoder_stride_size": 2,
|
| 208 |
+
"hop_length": 160,
|
| 209 |
+
"kernel_size": 3,
|
| 210 |
+
"max_audio_seconds": 30,
|
| 211 |
+
"n_fft": 400,
|
| 212 |
+
"num_hidden_layers": 32,
|
| 213 |
+
"sampling_rate": 16000,
|
| 214 |
+
"stride_size": 2,
|
| 215 |
+
|
| 216 |
+
"vq_config": {
|
| 217 |
+
"codebook_sizes": [
|
| 218 |
+
8192,
|
| 219 |
+
4096,
|
| 220 |
+
2048,
|
| 221 |
+
1024,
|
| 222 |
+
1024,
|
| 223 |
+
1024,
|
| 224 |
+
1024,
|
| 225 |
+
1024
|
| 226 |
+
]
|
| 227 |
+
},
|
| 228 |
+
|
| 229 |
+
"vocoder_config": {
|
| 230 |
+
"channels": [
|
| 231 |
+
256,
|
| 232 |
+
256,
|
| 233 |
+
256,
|
| 234 |
+
256,
|
| 235 |
+
256
|
| 236 |
+
],
|
| 237 |
+
"hop_length": 256,
|
| 238 |
+
"num_mel_bins": 80,
|
| 239 |
+
"sampling_rate": 16000
|
| 240 |
+
},
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
"flow_matching_config": {
|
| 244 |
+
"in_channels": 80,
|
| 245 |
+
"spk_emb_dim": 0,
|
| 246 |
+
"diffusion_steps": 10,
|
| 247 |
+
"cal_mel_mae": true,
|
| 248 |
+
|
| 249 |
+
"prenet_activation_function": "gelu",
|
| 250 |
+
"prenet_attention_heads": 8,
|
| 251 |
+
"prenet_d_model": 512,
|
| 252 |
+
"prenet_ffn_dim": 2048,
|
| 253 |
+
"prenet_in_dim": 1280,
|
| 254 |
+
"prenet_max_source_positions": 5000,
|
| 255 |
+
"prenet_nlayers": 12,
|
| 256 |
+
"prenet_out_dim": 80,
|
| 257 |
+
"prenet_target_mel_length_scale_ratio": 1.0,
|
| 258 |
+
|
| 259 |
+
"channels": [
|
| 260 |
+
256
|
| 261 |
+
],
|
| 262 |
+
"dropout": 0.0,
|
| 263 |
+
"attention_head_dim": 64,
|
| 264 |
+
"n_blocks": 4,
|
| 265 |
+
"num_heads": 8,
|
| 266 |
+
"num_mid_blocks": 12,
|
| 267 |
+
"act_fn": "gelu",
|
| 268 |
+
|
| 269 |
+
"cfm_params": {
|
| 270 |
+
"inference_cfg_rate": 0.7,
|
| 271 |
+
"sigma_min": 1e-06,
|
| 272 |
+
"solver": "euler",
|
| 273 |
+
"t_scheduler": "cosine",
|
| 274 |
+
"training_cfg_rate": 0.2
|
| 275 |
+
},
|
| 276 |
+
|
| 277 |
+
"use_hidden_states_before_dconv2": true
|
| 278 |
+
},
|
| 279 |
+
|
| 280 |
+
"cosy24kvocoder_config": {
|
| 281 |
+
"weight_path": "WEIGHT_PATH_TO_LONGCAT_NEXT/cosy24k_vocoder/hift.pt"
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
}
|
| 285 |
+
}
|
configuration_longcat_next.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 2 |
+
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig
|
| 3 |
+
from transformers.models.whisper.configuration_whisper import WhisperConfig
|
| 4 |
+
|
| 5 |
+
from .configuration_longcat_ngram import LongcatFlashNgramConfig
|
| 6 |
+
|
| 7 |
+
class LongcatNextConfig(LongcatFlashNgramConfig):
|
| 8 |
+
def __init__(
|
| 9 |
+
self,
|
| 10 |
+
vocab_size=131072,
|
| 11 |
+
hidden_size=6144,
|
| 12 |
+
num_hidden_layers=56,
|
| 13 |
+
num_layers=28,
|
| 14 |
+
num_attention_heads=64,
|
| 15 |
+
num_key_value_heads=None,
|
| 16 |
+
hidden_act="silu",
|
| 17 |
+
max_position_embeddings=131072,
|
| 18 |
+
initializer_range=0.02,
|
| 19 |
+
rms_norm_eps=1e-5,
|
| 20 |
+
use_cache=True,
|
| 21 |
+
pad_token_id=None,
|
| 22 |
+
bos_token_id=1,
|
| 23 |
+
eos_token_id=2,
|
| 24 |
+
tie_word_embeddings=False,
|
| 25 |
+
rope_theta=10000000.0,
|
| 26 |
+
rope_scaling=None,
|
| 27 |
+
attention_bias=False,
|
| 28 |
+
attention_dropout=0.0,
|
| 29 |
+
ffn_hidden_size=12288,
|
| 30 |
+
q_lora_rank=1536,
|
| 31 |
+
kv_lora_rank=512,
|
| 32 |
+
qk_nope_head_dim=128,
|
| 33 |
+
qk_rope_head_dim=64,
|
| 34 |
+
head_dim=64,
|
| 35 |
+
v_head_dim=128,
|
| 36 |
+
qk_head_dim=None,
|
| 37 |
+
moe_topk=12,
|
| 38 |
+
n_routed_experts=512,
|
| 39 |
+
zero_expert_num=256,
|
| 40 |
+
expert_ffn_hidden_size=2048,
|
| 41 |
+
routed_scaling_factor=6.0,
|
| 42 |
+
emb_neighbor_num=None,
|
| 43 |
+
emb_split_num=None,
|
| 44 |
+
ngram_vocab_size_ratio=None,
|
| 45 |
+
oe_ignored_token_ids=[],
|
| 46 |
+
text_vocab_size=131072, # text vocab size (vocab_size = text_vocab_size + audio_token + visual_token + multimodal_special_token_list)
|
| 47 |
+
text_vocab_plus_multimodal_special_token_size=131125,
|
| 48 |
+
visual_embedding_layer_intermediate_size=8192,
|
| 49 |
+
visual_embedding_layer_hidden_act="silu",
|
| 50 |
+
visual_offset=150581,
|
| 51 |
+
audio_offset=131125,
|
| 52 |
+
visual_config={},
|
| 53 |
+
audio_config={},
|
| 54 |
+
**kwargs,
|
| 55 |
+
):
|
| 56 |
+
self.text_vocab_size = text_vocab_size
|
| 57 |
+
self.text_vocab_plus_multimodal_special_token_size = text_vocab_plus_multimodal_special_token_size
|
| 58 |
+
self.visual_embedding_layer_intermediate_size = visual_embedding_layer_intermediate_size
|
| 59 |
+
self.visual_embedding_layer_hidden_act = visual_embedding_layer_hidden_act
|
| 60 |
+
self.visual_offset = visual_offset
|
| 61 |
+
self.audio_offset = audio_offset
|
| 62 |
+
self.visual_config = LongcatNextVisualConfig(**visual_config)
|
| 63 |
+
self.audio_config = LongcatNextAudioConfig(**audio_config)
|
| 64 |
+
oe_ignored_token_ids = oe_ignored_token_ids or list(range(self.text_vocab_size, self.text_vocab_plus_multimodal_special_token_size))
|
| 65 |
+
|
| 66 |
+
super().__init__(
|
| 67 |
+
vocab_size=vocab_size,
|
| 68 |
+
hidden_size=hidden_size,
|
| 69 |
+
num_hidden_layers=num_hidden_layers,
|
| 70 |
+
num_layers=num_layers,
|
| 71 |
+
num_attention_heads=num_attention_heads,
|
| 72 |
+
num_key_value_heads=num_key_value_heads,
|
| 73 |
+
hidden_act=hidden_act,
|
| 74 |
+
max_position_embeddings=max_position_embeddings,
|
| 75 |
+
initializer_range=initializer_range,
|
| 76 |
+
rms_norm_eps=rms_norm_eps,
|
| 77 |
+
use_cache=use_cache,
|
| 78 |
+
pad_token_id=pad_token_id,
|
| 79 |
+
bos_token_id=bos_token_id,
|
| 80 |
+
eos_token_id=eos_token_id,
|
| 81 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 82 |
+
rope_theta=rope_theta,
|
| 83 |
+
rope_scaling=rope_scaling,
|
| 84 |
+
attention_bias=attention_bias,
|
| 85 |
+
attention_dropout=attention_dropout,
|
| 86 |
+
ffn_hidden_size=ffn_hidden_size,
|
| 87 |
+
q_lora_rank=q_lora_rank,
|
| 88 |
+
kv_lora_rank=kv_lora_rank,
|
| 89 |
+
qk_nope_head_dim=qk_nope_head_dim,
|
| 90 |
+
qk_rope_head_dim=qk_rope_head_dim,
|
| 91 |
+
head_dim=head_dim,
|
| 92 |
+
v_head_dim=v_head_dim,
|
| 93 |
+
qk_head_dim=qk_head_dim,
|
| 94 |
+
moe_topk=moe_topk,
|
| 95 |
+
n_routed_experts=n_routed_experts,
|
| 96 |
+
zero_expert_num=zero_expert_num,
|
| 97 |
+
expert_ffn_hidden_size=expert_ffn_hidden_size,
|
| 98 |
+
routed_scaling_factor=routed_scaling_factor,
|
| 99 |
+
emb_neighbor_num=emb_neighbor_num,
|
| 100 |
+
emb_split_num=emb_split_num,
|
| 101 |
+
ngram_vocab_size_ratio=ngram_vocab_size_ratio,
|
| 102 |
+
oe_ignored_token_ids=oe_ignored_token_ids,
|
| 103 |
+
**kwargs,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
class LongcatNextVisualConfig(Qwen2_5_VLVisionConfig):
|
| 107 |
+
model_type = "longcat_next_visual"
|
| 108 |
+
base_config_key = ""
|
| 109 |
+
|
| 110 |
+
def __init__(
|
| 111 |
+
self,
|
| 112 |
+
image_start_token_id=131106,
|
| 113 |
+
image_end_token_id=131107,
|
| 114 |
+
image_pad_token_id=131108,
|
| 115 |
+
image_newline_token_id=131109,
|
| 116 |
+
vq_config={},
|
| 117 |
+
visual_decoder_config={},
|
| 118 |
+
**kwargs,
|
| 119 |
+
):
|
| 120 |
+
self.image_start_token_id = image_start_token_id
|
| 121 |
+
self.image_end_token_id = image_end_token_id
|
| 122 |
+
self.image_pad_token_id = image_pad_token_id
|
| 123 |
+
self.image_newline_token_id = image_newline_token_id
|
| 124 |
+
self.vq_config = PretrainedConfig(**vq_config)
|
| 125 |
+
self.visual_decoder_config = PretrainedConfig(**visual_decoder_config)
|
| 126 |
+
self.visual_decoder_config.image_decoder_config = PretrainedConfig(**getattr(self.visual_decoder_config, "image_decoder_config", {}))
|
| 127 |
+
self.visual_decoder_config.transformer_config = PretrainedConfig(**getattr(self.visual_decoder_config, "transformer_config", {}))
|
| 128 |
+
self.visual_decoder_config.vae_config = PretrainedConfig(**getattr(self.visual_decoder_config, "vae_config", {}))
|
| 129 |
+
self.visual_decoder_config.scheduler_config = PretrainedConfig(**getattr(self.visual_decoder_config, "scheduler_config", {}))
|
| 130 |
+
super().__init__(**kwargs)
|
| 131 |
+
|
| 132 |
+
class LongcatNextAudioConfig(WhisperConfig):
|
| 133 |
+
model_type = "longcat_next_audio"
|
| 134 |
+
base_config_key = ""
|
| 135 |
+
|
| 136 |
+
def __init__(
|
| 137 |
+
self,
|
| 138 |
+
vq_config={},
|
| 139 |
+
vocoder_config={},
|
| 140 |
+
flow_matching_config={},
|
| 141 |
+
cosy24kvocoder_config={},
|
| 142 |
+
**kwargs
|
| 143 |
+
):
|
| 144 |
+
self.vq_config = PretrainedConfig(**vq_config)
|
| 145 |
+
self.vocoder_config = PretrainedConfig(**vocoder_config)
|
| 146 |
+
self.flow_matching_config = PretrainedConfig(**flow_matching_config)
|
| 147 |
+
self.flow_matching_config.cfm_params = PretrainedConfig(**getattr(self.flow_matching_config, "cfm_params", {}))
|
| 148 |
+
self.cosy24kvocoder_config = PretrainedConfig(**cosy24kvocoder_config)
|
| 149 |
+
super().__init__(**kwargs)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
__all__ = ["LongcatNextConfig", "LongcatNextVisualConfig", "LongcatNextAudioConfig"]
|
configuration_longcat_ngram.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers.models.longcat_flash import LongcatFlashConfig
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class LongcatFlashNgramConfig(LongcatFlashConfig):
|
| 5 |
+
r"""
|
| 6 |
+
This is the configuration class to store the configuration of a [`LongcatFlashNgramModel`]. It is used to instantiate
|
| 7 |
+
a LongCat Flash model with N-gram enhanced embeddings according to the specified arguments, defining the model architecture.
|
| 8 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 9 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
vocab_size (`int`, *optional*, defaults to 131072):
|
| 14 |
+
Vocabulary size of the LongCat Flash model. Defines the number of different tokens that can be represented by the
|
| 15 |
+
`input_ids` passed when calling [`LongcatFlashNgramModel`]
|
| 16 |
+
hidden_size (`int`, *optional*, defaults to 6144):
|
| 17 |
+
Dimension of the hidden representations.
|
| 18 |
+
num_hidden_layers (`int`, *optional*, defaults to 56):
|
| 19 |
+
Number of hidden layers in the Transformer decoder.
|
| 20 |
+
num_layers (`int`, *optional*, defaults to 28):
|
| 21 |
+
Number of layers, each with 2 sublayers.
|
| 22 |
+
num_attention_heads (`int`, *optional*, defaults to 64):
|
| 23 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
| 24 |
+
num_key_value_heads (`int`, *optional*):
|
| 25 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 26 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 27 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
| 28 |
+
converting from a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
|
| 29 |
+
constructed by meanpooling all the original heads within that group. For more details checkout [this
|
| 30 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
| 31 |
+
`num_attention_heads`.
|
| 32 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 33 |
+
The non-linear activation function (function or string) in the decoder.
|
| 34 |
+
max_position_embeddings (`int`, *optional*, defaults to 131072):
|
| 35 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
| 36 |
+
just in case (e.g., 512 or 1024 or 2048).
|
| 37 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 38 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 39 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 40 |
+
The epsilon value used by the RMS normalization layers.
|
| 41 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 42 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 43 |
+
relevant if `config.is_decoder=True`.
|
| 44 |
+
pad_token_id (`int`, *optional*):
|
| 45 |
+
Padding token id.
|
| 46 |
+
bos_token_id (`int`, *optional*, defaults to 1):
|
| 47 |
+
Beginning of stream token id.
|
| 48 |
+
eos_token_id (`int`, *optional*, defaults to 2):
|
| 49 |
+
End of stream token id.
|
| 50 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 51 |
+
Whether to tie input and output embeddings.
|
| 52 |
+
rope_theta (`float`, *optional*, defaults to 10000000.0):
|
| 53 |
+
The base period of the RoPE embeddings.
|
| 54 |
+
rope_scaling (`Dict`, *optional*):
|
| 55 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
| 56 |
+
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
| 57 |
+
`{"type": strategy name, "factor": scaling factor}`.
|
| 58 |
+
attention_bias (`bool`, *optional*, defaults to `False`):
|
| 59 |
+
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
| 60 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 61 |
+
The dropout ratio for the attention probabilities.
|
| 62 |
+
ffn_hidden_size (`int`, *optional*, defaults to 12288):
|
| 63 |
+
Dimension of the MLP representations.
|
| 64 |
+
q_lora_rank (`int`, *optional*, defaults to 1536):
|
| 65 |
+
The rank of the query LoRA projection in MLA (Multi-head Latent Attention).
|
| 66 |
+
kv_lora_rank (`int`, *optional*, defaults to 512):
|
| 67 |
+
The rank of the key-value LoRA projection in MLA.
|
| 68 |
+
qk_nope_head_dim (`int`, *optional*, defaults to 128):
|
| 69 |
+
The dimension of the non-position encoding part of query/key heads.
|
| 70 |
+
qk_rope_head_dim (`int`, *optional*, defaults to 64):
|
| 71 |
+
The dimension of the RoPE part of query/key heads.
|
| 72 |
+
head_dim (`int`, *optional*, defaults to 64):
|
| 73 |
+
Standard dimension of qk heads, unused except for CI.
|
| 74 |
+
v_head_dim (`int`, *optional*, defaults to 128):
|
| 75 |
+
The dimension of value heads.
|
| 76 |
+
qk_head_dim (`int`, *optional*):
|
| 77 |
+
The total dimension of query/key heads. If not specified, set to `qk_nope_head_dim + qk_rope_head_dim`.
|
| 78 |
+
moe_topk (`int`, *optional*, defaults to 12):
|
| 79 |
+
Number of experts to route to for each token in the MoE layer.
|
| 80 |
+
n_routed_experts (`int`, *optional*, defaults to 512):
|
| 81 |
+
Number of routed experts in the MoE layer.
|
| 82 |
+
zero_expert_num (`int`, *optional*, defaults to 256):
|
| 83 |
+
Number of zero experts (identity function) to add to the expert pool.
|
| 84 |
+
expert_ffn_hidden_size (`int`, *optional*, defaults to 2048):
|
| 85 |
+
Hidden size of individual expert FFN layers.
|
| 86 |
+
routed_scaling_factor (`float`, *optional*, defaults to 6.0):
|
| 87 |
+
Scaling factor applied to the routing weights.
|
| 88 |
+
emb_neighbor_num (`int`, *optional*):
|
| 89 |
+
Maximum N-gram length for N-gram embeddings. This parameter determines the context window size for N-gram computation. Higher values capture
|
| 90 |
+
longer-range lexical patterns but increase memory usage.
|
| 91 |
+
emb_split_num (`int`, *optional*):
|
| 92 |
+
Number of hash functions (or splits) to use for N-gram embeddings. Multiple hash functions help improve the quality of N-gram representations.
|
| 93 |
+
ngram_vocab_size_ratio (`float`, *optional*):
|
| 94 |
+
Ratio multiplier for N-gram vocabulary size relative to the base vocabulary size. The N-gram vocabulary
|
| 95 |
+
size is calculated as `vocab_size * ngram_vocab_size_ratio`.
|
| 96 |
+
|
| 97 |
+
Example:
|
| 98 |
+
```python
|
| 99 |
+
>>> from transformers import LongcatFlashNgramModel, LongcatFlashNgramConfig
|
| 100 |
+
|
| 101 |
+
>>> # Initializing a LongCat Flash N-gram style configuration
|
| 102 |
+
>>> configuration = LongcatFlashNgramConfig(
|
| 103 |
+
... emb_neighbor_num=3,
|
| 104 |
+
... emb_split_num=4,
|
| 105 |
+
... ngram_vocab_size_ratio=1.5
|
| 106 |
+
... )
|
| 107 |
+
|
| 108 |
+
>>> # Initializing a model from the configuration
|
| 109 |
+
>>> model = LongcatFlashNgramModel(configuration)
|
| 110 |
+
|
| 111 |
+
>>> # Accessing the model configuration
|
| 112 |
+
>>> configuration = model.config
|
| 113 |
+
```"""
|
| 114 |
+
|
| 115 |
+
model_type = "longcat_flash_ngram"
|
| 116 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 117 |
+
base_model_tp_plan = {
|
| 118 |
+
"layers.*.self_attn.*.q_b_proj": "colwise",
|
| 119 |
+
"layers.*.self_attn.*.kv_b_proj": "colwise",
|
| 120 |
+
"layers.*.self_attn.*.o_proj": "rowwise",
|
| 121 |
+
"layers.*.mlps.*.gate_proj": "colwise",
|
| 122 |
+
"layers.*.mlps.*.up_proj": "colwise",
|
| 123 |
+
"layers.*.mlps.*.down_proj": "rowwise",
|
| 124 |
+
"layers.*.mlp.experts.*.gate_proj": "colwise",
|
| 125 |
+
"layers.*.mlp.experts.*.up_proj": "colwise",
|
| 126 |
+
"layers.*.mlp.experts.*.down_proj": "rowwise",
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
base_model_pp_plan = {
|
| 130 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 131 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 132 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
def __init__(
|
| 136 |
+
self,
|
| 137 |
+
vocab_size=131072,
|
| 138 |
+
hidden_size=6144,
|
| 139 |
+
num_hidden_layers=56,
|
| 140 |
+
num_layers=28,
|
| 141 |
+
num_attention_heads=64,
|
| 142 |
+
num_key_value_heads=None,
|
| 143 |
+
hidden_act="silu",
|
| 144 |
+
max_position_embeddings=131072,
|
| 145 |
+
initializer_range=0.02,
|
| 146 |
+
rms_norm_eps=1e-5,
|
| 147 |
+
use_cache=True,
|
| 148 |
+
pad_token_id=None,
|
| 149 |
+
bos_token_id=1,
|
| 150 |
+
eos_token_id=2,
|
| 151 |
+
tie_word_embeddings=False,
|
| 152 |
+
rope_theta=10000000.0,
|
| 153 |
+
rope_scaling=None,
|
| 154 |
+
attention_bias=False,
|
| 155 |
+
attention_dropout=0.0,
|
| 156 |
+
ffn_hidden_size=12288,
|
| 157 |
+
q_lora_rank=1536,
|
| 158 |
+
kv_lora_rank=512,
|
| 159 |
+
qk_nope_head_dim=128,
|
| 160 |
+
qk_rope_head_dim=64,
|
| 161 |
+
head_dim=64,
|
| 162 |
+
v_head_dim=128,
|
| 163 |
+
qk_head_dim=None,
|
| 164 |
+
moe_topk=12,
|
| 165 |
+
n_routed_experts=512,
|
| 166 |
+
zero_expert_num=256,
|
| 167 |
+
expert_ffn_hidden_size=2048,
|
| 168 |
+
routed_scaling_factor=6.0,
|
| 169 |
+
emb_neighbor_num=None,
|
| 170 |
+
emb_split_num=None,
|
| 171 |
+
ngram_vocab_size_ratio=None,
|
| 172 |
+
oe_ignored_token_ids=[],
|
| 173 |
+
**kwargs,
|
| 174 |
+
):
|
| 175 |
+
# N-gram embedding specific parameters
|
| 176 |
+
self.emb_neighbor_num = emb_neighbor_num
|
| 177 |
+
self.emb_split_num = emb_split_num
|
| 178 |
+
self.ngram_vocab_size_ratio = ngram_vocab_size_ratio
|
| 179 |
+
self.oe_ignored_token_ids = oe_ignored_token_ids
|
| 180 |
+
|
| 181 |
+
super().__init__(
|
| 182 |
+
vocab_size=vocab_size,
|
| 183 |
+
hidden_size=hidden_size,
|
| 184 |
+
num_hidden_layers=num_hidden_layers,
|
| 185 |
+
num_layers=num_layers,
|
| 186 |
+
num_attention_heads=num_attention_heads,
|
| 187 |
+
num_key_value_heads=num_key_value_heads,
|
| 188 |
+
hidden_act=hidden_act,
|
| 189 |
+
max_position_embeddings=max_position_embeddings,
|
| 190 |
+
initializer_range=initializer_range,
|
| 191 |
+
rms_norm_eps=rms_norm_eps,
|
| 192 |
+
use_cache=use_cache,
|
| 193 |
+
pad_token_id=pad_token_id,
|
| 194 |
+
bos_token_id=bos_token_id,
|
| 195 |
+
eos_token_id=eos_token_id,
|
| 196 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 197 |
+
rope_theta=rope_theta,
|
| 198 |
+
rope_scaling=rope_scaling,
|
| 199 |
+
attention_bias=attention_bias,
|
| 200 |
+
attention_dropout=attention_dropout,
|
| 201 |
+
ffn_hidden_size=ffn_hidden_size,
|
| 202 |
+
q_lora_rank=q_lora_rank,
|
| 203 |
+
kv_lora_rank=kv_lora_rank,
|
| 204 |
+
qk_nope_head_dim=qk_nope_head_dim,
|
| 205 |
+
qk_rope_head_dim=qk_rope_head_dim,
|
| 206 |
+
head_dim=head_dim,
|
| 207 |
+
v_head_dim=v_head_dim,
|
| 208 |
+
qk_head_dim=qk_head_dim,
|
| 209 |
+
moe_topk=moe_topk,
|
| 210 |
+
n_routed_experts=n_routed_experts,
|
| 211 |
+
zero_expert_num=zero_expert_num,
|
| 212 |
+
expert_ffn_hidden_size=expert_ffn_hidden_size,
|
| 213 |
+
routed_scaling_factor=routed_scaling_factor,
|
| 214 |
+
**kwargs,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
__all__ = ["LongcatFlashNgramConfig"]
|
cosy24k_vocoder.py
ADDED
|
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""HIFI-GAN"""
|
| 16 |
+
|
| 17 |
+
from typing import Dict, Optional, List
|
| 18 |
+
import numpy as np
|
| 19 |
+
from scipy.signal import get_window
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from torch.nn import Conv1d
|
| 24 |
+
from torch.nn import ConvTranspose1d
|
| 25 |
+
from torch.nn.utils import remove_weight_norm
|
| 26 |
+
from torch.nn.utils import weight_norm
|
| 27 |
+
from torch.distributions.uniform import Uniform
|
| 28 |
+
from torch.nn import Parameter
|
| 29 |
+
from torch import nn, sin, pow
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class Snake(nn.Module):
|
| 33 |
+
'''
|
| 34 |
+
Implementation of a sine-based periodic activation function
|
| 35 |
+
Shape:
|
| 36 |
+
- Input: (B, C, T)
|
| 37 |
+
- Output: (B, C, T), same shape as the input
|
| 38 |
+
Parameters:
|
| 39 |
+
- alpha - trainable parameter
|
| 40 |
+
References:
|
| 41 |
+
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
| 42 |
+
https://arxiv.org/abs/2006.08195
|
| 43 |
+
Examples:
|
| 44 |
+
>>> a1 = snake(256)
|
| 45 |
+
>>> x = torch.randn(256)
|
| 46 |
+
>>> x = a1(x)
|
| 47 |
+
'''
|
| 48 |
+
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
| 49 |
+
'''
|
| 50 |
+
Initialization.
|
| 51 |
+
INPUT:
|
| 52 |
+
- in_features: shape of the input
|
| 53 |
+
- alpha: trainable parameter
|
| 54 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
| 55 |
+
alpha will be trained along with the rest of your model.
|
| 56 |
+
'''
|
| 57 |
+
super(Snake, self).__init__()
|
| 58 |
+
self.in_features = in_features
|
| 59 |
+
|
| 60 |
+
# initialize alpha
|
| 61 |
+
self.alpha_logscale = alpha_logscale
|
| 62 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
| 63 |
+
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
| 64 |
+
else: # linear scale alphas initialized to ones
|
| 65 |
+
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
| 66 |
+
|
| 67 |
+
self.alpha.requires_grad = alpha_trainable
|
| 68 |
+
|
| 69 |
+
self.no_div_by_zero = 0.000000001
|
| 70 |
+
|
| 71 |
+
def forward(self, x):
|
| 72 |
+
'''
|
| 73 |
+
Forward pass of the function.
|
| 74 |
+
Applies the function to the input elementwise.
|
| 75 |
+
Snake ∶= x + 1/a * sin^2 (xa)
|
| 76 |
+
'''
|
| 77 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
| 78 |
+
if self.alpha_logscale:
|
| 79 |
+
alpha = torch.exp(alpha)
|
| 80 |
+
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
| 81 |
+
|
| 82 |
+
return x
|
| 83 |
+
|
| 84 |
+
def get_padding(kernel_size, dilation=1):
|
| 85 |
+
return int((kernel_size * dilation - dilation) / 2)
|
| 86 |
+
|
| 87 |
+
def init_weights(m, mean=0.0, std=0.01):
|
| 88 |
+
classname = m.__class__.__name__
|
| 89 |
+
if classname.find("Conv") != -1:
|
| 90 |
+
m.weight.data.normal_(mean, std)
|
| 91 |
+
|
| 92 |
+
"""hifigan based generator implementation.
|
| 93 |
+
|
| 94 |
+
This code is modified from https://github.com/jik876/hifi-gan
|
| 95 |
+
,https://github.com/kan-bayashi/ParallelWaveGAN and
|
| 96 |
+
https://github.com/NVIDIA/BigVGAN
|
| 97 |
+
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class ResBlock(torch.nn.Module):
|
| 102 |
+
"""Residual block module in HiFiGAN/BigVGAN."""
|
| 103 |
+
def __init__(
|
| 104 |
+
self,
|
| 105 |
+
channels: int = 512,
|
| 106 |
+
kernel_size: int = 3,
|
| 107 |
+
dilations: List[int] = [1, 3, 5],
|
| 108 |
+
):
|
| 109 |
+
super(ResBlock, self).__init__()
|
| 110 |
+
self.convs1 = nn.ModuleList()
|
| 111 |
+
self.convs2 = nn.ModuleList()
|
| 112 |
+
|
| 113 |
+
for dilation in dilations:
|
| 114 |
+
self.convs1.append(
|
| 115 |
+
weight_norm(
|
| 116 |
+
Conv1d(
|
| 117 |
+
channels,
|
| 118 |
+
channels,
|
| 119 |
+
kernel_size,
|
| 120 |
+
1,
|
| 121 |
+
dilation=dilation,
|
| 122 |
+
padding=get_padding(kernel_size, dilation)
|
| 123 |
+
)
|
| 124 |
+
)
|
| 125 |
+
)
|
| 126 |
+
self.convs2.append(
|
| 127 |
+
weight_norm(
|
| 128 |
+
Conv1d(
|
| 129 |
+
channels,
|
| 130 |
+
channels,
|
| 131 |
+
kernel_size,
|
| 132 |
+
1,
|
| 133 |
+
dilation=1,
|
| 134 |
+
padding=get_padding(kernel_size, 1)
|
| 135 |
+
)
|
| 136 |
+
)
|
| 137 |
+
)
|
| 138 |
+
self.convs1.apply(init_weights)
|
| 139 |
+
self.convs2.apply(init_weights)
|
| 140 |
+
self.activations1 = nn.ModuleList([
|
| 141 |
+
Snake(channels, alpha_logscale=False)
|
| 142 |
+
for _ in range(len(self.convs1))
|
| 143 |
+
])
|
| 144 |
+
self.activations2 = nn.ModuleList([
|
| 145 |
+
Snake(channels, alpha_logscale=False)
|
| 146 |
+
for _ in range(len(self.convs2))
|
| 147 |
+
])
|
| 148 |
+
|
| 149 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 150 |
+
for idx in range(len(self.convs1)):
|
| 151 |
+
xt = self.activations1[idx](x)
|
| 152 |
+
xt = self.convs1[idx](xt)
|
| 153 |
+
xt = self.activations2[idx](xt)
|
| 154 |
+
xt = self.convs2[idx](xt)
|
| 155 |
+
x = xt + x
|
| 156 |
+
return x
|
| 157 |
+
|
| 158 |
+
def remove_weight_norm(self):
|
| 159 |
+
for idx in range(len(self.convs1)):
|
| 160 |
+
remove_weight_norm(self.convs1[idx])
|
| 161 |
+
remove_weight_norm(self.convs2[idx])
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class SineGen(torch.nn.Module):
|
| 165 |
+
""" Definition of sine generator
|
| 166 |
+
SineGen(samp_rate, harmonic_num = 0,
|
| 167 |
+
sine_amp = 0.1, noise_std = 0.003,
|
| 168 |
+
voiced_threshold = 0,
|
| 169 |
+
flag_for_pulse=False)
|
| 170 |
+
samp_rate: sampling rate in Hz
|
| 171 |
+
harmonic_num: number of harmonic overtones (default 0)
|
| 172 |
+
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
| 173 |
+
noise_std: std of Gaussian noise (default 0.003)
|
| 174 |
+
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
| 175 |
+
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
| 176 |
+
Note: when flag_for_pulse is True, the first time step of a voiced
|
| 177 |
+
segment is always sin(np.pi) or cos(0)
|
| 178 |
+
"""
|
| 179 |
+
|
| 180 |
+
def __init__(self, samp_rate, harmonic_num=0,
|
| 181 |
+
sine_amp=0.1, noise_std=0.003,
|
| 182 |
+
voiced_threshold=0):
|
| 183 |
+
super(SineGen, self).__init__()
|
| 184 |
+
self.sine_amp = sine_amp
|
| 185 |
+
self.noise_std = noise_std
|
| 186 |
+
self.harmonic_num = harmonic_num
|
| 187 |
+
self.sampling_rate = samp_rate
|
| 188 |
+
self.voiced_threshold = voiced_threshold
|
| 189 |
+
|
| 190 |
+
def _f02uv(self, f0):
|
| 191 |
+
# generate uv signal
|
| 192 |
+
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
| 193 |
+
return uv
|
| 194 |
+
|
| 195 |
+
@torch.no_grad()
|
| 196 |
+
def forward(self, f0):
|
| 197 |
+
"""
|
| 198 |
+
:param f0: [B, 1, sample_len], Hz
|
| 199 |
+
:return: [B, 1, sample_len]
|
| 200 |
+
"""
|
| 201 |
+
|
| 202 |
+
F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
|
| 203 |
+
for i in range(self.harmonic_num + 1):
|
| 204 |
+
F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
|
| 205 |
+
|
| 206 |
+
theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
|
| 207 |
+
u_dist = Uniform(low=-np.pi, high=np.pi)
|
| 208 |
+
phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
|
| 209 |
+
phase_vec[:, 0, :] = 0
|
| 210 |
+
|
| 211 |
+
# generate sine waveforms
|
| 212 |
+
sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
|
| 213 |
+
|
| 214 |
+
# generate uv signal
|
| 215 |
+
uv = self._f02uv(f0)
|
| 216 |
+
|
| 217 |
+
# noise: for unvoiced should be similar to sine_amp
|
| 218 |
+
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
| 219 |
+
# . for voiced regions is self.noise_std
|
| 220 |
+
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
| 221 |
+
noise = noise_amp * torch.randn_like(sine_waves)
|
| 222 |
+
|
| 223 |
+
# first: set the unvoiced part to 0 by uv
|
| 224 |
+
# then: additive noise
|
| 225 |
+
sine_waves = sine_waves * uv + noise
|
| 226 |
+
return sine_waves, uv, noise
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class SourceModuleHnNSF(torch.nn.Module):
|
| 230 |
+
""" SourceModule for hn-nsf
|
| 231 |
+
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
| 232 |
+
add_noise_std=0.003, voiced_threshod=0)
|
| 233 |
+
sampling_rate: sampling_rate in Hz
|
| 234 |
+
harmonic_num: number of harmonic above F0 (default: 0)
|
| 235 |
+
sine_amp: amplitude of sine source signal (default: 0.1)
|
| 236 |
+
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
| 237 |
+
note that amplitude of noise in unvoiced is decided
|
| 238 |
+
by sine_amp
|
| 239 |
+
voiced_threshold: threhold to set U/V given F0 (default: 0)
|
| 240 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
| 241 |
+
F0_sampled (batchsize, length, 1)
|
| 242 |
+
Sine_source (batchsize, length, 1)
|
| 243 |
+
noise_source (batchsize, length 1)
|
| 244 |
+
uv (batchsize, length, 1)
|
| 245 |
+
"""
|
| 246 |
+
|
| 247 |
+
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
| 248 |
+
add_noise_std=0.003, voiced_threshod=0):
|
| 249 |
+
super(SourceModuleHnNSF, self).__init__()
|
| 250 |
+
|
| 251 |
+
self.sine_amp = sine_amp
|
| 252 |
+
self.noise_std = add_noise_std
|
| 253 |
+
|
| 254 |
+
# to produce sine waveforms
|
| 255 |
+
self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
|
| 256 |
+
sine_amp, add_noise_std, voiced_threshod)
|
| 257 |
+
|
| 258 |
+
# to merge source harmonics into a single excitation
|
| 259 |
+
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
| 260 |
+
self.l_tanh = torch.nn.Tanh()
|
| 261 |
+
|
| 262 |
+
def forward(self, x):
|
| 263 |
+
"""
|
| 264 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
| 265 |
+
F0_sampled (batchsize, length, 1)
|
| 266 |
+
Sine_source (batchsize, length, 1)
|
| 267 |
+
noise_source (batchsize, length 1)
|
| 268 |
+
"""
|
| 269 |
+
# source for harmonic branch
|
| 270 |
+
with torch.no_grad():
|
| 271 |
+
sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
|
| 272 |
+
sine_wavs = sine_wavs.transpose(1, 2)
|
| 273 |
+
uv = uv.transpose(1, 2)
|
| 274 |
+
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
| 275 |
+
|
| 276 |
+
# source for noise branch, in the same shape as uv
|
| 277 |
+
noise = torch.randn_like(uv) * self.sine_amp / 3
|
| 278 |
+
return sine_merge, noise, uv
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class HiFTGenerator(nn.Module):
|
| 282 |
+
"""
|
| 283 |
+
HiFTNet Generator: Neural Source Filter + ISTFTNet
|
| 284 |
+
https://arxiv.org/abs/2309.09493
|
| 285 |
+
"""
|
| 286 |
+
def __init__(
|
| 287 |
+
self,
|
| 288 |
+
in_channels: int = 80,
|
| 289 |
+
base_channels: int = 512,
|
| 290 |
+
nb_harmonics: int = 8,
|
| 291 |
+
sampling_rate: int = 22050,
|
| 292 |
+
nsf_alpha: float = 0.1,
|
| 293 |
+
nsf_sigma: float = 0.003,
|
| 294 |
+
nsf_voiced_threshold: float = 10,
|
| 295 |
+
upsample_rates: List[int] = [8, 8],
|
| 296 |
+
upsample_kernel_sizes: List[int] = [16, 16],
|
| 297 |
+
istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
|
| 298 |
+
resblock_kernel_sizes: List[int] = [3, 7, 11],
|
| 299 |
+
resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
| 300 |
+
source_resblock_kernel_sizes: List[int] = [7, 11],
|
| 301 |
+
source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
|
| 302 |
+
lrelu_slope: float = 0.1,
|
| 303 |
+
audio_limit: float = 0.99,
|
| 304 |
+
f0_predictor: torch.nn.Module = None,
|
| 305 |
+
):
|
| 306 |
+
super(HiFTGenerator, self).__init__()
|
| 307 |
+
|
| 308 |
+
self.out_channels = 1
|
| 309 |
+
self.nb_harmonics = nb_harmonics
|
| 310 |
+
self.sampling_rate = sampling_rate
|
| 311 |
+
self.istft_params = istft_params
|
| 312 |
+
self.lrelu_slope = lrelu_slope
|
| 313 |
+
self.audio_limit = audio_limit
|
| 314 |
+
|
| 315 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
| 316 |
+
self.num_upsamples = len(upsample_rates)
|
| 317 |
+
self.m_source = SourceModuleHnNSF(
|
| 318 |
+
sampling_rate=sampling_rate,
|
| 319 |
+
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
|
| 320 |
+
harmonic_num=nb_harmonics,
|
| 321 |
+
sine_amp=nsf_alpha,
|
| 322 |
+
add_noise_std=nsf_sigma,
|
| 323 |
+
voiced_threshod=nsf_voiced_threshold)
|
| 324 |
+
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
|
| 325 |
+
|
| 326 |
+
self.conv_pre = weight_norm(
|
| 327 |
+
Conv1d(in_channels, base_channels, 7, 1, padding=3)
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
# Up
|
| 331 |
+
self.ups = nn.ModuleList()
|
| 332 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
| 333 |
+
self.ups.append(
|
| 334 |
+
weight_norm(
|
| 335 |
+
ConvTranspose1d(
|
| 336 |
+
base_channels // (2**i),
|
| 337 |
+
base_channels // (2**(i + 1)),
|
| 338 |
+
k,
|
| 339 |
+
u,
|
| 340 |
+
padding=(k - u) // 2,
|
| 341 |
+
)
|
| 342 |
+
)
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
# Down
|
| 346 |
+
self.source_downs = nn.ModuleList()
|
| 347 |
+
self.source_resblocks = nn.ModuleList()
|
| 348 |
+
downsample_rates = [1] + upsample_rates[::-1][:-1]
|
| 349 |
+
downsample_cum_rates = np.cumprod(downsample_rates)
|
| 350 |
+
for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
|
| 351 |
+
if u == 1:
|
| 352 |
+
self.source_downs.append(
|
| 353 |
+
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
|
| 354 |
+
)
|
| 355 |
+
else:
|
| 356 |
+
self.source_downs.append(
|
| 357 |
+
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
self.source_resblocks.append(
|
| 361 |
+
ResBlock(base_channels // (2 ** (i + 1)), k, d)
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
self.resblocks = nn.ModuleList()
|
| 365 |
+
for i in range(len(self.ups)):
|
| 366 |
+
ch = base_channels // (2**(i + 1))
|
| 367 |
+
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
| 368 |
+
self.resblocks.append(ResBlock(ch, k, d))
|
| 369 |
+
|
| 370 |
+
self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
|
| 371 |
+
self.ups.apply(init_weights)
|
| 372 |
+
self.conv_post.apply(init_weights)
|
| 373 |
+
self.reflection_pad = nn.ReflectionPad1d((1, 0))
|
| 374 |
+
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
|
| 375 |
+
self.f0_predictor = f0_predictor
|
| 376 |
+
|
| 377 |
+
def remove_weight_norm(self):
|
| 378 |
+
print('Removing weight norm...')
|
| 379 |
+
for l in self.ups:
|
| 380 |
+
remove_weight_norm(l)
|
| 381 |
+
for l in self.resblocks:
|
| 382 |
+
l.remove_weight_norm()
|
| 383 |
+
remove_weight_norm(self.conv_pre)
|
| 384 |
+
remove_weight_norm(self.conv_post)
|
| 385 |
+
self.m_source.remove_weight_norm()
|
| 386 |
+
for l in self.source_downs:
|
| 387 |
+
remove_weight_norm(l)
|
| 388 |
+
for l in self.source_resblocks:
|
| 389 |
+
l.remove_weight_norm()
|
| 390 |
+
|
| 391 |
+
def _stft(self, x):
|
| 392 |
+
spec = torch.stft(
|
| 393 |
+
x,
|
| 394 |
+
self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
|
| 395 |
+
return_complex=True)
|
| 396 |
+
spec = torch.view_as_real(spec) # [B, F, TT, 2]
|
| 397 |
+
return spec[..., 0], spec[..., 1]
|
| 398 |
+
|
| 399 |
+
def _istft(self, magnitude, phase):
|
| 400 |
+
magnitude = torch.clip(magnitude, max=1e2)
|
| 401 |
+
real = magnitude * torch.cos(phase)
|
| 402 |
+
img = magnitude * torch.sin(phase)
|
| 403 |
+
inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
|
| 404 |
+
self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
|
| 405 |
+
return inverse_transform
|
| 406 |
+
|
| 407 |
+
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
| 408 |
+
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
| 409 |
+
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
|
| 410 |
+
|
| 411 |
+
x = self.conv_pre(x)
|
| 412 |
+
for i in range(self.num_upsamples):
|
| 413 |
+
x = F.leaky_relu(x, self.lrelu_slope)
|
| 414 |
+
x = self.ups[i](x)
|
| 415 |
+
|
| 416 |
+
if i == self.num_upsamples - 1:
|
| 417 |
+
x = self.reflection_pad(x)
|
| 418 |
+
|
| 419 |
+
# fusion
|
| 420 |
+
si = self.source_downs[i](s_stft)
|
| 421 |
+
si = self.source_resblocks[i](si)
|
| 422 |
+
x = x + si
|
| 423 |
+
|
| 424 |
+
xs = None
|
| 425 |
+
for j in range(self.num_kernels):
|
| 426 |
+
if xs is None:
|
| 427 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
| 428 |
+
else:
|
| 429 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
| 430 |
+
x = xs / self.num_kernels
|
| 431 |
+
|
| 432 |
+
x = F.leaky_relu(x)
|
| 433 |
+
x = self.conv_post(x)
|
| 434 |
+
magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
|
| 435 |
+
phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
|
| 436 |
+
|
| 437 |
+
x = self._istft(magnitude, phase)
|
| 438 |
+
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
|
| 439 |
+
return x
|
| 440 |
+
|
| 441 |
+
def forward(
|
| 442 |
+
self,
|
| 443 |
+
batch: dict,
|
| 444 |
+
# device: torch.device,
|
| 445 |
+
) -> Dict[str, Optional[torch.Tensor]]:
|
| 446 |
+
speech_feat = batch['speech_feat'].transpose(1, 2) # .to(device)
|
| 447 |
+
# mel->f0
|
| 448 |
+
f0 = self.f0_predictor(speech_feat)
|
| 449 |
+
# f0->source
|
| 450 |
+
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
| 451 |
+
s, _, _ = self.m_source(s)
|
| 452 |
+
s = s.transpose(1, 2)
|
| 453 |
+
# mel+source->speech
|
| 454 |
+
generated_speech = self.decode(x=speech_feat, s=s)
|
| 455 |
+
return generated_speech, f0
|
| 456 |
+
|
| 457 |
+
@torch.inference_mode()
|
| 458 |
+
def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
| 459 |
+
# mel->f0
|
| 460 |
+
f0 = self.f0_predictor(speech_feat)
|
| 461 |
+
# f0->source
|
| 462 |
+
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
| 463 |
+
s, _, _ = self.m_source(s)
|
| 464 |
+
s = s.transpose(1, 2)
|
| 465 |
+
# use cache_source to avoid glitch
|
| 466 |
+
if cache_source.shape[2] != 0:
|
| 467 |
+
s[:, :, :cache_source.shape[2]] = cache_source
|
| 468 |
+
generated_speech = self.decode(x=speech_feat, s=s)
|
| 469 |
+
return generated_speech, s
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
class ConvRNNF0Predictor(nn.Module):
|
| 473 |
+
def __init__(self,
|
| 474 |
+
num_class: int = 1,
|
| 475 |
+
in_channels: int = 80,
|
| 476 |
+
cond_channels: int = 512
|
| 477 |
+
):
|
| 478 |
+
super().__init__()
|
| 479 |
+
|
| 480 |
+
self.num_class = num_class
|
| 481 |
+
self.condnet = nn.Sequential(
|
| 482 |
+
weight_norm(
|
| 483 |
+
nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
|
| 484 |
+
),
|
| 485 |
+
nn.ELU(),
|
| 486 |
+
weight_norm(
|
| 487 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
| 488 |
+
),
|
| 489 |
+
nn.ELU(),
|
| 490 |
+
weight_norm(
|
| 491 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
| 492 |
+
),
|
| 493 |
+
nn.ELU(),
|
| 494 |
+
weight_norm(
|
| 495 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
| 496 |
+
),
|
| 497 |
+
nn.ELU(),
|
| 498 |
+
weight_norm(
|
| 499 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
| 500 |
+
),
|
| 501 |
+
nn.ELU(),
|
| 502 |
+
)
|
| 503 |
+
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
|
| 504 |
+
|
| 505 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 506 |
+
x = self.condnet(x)
|
| 507 |
+
x = x.transpose(1, 2)
|
| 508 |
+
return torch.abs(self.classifier(x).squeeze(-1))
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
class Cosy24kVocoder(nn.Module):
|
| 512 |
+
def __init__(self):
|
| 513 |
+
super().__init__()
|
| 514 |
+
self.hifigan_generator = HiFTGenerator(
|
| 515 |
+
in_channels=80,
|
| 516 |
+
base_channels=512,
|
| 517 |
+
nb_harmonics=8,
|
| 518 |
+
sampling_rate=24000,
|
| 519 |
+
nsf_alpha=0.1,
|
| 520 |
+
nsf_sigma=0.003,
|
| 521 |
+
nsf_voiced_threshold=10,
|
| 522 |
+
upsample_rates=[8, 5, 3],
|
| 523 |
+
upsample_kernel_sizes=[16, 11, 7],
|
| 524 |
+
resblock_kernel_sizes=[3, 7, 11],
|
| 525 |
+
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
| 526 |
+
source_resblock_kernel_sizes=[7, 7, 11],
|
| 527 |
+
source_resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
| 528 |
+
lrelu_slope=0.1,
|
| 529 |
+
audio_limit=0.99,
|
| 530 |
+
f0_predictor=ConvRNNF0Predictor(
|
| 531 |
+
num_class=1,
|
| 532 |
+
in_channels=80,
|
| 533 |
+
cond_channels=512,
|
| 534 |
+
),
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
def decode(self, mel, device="cuda"):
|
| 538 |
+
"""
|
| 539 |
+
Args: mel: (batch_size, n_frames, n_mel)
|
| 540 |
+
"""
|
| 541 |
+
generated_speech, f0 = self.hifigan_generator.forward(
|
| 542 |
+
{"speech_feat": mel.transpose(1, 2)}, # device=device
|
| 543 |
+
)
|
| 544 |
+
return generated_speech
|
| 545 |
+
|
| 546 |
+
@classmethod
|
| 547 |
+
def from_pretrained(cls, model_path: str):
|
| 548 |
+
"""Load a pretrained model from a checkpoint."""
|
| 549 |
+
model = cls()
|
| 550 |
+
model.hifigan_generator.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=True)
|
| 551 |
+
model.eval()
|
| 552 |
+
return model
|
environment.yml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: longcat_next
|
| 2 |
+
|
| 3 |
+
dependencies:
|
| 4 |
+
- python=3.10
|
| 5 |
+
- ffmpeg<7
|
| 6 |
+
- pip
|
| 7 |
+
- pip:
|
| 8 |
+
- soundfile==0.13.1
|
generation_config.json
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 1,
|
| 3 |
+
"eos_token_id": 2,
|
| 4 |
+
"pad_token_id": 3,
|
| 5 |
+
|
| 6 |
+
"max_new_tokens": 2048,
|
| 7 |
+
"do_sample": true,
|
| 8 |
+
"temperature": 0.4,
|
| 9 |
+
"top_k": 20,
|
| 10 |
+
"top_p": 0.85,
|
| 11 |
+
"repetition_penalty": 1.1,
|
| 12 |
+
|
| 13 |
+
"visual_generation_config": {
|
| 14 |
+
"do_sample": true,
|
| 15 |
+
"temperature": 0.5,
|
| 16 |
+
"top_p": 0.75,
|
| 17 |
+
"top_k": 1024,
|
| 18 |
+
"custom_params": {
|
| 19 |
+
"cfg_scale": 3.0,
|
| 20 |
+
"token_h": 37,
|
| 21 |
+
"token_w": 37,
|
| 22 |
+
"anyres_prefix": "<longcat_img_token_size>{h} {w}</longcat_img_token_size>"
|
| 23 |
+
}
|
| 24 |
+
},
|
| 25 |
+
|
| 26 |
+
"audio_generation_config": {
|
| 27 |
+
"audio_parallel_decoding": false,
|
| 28 |
+
"do_sample": true,
|
| 29 |
+
"temperature": 0.5,
|
| 30 |
+
"top_k": 5,
|
| 31 |
+
"top_p": 0.85,
|
| 32 |
+
"repetition_penalty": 1.3,
|
| 33 |
+
"custom_params": {
|
| 34 |
+
"sampling_rate": 24000,
|
| 35 |
+
"wave_concat_overlap": 1200
|
| 36 |
+
}
|
| 37 |
+
},
|
| 38 |
+
|
| 39 |
+
"transformers_version": "4.57.6"
|
| 40 |
+
}
|
image_refiner.py
ADDED
|
@@ -0,0 +1,748 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Image refiner: refiner pipeline, refiner container, and utilities.
|
| 2 |
+
|
| 3 |
+
Contains:
|
| 4 |
+
- RefinerImageProcessor: Image pre/post-processing for the diffusion pipeline
|
| 5 |
+
- RefinerPipeline: DiffusionPipeline for image refinement
|
| 6 |
+
- ImageRefinerContainer: nn.Module container for refiner sub-modules
|
| 7 |
+
- IdentityWithArgs: Placeholder module for cond_proj
|
| 8 |
+
- de_transform / tensor2pil: Tensor-to-PIL conversion utilities
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import inspect
|
| 12 |
+
import math
|
| 13 |
+
import warnings
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
from safetensors.torch import load_file
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
from PIL import Image
|
| 23 |
+
|
| 24 |
+
from diffusers import DiffusionPipeline
|
| 25 |
+
from diffusers.configuration_utils import register_to_config
|
| 26 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor, is_valid_image_imagelist
|
| 27 |
+
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
| 28 |
+
from .refiner_modules import FlowMatchEulerDiscreteScheduler
|
| 29 |
+
|
| 30 |
+
from .refiner_modules import Transformer2DModel, RotaryPosEmbed
|
| 31 |
+
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
# Helpers
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _clean_config_dict(cfg, cls=None) -> dict:
|
| 38 |
+
"""Convert a PretrainedConfig to a clean dict for model construction.
|
| 39 |
+
|
| 40 |
+
If ``cls`` is provided, only keeps keys that match the cls.__init__ params
|
| 41 |
+
(allowlist approach). Otherwise falls back to blocklist filtering.
|
| 42 |
+
"""
|
| 43 |
+
if hasattr(cfg, "to_dict"):
|
| 44 |
+
d = cfg.to_dict()
|
| 45 |
+
elif isinstance(cfg, dict):
|
| 46 |
+
d = dict(cfg)
|
| 47 |
+
else:
|
| 48 |
+
d = {k: v for k, v in vars(cfg).items()}
|
| 49 |
+
|
| 50 |
+
if cls is not None:
|
| 51 |
+
import inspect
|
| 52 |
+
sig = inspect.signature(cls.__init__)
|
| 53 |
+
valid_keys = set(sig.parameters.keys()) - {"self"}
|
| 54 |
+
if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()):
|
| 55 |
+
# Has **kwargs — can't filter by allowlist, fall through to blocklist
|
| 56 |
+
pass
|
| 57 |
+
else:
|
| 58 |
+
return {k: v for k, v in d.items() if k in valid_keys}
|
| 59 |
+
|
| 60 |
+
# Blocklist: remove HuggingFace PretrainedConfig metadata
|
| 61 |
+
_PRETRAINED_CONFIG_KEYS = {
|
| 62 |
+
"_name_or_path", "transformers_version", "model_type", "_commit_hash",
|
| 63 |
+
"_attn_implementation", "_attn_implementation_autoset", "return_dict",
|
| 64 |
+
"output_hidden_states", "output_attentions", "use_bfloat16",
|
| 65 |
+
"torchscript", "torch_dtype", "is_encoder_decoder", "is_decoder",
|
| 66 |
+
"add_cross_attention", "tie_encoder_decoder", "tie_word_embeddings",
|
| 67 |
+
"cross_attention_hidden_size", "chunk_size_feed_forward", "decoder_start_token_id",
|
| 68 |
+
"architectures", "finetuning_task", "id2label", "label2id", "prefix",
|
| 69 |
+
"problem_type", "tokenizer_class", "task_specific_params", "pruned_heads",
|
| 70 |
+
"bos_token_id", "eos_token_id", "pad_token_id", "sep_token_id",
|
| 71 |
+
"max_length", "min_length", "do_sample", "early_stopping",
|
| 72 |
+
"num_beams", "num_beam_groups", "diversity_penalty", "temperature",
|
| 73 |
+
"top_k", "top_p", "typical_p", "repetition_penalty", "length_penalty",
|
| 74 |
+
"no_repeat_ngram_size", "encoder_no_repeat_ngram_size", "bad_words_ids",
|
| 75 |
+
"num_return_sequences", "output_scores", "return_dict_in_generate",
|
| 76 |
+
"forced_bos_token_id", "forced_eos_token_id", "remove_invalid_values",
|
| 77 |
+
"exponential_decay_length_penalty", "suppress_tokens", "begin_suppress_tokens",
|
| 78 |
+
"tf_legacy_loss", "dtype",
|
| 79 |
+
}
|
| 80 |
+
return {k: v for k, v in d.items() if not k.startswith("_") and k not in _PRETRAINED_CONFIG_KEYS}
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# ---------------------------------------------------------------------------
|
| 84 |
+
# Image Refiner Container (nn.Module for state_dict loading)
|
| 85 |
+
# ---------------------------------------------------------------------------
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class ImageRefinerContainer(nn.Module):
|
| 89 |
+
"""Container for refiner components.
|
| 90 |
+
|
| 91 |
+
Holds base_transformer, vae, cond_proj as nn.Module children so their
|
| 92 |
+
parameters appear in the parent model's state_dict and are loaded
|
| 93 |
+
automatically via from_pretrained.
|
| 94 |
+
"""
|
| 95 |
+
|
| 96 |
+
def __init__(self, visual_decoder_config):
|
| 97 |
+
super().__init__()
|
| 98 |
+
|
| 99 |
+
tc = visual_decoder_config.transformer_config
|
| 100 |
+
vc = visual_decoder_config.vae_config
|
| 101 |
+
|
| 102 |
+
self.base_transformer = Transformer2DModel(**_clean_config_dict(tc))
|
| 103 |
+
|
| 104 |
+
self.vae = AutoencoderKL(**_clean_config_dict(vc))
|
| 105 |
+
self.vae.requires_grad_(False)
|
| 106 |
+
|
| 107 |
+
text_feat_dim = getattr(tc, "text_feat_dim", 3584)
|
| 108 |
+
codebook_dim = getattr(visual_decoder_config, "codebook_dim", text_feat_dim)
|
| 109 |
+
if codebook_dim != text_feat_dim:
|
| 110 |
+
self.cond_proj = nn.Linear(codebook_dim, text_feat_dim)
|
| 111 |
+
else:
|
| 112 |
+
self.cond_proj = IdentityWithArgs()
|
| 113 |
+
|
| 114 |
+
@classmethod
|
| 115 |
+
def from_pretrained(cls, config, model_path: str):
|
| 116 |
+
model = cls(config)
|
| 117 |
+
weight_dict = load_file(model_path, device="cpu")
|
| 118 |
+
model.load_state_dict({k.removeprefix("image_refiner."): v for k, v in weight_dict.items() if k.startswith("image_refiner.")}, strict=True)
|
| 119 |
+
model.eval()
|
| 120 |
+
return model
|
| 121 |
+
|
| 122 |
+
@property
|
| 123 |
+
def device(self):
|
| 124 |
+
return next(self.parameters()).device
|
| 125 |
+
|
| 126 |
+
@property
|
| 127 |
+
def dtype(self):
|
| 128 |
+
return next(self.parameters()).dtype
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class RefinerImageProcessor(VaeImageProcessor):
|
| 132 |
+
"""Image processor for refiner - extends diffusers' VaeImageProcessor."""
|
| 133 |
+
|
| 134 |
+
@register_to_config
|
| 135 |
+
def __init__(
|
| 136 |
+
self,
|
| 137 |
+
do_resize: bool = True,
|
| 138 |
+
vae_scale_factor: int = 16,
|
| 139 |
+
resample: str = "lanczos",
|
| 140 |
+
max_pixels: Optional[int] = None,
|
| 141 |
+
max_side_length: Optional[int] = None,
|
| 142 |
+
do_normalize: bool = True,
|
| 143 |
+
do_binarize: bool = False,
|
| 144 |
+
do_convert_grayscale: bool = False,
|
| 145 |
+
):
|
| 146 |
+
super().__init__(
|
| 147 |
+
do_resize=do_resize,
|
| 148 |
+
vae_scale_factor=vae_scale_factor,
|
| 149 |
+
resample=resample,
|
| 150 |
+
do_normalize=do_normalize,
|
| 151 |
+
do_binarize=do_binarize,
|
| 152 |
+
do_convert_grayscale=do_convert_grayscale,
|
| 153 |
+
)
|
| 154 |
+
self.max_pixels = max_pixels
|
| 155 |
+
self.max_side_length = max_side_length
|
| 156 |
+
|
| 157 |
+
def get_new_height_width(
|
| 158 |
+
self,
|
| 159 |
+
image: Union["PIL.Image.Image", np.ndarray, torch.Tensor],
|
| 160 |
+
height: Optional[int] = None,
|
| 161 |
+
width: Optional[int] = None,
|
| 162 |
+
max_pixels: Optional[int] = None,
|
| 163 |
+
max_side_length: Optional[int] = None,
|
| 164 |
+
) -> Tuple[int, int]:
|
| 165 |
+
import PIL.Image
|
| 166 |
+
|
| 167 |
+
if height is None:
|
| 168 |
+
if isinstance(image, PIL.Image.Image):
|
| 169 |
+
height = image.height
|
| 170 |
+
elif isinstance(image, torch.Tensor):
|
| 171 |
+
height = image.shape[2]
|
| 172 |
+
else:
|
| 173 |
+
height = image.shape[1]
|
| 174 |
+
|
| 175 |
+
if width is None:
|
| 176 |
+
if isinstance(image, PIL.Image.Image):
|
| 177 |
+
width = image.width
|
| 178 |
+
elif isinstance(image, torch.Tensor):
|
| 179 |
+
width = image.shape[3]
|
| 180 |
+
else:
|
| 181 |
+
width = image.shape[2]
|
| 182 |
+
|
| 183 |
+
if max_side_length is None:
|
| 184 |
+
max_side_length = self.max_side_length
|
| 185 |
+
if max_pixels is None:
|
| 186 |
+
max_pixels = self.max_pixels
|
| 187 |
+
|
| 188 |
+
ratio = 1.0
|
| 189 |
+
if max_side_length is not None:
|
| 190 |
+
max_side_length_ratio = max_side_length / max(height, width)
|
| 191 |
+
else:
|
| 192 |
+
max_side_length_ratio = 1.0
|
| 193 |
+
|
| 194 |
+
cur_pixels = height * width
|
| 195 |
+
max_pixels_ratio = (max_pixels / cur_pixels) ** 0.5 if max_pixels is not None else 1.0
|
| 196 |
+
ratio = min(max_pixels_ratio, max_side_length_ratio, 1.0)
|
| 197 |
+
|
| 198 |
+
sf = self.config.vae_scale_factor
|
| 199 |
+
new_height = int(height * ratio) // sf * sf
|
| 200 |
+
new_width = int(width * ratio) // sf * sf
|
| 201 |
+
return new_height, new_width
|
| 202 |
+
|
| 203 |
+
def preprocess(
|
| 204 |
+
self,
|
| 205 |
+
image: PipelineImageInput,
|
| 206 |
+
height: Optional[int] = None,
|
| 207 |
+
width: Optional[int] = None,
|
| 208 |
+
max_pixels: Optional[int] = None,
|
| 209 |
+
max_side_length: Optional[int] = None,
|
| 210 |
+
resize_mode: str = "default",
|
| 211 |
+
crops_coords: Optional[Tuple[int, int, int, int]] = None,
|
| 212 |
+
) -> torch.Tensor:
|
| 213 |
+
import PIL.Image
|
| 214 |
+
|
| 215 |
+
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
| 216 |
+
|
| 217 |
+
if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
|
| 218 |
+
if isinstance(image, torch.Tensor):
|
| 219 |
+
image = image.unsqueeze(1)
|
| 220 |
+
else:
|
| 221 |
+
if image.shape[-1] == 1:
|
| 222 |
+
image = np.expand_dims(image, axis=0)
|
| 223 |
+
else:
|
| 224 |
+
image = np.expand_dims(image, axis=-1)
|
| 225 |
+
|
| 226 |
+
if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
|
| 227 |
+
warnings.warn(
|
| 228 |
+
"Passing `image` as a list of 4d np.ndarray is deprecated. "
|
| 229 |
+
"Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
|
| 230 |
+
FutureWarning,
|
| 231 |
+
)
|
| 232 |
+
image = np.concatenate(image, axis=0)
|
| 233 |
+
if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
|
| 234 |
+
warnings.warn(
|
| 235 |
+
"Passing `image` as a list of 4d torch.Tensor is deprecated. "
|
| 236 |
+
"Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
|
| 237 |
+
FutureWarning,
|
| 238 |
+
)
|
| 239 |
+
image = torch.cat(image, axis=0)
|
| 240 |
+
|
| 241 |
+
if not is_valid_image_imagelist(image):
|
| 242 |
+
raise ValueError(
|
| 243 |
+
f"Input is in incorrect format. Currently, we only support "
|
| 244 |
+
f"{', '.join(str(x) for x in supported_formats)}"
|
| 245 |
+
)
|
| 246 |
+
if not isinstance(image, list):
|
| 247 |
+
image = [image]
|
| 248 |
+
|
| 249 |
+
if isinstance(image[0], PIL.Image.Image):
|
| 250 |
+
if crops_coords is not None:
|
| 251 |
+
image = [i.crop(crops_coords) for i in image]
|
| 252 |
+
if self.config.do_resize:
|
| 253 |
+
height, width = self.get_new_height_width(image[0], height, width, max_pixels, max_side_length)
|
| 254 |
+
image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
|
| 255 |
+
if self.config.do_convert_grayscale:
|
| 256 |
+
image = [self.convert_to_grayscale(i) for i in image]
|
| 257 |
+
image = self.pil_to_numpy(image)
|
| 258 |
+
image = self.numpy_to_pt(image)
|
| 259 |
+
elif isinstance(image[0], np.ndarray):
|
| 260 |
+
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
|
| 261 |
+
image = self.numpy_to_pt(image)
|
| 262 |
+
height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
|
| 263 |
+
if self.config.do_resize:
|
| 264 |
+
image = self.resize(image, height, width)
|
| 265 |
+
elif isinstance(image[0], torch.Tensor):
|
| 266 |
+
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
|
| 267 |
+
if self.config.do_convert_grayscale and image.ndim == 3:
|
| 268 |
+
image = image.unsqueeze(1)
|
| 269 |
+
channel = image.shape[1]
|
| 270 |
+
if channel == self.config.vae_latent_channels:
|
| 271 |
+
return image
|
| 272 |
+
height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length)
|
| 273 |
+
if self.config.do_resize:
|
| 274 |
+
image = self.resize(image, height, width)
|
| 275 |
+
|
| 276 |
+
do_normalize = self.config.do_normalize
|
| 277 |
+
if do_normalize and image.min() < 0:
|
| 278 |
+
warnings.warn(
|
| 279 |
+
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. "
|
| 280 |
+
f"The expected value range for image tensor is [0,1] when passing as pytorch tensor or numpy Array. "
|
| 281 |
+
f"You passed `image` with value range [{image.min()},{image.max()}]",
|
| 282 |
+
FutureWarning,
|
| 283 |
+
)
|
| 284 |
+
do_normalize = False
|
| 285 |
+
if do_normalize:
|
| 286 |
+
image = self.normalize(image)
|
| 287 |
+
|
| 288 |
+
if self.config.do_binarize:
|
| 289 |
+
image = self.binarize(image)
|
| 290 |
+
|
| 291 |
+
return image
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
@dataclass
|
| 295 |
+
class RefinerOutput:
|
| 296 |
+
images: Union[List[Image.Image], torch.Tensor]
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class IdentityWithArgs(nn.Module):
|
| 300 |
+
"""Placeholder Identity module for cond_proj."""
|
| 301 |
+
|
| 302 |
+
def __init__(self, dtype=torch.float32, device=None):
|
| 303 |
+
super().__init__()
|
| 304 |
+
self.register_buffer("_dummy", torch.zeros((), dtype=dtype, device=device))
|
| 305 |
+
|
| 306 |
+
@property
|
| 307 |
+
def dtype(self):
|
| 308 |
+
return self._dummy.dtype
|
| 309 |
+
|
| 310 |
+
@property
|
| 311 |
+
def device(self):
|
| 312 |
+
return self._dummy.device
|
| 313 |
+
|
| 314 |
+
def forward(self, x, *args, **kwargs):
|
| 315 |
+
return x
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def _retrieve_timesteps(
|
| 319 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 320 |
+
num_inference_steps: Optional[int] = None,
|
| 321 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 322 |
+
timesteps: Optional[List[int]] = None,
|
| 323 |
+
**kwargs,
|
| 324 |
+
):
|
| 325 |
+
# If scheduler uses dynamic shifting and caller passed num_tokens, compute mu
|
| 326 |
+
# (same as training code refiner pipeline)
|
| 327 |
+
num_tokens = kwargs.pop("num_tokens", None)
|
| 328 |
+
if num_tokens is not None and getattr(scheduler.config, "use_dynamic_shifting", False):
|
| 329 |
+
# Compute mu from num_tokens using scheduler's linear interpolation
|
| 330 |
+
base_shift = getattr(scheduler.config, "base_shift", 0.5)
|
| 331 |
+
max_shift = getattr(scheduler.config, "max_shift", 1.15)
|
| 332 |
+
base_seq_len = getattr(scheduler.config, "base_image_seq_len", 256)
|
| 333 |
+
max_seq_len = getattr(scheduler.config, "max_image_seq_len", 4096)
|
| 334 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 335 |
+
b = base_shift - m * base_seq_len
|
| 336 |
+
mu = num_tokens * m + b
|
| 337 |
+
kwargs["mu"] = mu
|
| 338 |
+
|
| 339 |
+
accepted = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 340 |
+
filtered_kwargs = {k: v for k, v in kwargs.items() if k in accepted}
|
| 341 |
+
|
| 342 |
+
if timesteps is not None:
|
| 343 |
+
if "timesteps" not in accepted:
|
| 344 |
+
raise ValueError(
|
| 345 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 346 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 347 |
+
)
|
| 348 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **filtered_kwargs)
|
| 349 |
+
timesteps = scheduler.timesteps
|
| 350 |
+
num_inference_steps = len(timesteps)
|
| 351 |
+
else:
|
| 352 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **filtered_kwargs)
|
| 353 |
+
timesteps = scheduler.timesteps
|
| 354 |
+
return timesteps, num_inference_steps
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
class RefinerPipeline(DiffusionPipeline):
|
| 358 |
+
"""
|
| 359 |
+
Image refiner evaluation pipeline.
|
| 360 |
+
|
| 361 |
+
- cond comes from upstream model: encoder_hidden_states (quants / last_latent)
|
| 362 |
+
- grid_thw_list is used to split cond (consistent with training)
|
| 363 |
+
- image as ref image
|
| 364 |
+
- Supports FlowMatchEulerDiscreteScheduler + velocity model
|
| 365 |
+
"""
|
| 366 |
+
|
| 367 |
+
def __init__(
|
| 368 |
+
self,
|
| 369 |
+
vae: AutoencoderKL,
|
| 370 |
+
transformer: Transformer2DModel,
|
| 371 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 372 |
+
cond_proj: Optional[nn.Module] = None,
|
| 373 |
+
):
|
| 374 |
+
super().__init__()
|
| 375 |
+
|
| 376 |
+
self.register_modules(
|
| 377 |
+
vae=vae,
|
| 378 |
+
transformer=transformer,
|
| 379 |
+
scheduler=scheduler,
|
| 380 |
+
cond_proj=cond_proj if cond_proj is not None else IdentityWithArgs(),
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
self.vae_scale_factor = (
|
| 384 |
+
2 ** (len(self.vae.config.block_out_channels) - 1)
|
| 385 |
+
if hasattr(self.vae.config, "block_out_channels")
|
| 386 |
+
else 8
|
| 387 |
+
)
|
| 388 |
+
self.image_processor = RefinerImageProcessor(
|
| 389 |
+
vae_scale_factor=self.vae_scale_factor * 2, do_resize=True
|
| 390 |
+
)
|
| 391 |
+
self.patch_size = int(getattr(self.transformer.config, "patch_size", 16))
|
| 392 |
+
|
| 393 |
+
self._num_timesteps: int = 0
|
| 394 |
+
self._current_timestep: Optional[torch.Tensor] = None
|
| 395 |
+
self._interrupt: bool = False
|
| 396 |
+
self._freqs_cis: Optional[torch.Tensor] = None
|
| 397 |
+
self._text_guidance_scale: float = 1.0
|
| 398 |
+
self._image_guidance_scale: float = 1.0
|
| 399 |
+
self._cfg_range: Tuple[float, float] = (0.0, 1.0)
|
| 400 |
+
|
| 401 |
+
@torch.no_grad()
|
| 402 |
+
def _get_freqs_cis(self, device, dtype):
|
| 403 |
+
if self._freqs_cis is None:
|
| 404 |
+
self._freqs_cis = RotaryPosEmbed.get_freqs_cis(
|
| 405 |
+
self.transformer.config.axes_dim_rope,
|
| 406 |
+
self.transformer.config.axes_lens,
|
| 407 |
+
theta=10000,
|
| 408 |
+
)
|
| 409 |
+
return self._freqs_cis
|
| 410 |
+
|
| 411 |
+
@staticmethod
|
| 412 |
+
def _split_tokens(
|
| 413 |
+
encoder_hidden_states: torch.Tensor,
|
| 414 |
+
grid_thw_list: List[Tuple[int, int, int]],
|
| 415 |
+
) -> List[torch.Tensor]:
|
| 416 |
+
splits = [int(h) * int(w) // 4 for (_, h, w) in grid_thw_list]
|
| 417 |
+
return list(torch.split(encoder_hidden_states, splits, dim=1))
|
| 418 |
+
|
| 419 |
+
@staticmethod
|
| 420 |
+
def _looks_like_latents(x: Union[torch.Tensor, Image.Image], latent_ch_hint: int = 16) -> bool:
|
| 421 |
+
if not isinstance(x, torch.Tensor):
|
| 422 |
+
return False
|
| 423 |
+
if x.ndim not in (3, 4):
|
| 424 |
+
return False
|
| 425 |
+
c = int(x.shape[-3])
|
| 426 |
+
if c == 3:
|
| 427 |
+
return False
|
| 428 |
+
if c == latent_ch_hint:
|
| 429 |
+
return True
|
| 430 |
+
if c > 3 and c <= 32:
|
| 431 |
+
return True
|
| 432 |
+
return False
|
| 433 |
+
|
| 434 |
+
@torch.no_grad()
|
| 435 |
+
def _preprocess_to_vae_range(self, img: torch.Tensor) -> torch.Tensor:
|
| 436 |
+
if img.dtype not in (torch.float32, torch.float16, torch.bfloat16):
|
| 437 |
+
img = img.float()
|
| 438 |
+
if img.max() > 1.5:
|
| 439 |
+
img = img / 255.0
|
| 440 |
+
if img.min() >= 0.0 and img.max() <= 1.0:
|
| 441 |
+
img = img * 2.0 - 1.0
|
| 442 |
+
return img.clamp(-1, 1)
|
| 443 |
+
|
| 444 |
+
@torch.no_grad()
|
| 445 |
+
def _encode_image_to_latents(
|
| 446 |
+
self,
|
| 447 |
+
img_any: Union[Image.Image, torch.Tensor],
|
| 448 |
+
device,
|
| 449 |
+
dtype,
|
| 450 |
+
) -> Tuple[torch.Tensor, int, int]:
|
| 451 |
+
latent_ch_hint = int(getattr(getattr(self.vae, "config", None), "latent_channels", 16))
|
| 452 |
+
|
| 453 |
+
if self._looks_like_latents(img_any, latent_ch_hint=latent_ch_hint):
|
| 454 |
+
z = img_any
|
| 455 |
+
if z.ndim == 3:
|
| 456 |
+
z = z.unsqueeze(0)
|
| 457 |
+
z = z.to(device=device, dtype=dtype)
|
| 458 |
+
H_lat, W_lat = z.shape[-2], z.shape[-1]
|
| 459 |
+
return z, H_lat, W_lat
|
| 460 |
+
|
| 461 |
+
if isinstance(img_any, Image.Image):
|
| 462 |
+
img = torch.from_numpy(
|
| 463 |
+
np.array(img_any).astype("float32") / 255.0
|
| 464 |
+
).permute(2, 0, 1).unsqueeze(0)
|
| 465 |
+
elif isinstance(img_any, torch.Tensor):
|
| 466 |
+
img = img_any
|
| 467 |
+
if img.ndim == 3:
|
| 468 |
+
img = img.unsqueeze(0)
|
| 469 |
+
else:
|
| 470 |
+
raise TypeError("Unsupported image type. Use PIL.Image or torch.Tensor or latent Tensor.")
|
| 471 |
+
|
| 472 |
+
img = self._preprocess_to_vae_range(img)
|
| 473 |
+
|
| 474 |
+
H, W = img.shape[-2:]
|
| 475 |
+
base = self.patch_size * self.vae_scale_factor
|
| 476 |
+
target_H = max(base, math.ceil(H / base) * base)
|
| 477 |
+
target_W = max(base, math.ceil(W / base) * base)
|
| 478 |
+
if (H != target_H) or (W != target_W):
|
| 479 |
+
img = F.interpolate(img, size=(target_H, target_W), mode="bilinear", align_corners=False)
|
| 480 |
+
|
| 481 |
+
img = img.to(device=device, dtype=self.vae.dtype)
|
| 482 |
+
|
| 483 |
+
posterior = self.vae.encode(img).latent_dist
|
| 484 |
+
z0 = posterior.sample()
|
| 485 |
+
if getattr(self.vae.config, "shift_factor", None) is not None:
|
| 486 |
+
z0 = z0 - self.vae.config.shift_factor
|
| 487 |
+
if getattr(self.vae.config, "scaling_factor", None) is not None:
|
| 488 |
+
z0 = z0 * self.vae.config.scaling_factor
|
| 489 |
+
|
| 490 |
+
z0 = z0.to(device=device, dtype=dtype)
|
| 491 |
+
H_lat, W_lat = z0.shape[-2], z0.shape[-1]
|
| 492 |
+
return z0, H_lat, W_lat
|
| 493 |
+
|
| 494 |
+
@staticmethod
|
| 495 |
+
def _expand_to_list(x, n):
|
| 496 |
+
if x is None:
|
| 497 |
+
return [None] * n
|
| 498 |
+
if isinstance(x, (Image.Image, torch.Tensor)):
|
| 499 |
+
return [x] * n
|
| 500 |
+
assert isinstance(x, list), "`image` must be PIL / Tensor or list of them."
|
| 501 |
+
assert len(x) == n, "`len(image)` must equal number of image chunks"
|
| 502 |
+
return x
|
| 503 |
+
|
| 504 |
+
@torch.no_grad()
|
| 505 |
+
def _denoise_once(
|
| 506 |
+
self,
|
| 507 |
+
cond_tokens: torch.Tensor,
|
| 508 |
+
ref_img: Optional[Union[Image.Image, torch.Tensor]],
|
| 509 |
+
num_inference_steps: int = 28,
|
| 510 |
+
timesteps: Optional[List[int]] = None,
|
| 511 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 512 |
+
output_type: str = "pil",
|
| 513 |
+
text_guidance_scale: float = 1.0,
|
| 514 |
+
image_guidance_scale: float = 1.0,
|
| 515 |
+
cfg_range: Tuple[float, float] = (0.0, 1.0),
|
| 516 |
+
enable_processor_bar: bool = True,
|
| 517 |
+
):
|
| 518 |
+
device = cond_tokens.device
|
| 519 |
+
weight_dtype = self.transformer.dtype
|
| 520 |
+
|
| 521 |
+
self._text_guidance_scale = text_guidance_scale
|
| 522 |
+
self._image_guidance_scale = image_guidance_scale
|
| 523 |
+
self._cfg_range = cfg_range
|
| 524 |
+
|
| 525 |
+
cond_tokens = cond_tokens.to(device=device, dtype=weight_dtype)
|
| 526 |
+
text_feats = self.cond_proj(cond_tokens)
|
| 527 |
+
B, L, _ = text_feats.shape
|
| 528 |
+
text_mask = torch.ones(B, L, device=device, dtype=torch.bool)
|
| 529 |
+
|
| 530 |
+
ref_image_hidden_states = None
|
| 531 |
+
H_lat: int
|
| 532 |
+
W_lat: int
|
| 533 |
+
|
| 534 |
+
if ref_img is not None:
|
| 535 |
+
if isinstance(ref_img, torch.Tensor) and ref_img.ndim == 4 and ref_img.shape[0] == B:
|
| 536 |
+
z_ref, H_lat, W_lat = self._encode_image_to_latents(ref_img, device=device, dtype=weight_dtype)
|
| 537 |
+
ref_image_hidden_states = [[z_ref[b]] for b in range(B)]
|
| 538 |
+
else:
|
| 539 |
+
z_ref, H_lat, W_lat = self._encode_image_to_latents(ref_img, device=device, dtype=weight_dtype)
|
| 540 |
+
z_single = z_ref[0]
|
| 541 |
+
ref_image_hidden_states = [[z_single] for _ in range(B)]
|
| 542 |
+
else:
|
| 543 |
+
H_lat = W_lat = 128 // self.vae_scale_factor
|
| 544 |
+
|
| 545 |
+
C_lat = getattr(self.transformer.config, "in_channels", None)
|
| 546 |
+
if C_lat is None:
|
| 547 |
+
if ref_image_hidden_states is not None:
|
| 548 |
+
C_lat = ref_image_hidden_states[0][0].shape[0]
|
| 549 |
+
else:
|
| 550 |
+
raise ValueError("transformer.config.in_channels is None and no ref_img was provided.")
|
| 551 |
+
latents_shape = (B, C_lat, H_lat, W_lat)
|
| 552 |
+
|
| 553 |
+
if isinstance(generator, list):
|
| 554 |
+
if len(generator) != B:
|
| 555 |
+
raise ValueError(
|
| 556 |
+
f"len(generator)={len(generator)} must equal B={B} when passing list of generators."
|
| 557 |
+
)
|
| 558 |
+
latents = torch.stack(
|
| 559 |
+
[
|
| 560 |
+
torch.randn(
|
| 561 |
+
(1, C_lat, H_lat, W_lat),
|
| 562 |
+
generator=generator[i],
|
| 563 |
+
device=device,
|
| 564 |
+
dtype=weight_dtype,
|
| 565 |
+
).squeeze(0)
|
| 566 |
+
for i in range(B)
|
| 567 |
+
],
|
| 568 |
+
dim=0,
|
| 569 |
+
)
|
| 570 |
+
else:
|
| 571 |
+
latents = torch.randn(latents_shape, generator=generator, device=device, dtype=weight_dtype)
|
| 572 |
+
|
| 573 |
+
num_tokens = H_lat * W_lat
|
| 574 |
+
timesteps_sched, num_inference_steps = _retrieve_timesteps(
|
| 575 |
+
self.scheduler,
|
| 576 |
+
num_inference_steps=num_inference_steps,
|
| 577 |
+
device=device,
|
| 578 |
+
timesteps=timesteps,
|
| 579 |
+
num_tokens=num_tokens,
|
| 580 |
+
)
|
| 581 |
+
num_warmup_steps = max(len(timesteps_sched) - num_inference_steps * self.scheduler.order, 0)
|
| 582 |
+
self._num_timesteps = len(timesteps_sched)
|
| 583 |
+
|
| 584 |
+
freqs_cis = self._get_freqs_cis(device=device, dtype=weight_dtype)
|
| 585 |
+
|
| 586 |
+
progress_bar = self.progress_bar(total=num_inference_steps) if enable_processor_bar else None
|
| 587 |
+
for i, t in enumerate(timesteps_sched):
|
| 588 |
+
if self._interrupt:
|
| 589 |
+
continue
|
| 590 |
+
self._current_timestep = t
|
| 591 |
+
|
| 592 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
| 593 |
+
|
| 594 |
+
step_frac = i / max(len(timesteps_sched) - 1, 1)
|
| 595 |
+
use_cfg = (cfg_range[0] <= step_frac <= cfg_range[1]) and (
|
| 596 |
+
text_guidance_scale > 1.0 or image_guidance_scale > 1.0
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
if not use_cfg:
|
| 600 |
+
optional_kwargs: Dict[str, Any] = {}
|
| 601 |
+
if "ref_image_hidden_states" in inspect.signature(self.transformer.forward).parameters:
|
| 602 |
+
optional_kwargs["ref_image_hidden_states"] = ref_image_hidden_states
|
| 603 |
+
model_pred = self.transformer(
|
| 604 |
+
latents, timestep, text_feats, freqs_cis, text_mask, **optional_kwargs
|
| 605 |
+
)
|
| 606 |
+
else:
|
| 607 |
+
text_uncond = torch.zeros_like(text_feats)
|
| 608 |
+
|
| 609 |
+
opt_kwargs_text: Dict[str, Any] = {}
|
| 610 |
+
if "ref_image_hidden_states" in inspect.signature(self.transformer.forward).parameters:
|
| 611 |
+
opt_kwargs_text["ref_image_hidden_states"] = ref_image_hidden_states
|
| 612 |
+
|
| 613 |
+
model_pred_text = self.transformer(
|
| 614 |
+
latents, timestep, text_feats, freqs_cis, text_mask, **opt_kwargs_text
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
opt_kwargs_ref: Dict[str, Any] = {}
|
| 618 |
+
if "ref_image_hidden_states" in inspect.signature(self.transformer.forward).parameters:
|
| 619 |
+
opt_kwargs_ref["ref_image_hidden_states"] = ref_image_hidden_states
|
| 620 |
+
|
| 621 |
+
model_pred_ref = self.transformer(
|
| 622 |
+
latents, timestep, text_uncond, freqs_cis, text_mask, **opt_kwargs_ref
|
| 623 |
+
)
|
| 624 |
+
|
| 625 |
+
opt_kwargs_uncond: Dict[str, Any] = {}
|
| 626 |
+
if "ref_image_hidden_states" in inspect.signature(self.transformer.forward).parameters:
|
| 627 |
+
opt_kwargs_uncond["ref_image_hidden_states"] = None
|
| 628 |
+
|
| 629 |
+
model_pred_uncond = self.transformer(
|
| 630 |
+
latents, timestep, text_uncond, freqs_cis, text_mask, **opt_kwargs_uncond
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
|
| 634 |
+
model_pred = (
|
| 635 |
+
model_pred_uncond
|
| 636 |
+
+ image_guidance_scale * (model_pred_ref - model_pred_uncond)
|
| 637 |
+
+ text_guidance_scale * (model_pred_text - model_pred_ref)
|
| 638 |
+
)
|
| 639 |
+
elif text_guidance_scale > 1.0:
|
| 640 |
+
model_pred = model_pred_uncond + text_guidance_scale * (model_pred_text - model_pred_uncond)
|
| 641 |
+
elif image_guidance_scale > 1.0:
|
| 642 |
+
model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond)
|
| 643 |
+
else:
|
| 644 |
+
model_pred = model_pred_text
|
| 645 |
+
|
| 646 |
+
latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
|
| 647 |
+
latents = latents.to(dtype=weight_dtype)
|
| 648 |
+
|
| 649 |
+
if progress_bar is not None:
|
| 650 |
+
if i == len(timesteps_sched) - 1 or (
|
| 651 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
| 652 |
+
):
|
| 653 |
+
progress_bar.update()
|
| 654 |
+
|
| 655 |
+
if progress_bar is not None:
|
| 656 |
+
progress_bar.close()
|
| 657 |
+
|
| 658 |
+
self._current_timestep = None
|
| 659 |
+
|
| 660 |
+
latents = latents.to(dtype=self.vae.dtype)
|
| 661 |
+
if getattr(self.vae.config, "scaling_factor", None) is not None:
|
| 662 |
+
latents = latents / self.vae.config.scaling_factor
|
| 663 |
+
if getattr(self.vae.config, "shift_factor", None) is not None:
|
| 664 |
+
latents = latents + self.vae.config.shift_factor
|
| 665 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 666 |
+
|
| 667 |
+
images = self.image_processor.postprocess(image, output_type=output_type)
|
| 668 |
+
return images
|
| 669 |
+
|
| 670 |
+
@torch.no_grad()
|
| 671 |
+
def __call__(
|
| 672 |
+
self,
|
| 673 |
+
*,
|
| 674 |
+
encoder_hidden_states: torch.Tensor,
|
| 675 |
+
grid_thw_list: List[Tuple[int, int, int]],
|
| 676 |
+
image: Union[Image.Image, torch.Tensor, List[Union[Image.Image, torch.Tensor]], None] = None,
|
| 677 |
+
num_inference_steps: int = 28,
|
| 678 |
+
timesteps: Optional[List[int]] = None,
|
| 679 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 680 |
+
output_type: str = "pil",
|
| 681 |
+
return_dict: bool = True,
|
| 682 |
+
text_guidance_scale: float = 1.5,
|
| 683 |
+
image_guidance_scale: float = 1.5,
|
| 684 |
+
cfg_range: Tuple[float, float] = (0.0, 1.0),
|
| 685 |
+
enable_processor_bar: bool = True,
|
| 686 |
+
**kwargs,
|
| 687 |
+
) -> Union[RefinerOutput, List[Image.Image], torch.Tensor]:
|
| 688 |
+
self._interrupt = False
|
| 689 |
+
|
| 690 |
+
token_chunks = self._split_tokens(encoder_hidden_states, grid_thw_list)
|
| 691 |
+
ref_list = self._expand_to_list(image, len(token_chunks))
|
| 692 |
+
|
| 693 |
+
results_pil: List[Image.Image] = []
|
| 694 |
+
results_pt: Optional[torch.Tensor] = None
|
| 695 |
+
|
| 696 |
+
for tok, _, img_any in zip(token_chunks, grid_thw_list, ref_list):
|
| 697 |
+
imgs = self._denoise_once(
|
| 698 |
+
cond_tokens=tok,
|
| 699 |
+
ref_img=img_any,
|
| 700 |
+
num_inference_steps=num_inference_steps,
|
| 701 |
+
timesteps=timesteps,
|
| 702 |
+
generator=generator,
|
| 703 |
+
output_type=output_type,
|
| 704 |
+
text_guidance_scale=text_guidance_scale,
|
| 705 |
+
image_guidance_scale=image_guidance_scale,
|
| 706 |
+
cfg_range=cfg_range,
|
| 707 |
+
enable_processor_bar=enable_processor_bar,
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
+
if output_type == "pil":
|
| 711 |
+
results_pil += imgs
|
| 712 |
+
else:
|
| 713 |
+
results_pt = imgs if results_pt is None else torch.cat([results_pt, imgs], dim=0)
|
| 714 |
+
|
| 715 |
+
if not return_dict:
|
| 716 |
+
return results_pil if output_type == "pil" else results_pt
|
| 717 |
+
return RefinerOutput(images=results_pil if output_type == "pil" else results_pt)
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
def de_transform(
|
| 721 |
+
tensor: torch.Tensor,
|
| 722 |
+
mean=(0.48145466, 0.4578275, 0.40821073),
|
| 723 |
+
std=(0.26862954, 0.26130258, 0.27577711),
|
| 724 |
+
rescale_factor: float = 1 / 255,
|
| 725 |
+
) -> torch.Tensor:
|
| 726 |
+
"""De-normalize and de-rescale, suitable for images processed by Qwen2VLImageProcessor."""
|
| 727 |
+
if tensor.ndim == 3:
|
| 728 |
+
tensor = tensor.unsqueeze(0)
|
| 729 |
+
mean_t = torch.tensor(mean).view(1, -1, 1, 1).to(tensor.device)
|
| 730 |
+
std_t = torch.tensor(std).view(1, -1, 1, 1).to(tensor.device)
|
| 731 |
+
tensor = tensor * std_t + mean_t
|
| 732 |
+
tensor = tensor / rescale_factor
|
| 733 |
+
tensor = torch.clamp(tensor / 255.0, 0, 1)
|
| 734 |
+
return tensor
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
def tensor2pil(image_t: torch.Tensor, image_mean, image_std) -> Image.Image:
|
| 738 |
+
"""Convert a tensor to a PIL Image."""
|
| 739 |
+
image_t = image_t.detach().cpu()
|
| 740 |
+
rescale_factor = 1 / 255
|
| 741 |
+
sample = de_transform(
|
| 742 |
+
image_t,
|
| 743 |
+
mean=image_mean,
|
| 744 |
+
std=image_std,
|
| 745 |
+
rescale_factor=rescale_factor,
|
| 746 |
+
)[0]
|
| 747 |
+
ndarr = sample.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
| 748 |
+
return Image.fromarray(ndarr)
|
modeling_longcat_next.py
ADDED
|
@@ -0,0 +1,824 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2026 Meituan
|
| 3 |
+
# This code is licensed under the MIT License, for details, see the ./LICENSE file.
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from typing import Optional, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
from transformers.cache_utils import Cache
|
| 15 |
+
from transformers.generation.configuration_utils import GenerationConfig
|
| 16 |
+
from transformers.generation.logits_process import LogitsProcessorList
|
| 17 |
+
from transformers.generation.stopping_criteria import StoppingCriteriaList
|
| 18 |
+
from transformers.generation.utils import GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, GenerateNonBeamOutput
|
| 19 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 20 |
+
from transformers.models.longcat_flash.modeling_longcat_flash import LongcatFlashForCausalLM
|
| 21 |
+
from transformers.processing_utils import Unpack
|
| 22 |
+
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
|
| 23 |
+
|
| 24 |
+
from .configuration_longcat_next import LongcatNextConfig
|
| 25 |
+
from .modeling_longcat_ngram import LongcatFlashNgramModel, NgramCache
|
| 26 |
+
from .modular_longcat_next import CasualDepthTransformerHead
|
| 27 |
+
from .modular_longcat_next_audio import LongcatNextAudioTokenizer
|
| 28 |
+
from .modular_longcat_next_visual import LongcatNextVisualTokenizer
|
| 29 |
+
|
| 30 |
+
from .cosy24k_vocoder import Cosy24kVocoder
|
| 31 |
+
from .image_refiner import ImageRefinerContainer
|
| 32 |
+
from .refiner_modules import FlowMatchEulerDiscreteScheduler
|
| 33 |
+
|
| 34 |
+
logger = logging.get_logger(__name__)
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class LongcatNextForCausalLMOutputWithPast(CausalLMOutputWithPast):
|
| 38 |
+
visual_loss: Optional[torch.FloatTensor] = None
|
| 39 |
+
visual_logits: Optional[torch.FloatTensor] = None
|
| 40 |
+
visual_ids: Optional[torch.LongTensor] = None
|
| 41 |
+
audio_loss: Optional[torch.FloatTensor] = None
|
| 42 |
+
audio_logits: Optional[torch.FloatTensor] = None
|
| 43 |
+
audio_ids: Optional[torch.LongTensor] = None
|
| 44 |
+
|
| 45 |
+
@dataclass
|
| 46 |
+
class LongcatNextForCausalLMGenerateDecoderOnlyOutput(GenerateDecoderOnlyOutput):
|
| 47 |
+
visual_ids: Optional[torch.LongTensor] = None
|
| 48 |
+
audio_ids: Optional[torch.LongTensor] = None
|
| 49 |
+
audio_text_ids: Optional[torch.LongTensor] = None
|
| 50 |
+
|
| 51 |
+
@dataclass
|
| 52 |
+
class LongcatNextForCausalLMGenerateEncoderDecoderOutput(GenerateEncoderDecoderOutput):
|
| 53 |
+
visual_ids: Optional[torch.LongTensor] = None
|
| 54 |
+
audio_ids: Optional[torch.LongTensor] = None
|
| 55 |
+
audio_text_ids: Optional[torch.LongTensor] = None
|
| 56 |
+
|
| 57 |
+
@dataclass
|
| 58 |
+
class LongcatNextForCausalLMGenerationStatus:
|
| 59 |
+
mode: str = "text"
|
| 60 |
+
current_image_token_num: int = -1
|
| 61 |
+
audio_parallel_decoding: bool = False
|
| 62 |
+
is_audio_text_end: bool = False
|
| 63 |
+
is_audio_start: bool = False
|
| 64 |
+
last_step_mode: str = None
|
| 65 |
+
|
| 66 |
+
def __init__(self, visual_generation_config, audio_generation_config):
|
| 67 |
+
self.visual_generation_config = visual_generation_config
|
| 68 |
+
self.h = self.visual_generation_config.custom_params["token_h"]
|
| 69 |
+
self.w = self.visual_generation_config.custom_params["token_w"]
|
| 70 |
+
self.anyres_prefix = self.visual_generation_config.custom_params["anyres_prefix"].format(h=self.h, w=self.w)
|
| 71 |
+
self.audio_generation_config = audio_generation_config
|
| 72 |
+
self.audio_parallel_decoding = audio_generation_config.audio_parallel_decoding
|
| 73 |
+
|
| 74 |
+
def switch_to(self, modal):
|
| 75 |
+
assert modal in ["text", "visual", "audio"]
|
| 76 |
+
self.mode = modal
|
| 77 |
+
self.current_image_token_num = 0 if modal == "visual" else -1
|
| 78 |
+
self.is_audio_text_end = False
|
| 79 |
+
self.is_audio_start = False
|
| 80 |
+
|
| 81 |
+
@property
|
| 82 |
+
def is_img_newline(self):
|
| 83 |
+
return ((self.current_image_token_num + 1) % (self.w + 1)) == 0 and not self.is_img_end
|
| 84 |
+
|
| 85 |
+
@property
|
| 86 |
+
def is_img_end(self):
|
| 87 |
+
return (self.current_image_token_num + 1) / (self.w + 1) == self.h
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class LongcatNextModel(LongcatFlashNgramModel):
|
| 91 |
+
_keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
|
| 92 |
+
config_class = LongcatNextConfig
|
| 93 |
+
|
| 94 |
+
def __init__(self, config):
|
| 95 |
+
super().__init__(config)
|
| 96 |
+
self.visual_tokenizer = LongcatNextVisualTokenizer(config)
|
| 97 |
+
self.audio_tokenizer = LongcatNextAudioTokenizer(config)
|
| 98 |
+
|
| 99 |
+
self._init_multimodal_constants(config)
|
| 100 |
+
self.post_init()
|
| 101 |
+
|
| 102 |
+
def _init_multimodal_constants(self, config):
|
| 103 |
+
name2id_dict = {
|
| 104 |
+
"image_newline_token_id": self.config.visual_config.image_newline_token_id,
|
| 105 |
+
"image_end_token_id": self.config.visual_config.image_end_token_id,
|
| 106 |
+
"image_pad_token_id": self.config.visual_config.image_pad_token_id,
|
| 107 |
+
"audiotext_start_token_id": config.audio_config.audiotext_start_token_id,
|
| 108 |
+
"audiotext_pad_token_id": self.config.audio_config.audiotext_pad_token_id,
|
| 109 |
+
"audiogen_end_token_id": config.audio_config.audiogen_end_token_id,
|
| 110 |
+
"audio_pad_token_id": self.config.audio_config.audio_pad_token_id,
|
| 111 |
+
}
|
| 112 |
+
for k, v in name2id_dict.items():
|
| 113 |
+
self.register_buffer(k, torch.tensor([v], dtype=torch.long), persistent=False)
|
| 114 |
+
visual_offset_list = [config.visual_offset] + config.visual_config.vq_config.codebook_sizes[:-1]
|
| 115 |
+
visual_offset_vals = torch.cumsum(torch.tensor(visual_offset_list, dtype=torch.long), dim=0)
|
| 116 |
+
self.register_buffer("visual_offset_vals", visual_offset_vals, persistent=False)
|
| 117 |
+
audio_offset_list = [config.audio_offset] + config.audio_config.vq_config.codebook_sizes[:-1]
|
| 118 |
+
audio_offset_vals = torch.cumsum(torch.tensor(audio_offset_list, dtype=torch.long), dim=0)
|
| 119 |
+
self.register_buffer("audio_offset_vals", audio_offset_vals, persistent=False)
|
| 120 |
+
print(f"{self.visual_offset_vals=}")
|
| 121 |
+
print(f"{self.audio_offset_vals=}")
|
| 122 |
+
|
| 123 |
+
def forward(
|
| 124 |
+
self,
|
| 125 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 126 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 127 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 128 |
+
past_key_values: Optional[Cache] = None,
|
| 129 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 130 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 131 |
+
use_cache: Optional[bool] = None,
|
| 132 |
+
visual_inputs=None,
|
| 133 |
+
visual_ids=None,
|
| 134 |
+
audio_inputs=None,
|
| 135 |
+
audio_ids=None,
|
| 136 |
+
audio_text_ids=None,
|
| 137 |
+
multimodal_generation_status=None,
|
| 138 |
+
**kwargs
|
| 139 |
+
) -> BaseModelOutputWithPast:
|
| 140 |
+
|
| 141 |
+
if input_ids is None:
|
| 142 |
+
raise ValueError("You must specify input_ids")
|
| 143 |
+
|
| 144 |
+
# Extract N-gram context if available
|
| 145 |
+
ngram_context = None
|
| 146 |
+
if isinstance(past_key_values, NgramCache) and past_key_values.ngram_context is not None:
|
| 147 |
+
ngram_context = past_key_values.ngram_context
|
| 148 |
+
|
| 149 |
+
# assert input_ids.size(0) == 1, "only support bs=1 for now" # but when bs=2, idx=1 is for uncond_image_generation
|
| 150 |
+
special_visual_mask, special_audio_mask, special_audio_text_start_mask, special_audio_text_pad_mask = self.get_placeholder_mask(input_ids[:1]) # seq-dim
|
| 151 |
+
|
| 152 |
+
if inputs_embeds is None:
|
| 153 |
+
input_ids[:, special_visual_mask | special_audio_mask | special_audio_text_pad_mask | special_audio_text_start_mask] = 0
|
| 154 |
+
filled_text_pad_mask = torch.ones_like(special_audio_mask)
|
| 155 |
+
audio_text_position_mask = (special_audio_text_pad_mask | special_audio_text_start_mask | special_audio_mask)
|
| 156 |
+
|
| 157 |
+
if audio_text_ids is not None and audio_text_ids.size(1) > 0 and audio_text_position_mask.sum() > 0:
|
| 158 |
+
filled_text = audio_text_ids[:, -audio_text_position_mask.sum():]
|
| 159 |
+
filled_text_pad_mask = (filled_text==self.config.audio_config.audiotext_pad_token_id)[0]
|
| 160 |
+
input_ids[:, audio_text_position_mask] = filled_text
|
| 161 |
+
input_ids[input_ids == self.config.audio_config.audiotext_pad_token_id] = 0
|
| 162 |
+
|
| 163 |
+
inputs_embeds = self.ngram_embeddings(input_ids, ngram_context=ngram_context)
|
| 164 |
+
inputs_embeds[:, (special_visual_mask | (special_audio_mask & filled_text_pad_mask))] = 0
|
| 165 |
+
|
| 166 |
+
if special_audio_text_start_mask.sum() > 0:
|
| 167 |
+
audio_text_start_embedding = self.embed_tokens(self.audiotext_start_token_id)
|
| 168 |
+
if multimodal_generation_status.last_step_mode is None: # prefill
|
| 169 |
+
inputs_embeds[:1, special_audio_text_start_mask] += audio_text_start_embedding
|
| 170 |
+
else:
|
| 171 |
+
inputs_embeds[:, special_audio_text_start_mask] += audio_text_start_embedding
|
| 172 |
+
|
| 173 |
+
if visual_inputs is not None:
|
| 174 |
+
visual_ids = self.get_visual_ids(**visual_inputs) # [<bs=1>*seq, lev]
|
| 175 |
+
|
| 176 |
+
if visual_ids is not None and special_visual_mask.sum() > 0:
|
| 177 |
+
visual_embeddings = self.get_visual_embeddings(visual_ids[-special_visual_mask.sum():]) # -> [seq, dim]
|
| 178 |
+
if multimodal_generation_status.last_step_mode is None: # prefill
|
| 179 |
+
inputs_embeds[:1, special_visual_mask] = visual_embeddings.to(inputs_embeds.device)
|
| 180 |
+
else:
|
| 181 |
+
inputs_embeds[:, special_visual_mask] = visual_embeddings.to(inputs_embeds.device)
|
| 182 |
+
|
| 183 |
+
if audio_inputs is not None:
|
| 184 |
+
audio_ids = self.get_audio_ids(**audio_inputs) # -> [<bs=1>*seq, lev]
|
| 185 |
+
|
| 186 |
+
if audio_ids is not None and special_audio_mask.sum() > 0:
|
| 187 |
+
audio_embeddings = self.get_audio_embeddings(audio_ids[-special_audio_mask.sum():]) # -> [seq, dim]
|
| 188 |
+
if multimodal_generation_status.last_step_mode is None: # prefill
|
| 189 |
+
inputs_embeds[:1, special_audio_mask] += audio_embeddings.to(inputs_embeds.device)
|
| 190 |
+
else:
|
| 191 |
+
inputs_embeds[:, special_audio_mask] += audio_embeddings.to(inputs_embeds.device)
|
| 192 |
+
|
| 193 |
+
# Initialize NgramCache if needed
|
| 194 |
+
if use_cache and past_key_values is None:
|
| 195 |
+
past_key_values = NgramCache(config=self.config)
|
| 196 |
+
|
| 197 |
+
# Update N-gram context
|
| 198 |
+
if use_cache and isinstance(past_key_values, NgramCache):
|
| 199 |
+
past_key_values.update_ngram_context(input_ids)
|
| 200 |
+
|
| 201 |
+
return super().forward(
|
| 202 |
+
input_ids=None,
|
| 203 |
+
attention_mask=attention_mask,
|
| 204 |
+
position_ids=position_ids,
|
| 205 |
+
past_key_values=past_key_values,
|
| 206 |
+
inputs_embeds=inputs_embeds,
|
| 207 |
+
cache_position=cache_position,
|
| 208 |
+
use_cache=use_cache,
|
| 209 |
+
**kwargs
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
def get_visual_ids(self, pixel_values, visual_grid_thw, offset=True):
|
| 213 |
+
visual_ids = self.visual_tokenizer.encode(pixel_values, visual_grid_thw)
|
| 214 |
+
if offset:
|
| 215 |
+
visual_ids += self.visual_offset_vals.to(visual_ids.device)
|
| 216 |
+
return visual_ids
|
| 217 |
+
|
| 218 |
+
def get_audio_ids(self, audio, encoder_length, bridge_length, offset=True):
|
| 219 |
+
audio_ids = self.audio_tokenizer.encode(audio, encoder_length, bridge_length)
|
| 220 |
+
if offset:
|
| 221 |
+
audio_ids += self.audio_offset_vals.to(audio_ids.device)
|
| 222 |
+
return audio_ids
|
| 223 |
+
|
| 224 |
+
@torch.no_grad()
|
| 225 |
+
def decode_visual_ids_and_save(
|
| 226 |
+
self,
|
| 227 |
+
visual_ids,
|
| 228 |
+
save_prefix,
|
| 229 |
+
token_h,
|
| 230 |
+
token_w,
|
| 231 |
+
**kwargs,
|
| 232 |
+
):
|
| 233 |
+
visual_ids -= self.visual_offset_vals.to(visual_ids.device)
|
| 234 |
+
|
| 235 |
+
if not (save_prefix.startswith("./") or save_prefix.startswith("/")):
|
| 236 |
+
save_prefix = f"./{save_prefix}"
|
| 237 |
+
os.makedirs(os.path.dirname(save_prefix), exist_ok=True)
|
| 238 |
+
return self.visual_tokenizer.lazy_decode_and_save(visual_ids, token_h, token_w, f"{save_prefix}_{0}.png")
|
| 239 |
+
|
| 240 |
+
@torch.no_grad()
|
| 241 |
+
def decode_audio_ids_and_save(
|
| 242 |
+
self,
|
| 243 |
+
audio_ids,
|
| 244 |
+
save_prefix,
|
| 245 |
+
sampling_rate,
|
| 246 |
+
wave_concat_overlap,
|
| 247 |
+
**kwargs,
|
| 248 |
+
):
|
| 249 |
+
audio_ids -= self.audio_offset_vals.to(audio_ids.device)
|
| 250 |
+
|
| 251 |
+
if not (save_prefix.startswith("./") or save_prefix.startswith("/")):
|
| 252 |
+
save_prefix = f"./{save_prefix}"
|
| 253 |
+
os.makedirs(os.path.dirname(save_prefix), exist_ok=True)
|
| 254 |
+
save_path = f"{save_prefix}_{0}.wav"
|
| 255 |
+
self.audio_tokenizer.lazy_decode_and_save(audio_ids, sampling_rate, wave_concat_overlap, save_path)
|
| 256 |
+
return [save_path]
|
| 257 |
+
|
| 258 |
+
def get_visual_embeddings(self, visual_ids):
|
| 259 |
+
visual_embeddings = self.embed_tokens(visual_ids).sum(dim=1) # [seq, lev] -> [seq, lev, dim] -> [seq, dim]
|
| 260 |
+
visual_embeddings = self.visual_tokenizer.visual_embedding_layer(visual_embeddings)
|
| 261 |
+
return visual_embeddings
|
| 262 |
+
|
| 263 |
+
def get_audio_embeddings(self, audio_ids):
|
| 264 |
+
audio_embeddings = self.embed_tokens(audio_ids).sum(dim=1)
|
| 265 |
+
return audio_embeddings
|
| 266 |
+
|
| 267 |
+
def get_placeholder_mask(self, input_ids: torch.LongTensor):
|
| 268 |
+
special_image_mask = (input_ids == self.config.visual_config.image_pad_token_id).squeeze(0)
|
| 269 |
+
special_audio_mask = (input_ids == self.config.audio_config.audio_pad_token_id).squeeze(0)
|
| 270 |
+
special_audio_text_start_mask = (input_ids == self.config.audio_config.audiotext_start_token_id).squeeze(0)
|
| 271 |
+
special_audio_text_pad_mask = (input_ids == self.config.audio_config.audiotext_pad_token_id).squeeze(0)
|
| 272 |
+
return special_image_mask, special_audio_mask, special_audio_text_start_mask, special_audio_text_pad_mask
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class LongcatNextForCausalLM(LongcatFlashForCausalLM):
|
| 276 |
+
_keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
|
| 277 |
+
_no_split_modules = [
|
| 278 |
+
"LongcatFlashDecoderLayer",
|
| 279 |
+
"CasualDepthTransformerHead",
|
| 280 |
+
]
|
| 281 |
+
config_class = LongcatNextConfig
|
| 282 |
+
|
| 283 |
+
def __init__(self, config):
|
| 284 |
+
super().__init__(config)
|
| 285 |
+
self.config = config
|
| 286 |
+
self.model = LongcatNextModel(config)
|
| 287 |
+
self.lm_head = nn.Linear(config.hidden_size, config.text_vocab_plus_multimodal_special_token_size, bias=False)
|
| 288 |
+
|
| 289 |
+
self.visual_head = CasualDepthTransformerHead(
|
| 290 |
+
hidden_size=config.hidden_size,
|
| 291 |
+
codebook_sizes=config.visual_config.vq_config.codebook_sizes,
|
| 292 |
+
transformer_layer_num=config.visual_config.image_head_transformer_layers,
|
| 293 |
+
transformer_dim=config.visual_config.image_head_transformer_dims,
|
| 294 |
+
transformer_ffn_scale=config.visual_config.image_head_transformer_ffn_scale,
|
| 295 |
+
)
|
| 296 |
+
self.audio_head = CasualDepthTransformerHead(
|
| 297 |
+
hidden_size=config.hidden_size,
|
| 298 |
+
codebook_sizes=config.audio_config.vq_config.codebook_sizes,
|
| 299 |
+
transformer_layer_num=config.audio_config.audio_head_transformer_layers,
|
| 300 |
+
transformer_dim=config.audio_config.audio_head_transformer_dims,
|
| 301 |
+
transformer_ffn_scale=config.audio_config.audio_head_transformer_ffn_scale,
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
self.post_init()
|
| 305 |
+
|
| 306 |
+
@can_return_tuple
|
| 307 |
+
@auto_docstring
|
| 308 |
+
def forward(
|
| 309 |
+
self,
|
| 310 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 311 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 312 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 313 |
+
past_key_values: Optional[Cache] = None,
|
| 314 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 315 |
+
labels: Optional[torch.LongTensor] = None,
|
| 316 |
+
use_cache: Optional[bool] = None,
|
| 317 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 318 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 319 |
+
visual_inputs=None,
|
| 320 |
+
visual_ids=None,
|
| 321 |
+
audio_inputs=None,
|
| 322 |
+
audio_ids=None,
|
| 323 |
+
audio_text_ids=None,
|
| 324 |
+
multimodal_generation_status: LongcatNextForCausalLMGenerationStatus = None,
|
| 325 |
+
visual_generation_config: GenerationConfig = None,
|
| 326 |
+
audio_generation_config: GenerationConfig = None,
|
| 327 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 328 |
+
) -> CausalLMOutputWithPast:
|
| 329 |
+
r"""
|
| 330 |
+
visual_inputs (`BatchFeature`, *optional*):
|
| 331 |
+
Visual inputs returned by the processor, containing pixel values and grid metadata for image encoding.
|
| 332 |
+
visual_ids (`torch.LongTensor` of shape `(num_visual_tokens, num_codebooks)`, *optional*):
|
| 333 |
+
Quantized visual token ids from the visual tokenizer, used to build visual embeddings during generation.
|
| 334 |
+
audio_inputs (`BatchFeature`, *optional*):
|
| 335 |
+
Audio inputs returned by the processor, containing mel-spectrogram features and length metadata.
|
| 336 |
+
audio_ids (`torch.LongTensor` of shape `(num_audio_tokens, num_codebooks)`, *optional*):
|
| 337 |
+
Quantized audio token ids from the audio tokenizer, used to build audio embeddings during generation.
|
| 338 |
+
audio_text_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 339 |
+
Token ids for the audio text transcript generated alongside audio tokens.
|
| 340 |
+
multimodal_generation_status (`LongcatNextForCausalLMGenerationStatus`, *optional*):
|
| 341 |
+
Stateful object tracking the current multimodal generation mode (text / visual / audio) and
|
| 342 |
+
associated counters used to route logits to the correct head during auto-regressive decoding.
|
| 343 |
+
visual_generation_config (`GenerationConfig`, *optional*):
|
| 344 |
+
Generation configuration for the visual head, controlling sampling parameters such as
|
| 345 |
+
`temperature`, `top_k`, `top_p`, and custom parameters like `cfg_scale` and `anyres_config`.
|
| 346 |
+
audio_generation_config (`GenerationConfig`, *optional*):
|
| 347 |
+
Generation configuration for the audio head, controlling sampling parameters such as
|
| 348 |
+
`temperature`, `top_k`, `top_p`, `repetition_penalty`, and `audio_parallel_decoding`.
|
| 349 |
+
"""
|
| 350 |
+
|
| 351 |
+
if multimodal_generation_status.mode == "visual" and visual_generation_config.custom_params["cfg_scale"] != 1.0 and input_ids.size(0) == 1:
|
| 352 |
+
input_ids = input_ids.repeat((2, 1))
|
| 353 |
+
|
| 354 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 355 |
+
input_ids=input_ids,
|
| 356 |
+
attention_mask=attention_mask,
|
| 357 |
+
position_ids=position_ids,
|
| 358 |
+
past_key_values=past_key_values,
|
| 359 |
+
inputs_embeds=inputs_embeds,
|
| 360 |
+
use_cache=use_cache,
|
| 361 |
+
cache_position=cache_position,
|
| 362 |
+
visual_inputs=visual_inputs,
|
| 363 |
+
visual_ids=visual_ids,
|
| 364 |
+
audio_inputs=audio_inputs,
|
| 365 |
+
audio_ids=audio_ids,
|
| 366 |
+
audio_text_ids=audio_text_ids,
|
| 367 |
+
multimodal_generation_status=multimodal_generation_status,
|
| 368 |
+
**kwargs,
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
hidden_states = outputs.last_hidden_state
|
| 372 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 373 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 374 |
+
slice_hidden_states = hidden_states[:, slice_indices, :]
|
| 375 |
+
|
| 376 |
+
loss, logits = None, None
|
| 377 |
+
if multimodal_generation_status.mode == "visual" and \
|
| 378 |
+
(not multimodal_generation_status.is_img_newline) and (not multimodal_generation_status.is_img_end):
|
| 379 |
+
visual_ids = self.get_multimodal_logits_and_ids(
|
| 380 |
+
self.visual_head,
|
| 381 |
+
visual_ids,
|
| 382 |
+
slice_hidden_states,
|
| 383 |
+
self.model.embed_tokens,
|
| 384 |
+
self.config.visual_config.vq_config.codebook_sizes,
|
| 385 |
+
self.model.visual_offset_vals,
|
| 386 |
+
visual_generation_config,
|
| 387 |
+
)
|
| 388 |
+
else:
|
| 389 |
+
logits = self.lm_head(slice_hidden_states)
|
| 390 |
+
|
| 391 |
+
if multimodal_generation_status.mode == "audio" and multimodal_generation_status.is_audio_start:
|
| 392 |
+
audio_ids = self.get_multimodal_logits_and_ids(
|
| 393 |
+
self.audio_head,
|
| 394 |
+
audio_ids,
|
| 395 |
+
slice_hidden_states,
|
| 396 |
+
self.model.embed_tokens,
|
| 397 |
+
self.config.audio_config.vq_config.codebook_sizes,
|
| 398 |
+
self.model.audio_offset_vals,
|
| 399 |
+
audio_generation_config,
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
return LongcatNextForCausalLMOutputWithPast(
|
| 403 |
+
loss=loss,
|
| 404 |
+
logits=logits,
|
| 405 |
+
past_key_values=outputs.past_key_values,
|
| 406 |
+
hidden_states=outputs.hidden_states,
|
| 407 |
+
attentions=outputs.attentions,
|
| 408 |
+
visual_ids=visual_ids,
|
| 409 |
+
audio_ids=audio_ids,
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
def get_multimodal_logits_and_ids(
|
| 413 |
+
self,
|
| 414 |
+
head_model,
|
| 415 |
+
multimodal_ids,
|
| 416 |
+
hidden_states,
|
| 417 |
+
multimodal_embedding_layer,
|
| 418 |
+
codebook_sizes,
|
| 419 |
+
offset_vals,
|
| 420 |
+
multimodal_generation_config,
|
| 421 |
+
):
|
| 422 |
+
next_token_ids = torch.zeros(hidden_states.size(0), len(codebook_sizes), dtype=torch.long, device=hidden_states.device)
|
| 423 |
+
multimodal_embedding_layer = multimodal_embedding_layer.to(hidden_states.device)
|
| 424 |
+
|
| 425 |
+
for level, _ in enumerate(codebook_sizes):
|
| 426 |
+
logits = head_model(hidden_states, next_token_ids, multimodal_embedding_layer, level) # -> (bs, 1, dim)
|
| 427 |
+
next_token_id = self.inner_sample(logits, multimodal_ids[None, :, level]-offset_vals[level], multimodal_generation_config) # (bs, 1)
|
| 428 |
+
next_token_id += offset_vals[level]
|
| 429 |
+
next_token_ids[:, level] = next_token_id
|
| 430 |
+
|
| 431 |
+
return next_token_ids[:1]
|
| 432 |
+
|
| 433 |
+
def inner_sample(
|
| 434 |
+
self,
|
| 435 |
+
next_token_logits: torch.Tensor,
|
| 436 |
+
multimodal_ids: torch.LongTensor,
|
| 437 |
+
generation_config: GenerationConfig,
|
| 438 |
+
) -> torch.Tensor:
|
| 439 |
+
logits_processor = self._get_logits_processor(generation_config)
|
| 440 |
+
|
| 441 |
+
if "cfg_scale" in generation_config.custom_params and generation_config.custom_params["cfg_scale"] != 1.0:
|
| 442 |
+
cond_logits, uncond_logits = next_token_logits.chunk(2, dim=0)
|
| 443 |
+
next_token_logits = generation_config.custom_params["cfg_scale"] * (cond_logits - uncond_logits) + uncond_logits
|
| 444 |
+
|
| 445 |
+
next_token_scores = logits_processor(multimodal_ids, next_token_logits.to(multimodal_ids.device))
|
| 446 |
+
if generation_config.do_sample:
|
| 447 |
+
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
| 448 |
+
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
| 449 |
+
else:
|
| 450 |
+
next_tokens = torch.argmax(next_token_scores, dim=-1)
|
| 451 |
+
return next_tokens
|
| 452 |
+
|
| 453 |
+
@torch.no_grad()
|
| 454 |
+
def generate(self, inputs=None, **kwargs):
|
| 455 |
+
"""Override to ensure NgramCache is used."""
|
| 456 |
+
|
| 457 |
+
if "past_key_values" not in kwargs or kwargs["past_key_values"] is None:
|
| 458 |
+
kwargs["past_key_values"] = NgramCache(config=self.config)
|
| 459 |
+
|
| 460 |
+
return super().generate(
|
| 461 |
+
inputs=inputs,
|
| 462 |
+
**kwargs,
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
def prepare_inputs_for_generation(
|
| 466 |
+
self,
|
| 467 |
+
input_ids,
|
| 468 |
+
visual_ids,
|
| 469 |
+
audio_ids,
|
| 470 |
+
audio_text_ids,
|
| 471 |
+
multimodal_generation_status,
|
| 472 |
+
generation_config,
|
| 473 |
+
attention_mask,
|
| 474 |
+
cache_position,
|
| 475 |
+
**kwargs,
|
| 476 |
+
):
|
| 477 |
+
extra_new_tokens = torch.empty(input_ids.size(0), 0, dtype=torch.long, device=input_ids.device)
|
| 478 |
+
if visual_ids is None:
|
| 479 |
+
visual_ids = torch.empty(0, len(self.config.visual_config.vq_config.codebook_sizes), dtype=torch.long, device=input_ids.device)
|
| 480 |
+
if audio_ids is None:
|
| 481 |
+
audio_ids = torch.empty(0, len(self.config.audio_config.vq_config.codebook_sizes), dtype=torch.long, device=input_ids.device)
|
| 482 |
+
if audio_text_ids is None:
|
| 483 |
+
audio_text_ids = torch.empty(input_ids.size(0), 0, dtype=torch.long, device=input_ids.device)
|
| 484 |
+
|
| 485 |
+
def insert_ids(new_ids, _input_ids, _attention_mask, _cache_position, position=0):
|
| 486 |
+
if position < 0:
|
| 487 |
+
parts = [_input_ids[:, :position], new_ids, _input_ids[:, position:]]
|
| 488 |
+
else:
|
| 489 |
+
parts = [_input_ids, new_ids]
|
| 490 |
+
_input_ids = torch.cat(parts, dim=1)
|
| 491 |
+
insert_len = new_ids.size(1)
|
| 492 |
+
_attention_mask = F.pad(_attention_mask, (0, insert_len), value=1)
|
| 493 |
+
insert_position = _cache_position[-1] + 1 + torch.arange(insert_len, device=_cache_position.device)
|
| 494 |
+
_cache_position = torch.cat([_cache_position, insert_position])
|
| 495 |
+
return _input_ids, _attention_mask, _cache_position
|
| 496 |
+
|
| 497 |
+
# multimodal generation status change
|
| 498 |
+
if cache_position[0] != 0:
|
| 499 |
+
multimodal_generation_status.last_step_mode = multimodal_generation_status.mode
|
| 500 |
+
|
| 501 |
+
if multimodal_generation_status.mode == "visual":
|
| 502 |
+
multimodal_generation_status.current_image_token_num += 1
|
| 503 |
+
|
| 504 |
+
if (input_ids[:, -1] == self.config.visual_config.image_start_token_id).all():
|
| 505 |
+
multimodal_generation_status.switch_to("visual")
|
| 506 |
+
anyres_prefix_ids = self.text_tokenizer.encode(multimodal_generation_status.anyres_prefix, return_tensors="pt")
|
| 507 |
+
anyres_prefix_ids = anyres_prefix_ids.to(input_ids.device)
|
| 508 |
+
extra_new_tokens = torch.cat([extra_new_tokens, anyres_prefix_ids], dim=1)
|
| 509 |
+
input_ids, attention_mask, cache_position = insert_ids(anyres_prefix_ids, input_ids, attention_mask, cache_position, position=-1)
|
| 510 |
+
if input_ids.size(0) == 1: # cfg, change bs=1 -> 2
|
| 511 |
+
input_ids = input_ids.repeat((2, input_ids.size(1)))
|
| 512 |
+
input_ids[1, :-(anyres_prefix_ids.size(-1)+1)] = 0
|
| 513 |
+
print(f"change to cfg, input_ids: {input_ids}")
|
| 514 |
+
attention_mask = attention_mask.repeat((2, attention_mask.size(1)))
|
| 515 |
+
|
| 516 |
+
elif (input_ids[:, -1] == self.config.audio_config.audiogen_start_token_id).all():
|
| 517 |
+
multimodal_generation_status.switch_to("audio")
|
| 518 |
+
|
| 519 |
+
elif (input_ids[:, -1] == self.config.audio_config.audiotext_start_token_id).all():
|
| 520 |
+
multimodal_generation_status.is_audio_start = True
|
| 521 |
+
|
| 522 |
+
elif ((input_ids[:, -1] == self.config.visual_config.image_end_token_id) | (input_ids[:, -1] == self.config.audio_config.audiogen_end_token_id)).all():
|
| 523 |
+
multimodal_generation_status.switch_to("text")
|
| 524 |
+
|
| 525 |
+
model_inputs = super().prepare_inputs_for_generation(
|
| 526 |
+
input_ids=input_ids,
|
| 527 |
+
visual_ids=visual_ids,
|
| 528 |
+
audio_ids=audio_ids,
|
| 529 |
+
audio_text_ids=audio_text_ids,
|
| 530 |
+
attention_mask=attention_mask,
|
| 531 |
+
cache_position=cache_position,
|
| 532 |
+
**kwargs,
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
if model_inputs["cache_position"][0] != 0:
|
| 536 |
+
model_inputs["visual_inputs"] = None
|
| 537 |
+
model_inputs["audio_inputs"] = None
|
| 538 |
+
|
| 539 |
+
return model_inputs, multimodal_generation_status, extra_new_tokens
|
| 540 |
+
|
| 541 |
+
def _sample(
|
| 542 |
+
self,
|
| 543 |
+
input_ids: torch.LongTensor,
|
| 544 |
+
logits_processor: LogitsProcessorList,
|
| 545 |
+
stopping_criteria: StoppingCriteriaList,
|
| 546 |
+
generation_config: GenerationConfig,
|
| 547 |
+
synced_gpus: bool = False,
|
| 548 |
+
streamer: Optional["BaseStreamer"] = None,
|
| 549 |
+
visual_ids=None,
|
| 550 |
+
audio_ids=None,
|
| 551 |
+
audio_text_ids=None,
|
| 552 |
+
**model_kwargs,
|
| 553 |
+
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
|
| 554 |
+
r"""
|
| 555 |
+
Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
|
| 556 |
+
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
| 557 |
+
|
| 558 |
+
Parameters:
|
| 559 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 560 |
+
The sequence used as a prompt for the generation.
|
| 561 |
+
logits_processor (`LogitsProcessorList`):
|
| 562 |
+
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
| 563 |
+
used to modify the prediction scores of the language modeling head applied at each generation step.
|
| 564 |
+
stopping_criteria (`StoppingCriteriaList`):
|
| 565 |
+
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
| 566 |
+
used to tell if the generation loop should stop.
|
| 567 |
+
generation_config ([`~generation.GenerationConfig`]):
|
| 568 |
+
The generation configuration to be used as parametrization of the decoding method.
|
| 569 |
+
synced_gpus (`bool`):
|
| 570 |
+
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
| 571 |
+
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
| 572 |
+
streamer (`BaseStreamer`, *optional*):
|
| 573 |
+
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
| 574 |
+
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
| 575 |
+
model_kwargs:
|
| 576 |
+
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
|
| 577 |
+
an encoder-decoder model the kwargs should include `encoder_outputs`.
|
| 578 |
+
|
| 579 |
+
Return:
|
| 580 |
+
[`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
|
| 581 |
+
A `torch.LongTensor` containing the generated tokens (default behaviour) or a
|
| 582 |
+
[`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
|
| 583 |
+
`return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
|
| 584 |
+
`model.config.is_encoder_decoder=True`.
|
| 585 |
+
"""
|
| 586 |
+
# init values
|
| 587 |
+
pad_token_id = generation_config._pad_token_tensor
|
| 588 |
+
output_attentions = generation_config.output_attentions
|
| 589 |
+
output_hidden_states = generation_config.output_hidden_states
|
| 590 |
+
output_scores = generation_config.output_scores
|
| 591 |
+
output_logits = generation_config.output_logits
|
| 592 |
+
return_dict_in_generate = generation_config.return_dict_in_generate
|
| 593 |
+
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
|
| 594 |
+
do_sample = generation_config.do_sample
|
| 595 |
+
|
| 596 |
+
# init attention / hidden states / scores tuples
|
| 597 |
+
scores = () if (return_dict_in_generate and output_scores) else None
|
| 598 |
+
raw_logits = () if (return_dict_in_generate and output_logits) else None
|
| 599 |
+
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
| 600 |
+
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
| 601 |
+
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
| 602 |
+
|
| 603 |
+
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
| 604 |
+
if return_dict_in_generate and self.config.is_encoder_decoder:
|
| 605 |
+
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
| 606 |
+
encoder_hidden_states = (
|
| 607 |
+
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
# keep track of which sequences are already finished
|
| 611 |
+
batch_size, cur_len = input_ids.shape[:2]
|
| 612 |
+
this_peer_finished = False
|
| 613 |
+
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
| 614 |
+
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
|
| 615 |
+
|
| 616 |
+
model_forward = self.__call__
|
| 617 |
+
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
|
| 618 |
+
if compile_forward:
|
| 619 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "0"
|
| 620 |
+
# If we use FA2 and a static cache, we cannot compile with fullgraph
|
| 621 |
+
if self.config._attn_implementation == "flash_attention_2":
|
| 622 |
+
# only raise warning if the user passed an explicit compile-config
|
| 623 |
+
if generation_config.compile_config is not None and generation_config.compile_config.fullgraph:
|
| 624 |
+
logger.warning_once(
|
| 625 |
+
"When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
|
| 626 |
+
"FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
|
| 627 |
+
)
|
| 628 |
+
generation_config.compile_config.fullgraph = False
|
| 629 |
+
model_forward = self.get_compiled_call(generation_config.compile_config)
|
| 630 |
+
|
| 631 |
+
if generation_config.prefill_chunk_size is not None:
|
| 632 |
+
model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
|
| 633 |
+
is_prefill = False
|
| 634 |
+
else:
|
| 635 |
+
is_prefill = True
|
| 636 |
+
|
| 637 |
+
visual_generation_config = GenerationConfig(**generation_config.visual_generation_config)
|
| 638 |
+
audio_generation_config = GenerationConfig(**generation_config.audio_generation_config)
|
| 639 |
+
multimodal_generation_status = LongcatNextForCausalLMGenerationStatus(visual_generation_config, audio_generation_config)
|
| 640 |
+
|
| 641 |
+
pbar = tqdm(iter(int, 1), desc="Generating", unit="tok")
|
| 642 |
+
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
| 643 |
+
# prepare model inputs
|
| 644 |
+
model_inputs, multimodal_generation_status, extra_new_tokens = self.prepare_inputs_for_generation(
|
| 645 |
+
input_ids,
|
| 646 |
+
visual_ids,
|
| 647 |
+
audio_ids,
|
| 648 |
+
audio_text_ids,
|
| 649 |
+
multimodal_generation_status,
|
| 650 |
+
generation_config,
|
| 651 |
+
**model_kwargs,
|
| 652 |
+
)
|
| 653 |
+
if extra_new_tokens.size(1) > 0:
|
| 654 |
+
input_ids = torch.cat([input_ids[:, :-1], extra_new_tokens, input_ids[:, -1:]], dim=1)
|
| 655 |
+
model_kwargs["attention_mask"] = model_inputs["attention_mask"]
|
| 656 |
+
model_kwargs["cache_position"] = model_inputs["cache_position"]
|
| 657 |
+
|
| 658 |
+
if multimodal_generation_status.mode == "text" and multimodal_generation_status.last_step_mode == "visual":
|
| 659 |
+
next_tokens = generation_config._eos_token_tensor
|
| 660 |
+
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
| 661 |
+
if streamer is not None:
|
| 662 |
+
streamer.put(next_tokens.cpu())
|
| 663 |
+
break
|
| 664 |
+
|
| 665 |
+
visual_ids = model_inputs["visual_ids"]
|
| 666 |
+
audio_ids = model_inputs["audio_ids"]
|
| 667 |
+
audio_text_ids = model_inputs["audio_text_ids"]
|
| 668 |
+
|
| 669 |
+
if is_prefill:
|
| 670 |
+
outputs = self(**model_inputs, return_dict=True, multimodal_generation_status=multimodal_generation_status, visual_generation_config=visual_generation_config, audio_generation_config=audio_generation_config)
|
| 671 |
+
is_prefill = False
|
| 672 |
+
else:
|
| 673 |
+
outputs = model_forward(**model_inputs, return_dict=True, multimodal_generation_status=multimodal_generation_status, visual_generation_config=visual_generation_config, audio_generation_config=audio_generation_config)
|
| 674 |
+
|
| 675 |
+
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
|
| 676 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
| 677 |
+
outputs,
|
| 678 |
+
model_kwargs,
|
| 679 |
+
is_encoder_decoder=self.config.is_encoder_decoder,
|
| 680 |
+
num_new_tokens=1,
|
| 681 |
+
)
|
| 682 |
+
if synced_gpus and this_peer_finished:
|
| 683 |
+
continue
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
# multimodal generation
|
| 687 |
+
if multimodal_generation_status.mode == "text" or \
|
| 688 |
+
(multimodal_generation_status.mode == "audio" and not multimodal_generation_status.is_audio_text_end):
|
| 689 |
+
# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
|
| 690 |
+
# (the clone itself is always small)
|
| 691 |
+
next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
|
| 692 |
+
|
| 693 |
+
# pre-process distribution
|
| 694 |
+
next_token_scores = logits_processor(input_ids, next_token_logits)
|
| 695 |
+
|
| 696 |
+
# Store scores, attentions and hidden_states when required
|
| 697 |
+
if return_dict_in_generate:
|
| 698 |
+
if output_scores:
|
| 699 |
+
scores += (next_token_scores,)
|
| 700 |
+
if output_logits:
|
| 701 |
+
raw_logits += (next_token_logits,)
|
| 702 |
+
if output_attentions:
|
| 703 |
+
decoder_attentions += (
|
| 704 |
+
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
| 705 |
+
)
|
| 706 |
+
if self.config.is_encoder_decoder:
|
| 707 |
+
cross_attentions += (outputs.cross_attentions,)
|
| 708 |
+
|
| 709 |
+
if output_hidden_states:
|
| 710 |
+
decoder_hidden_states += (
|
| 711 |
+
(outputs.decoder_hidden_states,)
|
| 712 |
+
if self.config.is_encoder_decoder
|
| 713 |
+
else (outputs.hidden_states,)
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
# token selection
|
| 717 |
+
if do_sample:
|
| 718 |
+
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
| 719 |
+
# TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
|
| 720 |
+
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
| 721 |
+
else:
|
| 722 |
+
next_tokens = torch.argmax(next_token_scores, dim=-1)
|
| 723 |
+
|
| 724 |
+
# audio_text_ids done
|
| 725 |
+
if multimodal_generation_status.mode == "audio" and (next_tokens == self.config.audio_config.audiotext_pad_token_id).all():
|
| 726 |
+
multimodal_generation_status.is_audio_text_end = True
|
| 727 |
+
|
| 728 |
+
elif multimodal_generation_status.mode == "visual":
|
| 729 |
+
if multimodal_generation_status.is_img_end:
|
| 730 |
+
next_tokens = self.model.image_end_token_id.to(input_ids.device)
|
| 731 |
+
|
| 732 |
+
elif multimodal_generation_status.is_img_newline:
|
| 733 |
+
next_tokens = self.model.image_newline_token_id.to(input_ids.device)
|
| 734 |
+
|
| 735 |
+
else:
|
| 736 |
+
visual_ids = torch.cat([visual_ids, outputs.visual_ids], dim=0) # [seq, lev]
|
| 737 |
+
next_tokens = self.model.image_pad_token_id.to(input_ids.device)
|
| 738 |
+
|
| 739 |
+
else: # mode == "audio" and multimodal_generation_status.is_audio_text_end
|
| 740 |
+
next_tokens = self.model.audio_pad_token_id.to(input_ids.device)
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
if multimodal_generation_status.mode == "audio":
|
| 744 |
+
# audio_text_ids update
|
| 745 |
+
audio_text_next_tokens = self.model.audiotext_pad_token_id.to(input_ids.device)
|
| 746 |
+
if not multimodal_generation_status.is_audio_text_end:
|
| 747 |
+
audio_text_next_tokens, next_tokens = next_tokens, audio_text_next_tokens
|
| 748 |
+
audio_text_ids = torch.cat((audio_text_ids, audio_text_next_tokens[:, None]), dim=1)
|
| 749 |
+
|
| 750 |
+
# audio_ids update
|
| 751 |
+
if multimodal_generation_status.is_audio_start:
|
| 752 |
+
if outputs.audio_ids[-1, 0] == (self.model.audio_offset_vals[1]): # offset + (level_1_len)
|
| 753 |
+
next_tokens = self.model.audiogen_end_token_id.to(input_ids.device)
|
| 754 |
+
else:
|
| 755 |
+
next_tokens = self.model.audio_pad_token_id.to(input_ids.device)
|
| 756 |
+
audio_ids = torch.cat([audio_ids, outputs.audio_ids], dim=0)
|
| 757 |
+
|
| 758 |
+
elif (multimodal_generation_status.audio_parallel_decoding) or \
|
| 759 |
+
(not multimodal_generation_status.audio_parallel_decoding and multimodal_generation_status.is_audio_text_end):
|
| 760 |
+
next_tokens = self.model.audiotext_start_token_id.to(input_ids.device)
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
# finished sentences should have their next token be a padding token
|
| 764 |
+
if has_eos_stopping_criteria:
|
| 765 |
+
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
| 766 |
+
|
| 767 |
+
# update generated ids, model inputs, and length for next step
|
| 768 |
+
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
| 769 |
+
|
| 770 |
+
# TODO: streaming mm ids
|
| 771 |
+
if streamer is not None:
|
| 772 |
+
streamer.put(next_tokens.cpu())
|
| 773 |
+
|
| 774 |
+
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
| 775 |
+
this_peer_finished = unfinished_sequences.max() == 0
|
| 776 |
+
cur_len += 1
|
| 777 |
+
|
| 778 |
+
# This is needed to properly delete outputs.logits which may be very large for first iteration
|
| 779 |
+
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
|
| 780 |
+
del outputs
|
| 781 |
+
|
| 782 |
+
pbar.update(1)
|
| 783 |
+
pbar.set_postfix({
|
| 784 |
+
"recent_5toks": f"{input_ids[:, -5:].tolist()}",
|
| 785 |
+
})
|
| 786 |
+
|
| 787 |
+
pbar.close()
|
| 788 |
+
|
| 789 |
+
if streamer is not None:
|
| 790 |
+
streamer.end()
|
| 791 |
+
|
| 792 |
+
if return_dict_in_generate:
|
| 793 |
+
if self.config.is_encoder_decoder:
|
| 794 |
+
return LongcatNextForCausalLMGenerateEncoderDecoderOutput(
|
| 795 |
+
sequences=input_ids,
|
| 796 |
+
scores=scores,
|
| 797 |
+
logits=raw_logits,
|
| 798 |
+
encoder_attentions=encoder_attentions,
|
| 799 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 800 |
+
decoder_attentions=decoder_attentions,
|
| 801 |
+
cross_attentions=cross_attentions,
|
| 802 |
+
decoder_hidden_states=decoder_hidden_states,
|
| 803 |
+
past_key_values=model_kwargs.get("past_key_values"),
|
| 804 |
+
visual_ids=visual_ids,
|
| 805 |
+
audio_ids=audio_ids,
|
| 806 |
+
audio_text_ids=audio_text_ids,
|
| 807 |
+
)
|
| 808 |
+
else:
|
| 809 |
+
return LongcatNextForCausalLMGenerateDecoderOnlyOutput(
|
| 810 |
+
sequences=input_ids,
|
| 811 |
+
scores=scores,
|
| 812 |
+
logits=raw_logits,
|
| 813 |
+
attentions=decoder_attentions,
|
| 814 |
+
hidden_states=decoder_hidden_states,
|
| 815 |
+
past_key_values=model_kwargs.get("past_key_values"),
|
| 816 |
+
visual_ids=visual_ids,
|
| 817 |
+
audio_ids=audio_ids,
|
| 818 |
+
audio_text_ids=audio_text_ids,
|
| 819 |
+
)
|
| 820 |
+
else:
|
| 821 |
+
return input_ids, visual_ids, audio_ids, audio_text_ids
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
__all__ = ["LongcatNextModel", "LongcatNextForCausalLM"]
|
modeling_longcat_ngram.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) 2026 Meituan
|
| 3 |
+
# This code is licensed under the MIT License, for details, see the ./LICENSE file.
|
| 4 |
+
|
| 5 |
+
from typing import Optional, Tuple, Dict, List
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 12 |
+
from transformers.masking_utils import create_causal_mask
|
| 13 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast
|
| 14 |
+
from transformers.processing_utils import Unpack
|
| 15 |
+
from transformers.utils import auto_docstring, logging
|
| 16 |
+
from transformers.models.longcat_flash.modeling_longcat_flash import (
|
| 17 |
+
LongcatFlashForCausalLM,
|
| 18 |
+
LongcatFlashModel,
|
| 19 |
+
LongcatFlashRMSNorm,
|
| 20 |
+
LongcatFlashRotaryEmbedding,
|
| 21 |
+
LongcatFlashDecoderLayer,
|
| 22 |
+
LongcatFlashPreTrainedModel,
|
| 23 |
+
)
|
| 24 |
+
from .configuration_longcat_ngram import LongcatFlashNgramConfig
|
| 25 |
+
|
| 26 |
+
logger = logging.get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@auto_docstring
|
| 30 |
+
class LongcatFlashNgramPreTrainedModel(LongcatFlashPreTrainedModel):
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class NgramCache(DynamicCache):
|
| 35 |
+
"""
|
| 36 |
+
Extended DynamicCache for storing N-gram context alongside KV cache.
|
| 37 |
+
"""
|
| 38 |
+
def __init__(self, config=None):
|
| 39 |
+
super().__init__()
|
| 40 |
+
self.ngram_context = None
|
| 41 |
+
# Keep only n-1 tokens (minimum needed for N-gram computation)
|
| 42 |
+
self.max_context_len = config.emb_neighbor_num - 1
|
| 43 |
+
self.oe_ignored_token_ids = torch.tensor(config.oe_ignored_token_ids, dtype=torch.long)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def update_ngram_context(self, new_tokens: torch.Tensor) -> None:
|
| 47 |
+
"""
|
| 48 |
+
Update N-gram context with window management.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
new_tokens: New tokens to append, shape (batch_size, seq_len)
|
| 52 |
+
"""
|
| 53 |
+
new_tokens = new_tokens.clone()
|
| 54 |
+
new_tokens[torch.isin(new_tokens, self.oe_ignored_token_ids.to(new_tokens.device))] = 0
|
| 55 |
+
|
| 56 |
+
if self.ngram_context is None:
|
| 57 |
+
self.ngram_context = new_tokens
|
| 58 |
+
else:
|
| 59 |
+
self.ngram_context = torch.cat([self.ngram_context, new_tokens], dim=-1)
|
| 60 |
+
|
| 61 |
+
# Truncate to maintain constant memory footprint
|
| 62 |
+
if self.ngram_context.size(-1) > self.max_context_len:
|
| 63 |
+
self.ngram_context = self.ngram_context[..., -self.max_context_len:]
|
| 64 |
+
|
| 65 |
+
def reorder_cache(self, beam_idx: torch.LongTensor) -> "Cache":
|
| 66 |
+
"""Reorder cache for beam search."""
|
| 67 |
+
# Reorder parent's KV cache
|
| 68 |
+
super().reorder_cache(beam_idx)
|
| 69 |
+
|
| 70 |
+
# Reorder N-gram context
|
| 71 |
+
if self.ngram_context is not None:
|
| 72 |
+
self.ngram_context = self.ngram_context.index_select(0, beam_idx.to(self.ngram_context.device))
|
| 73 |
+
|
| 74 |
+
return self
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class EmbeddingWithMask(nn.Embedding):
|
| 78 |
+
def forward(self, input: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
|
| 79 |
+
"""
|
| 80 |
+
Args:
|
| 81 |
+
x (torch.Tensor): Input indices of shape (batch_size, seq_len)
|
| 82 |
+
mask (torch.Tensor): Boolean mask of shape (batch_size, seq_len).
|
| 83 |
+
True means compute, False means skip and return 0.
|
| 84 |
+
Returns:
|
| 85 |
+
torch.Tensor: Embeddings of shape (batch_size, seq_len, embedding_dim)
|
| 86 |
+
"""
|
| 87 |
+
if mask is not None:
|
| 88 |
+
# Ensure mask is boolean
|
| 89 |
+
mask = mask.bool()
|
| 90 |
+
else:
|
| 91 |
+
mask = torch.ones_like(input, dtype=torch.bool)
|
| 92 |
+
|
| 93 |
+
batch_size, seq_len = input.shape
|
| 94 |
+
embedding_dim = self.embedding_dim
|
| 95 |
+
|
| 96 |
+
# 1. Initialize the output tensor with zeros on the correct device
|
| 97 |
+
output = torch.zeros(
|
| 98 |
+
(batch_size, seq_len, embedding_dim),
|
| 99 |
+
device=input.device,
|
| 100 |
+
dtype=self.weight.dtype
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# 2. Filter out the valid indices using the mask
|
| 104 |
+
# valid_indices is a 1D tensor containing only the elements where mask is True
|
| 105 |
+
valid_indices = input[mask]
|
| 106 |
+
|
| 107 |
+
# 3. Only perform the embedding lookup if there is at least one valid index
|
| 108 |
+
if valid_indices.numel() > 0:
|
| 109 |
+
# Look up only the necessary embeddings (saves compute/memory bandwidth)
|
| 110 |
+
valid_embeddings = F.embedding(
|
| 111 |
+
valid_indices, self.weight, self.padding_idx, self.max_norm,
|
| 112 |
+
self.norm_type, self.scale_grad_by_freq, self.sparse)
|
| 113 |
+
|
| 114 |
+
# 4. Scatter the valid embeddings back to their original positions in the output tensor
|
| 115 |
+
output[mask] = valid_embeddings
|
| 116 |
+
|
| 117 |
+
return output
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class NgramEmbedding(nn.Module):
|
| 121 |
+
"""
|
| 122 |
+
Computes embeddings enriched with N-gram features without maintaining internal state.
|
| 123 |
+
"""
|
| 124 |
+
def __init__(self, config, base_embeddings):
|
| 125 |
+
super().__init__()
|
| 126 |
+
self.config = config
|
| 127 |
+
self.word_embeddings = base_embeddings
|
| 128 |
+
|
| 129 |
+
# self.m = config.ngram_vocab_size_ratio * config.vocab_size
|
| 130 |
+
self.m = config.ngram_vocab_size_ratio * config.text_vocab_size
|
| 131 |
+
self.k = config.emb_split_num
|
| 132 |
+
self.n = config.emb_neighbor_num
|
| 133 |
+
self.oe_ignored_token_ids = torch.tensor(config.oe_ignored_token_ids)
|
| 134 |
+
|
| 135 |
+
self._init_ngram_embeddings()
|
| 136 |
+
self._vocab_mods_cache = None
|
| 137 |
+
|
| 138 |
+
def _init_ngram_embeddings(self) -> None:
|
| 139 |
+
"""Initialize N-gram embedding and projection layers."""
|
| 140 |
+
num_embedders = self.k * (self.n - 1)
|
| 141 |
+
emb_dim = self.config.hidden_size // num_embedders
|
| 142 |
+
|
| 143 |
+
embedders = []
|
| 144 |
+
post_projs = []
|
| 145 |
+
|
| 146 |
+
for i in range(num_embedders):
|
| 147 |
+
vocab_size = int(self.m + i * 2 + 1)
|
| 148 |
+
emb = EmbeddingWithMask(vocab_size, emb_dim, padding_idx=self.config.pad_token_id)
|
| 149 |
+
proj = nn.Linear(emb_dim, self.config.hidden_size, bias=False)
|
| 150 |
+
embedders.append(emb)
|
| 151 |
+
post_projs.append(proj)
|
| 152 |
+
|
| 153 |
+
self.embedders = nn.ModuleList(embedders)
|
| 154 |
+
self.post_projs = nn.ModuleList(post_projs)
|
| 155 |
+
|
| 156 |
+
def _shift_right_ignore_eos(self, tensor: torch.Tensor, n: int, eos_token_id: int = 2) -> torch.Tensor:
|
| 157 |
+
p, q = tensor.shape
|
| 158 |
+
# special_token / modal set 0
|
| 159 |
+
special_tokens = 0
|
| 160 |
+
|
| 161 |
+
if n == 0:
|
| 162 |
+
return tensor.clone()
|
| 163 |
+
|
| 164 |
+
if n >= q:
|
| 165 |
+
return torch.zeros_like(tensor)
|
| 166 |
+
|
| 167 |
+
result = torch.zeros_like(tensor)
|
| 168 |
+
|
| 169 |
+
# Find all special_token/modal/EOS locations
|
| 170 |
+
special_mask = (tensor == special_tokens)
|
| 171 |
+
total_mask = (tensor == eos_token_id | special_mask)
|
| 172 |
+
|
| 173 |
+
# Calculate the segment ID to which each position belongs
|
| 174 |
+
eos_cumsum = total_mask.long().cumsum(dim=1)
|
| 175 |
+
# Shift right by 1, so that the first EOS position still belongs to segment 0, and the second EOS position belongs to segment 1
|
| 176 |
+
segment_ids = torch.cat([
|
| 177 |
+
torch.zeros(p, 1, dtype=torch.long, device=tensor.device),
|
| 178 |
+
eos_cumsum[:, :-1]
|
| 179 |
+
], dim=1)
|
| 180 |
+
|
| 181 |
+
col_indices = torch.arange(q, device=tensor.device).unsqueeze(0).expand(p, q)
|
| 182 |
+
# Number of segments
|
| 183 |
+
max_segments = segment_ids.max().item() + 1
|
| 184 |
+
segment_starts = torch.full((p, max_segments), q, dtype=torch.long, device=tensor.device)
|
| 185 |
+
# Calculate the starting position of each segment
|
| 186 |
+
segment_starts.scatter_reduce_(1, segment_ids, col_indices, reduce='amin', include_self=False)
|
| 187 |
+
|
| 188 |
+
# Get the start position of the segment to which each position belongs
|
| 189 |
+
segment_start_per_pos = torch.gather(segment_starts, 1, segment_ids)
|
| 190 |
+
|
| 191 |
+
# Calculate the offset of each position within the segment
|
| 192 |
+
offset_in_segment = col_indices - segment_start_per_pos
|
| 193 |
+
|
| 194 |
+
# Data for each position should be taken from the position offset -n within the segment
|
| 195 |
+
source_offset = offset_in_segment - n
|
| 196 |
+
valid_mask = source_offset >= 0
|
| 197 |
+
|
| 198 |
+
# Calculate the actual source index
|
| 199 |
+
source_indices = segment_start_per_pos + torch.clamp(source_offset, min=0)
|
| 200 |
+
|
| 201 |
+
# Data is collected by source_indices
|
| 202 |
+
result = torch.gather(tensor, 1, source_indices)
|
| 203 |
+
|
| 204 |
+
# Set invalid position to zero
|
| 205 |
+
result = result * valid_mask * (~special_mask)
|
| 206 |
+
|
| 207 |
+
return result
|
| 208 |
+
|
| 209 |
+
def _precompute_vocab_mods(self) -> Dict[Tuple[int, int], List[int]]:
|
| 210 |
+
"""Precompute modular arithmetic values for vocabulary."""
|
| 211 |
+
if self._vocab_mods_cache is not None:
|
| 212 |
+
return self._vocab_mods_cache
|
| 213 |
+
|
| 214 |
+
vocab_mods = {}
|
| 215 |
+
vocab_size = self.config.text_vocab_size
|
| 216 |
+
|
| 217 |
+
for i in range(2, self.n + 1):
|
| 218 |
+
for j in range(self.k):
|
| 219 |
+
index = (i - 2) * self.k + j
|
| 220 |
+
emb_vocab_dim = int(self.m + index * 2 + 1)
|
| 221 |
+
|
| 222 |
+
mods = []
|
| 223 |
+
power_mod = 1
|
| 224 |
+
for _ in range(i - 1):
|
| 225 |
+
power_mod = (power_mod * vocab_size) % emb_vocab_dim
|
| 226 |
+
mods.append(power_mod)
|
| 227 |
+
|
| 228 |
+
vocab_mods[(i, j)] = mods
|
| 229 |
+
|
| 230 |
+
self._vocab_mods_cache = vocab_mods
|
| 231 |
+
return vocab_mods
|
| 232 |
+
|
| 233 |
+
def _get_ngram_ids(
|
| 234 |
+
self,
|
| 235 |
+
input_ids: torch.Tensor,
|
| 236 |
+
shifted_ids: Dict[int, torch.Tensor],
|
| 237 |
+
vocab_mods: List[int],
|
| 238 |
+
ngram: int
|
| 239 |
+
) -> torch.Tensor:
|
| 240 |
+
"""Compute N-gram hash IDs using polynomial rolling hash."""
|
| 241 |
+
ngram_ids = input_ids.clone()
|
| 242 |
+
for k in range(2, ngram + 1):
|
| 243 |
+
ngram_ids = ngram_ids + shifted_ids[k] * vocab_mods[k - 2]
|
| 244 |
+
return ngram_ids
|
| 245 |
+
|
| 246 |
+
def forward(
|
| 247 |
+
self,
|
| 248 |
+
input_ids: torch.Tensor,
|
| 249 |
+
ngram_context: Optional[torch.Tensor] = None
|
| 250 |
+
) -> torch.Tensor:
|
| 251 |
+
"""
|
| 252 |
+
Stateless forward pass.
|
| 253 |
+
|
| 254 |
+
Args:
|
| 255 |
+
input_ids: Current input token IDs of shape (batch_size, seq_len)
|
| 256 |
+
ngram_context: Optional historical context of shape (batch_size, context_len)
|
| 257 |
+
|
| 258 |
+
Returns:
|
| 259 |
+
Embedding tensor of shape (batch_size, seq_len, hidden_size)
|
| 260 |
+
"""
|
| 261 |
+
seq_len = input_ids.size(-1)
|
| 262 |
+
|
| 263 |
+
# Determine complete context
|
| 264 |
+
if ngram_context is not None:
|
| 265 |
+
context = torch.cat([ngram_context[..., -(self.n-1):], input_ids], dim=-1)
|
| 266 |
+
else:
|
| 267 |
+
context = input_ids.clone()
|
| 268 |
+
|
| 269 |
+
# Skip N-gram look-up for oe_ignored_token_ids
|
| 270 |
+
oe_ignored_mask = torch.isin(input_ids, self.oe_ignored_token_ids.to(device=input_ids.device))
|
| 271 |
+
context[torch.isin(context, self.oe_ignored_token_ids.to(device=context.device))] = 0
|
| 272 |
+
|
| 273 |
+
# Base word embeddings
|
| 274 |
+
device = self.word_embeddings.weight.device
|
| 275 |
+
x = self.word_embeddings(input_ids.to(device)).clone()
|
| 276 |
+
|
| 277 |
+
# Precompute modular values
|
| 278 |
+
vocab_mods = self._precompute_vocab_mods()
|
| 279 |
+
|
| 280 |
+
# Compute shifted IDs
|
| 281 |
+
shifted_ids = {}
|
| 282 |
+
for i in range(2, self.n + 1):
|
| 283 |
+
shifted_ids[i] = self._shift_right_ignore_eos(
|
| 284 |
+
context, i - 1, eos_token_id=self.config.eos_token_id
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
# Add N-gram embeddings
|
| 288 |
+
for i in range(2, self.n + 1):
|
| 289 |
+
for j in range(self.k):
|
| 290 |
+
index = (i - 2) * self.k + j
|
| 291 |
+
emb_vocab_dim = int(self.m + index * 2 + 1)
|
| 292 |
+
|
| 293 |
+
ngram_ids = self._get_ngram_ids(context, shifted_ids, vocab_mods[(i, j)], ngram=i)
|
| 294 |
+
new_ids = (ngram_ids % emb_vocab_dim)[..., -seq_len:]
|
| 295 |
+
text_mask = new_ids > 0
|
| 296 |
+
|
| 297 |
+
embedder_device = self.embedders[index].weight.device
|
| 298 |
+
x_ngram = self.embedders[index](new_ids.to(embedder_device), text_mask)
|
| 299 |
+
|
| 300 |
+
proj_device = self.post_projs[index].weight.device
|
| 301 |
+
x_proj = self.post_projs[index](x_ngram.to(proj_device))
|
| 302 |
+
x = x + x_proj.to(x.device)
|
| 303 |
+
|
| 304 |
+
# Normalize
|
| 305 |
+
x[~oe_ignored_mask] /= (1 + self.k * (self.n - 1))
|
| 306 |
+
|
| 307 |
+
return x
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class LongcatFlashNgramModel(LongcatFlashModel):
|
| 311 |
+
"""LongcatFlash model with N-gram enhanced embeddings."""
|
| 312 |
+
_keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
|
| 313 |
+
config_class = LongcatFlashNgramConfig
|
| 314 |
+
|
| 315 |
+
def __init__(self, config):
|
| 316 |
+
super().__init__(config)
|
| 317 |
+
|
| 318 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 319 |
+
self.ngram_embeddings = NgramEmbedding(config, self.embed_tokens)
|
| 320 |
+
|
| 321 |
+
self.layers = nn.ModuleList(
|
| 322 |
+
[LongcatFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_layers)]
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
self.head_dim = config.head_dim
|
| 326 |
+
self.config.num_hidden_layers = 2 * config.num_layers
|
| 327 |
+
self.norm = LongcatFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 328 |
+
self.rotary_emb = LongcatFlashRotaryEmbedding(config=config)
|
| 329 |
+
self.gradient_checkpointing = False
|
| 330 |
+
|
| 331 |
+
self.post_init()
|
| 332 |
+
|
| 333 |
+
def forward(
|
| 334 |
+
self,
|
| 335 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 336 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 337 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 338 |
+
past_key_values: Optional[Cache] = None,
|
| 339 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 340 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 341 |
+
use_cache: Optional[bool] = None,
|
| 342 |
+
**kwargs
|
| 343 |
+
) -> BaseModelOutputWithPast:
|
| 344 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 345 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 346 |
+
|
| 347 |
+
# Extract N-gram context if available
|
| 348 |
+
ngram_context = None
|
| 349 |
+
if isinstance(past_key_values, NgramCache) and past_key_values.ngram_context is not None:
|
| 350 |
+
ngram_context = past_key_values.ngram_context
|
| 351 |
+
|
| 352 |
+
if inputs_embeds is None:
|
| 353 |
+
inputs_embeds = self.ngram_embeddings(input_ids, ngram_context=ngram_context)
|
| 354 |
+
|
| 355 |
+
# Initialize NgramCache if needed
|
| 356 |
+
if use_cache and past_key_values is None:
|
| 357 |
+
past_key_values = NgramCache(config=self.config)
|
| 358 |
+
|
| 359 |
+
# Update N-gram context
|
| 360 |
+
if use_cache and isinstance(past_key_values, NgramCache) and input_ids is not None:
|
| 361 |
+
past_key_values.update_ngram_context(input_ids)
|
| 362 |
+
|
| 363 |
+
# Prepare cache position
|
| 364 |
+
if cache_position is None:
|
| 365 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 366 |
+
cache_position = torch.arange(
|
| 367 |
+
inputs_embeds.shape[1], device=inputs_embeds.device
|
| 368 |
+
) + past_seen_tokens
|
| 369 |
+
|
| 370 |
+
if position_ids is None:
|
| 371 |
+
position_ids = cache_position.unsqueeze(0)
|
| 372 |
+
|
| 373 |
+
# Create causal mask
|
| 374 |
+
causal_mask = create_causal_mask(
|
| 375 |
+
config=self.config,
|
| 376 |
+
input_embeds=inputs_embeds,
|
| 377 |
+
attention_mask=attention_mask,
|
| 378 |
+
cache_position=cache_position,
|
| 379 |
+
past_key_values=past_key_values,
|
| 380 |
+
position_ids=position_ids,
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
# Forward through decoder layers
|
| 384 |
+
hidden_states = inputs_embeds
|
| 385 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 386 |
+
|
| 387 |
+
for decoder_layer in self.layers[: self.config.num_layers]:
|
| 388 |
+
hidden_states = decoder_layer(
|
| 389 |
+
hidden_states,
|
| 390 |
+
attention_mask=causal_mask,
|
| 391 |
+
position_ids=position_ids,
|
| 392 |
+
past_key_values=past_key_values,
|
| 393 |
+
cache_position=cache_position,
|
| 394 |
+
position_embeddings=position_embeddings,
|
| 395 |
+
**kwargs,
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
hidden_states = self.norm(hidden_states)
|
| 399 |
+
|
| 400 |
+
return BaseModelOutputWithPast(
|
| 401 |
+
last_hidden_state=hidden_states,
|
| 402 |
+
past_key_values=past_key_values,
|
| 403 |
+
hidden_states=None,
|
| 404 |
+
attentions=None,
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
class LongcatFlashNgramForCausalLM(LongcatFlashForCausalLM):
|
| 409 |
+
"""LongcatFlash model for causal language modeling with N-gram embeddings."""
|
| 410 |
+
_keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
|
| 411 |
+
config_class = LongcatFlashNgramConfig
|
| 412 |
+
|
| 413 |
+
def __init__(self, config):
|
| 414 |
+
super().__init__(config)
|
| 415 |
+
self.model = LongcatFlashNgramModel(config)
|
| 416 |
+
|
| 417 |
+
@torch.no_grad()
|
| 418 |
+
def generate(self, inputs=None, generation_config=None, **kwargs):
|
| 419 |
+
"""Override to ensure NgramCache is used."""
|
| 420 |
+
|
| 421 |
+
if "past_key_values" not in kwargs or kwargs["past_key_values"] is None:
|
| 422 |
+
kwargs["past_key_values"] = NgramCache(config=self.config)
|
| 423 |
+
|
| 424 |
+
return super().generate(inputs=inputs, generation_config=generation_config, **kwargs)
|
| 425 |
+
|
| 426 |
+
__all__ = ["LongcatFlashNgramPreTrainedModel", "LongcatFlashNgramModel", "LongcatFlashNgramForCausalLM"]
|
modular_longcat_next.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
from flash_attn import flash_attn_varlen_func
|
| 6 |
+
|
| 7 |
+
from transformers.models.t5.modeling_t5 import T5LayerNorm as RMSNorm
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class FlashVarLenAttention(nn.Module):
|
| 11 |
+
def __init__(self, embed_dim, num_heads, causal=False, window_size=(-1,-1)):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.embed_dim = embed_dim
|
| 14 |
+
self.num_heads = num_heads
|
| 15 |
+
self.head_dim = embed_dim // num_heads
|
| 16 |
+
|
| 17 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
|
| 18 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 19 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 20 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 21 |
+
|
| 22 |
+
self.causal = causal
|
| 23 |
+
self.window_size = window_size
|
| 24 |
+
|
| 25 |
+
def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
|
| 26 |
+
bsz, _ = hidden_states.size()
|
| 27 |
+
|
| 28 |
+
query_states = self.q_proj(hidden_states)
|
| 29 |
+
query_states = query_states.view(bsz, self.num_heads, self.head_dim).contiguous()
|
| 30 |
+
key_states = self.k_proj(hidden_states)
|
| 31 |
+
key_states = key_states.view(bsz, self.num_heads, self.head_dim).contiguous()
|
| 32 |
+
value_states = self.v_proj(hidden_states)
|
| 33 |
+
value_states = value_states.view(bsz, self.num_heads, self.head_dim).contiguous()
|
| 34 |
+
|
| 35 |
+
cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(torch.int32)
|
| 36 |
+
max_seqlen = torch.max(seq_len).to(torch.int32).detach()
|
| 37 |
+
|
| 38 |
+
attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_len, cu_len, max_seqlen,
|
| 39 |
+
max_seqlen, causal=self.causal, window_size=self.window_size) # (bsz * qlen, nheads, headdim)
|
| 40 |
+
attn_output = attn_output.reshape(bsz, self.embed_dim)
|
| 41 |
+
attn_output = self.out_proj(attn_output)
|
| 42 |
+
return attn_output
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class CasualDepthTransformerLayer(nn.Module):
|
| 47 |
+
def __init__(self, depth, transformer_dim, transformer_ffn_scale):
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.depth = depth
|
| 50 |
+
self.transformer_dim = transformer_dim
|
| 51 |
+
self.transformer_ffn_scale = transformer_ffn_scale
|
| 52 |
+
self.num_heads = self.transformer_dim // 128
|
| 53 |
+
|
| 54 |
+
assert self.transformer_dim % 128 == 0
|
| 55 |
+
assert self.transformer_dim % depth == 0
|
| 56 |
+
|
| 57 |
+
self.self_attention = FlashVarLenAttention(embed_dim=self.transformer_dim, num_heads=self.num_heads, causal=True)
|
| 58 |
+
|
| 59 |
+
self.layernorm1 = RMSNorm(self.transformer_dim)
|
| 60 |
+
self.layernorm2 = RMSNorm(self.transformer_dim)
|
| 61 |
+
|
| 62 |
+
self.linear1 = nn.Linear(self.transformer_dim, self.transformer_ffn_scale * self.transformer_dim)
|
| 63 |
+
self.linear2 = nn.Linear(self.transformer_ffn_scale * self.transformer_dim, self.transformer_dim)
|
| 64 |
+
|
| 65 |
+
def forward(self, x):
|
| 66 |
+
bsz = x.shape[0]
|
| 67 |
+
res = x
|
| 68 |
+
x = self.layernorm1(x)
|
| 69 |
+
seqlens = torch.tensor([self.depth] * bsz, dtype=torch.int32, device=x.device)
|
| 70 |
+
_x = self.self_attention(x.view(-1, self.transformer_dim), seqlens)
|
| 71 |
+
_x = _x.view(bsz, self.depth, self.transformer_dim).contiguous()
|
| 72 |
+
|
| 73 |
+
_res = _x + res # (bs, sl, d)
|
| 74 |
+
res = self.layernorm2(_res)
|
| 75 |
+
x = torch.einsum('bld,tld->blt', res, torch.reshape(self.linear1.weight, (self.transformer_ffn_scale * self.transformer_dim // self.depth, self.depth, self.transformer_dim)))
|
| 76 |
+
x = torch.nn.functional.gelu(x)
|
| 77 |
+
x = torch.einsum('blt,dlt->bld',x, torch.reshape(self.linear2.weight, (self.transformer_dim, self.depth, self.transformer_ffn_scale * self.transformer_dim // self.depth)))
|
| 78 |
+
return _res + x
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class CasualDepthTransformerHead(nn.Module):
|
| 82 |
+
"""
|
| 83 |
+
Depth-wise causal transformer head shared by image/audio heads.
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
def __init__(
|
| 87 |
+
self,
|
| 88 |
+
hidden_size,
|
| 89 |
+
codebook_sizes,
|
| 90 |
+
transformer_layer_num,
|
| 91 |
+
transformer_dim,
|
| 92 |
+
transformer_ffn_scale,
|
| 93 |
+
gradient_checkpointing=False,
|
| 94 |
+
):
|
| 95 |
+
super().__init__()
|
| 96 |
+
self.hidden_size = hidden_size
|
| 97 |
+
self.codebook_sizes = codebook_sizes
|
| 98 |
+
self.transformer_ffn_scale = transformer_ffn_scale
|
| 99 |
+
self.gradient_checkpointing = gradient_checkpointing
|
| 100 |
+
|
| 101 |
+
if self.transformer_ffn_scale > 0:
|
| 102 |
+
self.hidden_norm = RMSNorm(self.hidden_size)
|
| 103 |
+
self.hidden_proj = nn.Linear(self.hidden_size, transformer_dim, bias=False)
|
| 104 |
+
|
| 105 |
+
self.transformer_layers = nn.ModuleList(
|
| 106 |
+
[
|
| 107 |
+
CasualDepthTransformerLayer(len(codebook_sizes), transformer_dim, transformer_ffn_scale)
|
| 108 |
+
for _ in range(transformer_layer_num)
|
| 109 |
+
]
|
| 110 |
+
)
|
| 111 |
+
self.headnorm = RMSNorm(transformer_dim)
|
| 112 |
+
self.heads = nn.ModuleList(
|
| 113 |
+
[nn.Linear(transformer_dim, vq_size + 1) for vq_size in codebook_sizes]
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
for param in self.parameters():
|
| 117 |
+
param.requires_grad = False
|
| 118 |
+
|
| 119 |
+
def forward(self, x, visual_tokens, visual_emb_layers, level):
|
| 120 |
+
main_device = "cuda:0"
|
| 121 |
+
visual_tokens = visual_tokens.to(main_device)
|
| 122 |
+
visual_emb_layers = visual_emb_layers.to(main_device)
|
| 123 |
+
|
| 124 |
+
cumsum_visual_embed = torch.stack([
|
| 125 |
+
visual_emb_layers(visual_tokens[..., i])
|
| 126 |
+
for i, vq_size in enumerate(self.codebook_sizes[:-1])
|
| 127 |
+
], dim=1).to(x.device)
|
| 128 |
+
|
| 129 |
+
cumsum_visual_embed = torch.cumsum(cumsum_visual_embed, dim=1) # (bs, depth-1, d)
|
| 130 |
+
|
| 131 |
+
hidden_states = torch.concat([x.reshape(-1, 1, self.hidden_size), cumsum_visual_embed], dim=1) # (bs, depth, d)
|
| 132 |
+
assert hidden_states.size(1) == len(self.codebook_sizes)
|
| 133 |
+
|
| 134 |
+
if self.transformer_ffn_scale > 0:
|
| 135 |
+
hidden_states = self.hidden_norm(hidden_states)
|
| 136 |
+
hidden_states = self.hidden_proj(hidden_states)
|
| 137 |
+
|
| 138 |
+
for i, tlayer in enumerate(self.transformer_layers):
|
| 139 |
+
if self.gradient_checkpointing and self.training:
|
| 140 |
+
|
| 141 |
+
def create_custom_forward(module):
|
| 142 |
+
def custom_forward(*inputs):
|
| 143 |
+
# None for past_key_value
|
| 144 |
+
return module(*inputs)
|
| 145 |
+
|
| 146 |
+
return custom_forward
|
| 147 |
+
|
| 148 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 149 |
+
create_custom_forward(tlayer), hidden_states,
|
| 150 |
+
)
|
| 151 |
+
else:
|
| 152 |
+
hidden_states = tlayer(
|
| 153 |
+
hidden_states,
|
| 154 |
+
)
|
| 155 |
+
hidden_states = self.headnorm(hidden_states)
|
| 156 |
+
logits = self.heads[level](hidden_states[:, level])
|
| 157 |
+
return logits
|
modular_longcat_next_audio.py
ADDED
|
@@ -0,0 +1,2039 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import copy
|
| 3 |
+
from abc import ABC
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any, Dict, Optional
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torchaudio
|
| 10 |
+
from einops import pack, rearrange, repeat
|
| 11 |
+
from flash_attn import flash_attn_varlen_func
|
| 12 |
+
from torch import nn
|
| 13 |
+
from torch.cuda.amp import autocast
|
| 14 |
+
from torch.nn import functional as F
|
| 15 |
+
|
| 16 |
+
from diffusers.models.activations import get_activation
|
| 17 |
+
from diffusers.models.attention import (
|
| 18 |
+
GEGLU,
|
| 19 |
+
GELU,
|
| 20 |
+
AdaLayerNorm,
|
| 21 |
+
AdaLayerNormZero,
|
| 22 |
+
ApproximateGELU,
|
| 23 |
+
)
|
| 24 |
+
from diffusers.models.attention_processor import Attention
|
| 25 |
+
from diffusers.models.lora import LoRACompatibleLinear
|
| 26 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 27 |
+
|
| 28 |
+
from transformers.activations import ACT2FN
|
| 29 |
+
from transformers.modeling_outputs import ModelOutput
|
| 30 |
+
from transformers.utils import logging
|
| 31 |
+
|
| 32 |
+
from .cosy24k_vocoder import Cosy24kVocoder
|
| 33 |
+
|
| 34 |
+
logger = logging.get_logger(__name__)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def sinusoids(length, channels, max_timescale=10000):
|
| 38 |
+
"""Returns sinusoids for positional embedding"""
|
| 39 |
+
assert channels % 2 == 0
|
| 40 |
+
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
| 41 |
+
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
| 42 |
+
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
| 43 |
+
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_sequence_mask(inputs, inputs_length):
|
| 47 |
+
if inputs.dim() == 3:
|
| 48 |
+
bsz, tgt_len, _ = inputs.size()
|
| 49 |
+
else:
|
| 50 |
+
bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
|
| 51 |
+
sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
|
| 52 |
+
sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
|
| 53 |
+
unpacking_index = torch.cumsum(sequence_mask.to(torch.int64).view(-1), dim=0) - 1 # 转成下标
|
| 54 |
+
return sequence_mask, unpacking_index
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def unpack_hidden_states(hidden_states, lengths):
|
| 58 |
+
bsz = lengths.shape[0]
|
| 59 |
+
sequence_mask, unpacking_index = get_sequence_mask(hidden_states, lengths)
|
| 60 |
+
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
|
| 61 |
+
bsz, torch.max(lengths), hidden_states.shape[-1]
|
| 62 |
+
)
|
| 63 |
+
hidden_states = torch.where(
|
| 64 |
+
sequence_mask, hidden_states, 0
|
| 65 |
+
) # 3d (bsz, max_input_len, d)
|
| 66 |
+
return hidden_states
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def uniform_init(*shape):
|
| 70 |
+
t = torch.zeros(shape)
|
| 71 |
+
nn.init.kaiming_uniform_(t)
|
| 72 |
+
return t
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def cdist(x, y):
|
| 76 |
+
x2 = torch.sum(x ** 2, dim=-1, keepdims=True) # (b, 1)
|
| 77 |
+
y2 = torch.sum(y ** 2, dim=-1).reshape(1, -1) # (1, c)
|
| 78 |
+
xy = torch.einsum('bd,cd->bc', x, y) * -2
|
| 79 |
+
return (x2 + y2 + xy).clamp(min=0).sqrt() # (b, c)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
| 83 |
+
assert mask.dtype == torch.bool
|
| 84 |
+
assert dtype in [torch.float32, torch.bfloat16, torch.float16]
|
| 85 |
+
mask = mask.to(dtype)
|
| 86 |
+
# attention mask bias
|
| 87 |
+
# NOTE(Mddct): torch.finfo jit issues
|
| 88 |
+
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
|
| 89 |
+
mask = (1.0 - mask) * torch.finfo(dtype).min
|
| 90 |
+
return mask
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def subsequent_chunk_mask(
|
| 94 |
+
size: int,
|
| 95 |
+
chunk_size: int,
|
| 96 |
+
num_left_chunks: int = -1,
|
| 97 |
+
device: torch.device = torch.device("cpu"),
|
| 98 |
+
) -> torch.Tensor:
|
| 99 |
+
"""Create mask for subsequent steps (size, size) with chunk size,
|
| 100 |
+
this is for streaming encoder
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
size (int): size of mask
|
| 104 |
+
chunk_size (int): size of chunk
|
| 105 |
+
num_left_chunks (int): number of left chunks
|
| 106 |
+
<0: use full chunk
|
| 107 |
+
>=0: use num_left_chunks
|
| 108 |
+
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
torch.Tensor: mask
|
| 112 |
+
|
| 113 |
+
Examples:
|
| 114 |
+
>>> subsequent_chunk_mask(4, 2)
|
| 115 |
+
[[1, 1, 0, 0],
|
| 116 |
+
[1, 1, 0, 0],
|
| 117 |
+
[1, 1, 1, 1],
|
| 118 |
+
[1, 1, 1, 1]]
|
| 119 |
+
"""
|
| 120 |
+
# NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
|
| 121 |
+
# actually this is not needed after we have inference cache implemented, will remove it later
|
| 122 |
+
pos_idx = torch.arange(size, device=device)
|
| 123 |
+
block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
|
| 124 |
+
ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
|
| 125 |
+
return ret
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def add_optional_chunk_mask(xs: torch.Tensor,
|
| 129 |
+
masks: torch.Tensor,
|
| 130 |
+
use_dynamic_chunk: bool,
|
| 131 |
+
use_dynamic_left_chunk: bool,
|
| 132 |
+
decoding_chunk_size: int,
|
| 133 |
+
static_chunk_size: int,
|
| 134 |
+
num_decoding_left_chunks: int,
|
| 135 |
+
enable_full_context: bool = True):
|
| 136 |
+
""" Apply optional mask for encoder.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
xs (torch.Tensor): padded input, (B, L, D), L for max length
|
| 140 |
+
mask (torch.Tensor): mask for xs, (B, 1, L)
|
| 141 |
+
use_dynamic_chunk (bool): whether to use dynamic chunk or not
|
| 142 |
+
use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
|
| 143 |
+
training.
|
| 144 |
+
decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
|
| 145 |
+
0: default for training, use random dynamic chunk.
|
| 146 |
+
<0: for decoding, use full chunk.
|
| 147 |
+
>0: for decoding, use fixed chunk size as set.
|
| 148 |
+
static_chunk_size (int): chunk size for static chunk training/decoding
|
| 149 |
+
if it's greater than 0, if use_dynamic_chunk is true,
|
| 150 |
+
this parameter will be ignored
|
| 151 |
+
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
| 152 |
+
the chunk size is decoding_chunk_size.
|
| 153 |
+
>=0: use num_decoding_left_chunks
|
| 154 |
+
<0: use all left chunks
|
| 155 |
+
enable_full_context (bool):
|
| 156 |
+
True: chunk size is either [1, 25] or full context(max_len)
|
| 157 |
+
False: chunk size ~ U[1, 25]
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
torch.Tensor: chunk mask of the input xs.
|
| 161 |
+
"""
|
| 162 |
+
# Whether to use chunk mask or not
|
| 163 |
+
if use_dynamic_chunk:
|
| 164 |
+
max_len = xs.size(1)
|
| 165 |
+
if decoding_chunk_size < 0:
|
| 166 |
+
chunk_size = max_len
|
| 167 |
+
num_left_chunks = -1
|
| 168 |
+
elif decoding_chunk_size > 0:
|
| 169 |
+
chunk_size = decoding_chunk_size
|
| 170 |
+
num_left_chunks = num_decoding_left_chunks
|
| 171 |
+
else:
|
| 172 |
+
# chunk size is either [1, 25] or full context(max_len).
|
| 173 |
+
# Since we use 4 times subsampling and allow up to 1s(100 frames)
|
| 174 |
+
# delay, the maximum frame is 100 / 4 = 25.
|
| 175 |
+
chunk_size = torch.randint(1, max_len, (1, )).item()
|
| 176 |
+
num_left_chunks = -1
|
| 177 |
+
if chunk_size > max_len // 2 and enable_full_context:
|
| 178 |
+
chunk_size = max_len
|
| 179 |
+
else:
|
| 180 |
+
chunk_size = chunk_size % 25 + 1
|
| 181 |
+
if use_dynamic_left_chunk:
|
| 182 |
+
max_left_chunks = (max_len - 1) // chunk_size
|
| 183 |
+
num_left_chunks = torch.randint(0, max_left_chunks,
|
| 184 |
+
(1, )).item()
|
| 185 |
+
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
|
| 186 |
+
num_left_chunks,
|
| 187 |
+
xs.device) # (L, L)
|
| 188 |
+
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
| 189 |
+
chunk_masks = masks & chunk_masks # (B, L, L)
|
| 190 |
+
elif static_chunk_size > 0:
|
| 191 |
+
num_left_chunks = num_decoding_left_chunks
|
| 192 |
+
chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
|
| 193 |
+
num_left_chunks,
|
| 194 |
+
xs.device) # (L, L)
|
| 195 |
+
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
| 196 |
+
chunk_masks = masks & chunk_masks # (B, L, L)
|
| 197 |
+
else:
|
| 198 |
+
chunk_masks = masks
|
| 199 |
+
return chunk_masks
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class EuclideanCodebook(nn.Module):
|
| 203 |
+
def __init__(
|
| 204 |
+
self,
|
| 205 |
+
dim,
|
| 206 |
+
codebook_size,
|
| 207 |
+
init_std=0.02,
|
| 208 |
+
):
|
| 209 |
+
super().__init__()
|
| 210 |
+
self.init_std = init_std
|
| 211 |
+
self.dim = dim
|
| 212 |
+
self.codebook_size = codebook_size
|
| 213 |
+
|
| 214 |
+
embed = uniform_init(codebook_size, dim).to(torch.float32)
|
| 215 |
+
self.cluster_size = nn.Parameter(torch.ones(codebook_size))
|
| 216 |
+
self.embed_avg = nn.Parameter(embed.clone())
|
| 217 |
+
self.embed = nn.Parameter(embed)
|
| 218 |
+
del embed
|
| 219 |
+
|
| 220 |
+
@autocast(enabled=True, dtype=torch.float32)
|
| 221 |
+
@torch.no_grad()
|
| 222 |
+
def forward(self, x):
|
| 223 |
+
assert(len(x.shape) == 2)
|
| 224 |
+
assert(x.dtype == torch.float32)
|
| 225 |
+
embed = self.embed.detach().to(x.device)
|
| 226 |
+
dist = -cdist(x, embed) # dist((bs*sl, d), (c, d)) --> (bs*sl, c)
|
| 227 |
+
embed_ind = dist.argmax(dim=-1)
|
| 228 |
+
quantize = embed[embed_ind] # (bs*sl, d)
|
| 229 |
+
return quantize, embed_ind, dist
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class VectorQuantize(nn.Module):
|
| 233 |
+
def __init__(self, config, *args, **kwargs):
|
| 234 |
+
super().__init__(*args, **kwargs)
|
| 235 |
+
self.config = config
|
| 236 |
+
self.codebook = EuclideanCodebook(dim=config.dim, codebook_size=config.codebook_size)
|
| 237 |
+
|
| 238 |
+
def forward(self, x, input_length):
|
| 239 |
+
batch_size, seq_len, _ = x.shape
|
| 240 |
+
mask, unpacking_index = get_sequence_mask(x, input_length)
|
| 241 |
+
if x.dtype != torch.float32:
|
| 242 |
+
x = x.to(torch.float32)
|
| 243 |
+
x = torch.masked_select(x, mask).reshape(-1, self.config.dim) # (bs*sl?, d)
|
| 244 |
+
quantize, embed_ind, _ = self.codebook(x)
|
| 245 |
+
quantize = torch.index_select(quantize, 0, unpacking_index).view(batch_size, seq_len, self.config.dim)
|
| 246 |
+
quantize = torch.where(mask, quantize, 0)
|
| 247 |
+
embed_ind = torch.index_select(embed_ind.reshape(-1, 1), 0, unpacking_index).view(batch_size, seq_len, 1)
|
| 248 |
+
embed_ind = torch.where(mask, embed_ind, -1).squeeze()
|
| 249 |
+
return quantize, embed_ind
|
| 250 |
+
|
| 251 |
+
def get_output_from_indices(self, indices):
|
| 252 |
+
indices = indices.to(self.codebook.embed.device)
|
| 253 |
+
return self.codebook.embed[indices]
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class SnakeBeta(nn.Module):
|
| 257 |
+
"""
|
| 258 |
+
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
| 259 |
+
Shape:
|
| 260 |
+
- Input: (B, C, T)
|
| 261 |
+
- Output: (B, C, T), same shape as the input
|
| 262 |
+
Parameters:
|
| 263 |
+
- alpha - trainable parameter that controls frequency
|
| 264 |
+
- beta - trainable parameter that controls magnitude
|
| 265 |
+
References:
|
| 266 |
+
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
| 267 |
+
https://arxiv.org/abs/2006.08195
|
| 268 |
+
Examples:
|
| 269 |
+
>>> a1 = snakebeta(256)
|
| 270 |
+
>>> x = torch.randn(256)
|
| 271 |
+
>>> x = a1(x)
|
| 272 |
+
"""
|
| 273 |
+
|
| 274 |
+
def __init__(
|
| 275 |
+
self,
|
| 276 |
+
in_features,
|
| 277 |
+
out_features,
|
| 278 |
+
alpha=1.0,
|
| 279 |
+
alpha_trainable=True,
|
| 280 |
+
alpha_logscale=True,
|
| 281 |
+
):
|
| 282 |
+
"""
|
| 283 |
+
Initialization.
|
| 284 |
+
INPUT:
|
| 285 |
+
- in_features: shape of the input
|
| 286 |
+
- alpha - trainable parameter that controls frequency
|
| 287 |
+
- beta - trainable parameter that controls magnitude
|
| 288 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
| 289 |
+
beta is initialized to 1 by default, higher values = higher-magnitude.
|
| 290 |
+
alpha will be trained along with the rest of your model.
|
| 291 |
+
"""
|
| 292 |
+
super().__init__()
|
| 293 |
+
self.in_features = (
|
| 294 |
+
out_features if isinstance(out_features, list) else [out_features]
|
| 295 |
+
)
|
| 296 |
+
self.proj = LoRACompatibleLinear(in_features, out_features)
|
| 297 |
+
|
| 298 |
+
# initialize alpha
|
| 299 |
+
self.alpha_logscale = alpha_logscale
|
| 300 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
| 301 |
+
self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
|
| 302 |
+
self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
|
| 303 |
+
else: # linear scale alphas initialized to ones
|
| 304 |
+
self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
|
| 305 |
+
self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
|
| 306 |
+
|
| 307 |
+
self.alpha.requires_grad = alpha_trainable
|
| 308 |
+
self.beta.requires_grad = alpha_trainable
|
| 309 |
+
|
| 310 |
+
self.no_div_by_zero = 0.000000001
|
| 311 |
+
|
| 312 |
+
def forward(self, x):
|
| 313 |
+
"""
|
| 314 |
+
Forward pass of the function.
|
| 315 |
+
Applies the function to the input elementwise.
|
| 316 |
+
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
| 317 |
+
"""
|
| 318 |
+
x = self.proj(x)
|
| 319 |
+
if self.alpha_logscale:
|
| 320 |
+
alpha = torch.exp(self.alpha)
|
| 321 |
+
beta = torch.exp(self.beta)
|
| 322 |
+
else:
|
| 323 |
+
alpha = self.alpha
|
| 324 |
+
beta = self.beta
|
| 325 |
+
|
| 326 |
+
x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
|
| 327 |
+
torch.sin(x * alpha), 2
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
return x
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
class FeedForward(nn.Module):
|
| 334 |
+
r"""
|
| 335 |
+
A feed-forward layer.
|
| 336 |
+
|
| 337 |
+
Parameters:
|
| 338 |
+
dim (`int`): The number of channels in the input.
|
| 339 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
| 340 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
| 341 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
| 342 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
| 343 |
+
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
| 344 |
+
"""
|
| 345 |
+
|
| 346 |
+
def __init__(
|
| 347 |
+
self,
|
| 348 |
+
dim: int,
|
| 349 |
+
dim_out: Optional[int] = None,
|
| 350 |
+
mult: int = 4,
|
| 351 |
+
dropout: float = 0.0,
|
| 352 |
+
activation_fn: str = "geglu",
|
| 353 |
+
final_dropout: bool = False,
|
| 354 |
+
):
|
| 355 |
+
super().__init__()
|
| 356 |
+
inner_dim = int(dim * mult)
|
| 357 |
+
dim_out = dim_out if dim_out is not None else dim
|
| 358 |
+
|
| 359 |
+
if activation_fn == "gelu":
|
| 360 |
+
act_fn = GELU(dim, inner_dim)
|
| 361 |
+
if activation_fn == "gelu-approximate":
|
| 362 |
+
act_fn = GELU(dim, inner_dim, approximate="tanh")
|
| 363 |
+
elif activation_fn == "geglu":
|
| 364 |
+
act_fn = GEGLU(dim, inner_dim)
|
| 365 |
+
elif activation_fn == "geglu-approximate":
|
| 366 |
+
act_fn = ApproximateGELU(dim, inner_dim)
|
| 367 |
+
elif activation_fn == "snakebeta":
|
| 368 |
+
act_fn = SnakeBeta(dim, inner_dim)
|
| 369 |
+
|
| 370 |
+
self.net = nn.ModuleList([])
|
| 371 |
+
# project in
|
| 372 |
+
self.net.append(act_fn)
|
| 373 |
+
# project dropout
|
| 374 |
+
self.net.append(nn.Dropout(dropout))
|
| 375 |
+
# project out
|
| 376 |
+
self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
|
| 377 |
+
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
| 378 |
+
if final_dropout:
|
| 379 |
+
self.net.append(nn.Dropout(dropout))
|
| 380 |
+
|
| 381 |
+
def forward(self, hidden_states):
|
| 382 |
+
for module in self.net:
|
| 383 |
+
hidden_states = module(hidden_states)
|
| 384 |
+
return hidden_states
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
@maybe_allow_in_graph
|
| 388 |
+
class BasicTransformerBlock(nn.Module):
|
| 389 |
+
r"""
|
| 390 |
+
A basic Transformer block.
|
| 391 |
+
|
| 392 |
+
Parameters:
|
| 393 |
+
dim (`int`): The number of channels in the input and output.
|
| 394 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
| 395 |
+
attention_head_dim (`int`): The number of channels in each head.
|
| 396 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
| 397 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
| 398 |
+
only_cross_attention (`bool`, *optional*):
|
| 399 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
| 400 |
+
double_self_attention (`bool`, *optional*):
|
| 401 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
| 402 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
| 403 |
+
num_embeds_ada_norm (:
|
| 404 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
| 405 |
+
attention_bias (:
|
| 406 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
| 407 |
+
"""
|
| 408 |
+
|
| 409 |
+
def __init__(
|
| 410 |
+
self,
|
| 411 |
+
dim: int,
|
| 412 |
+
num_attention_heads: int,
|
| 413 |
+
attention_head_dim: int,
|
| 414 |
+
dropout=0.0,
|
| 415 |
+
cross_attention_dim: Optional[int] = None,
|
| 416 |
+
activation_fn: str = "geglu",
|
| 417 |
+
num_embeds_ada_norm: Optional[int] = None,
|
| 418 |
+
attention_bias: bool = False,
|
| 419 |
+
only_cross_attention: bool = False,
|
| 420 |
+
double_self_attention: bool = False,
|
| 421 |
+
upcast_attention: bool = False,
|
| 422 |
+
norm_elementwise_affine: bool = True,
|
| 423 |
+
norm_type: str = "layer_norm",
|
| 424 |
+
final_dropout: bool = False,
|
| 425 |
+
use_omni_attn: bool = False,
|
| 426 |
+
):
|
| 427 |
+
super().__init__()
|
| 428 |
+
|
| 429 |
+
self.use_omni_attn = use_omni_attn
|
| 430 |
+
self.dim = dim
|
| 431 |
+
|
| 432 |
+
self.only_cross_attention = only_cross_attention
|
| 433 |
+
|
| 434 |
+
self.use_ada_layer_norm_zero = (
|
| 435 |
+
num_embeds_ada_norm is not None
|
| 436 |
+
) and norm_type == "ada_norm_zero"
|
| 437 |
+
self.use_ada_layer_norm = (
|
| 438 |
+
num_embeds_ada_norm is not None
|
| 439 |
+
) and norm_type == "ada_norm"
|
| 440 |
+
|
| 441 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
| 442 |
+
raise ValueError(
|
| 443 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
| 444 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
| 448 |
+
# 1. Self-Attn
|
| 449 |
+
if self.use_ada_layer_norm:
|
| 450 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
| 451 |
+
elif self.use_ada_layer_norm_zero:
|
| 452 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
| 453 |
+
else:
|
| 454 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
| 455 |
+
|
| 456 |
+
if self.use_omni_attn:
|
| 457 |
+
if only_cross_attention:
|
| 458 |
+
raise NotImplementedError
|
| 459 |
+
print(
|
| 460 |
+
"Use OmniWhisperAttention with flash attention. Dropout is ignored."
|
| 461 |
+
)
|
| 462 |
+
self.attn1 = OmniWhisperAttention(
|
| 463 |
+
embed_dim=dim, num_heads=num_attention_heads, causal=False
|
| 464 |
+
)
|
| 465 |
+
else:
|
| 466 |
+
self.attn1 = Attention(
|
| 467 |
+
query_dim=dim,
|
| 468 |
+
heads=num_attention_heads,
|
| 469 |
+
dim_head=attention_head_dim,
|
| 470 |
+
dropout=dropout,
|
| 471 |
+
bias=attention_bias,
|
| 472 |
+
cross_attention_dim=(
|
| 473 |
+
cross_attention_dim if only_cross_attention else None
|
| 474 |
+
),
|
| 475 |
+
upcast_attention=upcast_attention,
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
# 2. Cross-Attn
|
| 479 |
+
if cross_attention_dim is not None or double_self_attention:
|
| 480 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
| 481 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
| 482 |
+
# the second cross attention block.
|
| 483 |
+
self.norm2 = (
|
| 484 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
| 485 |
+
if self.use_ada_layer_norm
|
| 486 |
+
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
| 487 |
+
)
|
| 488 |
+
self.attn2 = Attention(
|
| 489 |
+
query_dim=dim,
|
| 490 |
+
cross_attention_dim=(
|
| 491 |
+
cross_attention_dim if not double_self_attention else None
|
| 492 |
+
),
|
| 493 |
+
heads=num_attention_heads,
|
| 494 |
+
dim_head=attention_head_dim,
|
| 495 |
+
dropout=dropout,
|
| 496 |
+
bias=attention_bias,
|
| 497 |
+
upcast_attention=upcast_attention,
|
| 498 |
+
# scale_qk=False, # uncomment this to not to use flash attention
|
| 499 |
+
) # is self-attn if encoder_hidden_states is none
|
| 500 |
+
else:
|
| 501 |
+
self.norm2 = None
|
| 502 |
+
self.attn2 = None
|
| 503 |
+
|
| 504 |
+
# 3. Feed-forward
|
| 505 |
+
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
| 506 |
+
self.ff = FeedForward(
|
| 507 |
+
dim,
|
| 508 |
+
dropout=dropout,
|
| 509 |
+
activation_fn=activation_fn,
|
| 510 |
+
final_dropout=final_dropout,
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
# let chunk size default to None
|
| 514 |
+
self._chunk_size = None
|
| 515 |
+
self._chunk_dim = 0
|
| 516 |
+
|
| 517 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
| 518 |
+
# Sets chunk feed-forward
|
| 519 |
+
self._chunk_size = chunk_size
|
| 520 |
+
self._chunk_dim = dim
|
| 521 |
+
|
| 522 |
+
def forward(
|
| 523 |
+
self,
|
| 524 |
+
hidden_states: torch.FloatTensor,
|
| 525 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 526 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 527 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 528 |
+
timestep: Optional[torch.LongTensor] = None,
|
| 529 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
| 530 |
+
class_labels: Optional[torch.LongTensor] = None,
|
| 531 |
+
):
|
| 532 |
+
|
| 533 |
+
bsz, tgt_len, d_model = hidden_states.shape
|
| 534 |
+
|
| 535 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
| 536 |
+
# 1. Self-Attention
|
| 537 |
+
if self.use_ada_layer_norm:
|
| 538 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
| 539 |
+
elif self.use_ada_layer_norm_zero:
|
| 540 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
| 541 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
| 542 |
+
)
|
| 543 |
+
else:
|
| 544 |
+
norm_hidden_states = self.norm1(hidden_states)
|
| 545 |
+
|
| 546 |
+
cross_attention_kwargs = (
|
| 547 |
+
cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
if self.use_omni_attn:
|
| 551 |
+
seq_len = attention_mask[:, 0, :].float().long().sum(dim=1)
|
| 552 |
+
var_len_attention_mask, unpacking_index = get_sequence_mask(
|
| 553 |
+
norm_hidden_states, seq_len
|
| 554 |
+
)
|
| 555 |
+
norm_hidden_states = torch.masked_select(
|
| 556 |
+
norm_hidden_states, var_len_attention_mask
|
| 557 |
+
)
|
| 558 |
+
norm_hidden_states = norm_hidden_states.view(torch.sum(seq_len), self.dim)
|
| 559 |
+
attn_output = self.attn1(norm_hidden_states, seq_len)
|
| 560 |
+
# unpacking
|
| 561 |
+
attn_output = torch.index_select(attn_output, 0, unpacking_index).view(
|
| 562 |
+
bsz, tgt_len, d_model
|
| 563 |
+
)
|
| 564 |
+
attn_output = torch.where(var_len_attention_mask, attn_output, 0)
|
| 565 |
+
else:
|
| 566 |
+
attn_output = self.attn1(
|
| 567 |
+
norm_hidden_states,
|
| 568 |
+
encoder_hidden_states=(
|
| 569 |
+
encoder_hidden_states if self.only_cross_attention else None
|
| 570 |
+
),
|
| 571 |
+
attention_mask=(
|
| 572 |
+
encoder_attention_mask
|
| 573 |
+
if self.only_cross_attention
|
| 574 |
+
else attention_mask
|
| 575 |
+
),
|
| 576 |
+
**cross_attention_kwargs,
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
if self.use_ada_layer_norm_zero:
|
| 580 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
| 581 |
+
hidden_states = attn_output + hidden_states
|
| 582 |
+
|
| 583 |
+
# 2. Cross-Attention
|
| 584 |
+
if self.attn2 is not None:
|
| 585 |
+
norm_hidden_states = (
|
| 586 |
+
self.norm2(hidden_states, timestep)
|
| 587 |
+
if self.use_ada_layer_norm
|
| 588 |
+
else self.norm2(hidden_states)
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
attn_output = self.attn2(
|
| 592 |
+
norm_hidden_states,
|
| 593 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 594 |
+
attention_mask=encoder_attention_mask,
|
| 595 |
+
**cross_attention_kwargs,
|
| 596 |
+
)
|
| 597 |
+
hidden_states = attn_output + hidden_states
|
| 598 |
+
|
| 599 |
+
# 3. Feed-forward
|
| 600 |
+
norm_hidden_states = self.norm3(hidden_states)
|
| 601 |
+
|
| 602 |
+
if self.use_ada_layer_norm_zero:
|
| 603 |
+
norm_hidden_states = (
|
| 604 |
+
norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
if self._chunk_size is not None:
|
| 608 |
+
# "feed_forward_chunk_size" can be used to save memory
|
| 609 |
+
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
| 610 |
+
raise ValueError(
|
| 611 |
+
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
| 615 |
+
ff_output = torch.cat(
|
| 616 |
+
[
|
| 617 |
+
self.ff(hid_slice)
|
| 618 |
+
for hid_slice in norm_hidden_states.chunk(
|
| 619 |
+
num_chunks, dim=self._chunk_dim
|
| 620 |
+
)
|
| 621 |
+
],
|
| 622 |
+
dim=self._chunk_dim,
|
| 623 |
+
)
|
| 624 |
+
else:
|
| 625 |
+
ff_output = self.ff(norm_hidden_states)
|
| 626 |
+
|
| 627 |
+
if self.use_ada_layer_norm_zero:
|
| 628 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
| 629 |
+
|
| 630 |
+
hidden_states = ff_output + hidden_states
|
| 631 |
+
|
| 632 |
+
return hidden_states
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
class Transpose(torch.nn.Module):
|
| 636 |
+
def __init__(self, dim0: int, dim1: int):
|
| 637 |
+
super().__init__()
|
| 638 |
+
self.dim0 = dim0
|
| 639 |
+
self.dim1 = dim1
|
| 640 |
+
|
| 641 |
+
def forward(self, x: torch.Tensor):
|
| 642 |
+
x = torch.transpose(x, self.dim0, self.dim1)
|
| 643 |
+
return x
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
class Block1D(torch.nn.Module):
|
| 647 |
+
def __init__(self, dim, dim_out, groups=8):
|
| 648 |
+
super().__init__()
|
| 649 |
+
self.block = torch.nn.Sequential(
|
| 650 |
+
torch.nn.Conv1d(dim, dim_out, 3, padding=1),
|
| 651 |
+
torch.nn.GroupNorm(groups, dim_out),
|
| 652 |
+
nn.Mish(),
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
def forward(self, x, mask):
|
| 656 |
+
output = self.block(x * mask)
|
| 657 |
+
return output * mask
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
class ResnetBlock1D(torch.nn.Module):
|
| 661 |
+
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
|
| 662 |
+
super().__init__()
|
| 663 |
+
self.mlp = torch.nn.Sequential(
|
| 664 |
+
nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)
|
| 665 |
+
)
|
| 666 |
+
|
| 667 |
+
self.block1 = Block1D(dim, dim_out, groups=groups)
|
| 668 |
+
self.block2 = Block1D(dim_out, dim_out, groups=groups)
|
| 669 |
+
|
| 670 |
+
self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
|
| 671 |
+
|
| 672 |
+
def forward(self, x, mask, time_emb):
|
| 673 |
+
h = self.block1(x, mask)
|
| 674 |
+
h += self.mlp(time_emb).unsqueeze(-1)
|
| 675 |
+
h = self.block2(h, mask)
|
| 676 |
+
output = h + self.res_conv(x * mask)
|
| 677 |
+
return output
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
class CausalBlock1D(Block1D):
|
| 681 |
+
def __init__(self, dim: int, dim_out: int):
|
| 682 |
+
super(CausalBlock1D, self).__init__(dim, dim_out)
|
| 683 |
+
self.block = torch.nn.Sequential(
|
| 684 |
+
CausalConv1d(dim, dim_out, 3),
|
| 685 |
+
Transpose(1, 2),
|
| 686 |
+
nn.LayerNorm(dim_out),
|
| 687 |
+
Transpose(1, 2),
|
| 688 |
+
nn.Mish(),
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor):
|
| 692 |
+
output = self.block(x * mask)
|
| 693 |
+
return output * mask
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
class CausalResnetBlock1D(ResnetBlock1D):
|
| 697 |
+
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
|
| 698 |
+
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
|
| 699 |
+
self.block1 = CausalBlock1D(dim, dim_out)
|
| 700 |
+
self.block2 = CausalBlock1D(dim_out, dim_out)
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
class CausalConv1d(torch.nn.Conv1d):
|
| 704 |
+
def __init__(
|
| 705 |
+
self,
|
| 706 |
+
in_channels: int,
|
| 707 |
+
out_channels: int,
|
| 708 |
+
kernel_size: int,
|
| 709 |
+
stride: int = 1,
|
| 710 |
+
dilation: int = 1,
|
| 711 |
+
groups: int = 1,
|
| 712 |
+
bias: bool = True,
|
| 713 |
+
padding_mode: str = 'zeros',
|
| 714 |
+
device=None,
|
| 715 |
+
dtype=None
|
| 716 |
+
) -> None:
|
| 717 |
+
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
| 718 |
+
kernel_size, stride,
|
| 719 |
+
padding=0, dilation=dilation,
|
| 720 |
+
groups=groups, bias=bias,
|
| 721 |
+
padding_mode=padding_mode,
|
| 722 |
+
device=device, dtype=dtype)
|
| 723 |
+
assert stride == 1
|
| 724 |
+
self.causal_padding = (kernel_size - 1, 0)
|
| 725 |
+
|
| 726 |
+
def forward(self, x: torch.Tensor):
|
| 727 |
+
x = F.pad(x, self.causal_padding)
|
| 728 |
+
x = super(CausalConv1d, self).forward(x)
|
| 729 |
+
return x
|
| 730 |
+
|
| 731 |
+
|
| 732 |
+
class BASECFM(torch.nn.Module, ABC):
|
| 733 |
+
def __init__(
|
| 734 |
+
self,
|
| 735 |
+
n_feats,
|
| 736 |
+
cfm_params,
|
| 737 |
+
n_spks=1,
|
| 738 |
+
spk_emb_dim=128,
|
| 739 |
+
):
|
| 740 |
+
super().__init__()
|
| 741 |
+
self.n_feats = n_feats
|
| 742 |
+
self.n_spks = n_spks
|
| 743 |
+
self.spk_emb_dim = spk_emb_dim
|
| 744 |
+
self.solver = cfm_params.solver
|
| 745 |
+
if hasattr(cfm_params, "sigma_min"):
|
| 746 |
+
self.sigma_min = cfm_params.sigma_min
|
| 747 |
+
else:
|
| 748 |
+
self.sigma_min = 1e-4
|
| 749 |
+
|
| 750 |
+
self.estimator = None
|
| 751 |
+
|
| 752 |
+
@torch.inference_mode()
|
| 753 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
| 754 |
+
"""Forward diffusion
|
| 755 |
+
|
| 756 |
+
Args:
|
| 757 |
+
mu (torch.Tensor): output of encoder
|
| 758 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 759 |
+
mask (torch.Tensor): output_mask
|
| 760 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 761 |
+
n_timesteps (int): number of diffusion steps
|
| 762 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
| 763 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
| 764 |
+
shape: (batch_size, spk_emb_dim)
|
| 765 |
+
cond: Not used but kept for future purposes
|
| 766 |
+
|
| 767 |
+
Returns:
|
| 768 |
+
sample: generated mel-spectrogram
|
| 769 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 770 |
+
"""
|
| 771 |
+
z = torch.randn_like(mu) * temperature
|
| 772 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
| 773 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
|
| 774 |
+
|
| 775 |
+
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
| 776 |
+
"""
|
| 777 |
+
Fixed euler solver for ODEs.
|
| 778 |
+
Args:
|
| 779 |
+
x (torch.Tensor): random noise
|
| 780 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
| 781 |
+
shape: (n_timesteps + 1,)
|
| 782 |
+
mu (torch.Tensor): output of encoder
|
| 783 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 784 |
+
mask (torch.Tensor): output_mask
|
| 785 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 786 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
| 787 |
+
shape: (batch_size, spk_emb_dim)
|
| 788 |
+
cond: Not used but kept for future purposes
|
| 789 |
+
"""
|
| 790 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
| 791 |
+
|
| 792 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
| 793 |
+
# Or in future might add like a return_all_steps flag
|
| 794 |
+
sol = []
|
| 795 |
+
|
| 796 |
+
for step in range(1, len(t_span)):
|
| 797 |
+
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
|
| 798 |
+
|
| 799 |
+
x = x + dt * dphi_dt
|
| 800 |
+
t = t + dt
|
| 801 |
+
sol.append(x)
|
| 802 |
+
if step < len(t_span) - 1:
|
| 803 |
+
dt = t_span[step + 1] - t
|
| 804 |
+
|
| 805 |
+
return sol[-1]
|
| 806 |
+
|
| 807 |
+
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
| 808 |
+
"""Computes diffusion loss
|
| 809 |
+
|
| 810 |
+
Args:
|
| 811 |
+
x1 (torch.Tensor): Target
|
| 812 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 813 |
+
mask (torch.Tensor): target mask
|
| 814 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 815 |
+
mu (torch.Tensor): output of encoder
|
| 816 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 817 |
+
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
| 818 |
+
shape: (batch_size, spk_emb_dim)
|
| 819 |
+
|
| 820 |
+
Returns:
|
| 821 |
+
loss: conditional flow matching loss
|
| 822 |
+
y: conditional flow
|
| 823 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 824 |
+
"""
|
| 825 |
+
b, _, t = mu.shape
|
| 826 |
+
|
| 827 |
+
# random timestep
|
| 828 |
+
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
| 829 |
+
# sample noise p(x_0)
|
| 830 |
+
z = torch.randn_like(x1)
|
| 831 |
+
|
| 832 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
| 833 |
+
u = x1 - (1 - self.sigma_min) * z
|
| 834 |
+
|
| 835 |
+
loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
|
| 836 |
+
torch.sum(mask) * u.shape[1]
|
| 837 |
+
)
|
| 838 |
+
return loss, y
|
| 839 |
+
|
| 840 |
+
|
| 841 |
+
class ConditionalDecoder(nn.Module):
|
| 842 |
+
def __init__(
|
| 843 |
+
self,
|
| 844 |
+
in_channels,
|
| 845 |
+
out_channels,
|
| 846 |
+
causal=False,
|
| 847 |
+
channels=(256, 256),
|
| 848 |
+
dropout=0.05,
|
| 849 |
+
attention_head_dim=64,
|
| 850 |
+
n_blocks=1,
|
| 851 |
+
num_mid_blocks=2,
|
| 852 |
+
num_heads=4,
|
| 853 |
+
act_fn="snake",
|
| 854 |
+
gradient_checkpointing=False,
|
| 855 |
+
):
|
| 856 |
+
"""
|
| 857 |
+
This decoder requires an input with the same shape of the target. So, if your text content
|
| 858 |
+
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
|
| 859 |
+
"""
|
| 860 |
+
super().__init__()
|
| 861 |
+
channels = tuple(channels)
|
| 862 |
+
self.in_channels = in_channels
|
| 863 |
+
self.out_channels = out_channels
|
| 864 |
+
self.causal = causal
|
| 865 |
+
self.static_chunk_size = 2 * 25 * 2 # 2*input_frame_rate*token_mel_ratio
|
| 866 |
+
self.gradient_checkpointing = gradient_checkpointing
|
| 867 |
+
|
| 868 |
+
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
| 869 |
+
time_embed_dim = channels[0] * 4
|
| 870 |
+
self.time_mlp = TimestepEmbedding(
|
| 871 |
+
in_channels=in_channels,
|
| 872 |
+
time_embed_dim=time_embed_dim,
|
| 873 |
+
act_fn="silu",
|
| 874 |
+
)
|
| 875 |
+
self.down_blocks = nn.ModuleList([])
|
| 876 |
+
self.mid_blocks = nn.ModuleList([])
|
| 877 |
+
self.up_blocks = nn.ModuleList([])
|
| 878 |
+
|
| 879 |
+
output_channel = in_channels
|
| 880 |
+
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
| 881 |
+
input_channel = output_channel
|
| 882 |
+
output_channel = channels[i]
|
| 883 |
+
is_last = i == len(channels) - 1
|
| 884 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
| 885 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
| 886 |
+
transformer_blocks = nn.ModuleList(
|
| 887 |
+
[
|
| 888 |
+
BasicTransformerBlock(
|
| 889 |
+
dim=output_channel,
|
| 890 |
+
num_attention_heads=num_heads,
|
| 891 |
+
attention_head_dim=attention_head_dim,
|
| 892 |
+
dropout=dropout,
|
| 893 |
+
activation_fn=act_fn,
|
| 894 |
+
)
|
| 895 |
+
for _ in range(n_blocks)
|
| 896 |
+
]
|
| 897 |
+
)
|
| 898 |
+
downsample = (
|
| 899 |
+
Downsample1D(output_channel) if not is_last else
|
| 900 |
+
CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
| 901 |
+
)
|
| 902 |
+
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
| 903 |
+
|
| 904 |
+
for _ in range(num_mid_blocks):
|
| 905 |
+
input_channel = channels[-1]
|
| 906 |
+
out_channels = channels[-1]
|
| 907 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
|
| 908 |
+
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
| 909 |
+
|
| 910 |
+
transformer_blocks = nn.ModuleList(
|
| 911 |
+
[
|
| 912 |
+
BasicTransformerBlock(
|
| 913 |
+
dim=output_channel,
|
| 914 |
+
num_attention_heads=num_heads,
|
| 915 |
+
attention_head_dim=attention_head_dim,
|
| 916 |
+
dropout=dropout,
|
| 917 |
+
activation_fn=act_fn,
|
| 918 |
+
)
|
| 919 |
+
for _ in range(n_blocks)
|
| 920 |
+
]
|
| 921 |
+
)
|
| 922 |
+
|
| 923 |
+
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
| 924 |
+
|
| 925 |
+
channels = channels[::-1] + (channels[0],)
|
| 926 |
+
for i in range(len(channels) - 1):
|
| 927 |
+
input_channel = channels[i] * 2
|
| 928 |
+
output_channel = channels[i + 1]
|
| 929 |
+
is_last = i == len(channels) - 2
|
| 930 |
+
resnet = CausalResnetBlock1D(
|
| 931 |
+
dim=input_channel,
|
| 932 |
+
dim_out=output_channel,
|
| 933 |
+
time_emb_dim=time_embed_dim,
|
| 934 |
+
) if self.causal else ResnetBlock1D(
|
| 935 |
+
dim=input_channel,
|
| 936 |
+
dim_out=output_channel,
|
| 937 |
+
time_emb_dim=time_embed_dim,
|
| 938 |
+
)
|
| 939 |
+
transformer_blocks = nn.ModuleList(
|
| 940 |
+
[
|
| 941 |
+
BasicTransformerBlock(
|
| 942 |
+
dim=output_channel,
|
| 943 |
+
num_attention_heads=num_heads,
|
| 944 |
+
attention_head_dim=attention_head_dim,
|
| 945 |
+
dropout=dropout,
|
| 946 |
+
activation_fn=act_fn,
|
| 947 |
+
)
|
| 948 |
+
for _ in range(n_blocks)
|
| 949 |
+
]
|
| 950 |
+
)
|
| 951 |
+
upsample = (
|
| 952 |
+
Upsample1D(output_channel, use_conv_transpose=True)
|
| 953 |
+
if not is_last
|
| 954 |
+
else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
| 955 |
+
)
|
| 956 |
+
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
| 957 |
+
self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
|
| 958 |
+
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
| 959 |
+
self.initialize_weights()
|
| 960 |
+
|
| 961 |
+
def initialize_weights(self):
|
| 962 |
+
for m in self.modules():
|
| 963 |
+
if isinstance(m, nn.Conv1d):
|
| 964 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
| 965 |
+
if m.bias is not None:
|
| 966 |
+
nn.init.constant_(m.bias, 0)
|
| 967 |
+
elif isinstance(m, nn.GroupNorm):
|
| 968 |
+
nn.init.constant_(m.weight, 1)
|
| 969 |
+
nn.init.constant_(m.bias, 0)
|
| 970 |
+
elif isinstance(m, nn.Linear):
|
| 971 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
| 972 |
+
if m.bias is not None:
|
| 973 |
+
nn.init.constant_(m.bias, 0)
|
| 974 |
+
|
| 975 |
+
def forward(self, x, mask, mu, t, spks=None, cond=None):
|
| 976 |
+
"""Forward pass of the UNet1DConditional model.
|
| 977 |
+
|
| 978 |
+
Args:
|
| 979 |
+
x (torch.Tensor): shape (batch_size, in_channels, time)
|
| 980 |
+
mask (_type_): shape (batch_size, 1, time)
|
| 981 |
+
t (_type_): shape (batch_size)
|
| 982 |
+
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
| 983 |
+
cond (_type_, optional): placeholder for future use. Defaults to None.
|
| 984 |
+
|
| 985 |
+
Raises:
|
| 986 |
+
ValueError: _description_
|
| 987 |
+
ValueError: _description_
|
| 988 |
+
|
| 989 |
+
Returns:
|
| 990 |
+
_type_: _description_
|
| 991 |
+
"""
|
| 992 |
+
t = self.time_embeddings(t)
|
| 993 |
+
t = t.to(x.dtype)
|
| 994 |
+
t = self.time_mlp(t)
|
| 995 |
+
x = pack([x, mu], "b * t")[0]
|
| 996 |
+
mask = mask.to(x.dtype)
|
| 997 |
+
if spks is not None:
|
| 998 |
+
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
| 999 |
+
x = pack([x, spks], "b * t")[0]
|
| 1000 |
+
if cond is not None:
|
| 1001 |
+
x = pack([x, cond], "b * t")[0]
|
| 1002 |
+
|
| 1003 |
+
hiddens = []
|
| 1004 |
+
masks = [mask]
|
| 1005 |
+
for resnet, transformer_blocks, downsample in self.down_blocks:
|
| 1006 |
+
mask_down = masks[-1]
|
| 1007 |
+
x = resnet(x, mask_down, t)
|
| 1008 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 1009 |
+
# attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
|
| 1010 |
+
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
|
| 1011 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
| 1012 |
+
for transformer_block in transformer_blocks:
|
| 1013 |
+
if self.gradient_checkpointing and self.training:
|
| 1014 |
+
def create_custom_forward(module):
|
| 1015 |
+
def custom_forward(*inputs):
|
| 1016 |
+
return module(*inputs)
|
| 1017 |
+
return custom_forward
|
| 1018 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 1019 |
+
create_custom_forward(transformer_block),
|
| 1020 |
+
x,
|
| 1021 |
+
attn_mask,
|
| 1022 |
+
t,
|
| 1023 |
+
)
|
| 1024 |
+
else:
|
| 1025 |
+
x = transformer_block(
|
| 1026 |
+
hidden_states=x,
|
| 1027 |
+
attention_mask=attn_mask,
|
| 1028 |
+
timestep=t,
|
| 1029 |
+
)
|
| 1030 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 1031 |
+
hiddens.append(x) # Save hidden states for skip connections
|
| 1032 |
+
x = downsample(x * mask_down)
|
| 1033 |
+
masks.append(mask_down[:, :, ::2])
|
| 1034 |
+
masks = masks[:-1]
|
| 1035 |
+
mask_mid = masks[-1]
|
| 1036 |
+
|
| 1037 |
+
for resnet, transformer_blocks in self.mid_blocks:
|
| 1038 |
+
x = resnet(x, mask_mid, t)
|
| 1039 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 1040 |
+
# attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
|
| 1041 |
+
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
|
| 1042 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
| 1043 |
+
for transformer_block in transformer_blocks:
|
| 1044 |
+
if self.gradient_checkpointing and self.training:
|
| 1045 |
+
def create_custom_forward(module):
|
| 1046 |
+
def custom_forward(*inputs):
|
| 1047 |
+
return module(*inputs)
|
| 1048 |
+
return custom_forward
|
| 1049 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 1050 |
+
create_custom_forward(transformer_block),
|
| 1051 |
+
x,
|
| 1052 |
+
attn_mask,
|
| 1053 |
+
t,
|
| 1054 |
+
)
|
| 1055 |
+
else:
|
| 1056 |
+
x = transformer_block(
|
| 1057 |
+
hidden_states=x,
|
| 1058 |
+
attention_mask=attn_mask,
|
| 1059 |
+
timestep=t,
|
| 1060 |
+
)
|
| 1061 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 1062 |
+
|
| 1063 |
+
for resnet, transformer_blocks, upsample in self.up_blocks:
|
| 1064 |
+
mask_up = masks.pop()
|
| 1065 |
+
skip = hiddens.pop()
|
| 1066 |
+
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
| 1067 |
+
x = resnet(x, mask_up, t)
|
| 1068 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
| 1069 |
+
# attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
|
| 1070 |
+
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
|
| 1071 |
+
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
|
| 1072 |
+
for transformer_block in transformer_blocks:
|
| 1073 |
+
if self.gradient_checkpointing and self.training:
|
| 1074 |
+
def create_custom_forward(module):
|
| 1075 |
+
def custom_forward(*inputs):
|
| 1076 |
+
return module(*inputs)
|
| 1077 |
+
return custom_forward
|
| 1078 |
+
x = torch.utils.checkpoint.checkpoint(
|
| 1079 |
+
create_custom_forward(transformer_block),
|
| 1080 |
+
x,
|
| 1081 |
+
attn_mask,
|
| 1082 |
+
t,
|
| 1083 |
+
)
|
| 1084 |
+
else:
|
| 1085 |
+
x = transformer_block(
|
| 1086 |
+
hidden_states=x,
|
| 1087 |
+
attention_mask=attn_mask,
|
| 1088 |
+
timestep=t,
|
| 1089 |
+
)
|
| 1090 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
| 1091 |
+
x = upsample(x * mask_up)
|
| 1092 |
+
x = self.final_block(x, mask_up)
|
| 1093 |
+
output = self.final_proj(x * mask_up)
|
| 1094 |
+
return output * mask
|
| 1095 |
+
|
| 1096 |
+
|
| 1097 |
+
class ConditionalCFM(BASECFM):
|
| 1098 |
+
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64):
|
| 1099 |
+
super().__init__(
|
| 1100 |
+
n_feats=in_channels,
|
| 1101 |
+
cfm_params=cfm_params,
|
| 1102 |
+
n_spks=n_spks,
|
| 1103 |
+
spk_emb_dim=spk_emb_dim,
|
| 1104 |
+
)
|
| 1105 |
+
self.t_scheduler = cfm_params.t_scheduler
|
| 1106 |
+
self.training_cfg_rate = cfm_params.training_cfg_rate
|
| 1107 |
+
self.inference_cfg_rate = cfm_params.inference_cfg_rate
|
| 1108 |
+
|
| 1109 |
+
@torch.inference_mode()
|
| 1110 |
+
def forward(self, estimator, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
| 1111 |
+
"""Forward diffusion
|
| 1112 |
+
|
| 1113 |
+
Args:
|
| 1114 |
+
mu (torch.Tensor): output of encoder
|
| 1115 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 1116 |
+
mask (torch.Tensor): output_mask
|
| 1117 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 1118 |
+
n_timesteps (int): number of diffusion steps
|
| 1119 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
| 1120 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
| 1121 |
+
shape: (batch_size, spk_emb_dim)
|
| 1122 |
+
cond: Not used but kept for future purposes
|
| 1123 |
+
|
| 1124 |
+
Returns:
|
| 1125 |
+
sample: generated mel-spectrogram
|
| 1126 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 1127 |
+
"""
|
| 1128 |
+
z = torch.randn_like(mu) * temperature
|
| 1129 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
| 1130 |
+
if self.t_scheduler == 'cosine':
|
| 1131 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
| 1132 |
+
return self.solve_euler(estimator, z, t_span=t_span.to(mu.dtype), mu=mu, mask=mask, spks=spks, cond=cond)
|
| 1133 |
+
|
| 1134 |
+
def solve_euler(self, estimator, x, t_span, mu, mask, spks, cond):
|
| 1135 |
+
"""
|
| 1136 |
+
Fixed euler solver for ODEs.
|
| 1137 |
+
Args:
|
| 1138 |
+
x (torch.Tensor): random noise
|
| 1139 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
| 1140 |
+
shape: (n_timesteps + 1,)
|
| 1141 |
+
mu (torch.Tensor): output of encoder
|
| 1142 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 1143 |
+
mask (torch.Tensor): output_mask
|
| 1144 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 1145 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
| 1146 |
+
shape: (batch_size, spk_emb_dim)
|
| 1147 |
+
cond: Not used but kept for future purposes
|
| 1148 |
+
"""
|
| 1149 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
| 1150 |
+
|
| 1151 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
| 1152 |
+
# Or in future might add like a return_all_steps flag
|
| 1153 |
+
sol = []
|
| 1154 |
+
|
| 1155 |
+
for step in range(1, len(t_span)):
|
| 1156 |
+
dphi_dt = estimator(x, mask, mu, t, spks, cond)
|
| 1157 |
+
# Classifier-Free Guidance inference introduced in VoiceBox
|
| 1158 |
+
if self.inference_cfg_rate > 0:
|
| 1159 |
+
cfg_dphi_dt = estimator(
|
| 1160 |
+
x, mask,
|
| 1161 |
+
torch.zeros_like(mu), t,
|
| 1162 |
+
torch.zeros_like(spks) if spks is not None else None,
|
| 1163 |
+
cond=cond
|
| 1164 |
+
)
|
| 1165 |
+
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
|
| 1166 |
+
self.inference_cfg_rate * cfg_dphi_dt)
|
| 1167 |
+
x = x + dt * dphi_dt
|
| 1168 |
+
t = t + dt
|
| 1169 |
+
sol.append(x)
|
| 1170 |
+
if step < len(t_span) - 1:
|
| 1171 |
+
dt = t_span[step + 1] - t
|
| 1172 |
+
|
| 1173 |
+
return sol[-1]
|
| 1174 |
+
|
| 1175 |
+
def compute_loss(self, estimator, x1, mask, mu, spks=None, cond=None):
|
| 1176 |
+
"""Computes diffusion loss
|
| 1177 |
+
|
| 1178 |
+
Args:
|
| 1179 |
+
x1 (torch.Tensor): Target
|
| 1180 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 1181 |
+
mask (torch.Tensor): target mask
|
| 1182 |
+
shape: (batch_size, 1, mel_timesteps)
|
| 1183 |
+
mu (torch.Tensor): output of encoder
|
| 1184 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 1185 |
+
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
| 1186 |
+
shape: (batch_size, spk_emb_dim)
|
| 1187 |
+
|
| 1188 |
+
Returns:
|
| 1189 |
+
loss: conditional flow matching loss
|
| 1190 |
+
y: conditional flow
|
| 1191 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
| 1192 |
+
"""
|
| 1193 |
+
org_dtype = x1.dtype
|
| 1194 |
+
|
| 1195 |
+
b, _, t = mu.shape
|
| 1196 |
+
# random timestep
|
| 1197 |
+
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
| 1198 |
+
if self.t_scheduler == 'cosine':
|
| 1199 |
+
t = 1 - torch.cos(t * 0.5 * torch.pi)
|
| 1200 |
+
# sample noise p(x_0)
|
| 1201 |
+
z = torch.randn_like(x1)
|
| 1202 |
+
|
| 1203 |
+
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
| 1204 |
+
u = x1 - (1 - self.sigma_min) * z
|
| 1205 |
+
|
| 1206 |
+
# during training, we randomly drop condition to trade off mode coverage and sample fidelity
|
| 1207 |
+
if self.training_cfg_rate > 0:
|
| 1208 |
+
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
|
| 1209 |
+
mu = mu * cfg_mask.view(-1, 1, 1)
|
| 1210 |
+
if spks is not None:
|
| 1211 |
+
spks = spks * cfg_mask.view(-1, 1)
|
| 1212 |
+
if cond is not None:
|
| 1213 |
+
cond = cond * cfg_mask.view(-1, 1, 1)
|
| 1214 |
+
|
| 1215 |
+
pred = estimator(y, mask, mu, t.squeeze(), spks, cond)
|
| 1216 |
+
pred = pred.float()
|
| 1217 |
+
u = u.float()
|
| 1218 |
+
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
| 1219 |
+
loss = loss.to(org_dtype)
|
| 1220 |
+
return loss, y
|
| 1221 |
+
|
| 1222 |
+
|
| 1223 |
+
class SinusoidalPosEmb(torch.nn.Module):
|
| 1224 |
+
def __init__(self, dim):
|
| 1225 |
+
super().__init__()
|
| 1226 |
+
self.dim = dim
|
| 1227 |
+
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
|
| 1228 |
+
|
| 1229 |
+
def forward(self, x, scale=1000):
|
| 1230 |
+
if x.ndim < 1:
|
| 1231 |
+
x = x.unsqueeze(0)
|
| 1232 |
+
device = x.device
|
| 1233 |
+
half_dim = self.dim // 2
|
| 1234 |
+
emb = math.log(10000) / (half_dim - 1)
|
| 1235 |
+
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
|
| 1236 |
+
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
| 1237 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
| 1238 |
+
return emb
|
| 1239 |
+
|
| 1240 |
+
|
| 1241 |
+
class Downsample1D(nn.Module):
|
| 1242 |
+
def __init__(self, dim):
|
| 1243 |
+
super().__init__()
|
| 1244 |
+
self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
|
| 1245 |
+
|
| 1246 |
+
def forward(self, x):
|
| 1247 |
+
return self.conv(x)
|
| 1248 |
+
|
| 1249 |
+
|
| 1250 |
+
class TimestepEmbedding(nn.Module):
|
| 1251 |
+
def __init__(
|
| 1252 |
+
self,
|
| 1253 |
+
in_channels: int,
|
| 1254 |
+
time_embed_dim: int,
|
| 1255 |
+
act_fn: str = "silu",
|
| 1256 |
+
out_dim: int = None,
|
| 1257 |
+
post_act_fn: Optional[str] = None,
|
| 1258 |
+
cond_proj_dim=None,
|
| 1259 |
+
):
|
| 1260 |
+
super().__init__()
|
| 1261 |
+
|
| 1262 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
| 1263 |
+
|
| 1264 |
+
if cond_proj_dim is not None:
|
| 1265 |
+
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
| 1266 |
+
else:
|
| 1267 |
+
self.cond_proj = None
|
| 1268 |
+
|
| 1269 |
+
self.act = get_activation(act_fn)
|
| 1270 |
+
|
| 1271 |
+
if out_dim is not None:
|
| 1272 |
+
time_embed_dim_out = out_dim
|
| 1273 |
+
else:
|
| 1274 |
+
time_embed_dim_out = time_embed_dim
|
| 1275 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
| 1276 |
+
|
| 1277 |
+
if post_act_fn is None:
|
| 1278 |
+
self.post_act = None
|
| 1279 |
+
else:
|
| 1280 |
+
self.post_act = get_activation(post_act_fn)
|
| 1281 |
+
|
| 1282 |
+
def forward(self, sample, condition=None):
|
| 1283 |
+
if condition is not None:
|
| 1284 |
+
sample = sample + self.cond_proj(condition)
|
| 1285 |
+
sample = self.linear_1(sample)
|
| 1286 |
+
|
| 1287 |
+
if self.act is not None:
|
| 1288 |
+
sample = self.act(sample)
|
| 1289 |
+
|
| 1290 |
+
sample = self.linear_2(sample)
|
| 1291 |
+
|
| 1292 |
+
if self.post_act is not None:
|
| 1293 |
+
sample = self.post_act(sample)
|
| 1294 |
+
return sample
|
| 1295 |
+
|
| 1296 |
+
|
| 1297 |
+
class Upsample1D(nn.Module):
|
| 1298 |
+
"""A 1D upsampling layer with an optional convolution.
|
| 1299 |
+
|
| 1300 |
+
Parameters:
|
| 1301 |
+
channels (`int`):
|
| 1302 |
+
number of channels in the inputs and outputs.
|
| 1303 |
+
use_conv (`bool`, default `False`):
|
| 1304 |
+
option to use a convolution.
|
| 1305 |
+
use_conv_transpose (`bool`, default `False`):
|
| 1306 |
+
option to use a convolution transpose.
|
| 1307 |
+
out_channels (`int`, optional):
|
| 1308 |
+
number of output channels. Defaults to `channels`.
|
| 1309 |
+
"""
|
| 1310 |
+
|
| 1311 |
+
def __init__(
|
| 1312 |
+
self,
|
| 1313 |
+
channels,
|
| 1314 |
+
use_conv=False,
|
| 1315 |
+
use_conv_transpose=True,
|
| 1316 |
+
out_channels=None,
|
| 1317 |
+
name="conv",
|
| 1318 |
+
):
|
| 1319 |
+
super().__init__()
|
| 1320 |
+
self.channels = channels
|
| 1321 |
+
self.out_channels = out_channels or channels
|
| 1322 |
+
self.use_conv = use_conv
|
| 1323 |
+
self.use_conv_transpose = use_conv_transpose
|
| 1324 |
+
self.name = name
|
| 1325 |
+
|
| 1326 |
+
self.conv = None
|
| 1327 |
+
if use_conv_transpose:
|
| 1328 |
+
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
| 1329 |
+
elif use_conv:
|
| 1330 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
| 1331 |
+
|
| 1332 |
+
def forward(self, inputs):
|
| 1333 |
+
assert inputs.shape[1] == self.channels
|
| 1334 |
+
if self.use_conv_transpose:
|
| 1335 |
+
return self.conv(inputs)
|
| 1336 |
+
|
| 1337 |
+
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
|
| 1338 |
+
|
| 1339 |
+
if self.use_conv:
|
| 1340 |
+
outputs = self.conv(outputs)
|
| 1341 |
+
|
| 1342 |
+
return outputs
|
| 1343 |
+
|
| 1344 |
+
|
| 1345 |
+
class RMSNorm(nn.Module):
|
| 1346 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 1347 |
+
"""
|
| 1348 |
+
RMSNorm is equivalent to T5LayerNorm
|
| 1349 |
+
"""
|
| 1350 |
+
super().__init__()
|
| 1351 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 1352 |
+
self.variance_epsilon = eps
|
| 1353 |
+
|
| 1354 |
+
def forward(self, hidden_states):
|
| 1355 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
| 1356 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 1357 |
+
|
| 1358 |
+
# convert into half-precision if necessary
|
| 1359 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
| 1360 |
+
hidden_states = hidden_states.to(self.weight.dtype)
|
| 1361 |
+
|
| 1362 |
+
return self.weight * hidden_states
|
| 1363 |
+
|
| 1364 |
+
|
| 1365 |
+
class OmniWhisperAttention(nn.Module):
|
| 1366 |
+
def __init__(self, embed_dim, num_heads, causal=False):
|
| 1367 |
+
super().__init__()
|
| 1368 |
+
self.embed_dim = embed_dim
|
| 1369 |
+
self.num_heads = num_heads
|
| 1370 |
+
self.head_dim = embed_dim // num_heads
|
| 1371 |
+
|
| 1372 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
|
| 1373 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 1374 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 1375 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
| 1376 |
+
|
| 1377 |
+
self.causal = causal
|
| 1378 |
+
|
| 1379 |
+
def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor):
|
| 1380 |
+
bsz, _ = hidden_states.size()
|
| 1381 |
+
|
| 1382 |
+
query_states = self.q_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
|
| 1383 |
+
key_states = self.k_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
|
| 1384 |
+
value_states = self.v_proj(hidden_states).view(bsz, self.num_heads, self.head_dim)
|
| 1385 |
+
|
| 1386 |
+
cu_len = F.pad(torch.cumsum(seq_len, dim=0), (1, 0), "constant", 0).to(torch.int32)
|
| 1387 |
+
max_seqlen = torch.max(seq_len).to(torch.int32).detach()
|
| 1388 |
+
attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_len, cu_len, max_seqlen, max_seqlen, causal=self.causal) # (bsz * qlen, nheads, headdim)
|
| 1389 |
+
attn_output = attn_output.reshape(bsz, self.embed_dim)
|
| 1390 |
+
attn_output = self.out_proj(attn_output)
|
| 1391 |
+
return attn_output
|
| 1392 |
+
|
| 1393 |
+
|
| 1394 |
+
class OmniWhisperTransformerLayer(nn.Module):
|
| 1395 |
+
def __init__(
|
| 1396 |
+
self,
|
| 1397 |
+
act,
|
| 1398 |
+
d_model,
|
| 1399 |
+
encoder_attention_heads,
|
| 1400 |
+
encoder_ffn_dim,
|
| 1401 |
+
causal,
|
| 1402 |
+
ln_type="LayerNorm",
|
| 1403 |
+
):
|
| 1404 |
+
super().__init__()
|
| 1405 |
+
self.embed_dim = d_model
|
| 1406 |
+
self.self_attn = OmniWhisperAttention(
|
| 1407 |
+
self.embed_dim, encoder_attention_heads, causal
|
| 1408 |
+
)
|
| 1409 |
+
|
| 1410 |
+
if ln_type == "LayerNorm":
|
| 1411 |
+
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
| 1412 |
+
elif ln_type == "RMSNorm":
|
| 1413 |
+
self.self_attn_layer_norm = RMSNorm(self.embed_dim)
|
| 1414 |
+
else:
|
| 1415 |
+
raise ValueError(f"Unknown ln_type: {ln_type}")
|
| 1416 |
+
|
| 1417 |
+
self.activation_fn = act
|
| 1418 |
+
self.fc1 = nn.Linear(self.embed_dim, encoder_ffn_dim)
|
| 1419 |
+
self.fc2 = nn.Linear(encoder_ffn_dim, self.embed_dim)
|
| 1420 |
+
|
| 1421 |
+
if ln_type == "LayerNorm":
|
| 1422 |
+
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
| 1423 |
+
elif ln_type == "RMSNorm":
|
| 1424 |
+
self.final_layer_norm = RMSNorm(self.embed_dim)
|
| 1425 |
+
else:
|
| 1426 |
+
raise ValueError(f"Unknown ln_type: {ln_type}")
|
| 1427 |
+
|
| 1428 |
+
def forward(
|
| 1429 |
+
self, hidden_states: torch.Tensor, seq_len: torch.Tensor
|
| 1430 |
+
) -> torch.Tensor:
|
| 1431 |
+
residual = hidden_states
|
| 1432 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
| 1433 |
+
hidden_states = self.self_attn(hidden_states, seq_len)
|
| 1434 |
+
hidden_states = residual + hidden_states
|
| 1435 |
+
residual = hidden_states
|
| 1436 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
| 1437 |
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
| 1438 |
+
hidden_states = self.fc2(hidden_states)
|
| 1439 |
+
hidden_states = residual + hidden_states
|
| 1440 |
+
|
| 1441 |
+
if (
|
| 1442 |
+
hidden_states.dtype == torch.float16
|
| 1443 |
+
or hidden_states.dtype == torch.bfloat16
|
| 1444 |
+
) and (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
|
| 1445 |
+
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
| 1446 |
+
hidden_states = torch.clamp(
|
| 1447 |
+
hidden_states, min=-clamp_value, max=clamp_value
|
| 1448 |
+
)
|
| 1449 |
+
return hidden_states
|
| 1450 |
+
|
| 1451 |
+
|
| 1452 |
+
|
| 1453 |
+
class LongcatNextAudioEncoder(nn.Module):
|
| 1454 |
+
def __init__(self, config):
|
| 1455 |
+
super().__init__()
|
| 1456 |
+
self.config = config
|
| 1457 |
+
self.max_source_positions = (config.max_audio_seconds * config.sampling_rate // config.hop_length) // config.stride_size
|
| 1458 |
+
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
| 1459 |
+
|
| 1460 |
+
self.conv1 = nn.Conv1d(config.num_mel_bins, config.d_model, kernel_size=config.kernel_size, padding=1)
|
| 1461 |
+
self.conv2 = nn.Conv1d(config.d_model, config.d_model, kernel_size=config.kernel_size,
|
| 1462 |
+
stride=config.stride_size, padding=1)
|
| 1463 |
+
self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, config.d_model)) # 1500 * d
|
| 1464 |
+
|
| 1465 |
+
self.layers = nn.ModuleList([OmniWhisperTransformerLayer(
|
| 1466 |
+
ACT2FN[config.activation_function],
|
| 1467 |
+
config.d_model,
|
| 1468 |
+
config.encoder_attention_heads,
|
| 1469 |
+
config.encoder_ffn_dim,
|
| 1470 |
+
False) for _ in range(config.encoder_layers)])
|
| 1471 |
+
self.layer_norm = nn.LayerNorm(config.d_model)
|
| 1472 |
+
|
| 1473 |
+
def forward(
|
| 1474 |
+
self,
|
| 1475 |
+
input_features,
|
| 1476 |
+
output_length,
|
| 1477 |
+
):
|
| 1478 |
+
input_features = input_features.to(self.conv1.weight.dtype)
|
| 1479 |
+
inputs_embeds = nn.functional.gelu(self.conv1(input_features)) # (bs, channels, frames)
|
| 1480 |
+
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) # (bs, channels, frames // 2)
|
| 1481 |
+
inputs_embeds = inputs_embeds.permute(0, 2, 1) # (bs, frams, channels)
|
| 1482 |
+
bsz, tgt_len, _ = inputs_embeds.size()
|
| 1483 |
+
if tgt_len < self.positional_embedding.shape[0]:
|
| 1484 |
+
current_positional_embedding = self.positional_embedding[:tgt_len]
|
| 1485 |
+
else:
|
| 1486 |
+
current_positional_embedding = self.positional_embedding
|
| 1487 |
+
hidden_states = (inputs_embeds.to(torch.float32) + current_positional_embedding).to(inputs_embeds.dtype)
|
| 1488 |
+
|
| 1489 |
+
# packing hidden states
|
| 1490 |
+
attention_mask, unpacking_index = get_sequence_mask(hidden_states, output_length)
|
| 1491 |
+
hidden_states = torch.masked_select(hidden_states, attention_mask).view(torch.sum(output_length),
|
| 1492 |
+
self.config.d_model)
|
| 1493 |
+
|
| 1494 |
+
for idx, encoder_layer in enumerate(self.layers):
|
| 1495 |
+
hidden_states = encoder_layer(hidden_states, output_length)
|
| 1496 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 1497 |
+
# unpacking
|
| 1498 |
+
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, tgt_len, self.config.d_model)
|
| 1499 |
+
hidden_states = torch.where(attention_mask, hidden_states, 0)
|
| 1500 |
+
return hidden_states
|
| 1501 |
+
|
| 1502 |
+
|
| 1503 |
+
class CasualConvTranspose1d(nn.Module):
|
| 1504 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride):
|
| 1505 |
+
super().__init__()
|
| 1506 |
+
self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
|
| 1507 |
+
self.norm = nn.GroupNorm(1, out_channels)
|
| 1508 |
+
self.in_channels = in_channels
|
| 1509 |
+
self.out_channels = out_channels
|
| 1510 |
+
|
| 1511 |
+
def forward(self, hidden_states, input_length, output_dim=None):
|
| 1512 |
+
kernel_size = self.conv.kernel_size[0]
|
| 1513 |
+
stride = self.conv.stride[0]
|
| 1514 |
+
bsz = input_length.shape[0]
|
| 1515 |
+
|
| 1516 |
+
if output_dim is None:
|
| 1517 |
+
output_dim = hidden_states.dim()
|
| 1518 |
+
if hidden_states.dim() <= 2: # unpack sequence to 3d
|
| 1519 |
+
sequence_mask, unpacking_index = get_sequence_mask(hidden_states, input_length)
|
| 1520 |
+
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(bsz, torch.max(input_length),
|
| 1521 |
+
self.in_channels)
|
| 1522 |
+
hidden_states = torch.where(sequence_mask, hidden_states, 0) # 3d (bsz, max_input_len, d)
|
| 1523 |
+
|
| 1524 |
+
hidden_states = hidden_states.transpose(2, 1) # (N, L, C) -> (N, C, L)
|
| 1525 |
+
hidden_states = self.conv(hidden_states)
|
| 1526 |
+
hidden_states = self.norm(hidden_states)
|
| 1527 |
+
hidden_states = hidden_states.transpose(2, 1) # (N, C, L) -> (N, L, C)
|
| 1528 |
+
|
| 1529 |
+
casual_padding_right = max(0, kernel_size - stride)
|
| 1530 |
+
hidden_states = hidden_states[:, :hidden_states.shape[1] - casual_padding_right,
|
| 1531 |
+
:]
|
| 1532 |
+
output_length = (input_length - 1) * stride + kernel_size - casual_padding_right
|
| 1533 |
+
sequence_mask, _ = get_sequence_mask(hidden_states, output_length)
|
| 1534 |
+
if output_dim <= 2:
|
| 1535 |
+
hidden_states = torch.masked_select(hidden_states, sequence_mask).view(-1, self.out_channels)
|
| 1536 |
+
else:
|
| 1537 |
+
hidden_states = torch.where(sequence_mask, hidden_states, 0)
|
| 1538 |
+
hidden_states = hidden_states[:, :torch.max(output_length), :]
|
| 1539 |
+
return hidden_states, output_length
|
| 1540 |
+
|
| 1541 |
+
|
| 1542 |
+
class MelSpecRefineNet(nn.Module):
|
| 1543 |
+
"""
|
| 1544 |
+
# post net, coarse to refined mel-spectrogram frames
|
| 1545 |
+
# ref1: Autoregressive Speech Synthesis without Vector Quantization
|
| 1546 |
+
# ref2: CosyVoice length_regulator.py
|
| 1547 |
+
# ref3: Neural Speech Synthesis with Transformer Network https://github.com/soobinseo/Transformer-TTS/blob/master/network.py
|
| 1548 |
+
"""
|
| 1549 |
+
|
| 1550 |
+
def __init__(self, encoder_config, vocoder_config):
|
| 1551 |
+
super().__init__()
|
| 1552 |
+
self.encoder_config = encoder_config
|
| 1553 |
+
self.vocoder_config = vocoder_config
|
| 1554 |
+
|
| 1555 |
+
layers = nn.ModuleList([])
|
| 1556 |
+
in_channels = self.vocoder_config.num_mel_bins
|
| 1557 |
+
for i, out_channels in enumerate(self.vocoder_config.channels[:-1]):
|
| 1558 |
+
module = nn.Conv1d(in_channels, out_channels, 5, 1, 2) # cosyvoice kernel=3, stride=1, pad=1
|
| 1559 |
+
in_channels = out_channels
|
| 1560 |
+
norm = nn.GroupNorm(1, out_channels)
|
| 1561 |
+
act = nn.Mish()
|
| 1562 |
+
layers.extend([module, norm, act])
|
| 1563 |
+
layers.append(nn.Conv1d(in_channels, self.vocoder_config.num_mel_bins, 1, 1)) # projector
|
| 1564 |
+
self.layers = nn.Sequential(*layers)
|
| 1565 |
+
|
| 1566 |
+
def compute_output_length(self, input_length):
|
| 1567 |
+
output_length = input_length.to(
|
| 1568 |
+
torch.float32) * self.encoder_config.hop_length / self.encoder_config.sampling_rate
|
| 1569 |
+
output_length = output_length * self.vocoder_config.sampling_rate / self.vocoder_config.hop_length
|
| 1570 |
+
return output_length.to(torch.int64)
|
| 1571 |
+
|
| 1572 |
+
def forward(self, coarse_mel, input_length, output_length=None):
|
| 1573 |
+
bsz, _, d = coarse_mel.shape
|
| 1574 |
+
assert (d == self.vocoder_config.num_mel_bins)
|
| 1575 |
+
if output_length is None or not self.training:
|
| 1576 |
+
output_length = self.compute_output_length(input_length)
|
| 1577 |
+
coarse_mel, default_dtype = coarse_mel[:, :torch.max(input_length), :], coarse_mel.dtype
|
| 1578 |
+
coarse_mel = F.interpolate(coarse_mel.to(torch.float32).transpose(1, 2).contiguous(), size=output_length.max(),
|
| 1579 |
+
mode='nearest').to(default_dtype)
|
| 1580 |
+
refined_mel = self.layers(coarse_mel).transpose(1, 2).contiguous() # (bs, t, d)
|
| 1581 |
+
coarse_mel = coarse_mel.transpose(1, 2) # (bs, max(output_length), d)
|
| 1582 |
+
refined_mel += coarse_mel # residual conntection
|
| 1583 |
+
sequence_mask, _ = get_sequence_mask(refined_mel, output_length)
|
| 1584 |
+
coarse_mel = torch.where(sequence_mask, coarse_mel, 0)
|
| 1585 |
+
refined_mel = torch.where(sequence_mask, refined_mel, 0)
|
| 1586 |
+
return refined_mel, coarse_mel, output_length
|
| 1587 |
+
|
| 1588 |
+
|
| 1589 |
+
@dataclass
|
| 1590 |
+
class OmniAudioDecoderOutput(ModelOutput):
|
| 1591 |
+
refined_mel: Optional[torch.FloatTensor] = None
|
| 1592 |
+
coarse_mel: Optional[torch.FloatTensor] = None
|
| 1593 |
+
mel_length: Optional[torch.Tensor] = None
|
| 1594 |
+
hidden_states_before_dconv2: Optional[torch.FloatTensor] = None
|
| 1595 |
+
output_length_before_dconv2: Optional[torch.Tensor] = None
|
| 1596 |
+
|
| 1597 |
+
|
| 1598 |
+
class LongcatNextAudioDecoder(nn.Module):
|
| 1599 |
+
def __init__(self, config):
|
| 1600 |
+
super().__init__()
|
| 1601 |
+
self.config = config
|
| 1602 |
+
self.vocoder_config = config.vocoder_config
|
| 1603 |
+
self.max_source_positions = self.config.max_audio_seconds * self.config.sampling_rate // self.config.hop_length
|
| 1604 |
+
|
| 1605 |
+
self.dconv1 = CasualConvTranspose1d(
|
| 1606 |
+
self.config.d_model,
|
| 1607 |
+
self.config.d_model,
|
| 1608 |
+
self.config.decoder_kernel_size,
|
| 1609 |
+
self.config.avg_pooler,
|
| 1610 |
+
)
|
| 1611 |
+
self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, self.config.d_model))
|
| 1612 |
+
# causal transformer layers
|
| 1613 |
+
self.layers = nn.ModuleList(
|
| 1614 |
+
[OmniWhisperTransformerLayer(
|
| 1615 |
+
ACT2FN[self.config.activation_function],
|
| 1616 |
+
self.config.d_model,
|
| 1617 |
+
self.config.decoder_attention_heads,
|
| 1618 |
+
self.config.decoder_ffn_dim,
|
| 1619 |
+
True # causal
|
| 1620 |
+
) for _ in range(self.config.decoder_layers)
|
| 1621 |
+
])
|
| 1622 |
+
self.layer_norm = nn.LayerNorm(self.config.d_model)
|
| 1623 |
+
self.dconv2 = CasualConvTranspose1d(
|
| 1624 |
+
self.config.d_model,
|
| 1625 |
+
self.vocoder_config.num_mel_bins,
|
| 1626 |
+
self.config.decoder_kernel_size,
|
| 1627 |
+
self.config.decoder_stride_size
|
| 1628 |
+
)
|
| 1629 |
+
self.post_net = MelSpecRefineNet(self.config, self.vocoder_config)
|
| 1630 |
+
self.gradient_checkpointing = False
|
| 1631 |
+
|
| 1632 |
+
def forward(self,
|
| 1633 |
+
audio_embed,
|
| 1634 |
+
input_length,
|
| 1635 |
+
mel_labels=None,
|
| 1636 |
+
mel_labels_length=None,
|
| 1637 |
+
):
|
| 1638 |
+
assert (audio_embed.shape[-1] == self.config.d_model)
|
| 1639 |
+
audio_embed = audio_embed.to(self.layer_norm.weight) # device and type
|
| 1640 |
+
audio_embed, output_length = self.dconv1(audio_embed, input_length, output_dim=3) # (b, l*2, d_model)
|
| 1641 |
+
_, tgt_len, _ = audio_embed.size()
|
| 1642 |
+
if tgt_len < self.positional_embedding.shape[0]:
|
| 1643 |
+
current_positional_embedding = self.positional_embedding[:tgt_len]
|
| 1644 |
+
else:
|
| 1645 |
+
current_positional_embedding = self.positional_embedding
|
| 1646 |
+
hidden_states = (audio_embed.to(torch.float32) + current_positional_embedding).to(audio_embed.dtype)
|
| 1647 |
+
|
| 1648 |
+
# packing hidden states
|
| 1649 |
+
attention_mask, _ = get_sequence_mask(hidden_states, output_length)
|
| 1650 |
+
hidden_states = torch.masked_select(hidden_states, attention_mask).view(torch.sum(output_length), self.config.d_model)
|
| 1651 |
+
|
| 1652 |
+
for idx, encoder_layer in enumerate(self.layers):
|
| 1653 |
+
hidden_states = encoder_layer(hidden_states, output_length)
|
| 1654 |
+
|
| 1655 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 1656 |
+
hidden_states_before_dconv2 = hidden_states
|
| 1657 |
+
output_length_before_dconv2 = output_length
|
| 1658 |
+
|
| 1659 |
+
coarse_mel, output_length = self.dconv2(hidden_states, output_length, output_dim=3)
|
| 1660 |
+
refined_mel, coarse_mel, mel_labels_length = self.post_net(coarse_mel, output_length, mel_labels_length)
|
| 1661 |
+
|
| 1662 |
+
return OmniAudioDecoderOutput(
|
| 1663 |
+
refined_mel=refined_mel,
|
| 1664 |
+
coarse_mel=coarse_mel,
|
| 1665 |
+
mel_length=mel_labels_length,
|
| 1666 |
+
hidden_states_before_dconv2=hidden_states_before_dconv2,
|
| 1667 |
+
output_length_before_dconv2=output_length_before_dconv2,
|
| 1668 |
+
)
|
| 1669 |
+
|
| 1670 |
+
|
| 1671 |
+
class LongcatNextAudioVQBridger(nn.Module):
|
| 1672 |
+
def __init__(self, config):
|
| 1673 |
+
super().__init__()
|
| 1674 |
+
self.config = config
|
| 1675 |
+
self.gradient_checkpointing = False
|
| 1676 |
+
self.intermediate_dim = self.config.d_model * self.config.avg_pooler
|
| 1677 |
+
self.gate_proj = nn.Conv1d(self.config.d_model, self.intermediate_dim, self.config.avg_pooler, self.config.avg_pooler, bias=False)
|
| 1678 |
+
self.up_proj = nn.Conv1d(self.config.d_model, self.intermediate_dim, self.config.avg_pooler, self.config.avg_pooler, bias=False)
|
| 1679 |
+
|
| 1680 |
+
self.down_proj = nn.Linear(self.intermediate_dim, self.intermediate_dim, bias=False)
|
| 1681 |
+
self.act_fn = ACT2FN['silu']
|
| 1682 |
+
self.layer_norm = nn.LayerNorm(self.intermediate_dim)
|
| 1683 |
+
self.proj_decoder = nn.Linear(self.intermediate_dim, self.config.d_model)
|
| 1684 |
+
|
| 1685 |
+
self.vq_list = nn.ModuleList([])
|
| 1686 |
+
for idx, codebook_size in enumerate(self.config.vq_config.codebook_sizes):
|
| 1687 |
+
vq_config = copy.deepcopy(self.config.vq_config)
|
| 1688 |
+
vq_config.dim = self.intermediate_dim
|
| 1689 |
+
vq_config.codebook_size = codebook_size
|
| 1690 |
+
self.vq_list.append(VectorQuantize(vq_config))
|
| 1691 |
+
|
| 1692 |
+
def rvq_op(self, inputs, output_length):
|
| 1693 |
+
def rvq_layer_op(vq_layer, residual_encoding, output_length):
|
| 1694 |
+
q_v_i, code_ids_i = vq_layer(residual_encoding, output_length)
|
| 1695 |
+
residual_encoding = residual_encoding.float() - q_v_i.float()
|
| 1696 |
+
residual_encoding = residual_encoding.to(inputs.dtype)
|
| 1697 |
+
return residual_encoding, code_ids_i
|
| 1698 |
+
|
| 1699 |
+
cmt_loss, residual_encoding = 0, inputs
|
| 1700 |
+
code_ids_list = []
|
| 1701 |
+
for i, vq_layer in enumerate(self.vq_list):
|
| 1702 |
+
residual_encoding, code_ids_i = rvq_layer_op(vq_layer, residual_encoding, output_length)
|
| 1703 |
+
code_ids_list.append(code_ids_i)
|
| 1704 |
+
return torch.stack(code_ids_list, -1)
|
| 1705 |
+
|
| 1706 |
+
def forward(self, x, output_length):
|
| 1707 |
+
batch_size, _, _ = x.shape
|
| 1708 |
+
output_length = output_length.to(x.device)
|
| 1709 |
+
|
| 1710 |
+
if x.shape[1] % self.config.avg_pooler != 0:
|
| 1711 |
+
x = F.pad(x, (0, 0, 0, self.config.avg_pooler - x.shape[1] % self.config.avg_pooler), "constant", 0)
|
| 1712 |
+
xt = x.permute(0, 2, 1)
|
| 1713 |
+
g = self.gate_proj(xt).permute(0, 2, 1) # (bs, sl//poolersizre+1, d*2)
|
| 1714 |
+
u = self.up_proj(xt).permute(0, 2, 1)
|
| 1715 |
+
x = x.reshape(batch_size, -1, self.intermediate_dim) # (bs, sl//poolersizre+1, d*2)
|
| 1716 |
+
|
| 1717 |
+
c = self.down_proj(self.act_fn(g) * u)
|
| 1718 |
+
res = self.layer_norm(c + x)
|
| 1719 |
+
valid_mask, _ = get_sequence_mask(res, output_length)
|
| 1720 |
+
code_ids = self.rvq_op(res, output_length)
|
| 1721 |
+
code_ids = torch.masked_select(code_ids, valid_mask).reshape(-1, len(self.vq_list)) # (sum(valid_sequence_length), vq_num)
|
| 1722 |
+
return code_ids
|
| 1723 |
+
|
| 1724 |
+
@torch.no_grad()
|
| 1725 |
+
def decode(self, code_ids):
|
| 1726 |
+
vq_num = code_ids.shape[-1]
|
| 1727 |
+
res = sum(self.vq_list[i].get_output_from_indices(code_ids[:, i]).float() for i in range(vq_num-1,-1,-1)).to(self.proj_decoder.weight)
|
| 1728 |
+
decoder_emb = self.proj_decoder(res.to(self.proj_decoder.weight))
|
| 1729 |
+
return decoder_emb
|
| 1730 |
+
|
| 1731 |
+
@torch.no_grad()
|
| 1732 |
+
def recover(self, code_ids):
|
| 1733 |
+
vq_num = code_ids.shape[-1]
|
| 1734 |
+
res = sum(self.vq_list[i].get_output_from_indices(code_ids[:, i]).float() for i in range(vq_num-1,-1,-1)).to(self.proj_decoder.weight)
|
| 1735 |
+
return res
|
| 1736 |
+
|
| 1737 |
+
|
| 1738 |
+
class FlowmatchingPrenet(nn.Module):
|
| 1739 |
+
def __init__(
|
| 1740 |
+
self,
|
| 1741 |
+
input_feat_dim,
|
| 1742 |
+
out_feat_dim,
|
| 1743 |
+
d_model,
|
| 1744 |
+
attention_heads,
|
| 1745 |
+
ffn_dim,
|
| 1746 |
+
nlayers,
|
| 1747 |
+
activation_function,
|
| 1748 |
+
max_source_positions,
|
| 1749 |
+
target_mel_length_scale_ratio,
|
| 1750 |
+
):
|
| 1751 |
+
super().__init__()
|
| 1752 |
+
|
| 1753 |
+
self.d_model = d_model
|
| 1754 |
+
self.target_mel_length_scale_ratio = target_mel_length_scale_ratio
|
| 1755 |
+
self.gradient_checkpointing = False
|
| 1756 |
+
|
| 1757 |
+
self.register_buffer(
|
| 1758 |
+
"positional_embedding", sinusoids(max_source_positions, d_model)
|
| 1759 |
+
)
|
| 1760 |
+
|
| 1761 |
+
self.in_mlp = nn.Sequential(
|
| 1762 |
+
nn.Linear(input_feat_dim, d_model * 4),
|
| 1763 |
+
nn.SiLU(),
|
| 1764 |
+
nn.Linear(d_model * 4, d_model),
|
| 1765 |
+
)
|
| 1766 |
+
|
| 1767 |
+
self.transformer_layers = nn.ModuleList(
|
| 1768 |
+
[
|
| 1769 |
+
OmniWhisperTransformerLayer(
|
| 1770 |
+
act=ACT2FN[activation_function],
|
| 1771 |
+
d_model=d_model,
|
| 1772 |
+
encoder_attention_heads=attention_heads,
|
| 1773 |
+
encoder_ffn_dim=ffn_dim,
|
| 1774 |
+
causal=True, # causal
|
| 1775 |
+
ln_type="RMSNorm",
|
| 1776 |
+
)
|
| 1777 |
+
for _ in range(nlayers)
|
| 1778 |
+
]
|
| 1779 |
+
)
|
| 1780 |
+
|
| 1781 |
+
self.final_norm = RMSNorm(self.d_model)
|
| 1782 |
+
self.out_proj = nn.Linear(d_model, out_feat_dim, bias=False)
|
| 1783 |
+
|
| 1784 |
+
def compute_output_length(self, input_length):
|
| 1785 |
+
output_length = input_length.float() * self.target_mel_length_scale_ratio
|
| 1786 |
+
return output_length.to(torch.int64)
|
| 1787 |
+
|
| 1788 |
+
def forward(self, input_feat, input_length, output_length=None):
|
| 1789 |
+
"""
|
| 1790 |
+
Args:
|
| 1791 |
+
input_feat: [B, T, input_feat_dim]
|
| 1792 |
+
input_length: [B]
|
| 1793 |
+
output_length: [B]
|
| 1794 |
+
|
| 1795 |
+
"""
|
| 1796 |
+
if output_length is None or not self.training:
|
| 1797 |
+
output_length = self.compute_output_length(input_length)
|
| 1798 |
+
|
| 1799 |
+
input_feat = input_feat[:, : input_length.max(), :] # [B, T, D]
|
| 1800 |
+
orig_dtype = input_feat.dtype
|
| 1801 |
+
|
| 1802 |
+
input_feat = F.interpolate(
|
| 1803 |
+
input=input_feat.to(torch.float32).transpose(1, 2).contiguous(),
|
| 1804 |
+
size=output_length.max(),
|
| 1805 |
+
mode="nearest",
|
| 1806 |
+
).to(orig_dtype)
|
| 1807 |
+
input_feat = input_feat.transpose(1, 2).contiguous() # [B, T, D]
|
| 1808 |
+
hidden_states = self.in_mlp(input_feat)
|
| 1809 |
+
|
| 1810 |
+
# packing hidden states
|
| 1811 |
+
bsz, tgt_len, d_model = hidden_states.shape
|
| 1812 |
+
attention_mask, unpacking_index = get_sequence_mask(
|
| 1813 |
+
hidden_states, output_length
|
| 1814 |
+
)
|
| 1815 |
+
hidden_states = torch.masked_select(hidden_states, attention_mask).view(
|
| 1816 |
+
torch.sum(output_length), self.d_model
|
| 1817 |
+
)
|
| 1818 |
+
|
| 1819 |
+
for idx, encoder_layer in enumerate(self.transformer_layers):
|
| 1820 |
+
hidden_states = encoder_layer(hidden_states, output_length)
|
| 1821 |
+
|
| 1822 |
+
# unpacking
|
| 1823 |
+
hidden_states = torch.index_select(hidden_states, 0, unpacking_index).view(
|
| 1824 |
+
bsz, tgt_len, d_model
|
| 1825 |
+
)
|
| 1826 |
+
hidden_states = torch.where(attention_mask, hidden_states, 0)
|
| 1827 |
+
|
| 1828 |
+
hidden_states = self.final_norm(hidden_states)
|
| 1829 |
+
output = self.out_proj(hidden_states)
|
| 1830 |
+
return output, output_length
|
| 1831 |
+
|
| 1832 |
+
|
| 1833 |
+
@dataclass
|
| 1834 |
+
class OmniAudioFlowMatchingDecoderOutput(ModelOutput):
|
| 1835 |
+
flow_matching_mel: Optional[torch.FloatTensor] = None
|
| 1836 |
+
flow_matching_mel_lengths: Optional[torch.FloatTensor] = None
|
| 1837 |
+
|
| 1838 |
+
|
| 1839 |
+
class LongcatNextAudioFlowMatchingDecoder(nn.Module):
|
| 1840 |
+
def __init__(self, config):
|
| 1841 |
+
super().__init__()
|
| 1842 |
+
self.config = config.flow_matching_config
|
| 1843 |
+
self.in_channels = self.config.in_channels
|
| 1844 |
+
self.spk_emb_dim = self.config.spk_emb_dim
|
| 1845 |
+
self.diffusion_steps = self.config.diffusion_steps
|
| 1846 |
+
self.cal_mel_mae = self.config.cal_mel_mae
|
| 1847 |
+
self.forward_step = -1
|
| 1848 |
+
|
| 1849 |
+
self.prenet = FlowmatchingPrenet(
|
| 1850 |
+
input_feat_dim=self.config.prenet_in_dim,
|
| 1851 |
+
out_feat_dim=self.config.prenet_out_dim,
|
| 1852 |
+
d_model=self.config.prenet_d_model,
|
| 1853 |
+
attention_heads=self.config.prenet_attention_heads,
|
| 1854 |
+
ffn_dim=self.config.prenet_ffn_dim,
|
| 1855 |
+
nlayers=self.config.prenet_nlayers,
|
| 1856 |
+
activation_function=self.config.prenet_activation_function,
|
| 1857 |
+
max_source_positions=self.config.prenet_max_source_positions,
|
| 1858 |
+
target_mel_length_scale_ratio=self.config.prenet_target_mel_length_scale_ratio,
|
| 1859 |
+
)
|
| 1860 |
+
|
| 1861 |
+
self.conditional_decoder = ConditionalDecoder(
|
| 1862 |
+
in_channels=self.in_channels * 2 + self.spk_emb_dim,
|
| 1863 |
+
out_channels=self.in_channels,
|
| 1864 |
+
causal=True,
|
| 1865 |
+
channels=self.config.channels,
|
| 1866 |
+
dropout=self.config.dropout,
|
| 1867 |
+
attention_head_dim=self.config.attention_head_dim,
|
| 1868 |
+
n_blocks=self.config.n_blocks,
|
| 1869 |
+
num_mid_blocks=self.config.num_mid_blocks,
|
| 1870 |
+
num_heads=self.config.num_heads,
|
| 1871 |
+
act_fn=self.config.act_fn,
|
| 1872 |
+
)
|
| 1873 |
+
|
| 1874 |
+
self.cfm = ConditionalCFM(
|
| 1875 |
+
in_channels=self.in_channels,
|
| 1876 |
+
cfm_params=self.config.cfm_params,
|
| 1877 |
+
n_spks=0,
|
| 1878 |
+
spk_emb_dim=self.spk_emb_dim,
|
| 1879 |
+
)
|
| 1880 |
+
|
| 1881 |
+
|
| 1882 |
+
def unpack_hidden_states(self, hidden_states, output_length):
|
| 1883 |
+
unpacked = unpack_hidden_states(hidden_states, output_length)
|
| 1884 |
+
return unpacked, output_length
|
| 1885 |
+
|
| 1886 |
+
def forward(
|
| 1887 |
+
self, refined_mel, input_length, mel_labels=None, mel_labels_length=None
|
| 1888 |
+
):
|
| 1889 |
+
"""
|
| 1890 |
+
:param refined_mel: [bs, max_input_len, mel_bin]
|
| 1891 |
+
:param input_length: [batch_size]
|
| 1892 |
+
:param refined_mel: [bs, mel_bin, max_input_len]
|
| 1893 |
+
:return:
|
| 1894 |
+
"""
|
| 1895 |
+
self.forward_step += 1
|
| 1896 |
+
|
| 1897 |
+
orig_dtype = refined_mel.dtype
|
| 1898 |
+
prenet_mae_metric = torch.tensor(0.0).to(refined_mel.device)
|
| 1899 |
+
prenet_regression_loss = torch.tensor(0.0).to(refined_mel.device)
|
| 1900 |
+
|
| 1901 |
+
if self.prenet is not None:
|
| 1902 |
+
refined_mel = refined_mel[:, : torch.max(input_length), :]
|
| 1903 |
+
if mel_labels_length is None:
|
| 1904 |
+
mel_labels_length = self.prenet.compute_output_length(input_length)
|
| 1905 |
+
refined_mel, input_length = self.prenet(
|
| 1906 |
+
refined_mel, input_length, mel_labels_length
|
| 1907 |
+
)
|
| 1908 |
+
|
| 1909 |
+
float_dtype = refined_mel.dtype
|
| 1910 |
+
refined_mel = refined_mel.float()
|
| 1911 |
+
input_length = input_length.long()
|
| 1912 |
+
|
| 1913 |
+
refined_mel = refined_mel[:, : torch.max(input_length), :]
|
| 1914 |
+
sequence_mask, unpacking_index = get_sequence_mask(refined_mel, input_length)
|
| 1915 |
+
refined_mel = refined_mel.transpose(1, 2) # (bs, mel_bin, max_input_len)
|
| 1916 |
+
sequence_mask = sequence_mask.transpose(2, 1) # (bs, 1, sl)
|
| 1917 |
+
|
| 1918 |
+
fm_mel = self.cfm.forward(
|
| 1919 |
+
estimator=self.conditional_decoder,
|
| 1920 |
+
mu=refined_mel.to(float_dtype),
|
| 1921 |
+
mask=sequence_mask.float(),
|
| 1922 |
+
n_timesteps=self.diffusion_steps,
|
| 1923 |
+
)
|
| 1924 |
+
return OmniAudioFlowMatchingDecoderOutput(
|
| 1925 |
+
flow_matching_mel=fm_mel.transpose(1, 2),
|
| 1926 |
+
flow_matching_mel_lengths=mel_labels_length,
|
| 1927 |
+
)
|
| 1928 |
+
|
| 1929 |
+
|
| 1930 |
+
@torch.no_grad()
|
| 1931 |
+
def decode_wave_vocoder2(response, vocoder, audio_tokenizer):
|
| 1932 |
+
response_len = (response[:,:,0] == audio_tokenizer.config.audio_config.vq_config.codebook_sizes[0]).long().argmax(dim=1)
|
| 1933 |
+
valid_response_list = [response[i, :response_len[i], :] for i in range(response.shape[0]) if int(response_len[i])>0]
|
| 1934 |
+
|
| 1935 |
+
if len(valid_response_list)==0:
|
| 1936 |
+
return []
|
| 1937 |
+
flatten_response = torch.cat(valid_response_list, dim=0) if len(valid_response_list)>1 else valid_response_list[0]
|
| 1938 |
+
valid_response_len = response_len[response_len>0]
|
| 1939 |
+
ret = audio_tokenizer.decode(flatten_response.view(-1,response.shape[-1]),
|
| 1940 |
+
bridge_length=valid_response_len)
|
| 1941 |
+
batch_size = response.shape[0]
|
| 1942 |
+
valid_start = 0
|
| 1943 |
+
r = []
|
| 1944 |
+
for i in range(batch_size):
|
| 1945 |
+
if response_len[i]==0:
|
| 1946 |
+
r.append(None)
|
| 1947 |
+
continue
|
| 1948 |
+
if isinstance(ret, torch.Tensor):
|
| 1949 |
+
r.append(ret[valid_start:valid_start+1])
|
| 1950 |
+
valid_start+=1
|
| 1951 |
+
continue
|
| 1952 |
+
decode_wave = vocoder.decode(ret.flow_matching_mel[valid_start ][:ret.flow_matching_mel_lengths[valid_start ], :].transpose(0, 1).to(torch.float32).unsqueeze(0))
|
| 1953 |
+
r.append(decode_wave.cpu())
|
| 1954 |
+
valid_start+=1
|
| 1955 |
+
return r
|
| 1956 |
+
|
| 1957 |
+
|
| 1958 |
+
@torch.no_grad()
|
| 1959 |
+
def decode_save_concat2(response_list, vocoder, model, path, sampling_rate=16000, wave_concat_overlap=800):
|
| 1960 |
+
wave_list = []
|
| 1961 |
+
for response in response_list:
|
| 1962 |
+
wave_list.extend([wave_i for wave_i in decode_wave_vocoder2(response, vocoder, model) if wave_i is not None])
|
| 1963 |
+
new_wave_list = [wave_list[0]]
|
| 1964 |
+
for w in wave_list[1:]:
|
| 1965 |
+
if new_wave_list[-1].shape[1] > wave_concat_overlap and w.shape[1] > wave_concat_overlap:
|
| 1966 |
+
new_wave_list.append((new_wave_list[-1][:, -wave_concat_overlap:] * torch.linspace(1.0, 0.0, wave_concat_overlap, device=new_wave_list[-1].device)[None, :]
|
| 1967 |
+
+ w[:, :wave_concat_overlap] * torch.linspace(0.0, 1.0, wave_concat_overlap, device=new_wave_list[-1].device)[None, :]))
|
| 1968 |
+
new_wave_list.append(w)
|
| 1969 |
+
full_wave = torch.cat(new_wave_list, dim=1) if len(new_wave_list) > 1 else new_wave_list[0]
|
| 1970 |
+
torchaudio.save(path, full_wave, sampling_rate)
|
| 1971 |
+
|
| 1972 |
+
|
| 1973 |
+
class LongcatNextAudioTokenizer(nn.Module):
|
| 1974 |
+
|
| 1975 |
+
def __init__(self, config):
|
| 1976 |
+
super().__init__()
|
| 1977 |
+
self.config = config
|
| 1978 |
+
self.audio_model = LongcatNextAudioEncoder(config.audio_config)
|
| 1979 |
+
self.audio_bridge_model = LongcatNextAudioVQBridger(config.audio_config)
|
| 1980 |
+
self.audio_decoder = LongcatNextAudioDecoder(config.audio_config)
|
| 1981 |
+
self.audio_flow_matching_decoder = LongcatNextAudioFlowMatchingDecoder(config.audio_config)
|
| 1982 |
+
self.cosy24kvocoder = None
|
| 1983 |
+
|
| 1984 |
+
@torch.no_grad()
|
| 1985 |
+
def encode(self, x, encoder_length: Optional[torch.Tensor] = None, bridge_length: Optional[torch.Tensor] = None):
|
| 1986 |
+
audio_emb = self.audio_model(x, encoder_length)
|
| 1987 |
+
audio_tokens = self.audio_bridge_model(audio_emb, bridge_length)
|
| 1988 |
+
return audio_tokens
|
| 1989 |
+
|
| 1990 |
+
@torch.no_grad()
|
| 1991 |
+
def decode(self, audio_ids, bridge_length: Optional[torch.Tensor] = None):
|
| 1992 |
+
audio_emb = self.audio_bridge_model.decode(audio_ids)
|
| 1993 |
+
audio_dec = self.audio_decoder(
|
| 1994 |
+
audio_emb.to(next(self.audio_decoder.parameters())), bridge_length
|
| 1995 |
+
)
|
| 1996 |
+
if self.config.audio_config.flow_matching_config.use_hidden_states_before_dconv2:
|
| 1997 |
+
hidden_states, hidden_states_length = (
|
| 1998 |
+
self.audio_flow_matching_decoder.unpack_hidden_states(
|
| 1999 |
+
audio_dec.hidden_states_before_dconv2,
|
| 2000 |
+
audio_dec.output_length_before_dconv2,
|
| 2001 |
+
)
|
| 2002 |
+
)
|
| 2003 |
+
audio_flow_matching_decoder_ret = self.audio_flow_matching_decoder(
|
| 2004 |
+
hidden_states, hidden_states_length
|
| 2005 |
+
)
|
| 2006 |
+
else:
|
| 2007 |
+
audio_flow_matching_decoder_ret = self.audio_flow_matching_decoder(
|
| 2008 |
+
audio_dec.refined_mel, audio_dec.mel_length
|
| 2009 |
+
)
|
| 2010 |
+
return audio_flow_matching_decoder_ret
|
| 2011 |
+
|
| 2012 |
+
@torch.no_grad()
|
| 2013 |
+
def lazy_decode_and_save(self, audio_ids, sampling_rate, wave_concat_overlap, save_path):
|
| 2014 |
+
if self.cosy24kvocoder is None:
|
| 2015 |
+
print("lazy load cosy24kvocoder ...")
|
| 2016 |
+
device = next(self.parameters()).device
|
| 2017 |
+
self.cosy24kvocoder = Cosy24kVocoder.from_pretrained(self.config.audio_config.cosy24kvocoder_config.weight_path).to(device)
|
| 2018 |
+
|
| 2019 |
+
if audio_ids[-1, 0] != self.config.audio_config.vq_config.codebook_sizes[0]: # exceed max_new_tokens
|
| 2020 |
+
audio_ids = F.pad(audio_ids, (0, 0, 0, 1), value=self.config.audio_config.vq_config.codebook_sizes[0])
|
| 2021 |
+
|
| 2022 |
+
audio_end_pos = [-1] + (audio_ids[:, 0] == self.config.audio_config.vq_config.codebook_sizes[0]).nonzero().view(-1).tolist()
|
| 2023 |
+
|
| 2024 |
+
audio_ids_chunk = []
|
| 2025 |
+
for i in range(len(audio_end_pos) - 1):
|
| 2026 |
+
start = audio_end_pos[i] + 1
|
| 2027 |
+
end = audio_end_pos[i+1] + 1
|
| 2028 |
+
audio_ids_chunk.append(audio_ids[start:end].unsqueeze(0))
|
| 2029 |
+
|
| 2030 |
+
audio_ids = audio_ids_chunk
|
| 2031 |
+
|
| 2032 |
+
decode_save_concat2(
|
| 2033 |
+
response_list=audio_ids,
|
| 2034 |
+
vocoder=self.cosy24kvocoder,
|
| 2035 |
+
model=self,
|
| 2036 |
+
path=save_path,
|
| 2037 |
+
sampling_rate=sampling_rate,
|
| 2038 |
+
wave_concat_overlap=wave_concat_overlap,
|
| 2039 |
+
)
|
modular_longcat_next_visual.py
ADDED
|
@@ -0,0 +1,1077 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Iterable, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from safetensors.torch import load_file
|
| 5 |
+
import torch
|
| 6 |
+
import torch.utils.checkpoint
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.amp import autocast
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
from flash_attn import flash_attn_varlen_func
|
| 13 |
+
|
| 14 |
+
from transformers.activations import ACT2FN
|
| 15 |
+
from transformers.modeling_outputs import BaseModelOutput
|
| 16 |
+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
| 17 |
+
Qwen2RMSNorm,
|
| 18 |
+
Qwen2_5_VisionTransformerPretrainedModel,
|
| 19 |
+
)
|
| 20 |
+
from transformers.utils import logging
|
| 21 |
+
|
| 22 |
+
from .image_refiner import (
|
| 23 |
+
ImageRefinerContainer,
|
| 24 |
+
RefinerImageProcessor,
|
| 25 |
+
RefinerPipeline,
|
| 26 |
+
de_transform,
|
| 27 |
+
tensor2pil,
|
| 28 |
+
)
|
| 29 |
+
from .refiner_modules import FlowMatchEulerDiscreteScheduler
|
| 30 |
+
|
| 31 |
+
logger = logging.get_logger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def uniform_init(*shape):
|
| 35 |
+
t = torch.zeros(shape)
|
| 36 |
+
nn.init.kaiming_uniform_(t)
|
| 37 |
+
return t
|
| 38 |
+
|
| 39 |
+
class VQEmbedding(nn.Module):
|
| 40 |
+
"""VQ embedding module with ema update."""
|
| 41 |
+
|
| 42 |
+
def __init__(self, n_embed, embed_dim, ema=True, decay=0.99, restart_unused_codes=True, eps=1e-5, init_std=0.02):
|
| 43 |
+
super().__init__()
|
| 44 |
+
|
| 45 |
+
self.ema = ema
|
| 46 |
+
self.decay = decay
|
| 47 |
+
self.eps = eps
|
| 48 |
+
self.restart_unused_codes = restart_unused_codes
|
| 49 |
+
self.n_embed = n_embed
|
| 50 |
+
self.init_std = init_std
|
| 51 |
+
|
| 52 |
+
assert self.ema
|
| 53 |
+
embed = uniform_init(n_embed + 1, embed_dim).to(torch.float32)
|
| 54 |
+
self.embed = nn.Parameter(embed)
|
| 55 |
+
self.embed_ema = nn.Parameter(embed[:-1, :].clone())
|
| 56 |
+
self.cluster_size_ema = nn.Parameter(torch.ones(n_embed))
|
| 57 |
+
del embed
|
| 58 |
+
_ = [p.requires_grad_(False) for p in self.parameters()]
|
| 59 |
+
|
| 60 |
+
@torch.no_grad()
|
| 61 |
+
def compute_distances(self, inputs):
|
| 62 |
+
codebook_t = self.embed[:-1, :].t()
|
| 63 |
+
|
| 64 |
+
(embed_dim, _) = codebook_t.shape
|
| 65 |
+
inputs_shape = inputs.shape
|
| 66 |
+
assert inputs_shape[-1] == embed_dim
|
| 67 |
+
|
| 68 |
+
inputs_flat = inputs.reshape(-1, embed_dim)
|
| 69 |
+
|
| 70 |
+
inputs_norm_sq = inputs_flat.pow(2.).sum(dim=1, keepdim=True)
|
| 71 |
+
codebook_t_norm_sq = codebook_t.pow(2.).sum(dim=0, keepdim=True)
|
| 72 |
+
distances = torch.addmm(
|
| 73 |
+
inputs_norm_sq + codebook_t_norm_sq,
|
| 74 |
+
inputs_flat,
|
| 75 |
+
codebook_t,
|
| 76 |
+
alpha=-2.0,
|
| 77 |
+
)
|
| 78 |
+
distances = distances.reshape(*inputs_shape[:-1], -1) # [B, h, w, n_embed or n_embed+1]
|
| 79 |
+
return distances
|
| 80 |
+
|
| 81 |
+
@torch.no_grad()
|
| 82 |
+
def find_nearest_embedding(self, inputs):
|
| 83 |
+
distances = self.compute_distances(inputs) # [B, h, w, n_embed or n_embed+1]
|
| 84 |
+
embed_idxs = distances.argmin(dim=-1) # use padding index or not
|
| 85 |
+
|
| 86 |
+
return embed_idxs
|
| 87 |
+
|
| 88 |
+
@autocast('cuda', enabled=True, dtype=torch.float32)
|
| 89 |
+
@torch.no_grad()
|
| 90 |
+
def forward(self, inputs):
|
| 91 |
+
if inputs.dtype != torch.float32:
|
| 92 |
+
inputs = inputs.to(torch.float32)
|
| 93 |
+
embed_idxs = self.find_nearest_embedding(inputs)
|
| 94 |
+
embeds = self.embed[embed_idxs]
|
| 95 |
+
return embeds, embed_idxs
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class RQBottleneck(nn.Module):
|
| 99 |
+
"""
|
| 100 |
+
Quantization bottleneck via Residual Quantization.
|
| 101 |
+
|
| 102 |
+
Arguments:
|
| 103 |
+
latent_shape (Tuple[int, int, int]): the shape of latents, denoted (H, W, D)
|
| 104 |
+
code_shape (Tuple[int, int, int]): the shape of codes, denoted (h, w, d)
|
| 105 |
+
n_embed (int, List, or Tuple): the number of embeddings (i.e., the size of codebook)
|
| 106 |
+
If isinstance(n_embed, int), the sizes of all codebooks are same.
|
| 107 |
+
shared_codebook (bool): If True, codebooks are shared in all location. If False,
|
| 108 |
+
uses separate codebooks along the ``depth'' dimension. (default: False)
|
| 109 |
+
restart_unused_codes (bool): If True, it randomly assigns a feature vector in the curruent batch
|
| 110 |
+
as the new embedding of unused codes in training. (default: True)
|
| 111 |
+
"""
|
| 112 |
+
|
| 113 |
+
def __init__(self,
|
| 114 |
+
latent_shape,
|
| 115 |
+
code_shape,
|
| 116 |
+
n_embed,
|
| 117 |
+
decay=0.99,
|
| 118 |
+
shared_codebook=False,
|
| 119 |
+
restart_unused_codes=True,
|
| 120 |
+
commitment_loss='cumsum'
|
| 121 |
+
):
|
| 122 |
+
super().__init__()
|
| 123 |
+
|
| 124 |
+
if not len(code_shape) == len(latent_shape) == 3:
|
| 125 |
+
raise ValueError("incompatible code shape or latent shape")
|
| 126 |
+
if any([y % x != 0 for x, y in zip(code_shape[:2], latent_shape[:2])]):
|
| 127 |
+
raise ValueError("incompatible code shape or latent shape")
|
| 128 |
+
|
| 129 |
+
#residual quantization does not divide feature dims for quantization.
|
| 130 |
+
embed_dim = np.prod(latent_shape[:2]) // np.prod(code_shape[:2]) * latent_shape[2]
|
| 131 |
+
|
| 132 |
+
self.latent_shape = torch.Size(latent_shape)
|
| 133 |
+
self.code_shape = torch.Size(code_shape)
|
| 134 |
+
self.shape_divisor = torch.Size([latent_shape[i] // code_shape[i] for i in range(len(latent_shape))])
|
| 135 |
+
|
| 136 |
+
self.shared_codebook = shared_codebook
|
| 137 |
+
if self.shared_codebook:
|
| 138 |
+
if isinstance(n_embed, Iterable) or isinstance(decay, Iterable):
|
| 139 |
+
raise ValueError("Shared codebooks are incompatible \
|
| 140 |
+
with list types of momentums or sizes: Change it into int")
|
| 141 |
+
|
| 142 |
+
self.restart_unused_codes = restart_unused_codes
|
| 143 |
+
self.n_embed = n_embed if isinstance(n_embed, Iterable) else [n_embed for _ in range(self.code_shape[-1])]
|
| 144 |
+
self.decay = decay if isinstance(decay, Iterable) else [decay for _ in range(self.code_shape[-1])]
|
| 145 |
+
assert len(self.n_embed) == self.code_shape[-1]
|
| 146 |
+
assert len(self.decay) == self.code_shape[-1]
|
| 147 |
+
|
| 148 |
+
if self.shared_codebook:
|
| 149 |
+
codebook0 = VQEmbedding(self.n_embed[0],
|
| 150 |
+
embed_dim,
|
| 151 |
+
decay=self.decay[0],
|
| 152 |
+
restart_unused_codes=restart_unused_codes,
|
| 153 |
+
).to(torch.float32)
|
| 154 |
+
self.codebooks = nn.ModuleList([codebook0 for _ in range(self.code_shape[-1])])
|
| 155 |
+
else:
|
| 156 |
+
codebooks = [VQEmbedding(self.n_embed[idx],
|
| 157 |
+
embed_dim,
|
| 158 |
+
decay=self.decay[idx],
|
| 159 |
+
restart_unused_codes=restart_unused_codes,
|
| 160 |
+
).to(torch.float32) for idx in range(self.code_shape[-1])]
|
| 161 |
+
self.codebooks = nn.ModuleList(codebooks)
|
| 162 |
+
|
| 163 |
+
self.commitment_loss = commitment_loss
|
| 164 |
+
|
| 165 |
+
def to_code_shape(self, x):
|
| 166 |
+
(B, H, W, D) = x.shape
|
| 167 |
+
(rH, rW, _) = self.shape_divisor
|
| 168 |
+
|
| 169 |
+
x = x.reshape(B, H//rH, rH, W//rW, rW, D)
|
| 170 |
+
x = x.permute(0, 1, 3, 2, 4, 5)
|
| 171 |
+
x = x.reshape(B, H//rH, W//rW, -1)
|
| 172 |
+
|
| 173 |
+
return x
|
| 174 |
+
|
| 175 |
+
def to_latent_shape(self, x):
|
| 176 |
+
(B, h, w, _) = x.shape
|
| 177 |
+
(_, _, D) = self.latent_shape
|
| 178 |
+
(rH, rW, _) = self.shape_divisor
|
| 179 |
+
|
| 180 |
+
x = x.reshape(B, h, w, rH, rW, D)
|
| 181 |
+
x = x.permute(0, 1, 3, 2, 4, 5)
|
| 182 |
+
x = x.reshape(B, h*rH, w*rW, D)
|
| 183 |
+
|
| 184 |
+
return x
|
| 185 |
+
|
| 186 |
+
def quantize(self, x):
|
| 187 |
+
r"""
|
| 188 |
+
Return list of quantized features and the selected codewords by the residual quantization.
|
| 189 |
+
The code is selected by the residuals between x and quantized features by the previous codebooks.
|
| 190 |
+
|
| 191 |
+
Arguments:
|
| 192 |
+
x (Tensor): bottleneck feature maps to quantize.
|
| 193 |
+
|
| 194 |
+
Returns:
|
| 195 |
+
quant_list (list): list of sequentially aggregated and quantized feature maps by codebooks.
|
| 196 |
+
codes (LongTensor): codewords index, corresponding to quants.
|
| 197 |
+
|
| 198 |
+
Shape:
|
| 199 |
+
- x: (B, h, w, embed_dim)
|
| 200 |
+
- quant_list[i]: (B, h, w, embed_dim)
|
| 201 |
+
- codes: (B, h, w, d)
|
| 202 |
+
"""
|
| 203 |
+
B, h, w, embed_dim = x.shape
|
| 204 |
+
ori_dtype = x.dtype
|
| 205 |
+
x = x.to(torch.float32)
|
| 206 |
+
self.codebooks = self.codebooks.to(torch.float32)
|
| 207 |
+
|
| 208 |
+
residual_feature = x.detach().clone()
|
| 209 |
+
|
| 210 |
+
quant_list = []
|
| 211 |
+
code_list = []
|
| 212 |
+
aggregated_quants = torch.zeros_like(x)
|
| 213 |
+
for i in range(self.code_shape[-1]):
|
| 214 |
+
quant, code = self.codebooks[i](residual_feature)
|
| 215 |
+
residual_feature.sub_(quant)
|
| 216 |
+
aggregated_quants.add_(quant)
|
| 217 |
+
quant_list.append(aggregated_quants.clone().to(dtype=ori_dtype))
|
| 218 |
+
code_list.append(code.unsqueeze(-1))
|
| 219 |
+
|
| 220 |
+
codes = torch.cat(code_list, dim=-1)
|
| 221 |
+
return quant_list, codes
|
| 222 |
+
|
| 223 |
+
def forward(self, x):
|
| 224 |
+
x_reshaped = self.to_code_shape(x)
|
| 225 |
+
# 强制使用float32精度来执行
|
| 226 |
+
quant_list, codes = self.quantize(x_reshaped)
|
| 227 |
+
# quant_list, codes = self.quantize(x_reshaped)
|
| 228 |
+
|
| 229 |
+
commitment_loss = self.compute_commitment_loss(x_reshaped, quant_list)
|
| 230 |
+
quants_trunc = self.to_latent_shape(quant_list[-1])
|
| 231 |
+
quants_trunc = x + (quants_trunc - x).detach()
|
| 232 |
+
|
| 233 |
+
'''
|
| 234 |
+
if self.shared_codebook:
|
| 235 |
+
cur_len = codes.view(-1).shape[0]
|
| 236 |
+
self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone()
|
| 237 |
+
self.codebook_used[-cur_len:] = codes.view(-1)
|
| 238 |
+
codebook_usage = len(torch.unique(self.codebook_used)) / self.n_embed[0]
|
| 239 |
+
else:
|
| 240 |
+
# info|code: torch.Size([10, 16, 16, 4])
|
| 241 |
+
codebook_usage = 0
|
| 242 |
+
for idx in range(self.code_shape[-1]):
|
| 243 |
+
cur_len = codes[..., idx].view(-1).shape[0]
|
| 244 |
+
self.codebook_used[idx, :-cur_len] = self.codebook_used[idx, cur_len:].clone()
|
| 245 |
+
self.codebook_used[idx, -cur_len:] = codes[..., idx].view(-1)
|
| 246 |
+
codebook_usage += len(torch.unique(self.codebook_used[idx]))
|
| 247 |
+
codebook_usage /= (self.n_embed[0] * self.code_shape[-1])
|
| 248 |
+
'''
|
| 249 |
+
codebook_usage = 0
|
| 250 |
+
# (vq_loss, commit_loss, entropy_loss, codebook_usage) # 格式对齐
|
| 251 |
+
codebook_loss = [0, commitment_loss, 0, codebook_usage]
|
| 252 |
+
|
| 253 |
+
return quants_trunc, codebook_loss, codes
|
| 254 |
+
|
| 255 |
+
def compute_commitment_loss(self, x, quant_list):
|
| 256 |
+
r"""
|
| 257 |
+
Compute the commitment loss for the residual quantization.
|
| 258 |
+
The loss is iteratively computed by aggregating quantized features.
|
| 259 |
+
"""
|
| 260 |
+
loss_list = []
|
| 261 |
+
|
| 262 |
+
for idx, quant in enumerate(quant_list):
|
| 263 |
+
partial_loss = (x-quant.detach()).pow(2.0).mean()
|
| 264 |
+
loss_list.append(partial_loss)
|
| 265 |
+
|
| 266 |
+
commitment_loss = torch.mean(torch.stack(loss_list))
|
| 267 |
+
return commitment_loss
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class Qwen2_5_VisionRotaryEmbedding_Modified(nn.Module):
|
| 272 |
+
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
| 273 |
+
super().__init__()
|
| 274 |
+
self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
| 275 |
+
# self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 276 |
+
|
| 277 |
+
def forward(self, seqlen: int, device: torch.device) -> torch.Tensor:
|
| 278 |
+
self.inv_freq = self.inv_freq.to(device)
|
| 279 |
+
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
| 280 |
+
freqs = torch.outer(seq, self.inv_freq)
|
| 281 |
+
return freqs
|
| 282 |
+
|
| 283 |
+
class VisualEncoder(Qwen2_5_VisionTransformerPretrainedModel):
|
| 284 |
+
|
| 285 |
+
def __init__(self, config):
|
| 286 |
+
config._attn_implementation = 'flash_attention_2'
|
| 287 |
+
super().__init__(config)
|
| 288 |
+
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding_Modified(config.hidden_size // config.num_heads // 2)
|
| 289 |
+
self.gradient_checkpointing = False
|
| 290 |
+
self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
|
| 291 |
+
self.merge_size = config.merge_size if hasattr(config, 'merge_size') else 2
|
| 292 |
+
del self.merger # register visual.merger in visual_bridge_model
|
| 293 |
+
|
| 294 |
+
def get_dtype(self) -> torch.dtype:
|
| 295 |
+
return self.blocks[0].mlp.down_proj.weight.dtype
|
| 296 |
+
|
| 297 |
+
def get_device(self) -> torch.device:
|
| 298 |
+
return self.blocks[0].mlp.down_proj.weight.device
|
| 299 |
+
|
| 300 |
+
def rot_pos_emb(self, grid_thw):
|
| 301 |
+
pos_ids = []
|
| 302 |
+
for t, h, w in grid_thw:
|
| 303 |
+
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
| 304 |
+
hpos_ids = hpos_ids.reshape(
|
| 305 |
+
h // self.spatial_merge_size,
|
| 306 |
+
self.spatial_merge_size,
|
| 307 |
+
w // self.spatial_merge_size,
|
| 308 |
+
self.spatial_merge_size,
|
| 309 |
+
)
|
| 310 |
+
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
|
| 311 |
+
hpos_ids = hpos_ids.flatten()
|
| 312 |
+
|
| 313 |
+
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
| 314 |
+
wpos_ids = wpos_ids.reshape(
|
| 315 |
+
h // self.spatial_merge_size,
|
| 316 |
+
self.spatial_merge_size,
|
| 317 |
+
w // self.spatial_merge_size,
|
| 318 |
+
self.spatial_merge_size,
|
| 319 |
+
)
|
| 320 |
+
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
|
| 321 |
+
wpos_ids = wpos_ids.flatten()
|
| 322 |
+
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
| 323 |
+
pos_ids = torch.cat(pos_ids, dim=0)
|
| 324 |
+
max_grid_size = grid_thw[:, 1:].max()
|
| 325 |
+
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device=grid_thw.device)
|
| 326 |
+
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
| 327 |
+
return rotary_pos_emb
|
| 328 |
+
|
| 329 |
+
def forward(
|
| 330 |
+
self,
|
| 331 |
+
pixel_values: torch.Tensor,
|
| 332 |
+
grid_thw: torch.Tensor,
|
| 333 |
+
require_window_index: bool = False,
|
| 334 |
+
):
|
| 335 |
+
'''
|
| 336 |
+
pixel_values.shape=[NumOfPatches, 1176]
|
| 337 |
+
grid_thw.shape=[NumOfSamples, 3]. [grid_t,grid_h,grid_w]
|
| 338 |
+
'''
|
| 339 |
+
hidden_states = pixel_values.to(torch.bfloat16)
|
| 340 |
+
grid_thw = grid_thw.to(pixel_values.device)
|
| 341 |
+
|
| 342 |
+
hidden_states = self.patch_embed(hidden_states)
|
| 343 |
+
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
| 344 |
+
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
|
| 345 |
+
cu_window_seqlens = torch.tensor(
|
| 346 |
+
cu_window_seqlens,
|
| 347 |
+
device=hidden_states.device,
|
| 348 |
+
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
|
| 349 |
+
)
|
| 350 |
+
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
|
| 351 |
+
|
| 352 |
+
seq_len, _ = hidden_states.size()
|
| 353 |
+
hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
| 354 |
+
hidden_states = hidden_states[window_index, :, :]
|
| 355 |
+
hidden_states = hidden_states.reshape(seq_len, -1)
|
| 356 |
+
rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
| 357 |
+
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
|
| 358 |
+
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
| 359 |
+
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
| 360 |
+
position_embeddings = (emb.cos(), emb.sin())
|
| 361 |
+
|
| 362 |
+
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
| 363 |
+
dim=0,
|
| 364 |
+
# Select dtype based on the following factors:
|
| 365 |
+
# - FA2 requires that cu_seqlens_q must have dtype int32
|
| 366 |
+
# - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
|
| 367 |
+
# See https://github.com/huggingface/transformers/pull/34852 for more information
|
| 368 |
+
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
|
| 369 |
+
)
|
| 370 |
+
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
| 371 |
+
|
| 372 |
+
for layer_num, blk in enumerate(self.blocks):
|
| 373 |
+
if layer_num in self.fullatt_block_indexes:
|
| 374 |
+
cu_seqlens_now = cu_seqlens
|
| 375 |
+
else:
|
| 376 |
+
cu_seqlens_now = cu_window_seqlens
|
| 377 |
+
if self.gradient_checkpointing and self.training:
|
| 378 |
+
hidden_states = self._gradient_checkpointing_func(blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings)
|
| 379 |
+
else:
|
| 380 |
+
hidden_states = blk(
|
| 381 |
+
hidden_states,
|
| 382 |
+
cu_seqlens=cu_seqlens_now,
|
| 383 |
+
position_embeddings=position_embeddings,
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
if require_window_index:
|
| 387 |
+
return hidden_states, window_index
|
| 388 |
+
return hidden_states
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class OmniVisualBridge(nn.Module):
|
| 392 |
+
def __init__(self, config):
|
| 393 |
+
super().__init__()
|
| 394 |
+
self.config = config
|
| 395 |
+
self.merge_size = self.config.merge_size if hasattr(self.config, 'merge_size') else 2
|
| 396 |
+
self.hidden_size = self.config.hidden_size * (self.merge_size**2)
|
| 397 |
+
self.window_index = self.config.window_size
|
| 398 |
+
self.ln_q = Qwen2RMSNorm(self.config.hidden_size, eps=1e-6)
|
| 399 |
+
self.mlp = nn.Sequential(
|
| 400 |
+
nn.Linear(self.hidden_size, self.hidden_size),
|
| 401 |
+
nn.GELU(),
|
| 402 |
+
nn.Linear(self.hidden_size, self.config.out_hidden_size),
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
def forward(self, x: torch.Tensor, window_index) -> torch.Tensor:
|
| 406 |
+
x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
|
| 407 |
+
reverse_indices = torch.argsort(window_index)
|
| 408 |
+
x = x[reverse_indices, :]
|
| 409 |
+
|
| 410 |
+
return x
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
class VisualQuantizer(nn.Module):
|
| 414 |
+
def __init__(self, quantizer_config):
|
| 415 |
+
super().__init__()
|
| 416 |
+
|
| 417 |
+
self.config = quantizer_config
|
| 418 |
+
self.depth = self.config.depth
|
| 419 |
+
self.decay = self.config.decay
|
| 420 |
+
self.codebook_size = self.config.codebook_size
|
| 421 |
+
self.codebook_dim = self.config.codebook_dim
|
| 422 |
+
self.shared_codebook = self.config.shared_codebook
|
| 423 |
+
self.restart_unused_codes = self.config.restart_unused_codes
|
| 424 |
+
self.in_channels = self.config.in_channels
|
| 425 |
+
|
| 426 |
+
self.vq_loss_ratio = self.config.vq_loss_ratio
|
| 427 |
+
self.entropy_loss_ratio = self.config.entropy_loss_ratio
|
| 428 |
+
self.commit_loss_ratio = self.config.commit_loss_ratio
|
| 429 |
+
|
| 430 |
+
code_h_w = int(448 / 14)
|
| 431 |
+
latent_shape = [code_h_w, code_h_w, self.codebook_dim]
|
| 432 |
+
code_shape = [code_h_w, code_h_w, self.depth]
|
| 433 |
+
|
| 434 |
+
self.quantize = RQBottleneck(
|
| 435 |
+
latent_shape=latent_shape,
|
| 436 |
+
code_shape=code_shape,
|
| 437 |
+
n_embed=self.codebook_size,
|
| 438 |
+
decay=self.decay,
|
| 439 |
+
shared_codebook=self.shared_codebook,
|
| 440 |
+
restart_unused_codes=self.restart_unused_codes,
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
if self.config.quant_conv:
|
| 444 |
+
self.quant_conv = nn.Sequential(
|
| 445 |
+
nn.LayerNorm(self.in_channels),
|
| 446 |
+
nn.Linear(self.in_channels, self.in_channels),
|
| 447 |
+
nn.GELU(),
|
| 448 |
+
nn.Linear(self.in_channels, self.codebook_dim)
|
| 449 |
+
)
|
| 450 |
+
else:
|
| 451 |
+
self.quant_conv = None
|
| 452 |
+
|
| 453 |
+
def encode(self, x):
|
| 454 |
+
L, D = x.shape
|
| 455 |
+
to_qnt_feat = x.clone()
|
| 456 |
+
to_qnt_feat = to_qnt_feat.unsqueeze(0) # [L, D] -> [1, L, D]
|
| 457 |
+
N = 1
|
| 458 |
+
|
| 459 |
+
if self.quant_conv is not None:
|
| 460 |
+
to_qnt_feat = self.quant_conv(to_qnt_feat)
|
| 461 |
+
|
| 462 |
+
# quantizer needs nchw format. N,L,d -> N,1,L,d -> N,d,1,L
|
| 463 |
+
to_qnt_feat = to_qnt_feat.reshape(N, 1, L, self.codebook_dim).permute(0,3,1,2)
|
| 464 |
+
if self.config.quantizer_type == "rq":
|
| 465 |
+
to_qnt_feat = to_qnt_feat.permute(0, 2, 3, 1).contiguous() # N,d,1,L -> N,1,L,d
|
| 466 |
+
quant, emb_loss, info = self.quantize(to_qnt_feat)
|
| 467 |
+
info = info.reshape(-1, info.shape[-1]) # n,h,w,lv -> n*h*w,lv
|
| 468 |
+
info = [None, None, info]
|
| 469 |
+
quant = quant.permute(0, 3, 1, 2).contiguous() # N,1,L,d -> N,d,1,L
|
| 470 |
+
else:
|
| 471 |
+
quant, emb_loss, info = self.quantize(to_qnt_feat)
|
| 472 |
+
return quant, emb_loss, info, x.detach()
|
| 473 |
+
|
| 474 |
+
def forward(self, x):
|
| 475 |
+
quant, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices), align_feature = \
|
| 476 |
+
self.encode(x)
|
| 477 |
+
return min_encoding_indices
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
class MLP(nn.Module):
|
| 481 |
+
def __init__(
|
| 482 |
+
self,
|
| 483 |
+
hidden_size: int,
|
| 484 |
+
intermediate_size: int,
|
| 485 |
+
hidden_act: str,
|
| 486 |
+
):
|
| 487 |
+
super().__init__()
|
| 488 |
+
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
| 489 |
+
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
| 490 |
+
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
| 491 |
+
self.act_fn = ACT2FN[hidden_act]
|
| 492 |
+
|
| 493 |
+
def forward(self, x):
|
| 494 |
+
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 495 |
+
|
| 496 |
+
class DecoderLayer(nn.Module):
|
| 497 |
+
def __init__(self, config):
|
| 498 |
+
super().__init__()
|
| 499 |
+
self.hidden_size = config.hidden_size
|
| 500 |
+
self.mlp = MLP(
|
| 501 |
+
hidden_size=self.hidden_size,
|
| 502 |
+
intermediate_size=config.visual_embedding_layer_intermediate_size,
|
| 503 |
+
hidden_act=config.visual_embedding_layer_hidden_act,
|
| 504 |
+
)
|
| 505 |
+
self.pre_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 506 |
+
|
| 507 |
+
def forward(
|
| 508 |
+
self,
|
| 509 |
+
hidden_states: torch.Tensor,
|
| 510 |
+
):
|
| 511 |
+
residual = hidden_states
|
| 512 |
+
hidden_states = self.pre_layernorm(hidden_states)
|
| 513 |
+
hidden_states = self.mlp(hidden_states)
|
| 514 |
+
hidden_states = residual + hidden_states
|
| 515 |
+
|
| 516 |
+
return hidden_states
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
class VisualEmbeddingBridge(nn.Module):
|
| 520 |
+
def __init__(self, config):
|
| 521 |
+
super().__init__()
|
| 522 |
+
self.pre_buffer = DecoderLayer(config)
|
| 523 |
+
|
| 524 |
+
def forward(self, embeding):
|
| 525 |
+
return self.pre_buffer(embeding)
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
class VisualVQBridge(nn.Module):
|
| 529 |
+
def __init__(self, visual_config):
|
| 530 |
+
super().__init__()
|
| 531 |
+
self.bridge = OmniVisualBridge(visual_config)
|
| 532 |
+
self.quantizer = VisualQuantizer(visual_config.vq_config)
|
| 533 |
+
|
| 534 |
+
def forward(
|
| 535 |
+
self,
|
| 536 |
+
visual_embed: torch.Tensor,
|
| 537 |
+
window_index: torch.Tensor,
|
| 538 |
+
):
|
| 539 |
+
visual_embed = self.bridge(visual_embed, window_index)
|
| 540 |
+
indices = self.quantizer(visual_embed)
|
| 541 |
+
return indices
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
class LongcatNextVisualTokenizer(nn.Module):
|
| 545 |
+
|
| 546 |
+
def __init__(self, config):
|
| 547 |
+
super().__init__()
|
| 548 |
+
self.config = config
|
| 549 |
+
self.visual_model = VisualEncoder(config.visual_config)
|
| 550 |
+
self.visual_bridge_model = VisualVQBridge(config.visual_config)
|
| 551 |
+
self.visual_embedding_layer = VisualEmbeddingBridge(config)
|
| 552 |
+
self.image_decoder = None
|
| 553 |
+
self._refiner_pipeline = None
|
| 554 |
+
|
| 555 |
+
@torch.no_grad()
|
| 556 |
+
def encode(self, pixel_values: torch.Tensor, visual_grid_thw: torch.Tensor):
|
| 557 |
+
visual_embed, window_index = self.visual_model(pixel_values, grid_thw=visual_grid_thw, require_window_index=True)
|
| 558 |
+
indices = self.visual_bridge_model(visual_embed, window_index)
|
| 559 |
+
return indices
|
| 560 |
+
|
| 561 |
+
@torch.no_grad()
|
| 562 |
+
def lazy_decode_and_save(self, visual_ids, tokens_h, tokens_w, save_path):
|
| 563 |
+
device = next(self.parameters()).device
|
| 564 |
+
if self.image_decoder is None:
|
| 565 |
+
print("lazy load image_decoder / image_refiner / _refiner_pipeline ...")
|
| 566 |
+
vdc = self.config.visual_config.visual_decoder_config
|
| 567 |
+
self.image_decoder = VisionTransformerDecoder.from_pretrained(
|
| 568 |
+
vdc.image_decoder_config,
|
| 569 |
+
vdc.weight_path,
|
| 570 |
+
).to(device=device, dtype=torch.bfloat16)
|
| 571 |
+
image_refiner = ImageRefinerContainer.from_pretrained(vdc, vdc.weight_path).to(device=device, dtype=torch.bfloat16)
|
| 572 |
+
|
| 573 |
+
sc = vdc.scheduler_config
|
| 574 |
+
scheduler = FlowMatchEulerDiscreteScheduler(
|
| 575 |
+
num_train_timesteps=sc.num_train_timesteps,
|
| 576 |
+
dynamic_time_shift=sc.dynamic_time_shift)
|
| 577 |
+
self._refiner_pipeline = RefinerPipeline(
|
| 578 |
+
vae=image_refiner.vae,
|
| 579 |
+
transformer=image_refiner.base_transformer,
|
| 580 |
+
scheduler=scheduler,
|
| 581 |
+
cond_proj=image_refiner.cond_proj,
|
| 582 |
+
)
|
| 583 |
+
self._refiner_pipeline.set_progress_bar_config(disable=False)
|
| 584 |
+
|
| 585 |
+
data = torch.as_tensor(visual_ids, dtype=torch.long)
|
| 586 |
+
if data.ndim == 1:
|
| 587 |
+
data = data.view(-1, len(self.config.visual_config.vq_config.codebook_sizes))
|
| 588 |
+
if data.ndim == 2:
|
| 589 |
+
data = data.unsqueeze(0)
|
| 590 |
+
batch_size = data.shape[0]
|
| 591 |
+
|
| 592 |
+
quant_features = None
|
| 593 |
+
for idx in range(len(self.config.visual_config.vq_config.codebook_sizes)):
|
| 594 |
+
embed = self.visual_bridge_model.quantizer.quantize.codebooks[idx].embed
|
| 595 |
+
feat = embed[data[..., idx].to(embed.device)]
|
| 596 |
+
quant_features = feat if quant_features is None else quant_features + feat
|
| 597 |
+
quant_features = quant_features.to(device)
|
| 598 |
+
|
| 599 |
+
# tokens_h/tokens_w are the merged grid; expand to the full (unmerged) grid
|
| 600 |
+
s = self.image_decoder.spatial_merge_size
|
| 601 |
+
grid_thw_list = [(1, tokens_h * s, tokens_w * s)]
|
| 602 |
+
grid_thw_batch = list(grid_thw_list) * batch_size
|
| 603 |
+
|
| 604 |
+
image_mean = [0.48145466, 0.4578275, 0.40821073]
|
| 605 |
+
image_std = [0.26862954, 0.26130258, 0.27577711]
|
| 606 |
+
|
| 607 |
+
emb_2d = quant_features.reshape(-1, quant_features.shape[-1]).contiguous()
|
| 608 |
+
device_type = "cuda" if str(device).startswith("cuda") else str(device)
|
| 609 |
+
with torch.amp.autocast(device_type=device_type, enabled=True, dtype=torch.float32):
|
| 610 |
+
decoder_out = self.image_decoder(emb_2d, grid_thw_batch, return_pixel_features=False)
|
| 611 |
+
|
| 612 |
+
decoded_tensors = decoder_out.get("images") or []
|
| 613 |
+
decoded_images = [tensor2pil(t, image_mean, image_std) for t in decoded_tensors]
|
| 614 |
+
decoded_path = save_path.replace(".png", "_decoded.png")
|
| 615 |
+
# decoded_images[0].save(decoded_path)
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
ref_input = []
|
| 619 |
+
for t in decoded_tensors:
|
| 620 |
+
img_01 = de_transform(t, mean=image_mean, std=image_std, rescale_factor=1 / 255)
|
| 621 |
+
img_norm = RefinerImageProcessor.normalize(img_01)
|
| 622 |
+
ref_input.append(img_norm.squeeze(0).to(device))
|
| 623 |
+
|
| 624 |
+
generators = [torch.Generator(device=device).manual_seed(42 + b) for b in range(batch_size)]
|
| 625 |
+
out = self._refiner_pipeline(
|
| 626 |
+
encoder_hidden_states=quant_features,
|
| 627 |
+
grid_thw_list=grid_thw_list,
|
| 628 |
+
image=ref_input,
|
| 629 |
+
generator=generators[0] if batch_size == 1 else generators,
|
| 630 |
+
output_type="pil",
|
| 631 |
+
return_dict=True,
|
| 632 |
+
)
|
| 633 |
+
refined_images = out.images
|
| 634 |
+
refined_path = save_path.replace(".png", "_refined.png")
|
| 635 |
+
refined_images[0].save(refined_path)
|
| 636 |
+
|
| 637 |
+
return [refined_path]
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
# ---------------------------------------------------------------------------
|
| 641 |
+
# Vision Transformer Decoder
|
| 642 |
+
# ---------------------------------------------------------------------------
|
| 643 |
+
|
| 644 |
+
def _rotate_half(x):
|
| 645 |
+
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
| 646 |
+
x1, x2 = x.unbind(dim=-1)
|
| 647 |
+
x = torch.stack((-x2, x1), dim=-1)
|
| 648 |
+
return rearrange(x, "... d r -> ... (d r)")
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
class VisionRoPE2D(nn.Module):
|
| 652 |
+
"""2D Rotary Position Embedding for Q/K in vision decoder attention."""
|
| 653 |
+
|
| 654 |
+
def __init__(self, theta: float = 10000.0):
|
| 655 |
+
super().__init__()
|
| 656 |
+
self.theta = theta
|
| 657 |
+
|
| 658 |
+
def _rope_half(self, x_half, pos_1d, theta):
|
| 659 |
+
BH, T, d_half = x_half.shape
|
| 660 |
+
idx = torch.arange(0, d_half, 2, device=x_half.device, dtype=torch.float32)
|
| 661 |
+
inv_freq = (1.0 / (theta ** (idx / d_half))).to(x_half.dtype)
|
| 662 |
+
angles = pos_1d.to(x_half.dtype)[:, None] * inv_freq[None, :]
|
| 663 |
+
cos = torch.repeat_interleave(torch.cos(angles), 2, dim=-1).unsqueeze(0)
|
| 664 |
+
sin = torch.repeat_interleave(torch.sin(angles), 2, dim=-1).unsqueeze(0)
|
| 665 |
+
return x_half * cos + _rotate_half(x_half) * sin
|
| 666 |
+
|
| 667 |
+
def forward(self, x, positions_2d):
|
| 668 |
+
d_half = x.shape[-1] // 2
|
| 669 |
+
x_y = self._rope_half(x[:, :, :d_half], positions_2d[:, 0], self.theta)
|
| 670 |
+
x_x = self._rope_half(x[:, :, d_half:], positions_2d[:, 1], self.theta)
|
| 671 |
+
return torch.cat([x_y, x_x], dim=-1)
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
class VisionAttention(nn.Module):
|
| 675 |
+
"""Multi-headed attention with 2D RoPE + FlashAttention varlen."""
|
| 676 |
+
|
| 677 |
+
def __init__(self, config, rope=None, rope_shift=0):
|
| 678 |
+
super().__init__()
|
| 679 |
+
self.config = config
|
| 680 |
+
self.embed_dim = config.hidden_size
|
| 681 |
+
self.num_heads = config.num_attention_heads
|
| 682 |
+
self.head_dim = self.embed_dim // self.num_heads
|
| 683 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
| 684 |
+
raise ValueError(
|
| 685 |
+
f"embed_dim must be divisible by num_heads (got embed_dim={self.embed_dim}, num_heads={self.num_heads})"
|
| 686 |
+
)
|
| 687 |
+
self.scale = self.head_dim ** -0.5
|
| 688 |
+
self.dropout = config.attention_dropout
|
| 689 |
+
self.subln = config.subln
|
| 690 |
+
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=getattr(config, "k_bias", True))
|
| 691 |
+
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=getattr(config, "v_bias", True))
|
| 692 |
+
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=getattr(config, "q_bias", True))
|
| 693 |
+
self.inner_attn_ln = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps) if config.subln else nn.Identity()
|
| 694 |
+
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
|
| 695 |
+
self.rope = rope
|
| 696 |
+
self.rope_shift = int(rope_shift)
|
| 697 |
+
|
| 698 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
| 699 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
| 700 |
+
|
| 701 |
+
def _maybe_flash_attention(self, query_states, key_states, value_states, seq_lens, training):
|
| 702 |
+
if not (query_states.is_cuda and (query_states.dtype in (torch.float16, torch.bfloat16, torch.float32))):
|
| 703 |
+
return None
|
| 704 |
+
if seq_lens is None:
|
| 705 |
+
return None
|
| 706 |
+
try:
|
| 707 |
+
BxH, T, hd = query_states.shape
|
| 708 |
+
H = self.num_heads
|
| 709 |
+
assert BxH % H == 0
|
| 710 |
+
B = BxH // H
|
| 711 |
+
if int(seq_lens.sum().item()) != T:
|
| 712 |
+
return None
|
| 713 |
+
q = query_states.view(B, H, T, hd).transpose(1, 2).reshape(-1, H, hd).contiguous()
|
| 714 |
+
k = key_states.view(B, H, T, hd).transpose(1, 2).reshape(-1, H, hd).contiguous()
|
| 715 |
+
v = value_states.view(B, H, T, hd).transpose(1, 2).reshape(-1, H, hd).contiguous()
|
| 716 |
+
cu_q = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32, device=seq_lens.device)
|
| 717 |
+
cu_q[1:] = torch.cumsum(seq_lens.to(torch.int32), dim=0)
|
| 718 |
+
cu_k = cu_q
|
| 719 |
+
max_seqlen = int(seq_lens.max().item())
|
| 720 |
+
orig_dtype = q.dtype
|
| 721 |
+
use_dtype = q.dtype if q.dtype in (torch.float16, torch.bfloat16) else torch.float16
|
| 722 |
+
if q.dtype != use_dtype:
|
| 723 |
+
q = q.to(use_dtype)
|
| 724 |
+
k = k.to(use_dtype)
|
| 725 |
+
v = v.to(use_dtype)
|
| 726 |
+
out = flash_attn_varlen_func(
|
| 727 |
+
q, k, v, cu_q, cu_k, max_seqlen, max_seqlen,
|
| 728 |
+
dropout_p=self.dropout if training else 0.0,
|
| 729 |
+
softmax_scale=None, causal=False, return_attn_probs=False
|
| 730 |
+
)
|
| 731 |
+
if out.dtype != orig_dtype:
|
| 732 |
+
out = out.to(orig_dtype)
|
| 733 |
+
return out.view(B, -1, H, hd).transpose(1, 2).contiguous().view(B * H, T, hd)
|
| 734 |
+
except Exception:
|
| 735 |
+
return None
|
| 736 |
+
|
| 737 |
+
def forward(
|
| 738 |
+
self,
|
| 739 |
+
hidden_states: torch.Tensor,
|
| 740 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 741 |
+
causal_attention_mask: Optional[torch.Tensor] = None,
|
| 742 |
+
output_attentions: Optional[bool] = False,
|
| 743 |
+
positions_2d: Optional[torch.Tensor] = None,
|
| 744 |
+
seq_lens: Optional[torch.Tensor] = None,
|
| 745 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 746 |
+
bsz, tgt_len, embed_dim = hidden_states.size()
|
| 747 |
+
query_states = self.q_proj(hidden_states) * self.scale
|
| 748 |
+
key_states = self.k_proj(hidden_states)
|
| 749 |
+
value_states = self.v_proj(hidden_states)
|
| 750 |
+
query_states = self._shape(query_states, tgt_len, bsz).view(bsz * self.num_heads, tgt_len, self.head_dim)
|
| 751 |
+
key_states = self._shape(key_states, tgt_len, bsz).view(bsz * self.num_heads, tgt_len, self.head_dim)
|
| 752 |
+
value_states = self._shape(value_states, tgt_len, bsz).view(bsz * self.num_heads, tgt_len, self.head_dim)
|
| 753 |
+
if self.rope is not None and positions_2d is not None:
|
| 754 |
+
if self.rope_shift > 0:
|
| 755 |
+
q_pref = query_states[:, :self.rope_shift, :]
|
| 756 |
+
k_pref = key_states[:, :self.rope_shift, :]
|
| 757 |
+
q_rot = self.rope(query_states[:, self.rope_shift:, :], positions_2d[self.rope_shift:])
|
| 758 |
+
k_rot = self.rope(key_states[:, self.rope_shift:, :], positions_2d[self.rope_shift:])
|
| 759 |
+
query_states = torch.cat([q_pref, q_rot], dim=1).type_as(value_states)
|
| 760 |
+
key_states = torch.cat([k_pref, k_rot], dim=1).type_as(value_states)
|
| 761 |
+
else:
|
| 762 |
+
query_states = self.rope(query_states, positions_2d).type_as(value_states)
|
| 763 |
+
key_states = self.rope(key_states, positions_2d).type_as(value_states)
|
| 764 |
+
attn_output = self._maybe_flash_attention(
|
| 765 |
+
query_states, key_states, value_states, seq_lens=seq_lens, training=self.training
|
| 766 |
+
)
|
| 767 |
+
if attn_output is not None:
|
| 768 |
+
attn_weights_reshaped = None
|
| 769 |
+
else:
|
| 770 |
+
src_len = key_states.size(1)
|
| 771 |
+
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
| 772 |
+
if causal_attention_mask is not None:
|
| 773 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
|
| 774 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
| 775 |
+
if attention_mask is not None:
|
| 776 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
| 777 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
| 778 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
| 779 |
+
if output_attentions:
|
| 780 |
+
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
| 781 |
+
else:
|
| 782 |
+
attn_weights_reshaped = None
|
| 783 |
+
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
| 784 |
+
attn_output = torch.bmm(attn_probs, value_states)
|
| 785 |
+
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
| 786 |
+
attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim)
|
| 787 |
+
attn_output = self.inner_attn_ln(attn_output)
|
| 788 |
+
attn_output = self.out_proj(attn_output)
|
| 789 |
+
return attn_output, attn_weights_reshaped
|
| 790 |
+
|
| 791 |
+
|
| 792 |
+
class VisionSwiGLU(nn.Module):
|
| 793 |
+
def __init__(self, config):
|
| 794 |
+
super().__init__()
|
| 795 |
+
self.config = config
|
| 796 |
+
self.hidden_size = config.hidden_size
|
| 797 |
+
self.intermediate_size = config.intermediate_size
|
| 798 |
+
self.w1 = nn.Linear(self.hidden_size, self.intermediate_size)
|
| 799 |
+
self.w2 = nn.Linear(self.hidden_size, self.intermediate_size)
|
| 800 |
+
self.w3 = nn.Linear(self.intermediate_size, self.hidden_size)
|
| 801 |
+
self.act_fn = nn.SiLU()
|
| 802 |
+
self.ffn_ln = Qwen2RMSNorm(self.intermediate_size, eps=config.layer_norm_eps) if config.subln else nn.Identity()
|
| 803 |
+
|
| 804 |
+
def forward(self, x):
|
| 805 |
+
x1 = self.w1(x)
|
| 806 |
+
x2 = self.w2(x)
|
| 807 |
+
hidden = self.act_fn(x1) * x2
|
| 808 |
+
x = self.ffn_ln(hidden)
|
| 809 |
+
x = self.w3(x)
|
| 810 |
+
return x
|
| 811 |
+
|
| 812 |
+
|
| 813 |
+
class VisionMLP(nn.Module):
|
| 814 |
+
def __init__(self, config):
|
| 815 |
+
super().__init__()
|
| 816 |
+
self.config = config
|
| 817 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
| 818 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 819 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 820 |
+
self.ffn_ln = Qwen2RMSNorm(config.intermediate_size, eps=config.layer_norm_eps) if config.subln else nn.Identity()
|
| 821 |
+
|
| 822 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 823 |
+
hidden_states = self.fc1(hidden_states)
|
| 824 |
+
hidden_states = self.activation_fn(hidden_states)
|
| 825 |
+
hidden_states = self.ffn_ln(hidden_states)
|
| 826 |
+
hidden_states = self.fc2(hidden_states)
|
| 827 |
+
return hidden_states
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
class VisionEncoderLayer(nn.Module):
|
| 831 |
+
def __init__(self, config, rope=None, rope_shift=0):
|
| 832 |
+
super().__init__()
|
| 833 |
+
self.embed_dim = config.hidden_size
|
| 834 |
+
self.self_attn = VisionAttention(config, rope=rope, rope_shift=rope_shift)
|
| 835 |
+
self.layer_norm1 = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 836 |
+
self.mlp = VisionSwiGLU(config) if config.swiglu else VisionMLP(config)
|
| 837 |
+
self.layer_norm2 = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 838 |
+
|
| 839 |
+
def forward(
|
| 840 |
+
self,
|
| 841 |
+
hidden_states: torch.Tensor,
|
| 842 |
+
attention_mask: Optional[torch.Tensor],
|
| 843 |
+
causal_attention_mask: Optional[torch.Tensor],
|
| 844 |
+
output_attentions: Optional[bool] = False,
|
| 845 |
+
positions_2d: Optional[torch.Tensor] = None,
|
| 846 |
+
seq_lens: Optional[torch.Tensor] = None,
|
| 847 |
+
) -> Tuple[torch.FloatTensor, Optional[torch.Tensor]]:
|
| 848 |
+
residual = hidden_states
|
| 849 |
+
hidden_states = self.layer_norm1(hidden_states)
|
| 850 |
+
hidden_states, attn_weights = self.self_attn(
|
| 851 |
+
hidden_states=hidden_states,
|
| 852 |
+
attention_mask=attention_mask,
|
| 853 |
+
causal_attention_mask=causal_attention_mask,
|
| 854 |
+
output_attentions=output_attentions,
|
| 855 |
+
positions_2d=positions_2d,
|
| 856 |
+
seq_lens=seq_lens,
|
| 857 |
+
)
|
| 858 |
+
hidden_states = residual + hidden_states
|
| 859 |
+
residual = hidden_states
|
| 860 |
+
hidden_states = self.layer_norm2(hidden_states)
|
| 861 |
+
hidden_states = self.mlp(hidden_states)
|
| 862 |
+
hidden_states = residual + hidden_states
|
| 863 |
+
outputs = (hidden_states,)
|
| 864 |
+
if output_attentions:
|
| 865 |
+
outputs += (attn_weights,)
|
| 866 |
+
return outputs
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
class VisionEncoder(nn.Module):
|
| 870 |
+
def __init__(self, config, rope=None, rope_shift=0):
|
| 871 |
+
super().__init__()
|
| 872 |
+
self.config = config
|
| 873 |
+
self.layers = nn.ModuleList(
|
| 874 |
+
[VisionEncoderLayer(config, rope=rope, rope_shift=rope_shift) for _ in range(config.num_hidden_layers)]
|
| 875 |
+
)
|
| 876 |
+
self.gradient_checkpointing = False
|
| 877 |
+
self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
|
| 878 |
+
|
| 879 |
+
def forward(
|
| 880 |
+
self,
|
| 881 |
+
inputs_embeds: torch.Tensor,
|
| 882 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 883 |
+
causal_attention_mask: Optional[torch.Tensor] = None,
|
| 884 |
+
output_attentions: Optional[bool] = None,
|
| 885 |
+
output_hidden_states: Optional[bool] = None,
|
| 886 |
+
return_dict: Optional[bool] = None,
|
| 887 |
+
positions_2d: Optional[torch.Tensor] = None,
|
| 888 |
+
seq_lens: Optional[torch.Tensor] = None,
|
| 889 |
+
):
|
| 890 |
+
output_attentions = output_attentions if output_attentions is not None else False
|
| 891 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else False
|
| 892 |
+
return_dict = True if return_dict is None else return_dict
|
| 893 |
+
|
| 894 |
+
encoder_states = () if output_hidden_states else None
|
| 895 |
+
all_attentions = () if output_attentions else None
|
| 896 |
+
hidden_states = inputs_embeds
|
| 897 |
+
|
| 898 |
+
for layer in self.layers:
|
| 899 |
+
if output_hidden_states:
|
| 900 |
+
encoder_states = encoder_states + (hidden_states,)
|
| 901 |
+
if self.gradient_checkpointing and self.training:
|
| 902 |
+
def custom_forward(hs, attn, causal, pos2d, seqlens):
|
| 903 |
+
return layer(
|
| 904 |
+
hs,
|
| 905 |
+
attention_mask=attn,
|
| 906 |
+
causal_attention_mask=causal,
|
| 907 |
+
output_attentions=False,
|
| 908 |
+
positions_2d=pos2d,
|
| 909 |
+
seq_lens=seqlens,
|
| 910 |
+
)[0]
|
| 911 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 912 |
+
custom_forward,
|
| 913 |
+
hidden_states,
|
| 914 |
+
attention_mask if attention_mask is not None else torch.tensor(0., device=hidden_states.device),
|
| 915 |
+
causal_attention_mask if causal_attention_mask is not None else torch.tensor(0., device=hidden_states.device),
|
| 916 |
+
positions_2d,
|
| 917 |
+
seq_lens if seq_lens is not None else torch.tensor([], device=hidden_states.device),
|
| 918 |
+
use_reentrant=False,
|
| 919 |
+
)
|
| 920 |
+
else:
|
| 921 |
+
layer_outputs = layer(
|
| 922 |
+
hidden_states,
|
| 923 |
+
attention_mask,
|
| 924 |
+
causal_attention_mask,
|
| 925 |
+
output_attentions=output_attentions,
|
| 926 |
+
positions_2d=positions_2d,
|
| 927 |
+
seq_lens=seq_lens,
|
| 928 |
+
)
|
| 929 |
+
hidden_states = layer_outputs[0]
|
| 930 |
+
if output_attentions:
|
| 931 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
| 932 |
+
|
| 933 |
+
if output_hidden_states:
|
| 934 |
+
encoder_states = encoder_states + (hidden_states,)
|
| 935 |
+
|
| 936 |
+
if not return_dict:
|
| 937 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
| 938 |
+
|
| 939 |
+
return BaseModelOutput(
|
| 940 |
+
last_hidden_state=hidden_states,
|
| 941 |
+
hidden_states=encoder_states,
|
| 942 |
+
attentions=all_attentions,
|
| 943 |
+
)
|
| 944 |
+
|
| 945 |
+
|
| 946 |
+
class PatchUnMerger(nn.Module):
|
| 947 |
+
"""Learnable inverse of Qwen2_5_VLPatchMerger."""
|
| 948 |
+
def __init__(self, dim, context_dim, spatial_merge_size=2):
|
| 949 |
+
super().__init__()
|
| 950 |
+
self.spatial_merge_size = spatial_merge_size
|
| 951 |
+
self.context_dim = context_dim
|
| 952 |
+
hidden = context_dim * (spatial_merge_size ** 2)
|
| 953 |
+
self.ln_q = Qwen2RMSNorm(dim, eps=1e-6)
|
| 954 |
+
self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, hidden))
|
| 955 |
+
|
| 956 |
+
def forward(self, x):
|
| 957 |
+
x = self.mlp(self.ln_q(x))
|
| 958 |
+
return x.view(x.shape[0] * (self.spatial_merge_size ** 2), self.context_dim)
|
| 959 |
+
|
| 960 |
+
|
| 961 |
+
def restore_spatial_structure_and_convert_to_images(patches, grid_thw_list, patch_size,
|
| 962 |
+
channel_dim=3, temporal_patch_size=2, merge_size=2):
|
| 963 |
+
"""Convert decoder pixel features back to image tensors [3, H, W]."""
|
| 964 |
+
if isinstance(patches, tuple):
|
| 965 |
+
patches = patches[0]
|
| 966 |
+
image_tensors = []
|
| 967 |
+
ptr = 0
|
| 968 |
+
for grid in grid_thw_list:
|
| 969 |
+
gt, gh, gw = (int(x) for x in (grid if not isinstance(grid, torch.Tensor) else grid.tolist()))
|
| 970 |
+
n = gt * gh * gw
|
| 971 |
+
chunk = patches[ptr:ptr + n]
|
| 972 |
+
ptr += n
|
| 973 |
+
r = chunk.reshape(gt, gh // merge_size, gw // merge_size, merge_size, merge_size,
|
| 974 |
+
channel_dim, temporal_patch_size, patch_size, patch_size)
|
| 975 |
+
r = r.permute(0, 6, 5, 1, 3, 7, 2, 4, 8)
|
| 976 |
+
image_tensors.append(r.reshape(gt * temporal_patch_size, channel_dim, gh * patch_size, gw * patch_size)[0])
|
| 977 |
+
return image_tensors
|
| 978 |
+
|
| 979 |
+
|
| 980 |
+
class VisionTransformerDecoder(nn.Module):
|
| 981 |
+
def __init__(self, config):
|
| 982 |
+
super().__init__()
|
| 983 |
+
self.config = config
|
| 984 |
+
self.embed_dim = config.hidden_size
|
| 985 |
+
self.patch_size = config.patch_size
|
| 986 |
+
self.spatial_merge_size = config.spatial_merge_size
|
| 987 |
+
self.codebook_dim = config.codebook_dim
|
| 988 |
+
self.temporal_patch_size = config.temporal_patch_size
|
| 989 |
+
|
| 990 |
+
self.rope2d = VisionRoPE2D(theta=10000.0)
|
| 991 |
+
self.post_quant_conv = nn.Linear(self.codebook_dim, self.embed_dim)
|
| 992 |
+
self.post_quant_norm = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 993 |
+
self.patch_unmerger = PatchUnMerger(self.embed_dim, self.embed_dim, self.spatial_merge_size)
|
| 994 |
+
self.norm_in = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 995 |
+
self.encoder = VisionEncoder(config, rope=self.rope2d, rope_shift=0)
|
| 996 |
+
self.norm_out = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
| 997 |
+
self.decoder_head = nn.Sequential(
|
| 998 |
+
nn.Linear(self.embed_dim, config.intermediate_size), nn.GELU(),
|
| 999 |
+
nn.Linear(config.intermediate_size, 3 * self.patch_size * self.patch_size * self.temporal_patch_size),
|
| 1000 |
+
)
|
| 1001 |
+
|
| 1002 |
+
@classmethod
|
| 1003 |
+
def from_pretrained(cls, config, model_path: str):
|
| 1004 |
+
"""Load a pretrained model from a checkpoint."""
|
| 1005 |
+
model = cls(config)
|
| 1006 |
+
weight_dict = load_file(model_path, device="cpu")
|
| 1007 |
+
model.load_state_dict({k.removeprefix("image_decoder."): v for k, v in weight_dict.items() if k.startswith("image_decoder.")}, strict=True)
|
| 1008 |
+
model.eval()
|
| 1009 |
+
return model
|
| 1010 |
+
|
| 1011 |
+
def _build_2d_positions(self, grid_thw_list):
|
| 1012 |
+
pos_list = []
|
| 1013 |
+
for (t, gh, gw) in grid_thw_list:
|
| 1014 |
+
for _ in range(int(t)):
|
| 1015 |
+
for y in range(int(gh)):
|
| 1016 |
+
for x in range(int(gw)):
|
| 1017 |
+
pos_list.append([y, x])
|
| 1018 |
+
return torch.tensor(pos_list, dtype=torch.long)
|
| 1019 |
+
|
| 1020 |
+
def _build_attention_mask(self, grid_thw_list, device, dtype, B, num_heads):
|
| 1021 |
+
counts = [int(t) * int(h) * int(w) for (t, h, w) in grid_thw_list]
|
| 1022 |
+
L = sum(counts)
|
| 1023 |
+
mask = torch.zeros((B, num_heads, L, L), device=device, dtype=dtype)
|
| 1024 |
+
s = 0
|
| 1025 |
+
for c in counts:
|
| 1026 |
+
e = s + c
|
| 1027 |
+
if s > 0:
|
| 1028 |
+
mask[:, :, s:e, :s] = float("-inf")
|
| 1029 |
+
if e < L:
|
| 1030 |
+
mask[:, :, s:e, e:] = float("-inf")
|
| 1031 |
+
s = e
|
| 1032 |
+
return mask
|
| 1033 |
+
|
| 1034 |
+
def forward(self, embeddings, grid_thw, return_pixel_features=False, return_last_latent=False):
|
| 1035 |
+
device = embeddings.device
|
| 1036 |
+
grid_thw_list = ([(int(t), int(h), int(w)) for t, h, w in grid_thw.detach().cpu().numpy()]
|
| 1037 |
+
if isinstance(grid_thw, torch.Tensor) else list(grid_thw))
|
| 1038 |
+
|
| 1039 |
+
if embeddings.shape[-1] == self.codebook_dim:
|
| 1040 |
+
embeddings = self.post_quant_conv(embeddings)
|
| 1041 |
+
embeddings = self.post_quant_norm(embeddings)
|
| 1042 |
+
|
| 1043 |
+
unmerged = self.patch_unmerger(embeddings)
|
| 1044 |
+
if unmerged.dim() == 2:
|
| 1045 |
+
unmerged = unmerged.unsqueeze(0)
|
| 1046 |
+
B, L, D = unmerged.shape
|
| 1047 |
+
hidden_states = self.norm_in(unmerged)
|
| 1048 |
+
|
| 1049 |
+
positions_2d = self._build_2d_positions(grid_thw_list).to(device)
|
| 1050 |
+
seq_lens = torch.tensor([int(t) * int(h) * int(w) for (t, h, w) in grid_thw_list],
|
| 1051 |
+
device=device, dtype=torch.int32)
|
| 1052 |
+
assert positions_2d.shape[0] == L, f"positions_2d {positions_2d.shape[0]} != L {L}"
|
| 1053 |
+
|
| 1054 |
+
last_latent = hidden_states.detach().squeeze(0) if return_last_latent else None
|
| 1055 |
+
enc_out = self.encoder(
|
| 1056 |
+
inputs_embeds=hidden_states,
|
| 1057 |
+
attention_mask=None,
|
| 1058 |
+
causal_attention_mask=None,
|
| 1059 |
+
output_attentions=False,
|
| 1060 |
+
output_hidden_states=False,
|
| 1061 |
+
return_dict=True,
|
| 1062 |
+
positions_2d=positions_2d,
|
| 1063 |
+
seq_lens=seq_lens,
|
| 1064 |
+
)
|
| 1065 |
+
hidden_states = enc_out.last_hidden_state
|
| 1066 |
+
|
| 1067 |
+
hidden_states = self.norm_out(hidden_states)
|
| 1068 |
+
pixel_features = self.decoder_head(hidden_states).squeeze(0)
|
| 1069 |
+
|
| 1070 |
+
out_imgs = (None if return_pixel_features else
|
| 1071 |
+
restore_spatial_structure_and_convert_to_images(
|
| 1072 |
+
pixel_features, grid_thw_list, self.patch_size,
|
| 1073 |
+
temporal_patch_size=self.temporal_patch_size, merge_size=self.spatial_merge_size))
|
| 1074 |
+
ret = {"images": out_imgs, "pixel_features": pixel_features}
|
| 1075 |
+
if last_latent is not None:
|
| 1076 |
+
ret["last_latent"] = last_latent
|
| 1077 |
+
return ret
|
parse_model_response.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import json
|
| 3 |
+
import uuid
|
| 4 |
+
|
| 5 |
+
def parse_arguments(json_value):
|
| 6 |
+
"""
|
| 7 |
+
Attempt to parse a string as JSON
|
| 8 |
+
|
| 9 |
+
Args:
|
| 10 |
+
json_value: String to parse
|
| 11 |
+
|
| 12 |
+
Returns:
|
| 13 |
+
tuple: (parsed_value, is_valid_json)
|
| 14 |
+
"""
|
| 15 |
+
try:
|
| 16 |
+
parsed_value = json.loads(json_value)
|
| 17 |
+
return parsed_value, True
|
| 18 |
+
except:
|
| 19 |
+
return json_value, False
|
| 20 |
+
|
| 21 |
+
def get_argument_type(func_name: str, arg_key: str, defined_tools: list):
|
| 22 |
+
"""
|
| 23 |
+
Get the type definition of a tool parameter
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
func_name: Name of the function/tool
|
| 27 |
+
arg_key: Parameter key name
|
| 28 |
+
defined_tools: List of tool definitions
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
str or None: Type of the parameter ('string', 'object', 'array', 'integer', 'number', 'boolean')
|
| 32 |
+
"""
|
| 33 |
+
name2tool = {tool["name"]: tool for tool in defined_tools}
|
| 34 |
+
if func_name not in name2tool:
|
| 35 |
+
return None
|
| 36 |
+
tool = name2tool[func_name]
|
| 37 |
+
if "parameters" not in tool or "properties" not in tool["parameters"]:
|
| 38 |
+
return None
|
| 39 |
+
if arg_key not in tool["parameters"]["properties"]:
|
| 40 |
+
return None
|
| 41 |
+
return tool["parameters"]["properties"][arg_key].get("type")
|
| 42 |
+
|
| 43 |
+
def parse_model_response(response: str, defined_tools: list=[]):
|
| 44 |
+
"""
|
| 45 |
+
Parse model response to extract reasoning_content, content, and tool_calls
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
response: Raw response text from the model
|
| 49 |
+
defined_tools: List of tool definitions
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
dict: Message containing role, reasoning_content (optional), content (optional),
|
| 53 |
+
and tool_calls (optional)
|
| 54 |
+
"""
|
| 55 |
+
text = response
|
| 56 |
+
reasoning_content = None
|
| 57 |
+
content = None
|
| 58 |
+
tool_calls = []
|
| 59 |
+
|
| 60 |
+
formatted_tools = []
|
| 61 |
+
for tool in defined_tools:
|
| 62 |
+
if "function" in tool:
|
| 63 |
+
formatted_tools.append(tool['function'])
|
| 64 |
+
else:
|
| 65 |
+
formatted_tools.append(tool)
|
| 66 |
+
|
| 67 |
+
if '</longcat_think>' in text:
|
| 68 |
+
text = text.replace('<longcat_think>', '')
|
| 69 |
+
thinking_end = text.find('</longcat_think>')
|
| 70 |
+
reasoning_content = text[: thinking_end].strip()
|
| 71 |
+
text = text[thinking_end + len('</longcat_think>'):].lstrip()
|
| 72 |
+
|
| 73 |
+
assert '<longcat_think>' not in text, "Unclosed <longcat_think> tag found in remaining text"
|
| 74 |
+
assert '</longcat_think>' not in text, "Unexpected </longcat_think> tag found without opening tag"
|
| 75 |
+
|
| 76 |
+
if '<longcat_tool_call>' in text:
|
| 77 |
+
index = text.find('<longcat_tool_call>')
|
| 78 |
+
content = text[:index]
|
| 79 |
+
text = text[index:].strip()
|
| 80 |
+
else:
|
| 81 |
+
content = text
|
| 82 |
+
text = ""
|
| 83 |
+
|
| 84 |
+
open_tags = text.count('<longcat_tool_call>')
|
| 85 |
+
close_tags = text.count('</longcat_tool_call>')
|
| 86 |
+
assert open_tags == close_tags, \
|
| 87 |
+
f"Mismatched tool_call tags: {open_tags} opening tags, {close_tags} closing tags"
|
| 88 |
+
|
| 89 |
+
tool_call_strs = re.findall(
|
| 90 |
+
r'<longcat_tool_call>(.*?)</longcat_tool_call>',
|
| 91 |
+
text,
|
| 92 |
+
re.DOTALL
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
for call in tool_call_strs:
|
| 96 |
+
func_name_match = re.match(r'([^\n<]+)', call.strip())
|
| 97 |
+
assert func_name_match, f"Missing function name in tool call: {call[:100]}"
|
| 98 |
+
|
| 99 |
+
func_name = func_name_match.group(1).strip()
|
| 100 |
+
assert func_name, "Empty function name in tool call"
|
| 101 |
+
|
| 102 |
+
# Verify argument tags are properly paired
|
| 103 |
+
arg_key_count = call.count('<longcat_arg_key>')
|
| 104 |
+
arg_key_close_count = call.count('</longcat_arg_key>')
|
| 105 |
+
arg_value_count = call.count('<longcat_arg_value>')
|
| 106 |
+
arg_value_close_count = call.count('</longcat_arg_value>')
|
| 107 |
+
|
| 108 |
+
assert arg_key_count == arg_key_close_count, \
|
| 109 |
+
f"Mismatched arg_key tags in function {func_name}: {arg_key_count} opening, {arg_key_close_count} closing"
|
| 110 |
+
assert arg_value_count == arg_value_close_count, \
|
| 111 |
+
f"Mismatched arg_value tags in function {func_name}: {arg_value_count} opening, {arg_value_close_count} closing"
|
| 112 |
+
assert arg_key_count == arg_value_count, \
|
| 113 |
+
f"Mismatched arg_key and arg_value count in function {func_name}: {arg_key_count} keys, {arg_value_count} values"
|
| 114 |
+
|
| 115 |
+
pairs = re.findall(
|
| 116 |
+
r'<longcat_arg_key>(.*?)</longcat_arg_key>\s*<longcat_arg_value>(.*?)</longcat_arg_value>',
|
| 117 |
+
call,
|
| 118 |
+
re.DOTALL
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
assert len(pairs) == arg_key_count, \
|
| 122 |
+
f"Failed to parse all arguments in function {func_name}: expected {arg_key_count}, got {len(pairs)}"
|
| 123 |
+
|
| 124 |
+
arguments = {}
|
| 125 |
+
for arg_key, arg_value in pairs:
|
| 126 |
+
arg_key = arg_key.strip()
|
| 127 |
+
arg_value = arg_value.strip()
|
| 128 |
+
|
| 129 |
+
assert arg_key, f"Empty argument key in function {func_name}"
|
| 130 |
+
assert arg_key not in arguments, \
|
| 131 |
+
f"Duplicate argument key '{arg_key}' in function {func_name}"
|
| 132 |
+
|
| 133 |
+
arg_type = get_argument_type(func_name, arg_key, formatted_tools)
|
| 134 |
+
|
| 135 |
+
if arg_type and arg_type != 'string':
|
| 136 |
+
parsed_value, is_good_json = parse_arguments(arg_value)
|
| 137 |
+
arg_value = parsed_value
|
| 138 |
+
|
| 139 |
+
arguments[arg_key] = arg_value
|
| 140 |
+
|
| 141 |
+
tool_calls.append({
|
| 142 |
+
'id': "tool-call-" + str(uuid.uuid4()),
|
| 143 |
+
'type': "function",
|
| 144 |
+
'function': {
|
| 145 |
+
'name': func_name,
|
| 146 |
+
'arguments': arguments
|
| 147 |
+
}
|
| 148 |
+
})
|
| 149 |
+
|
| 150 |
+
message = {'role': 'assistant'}
|
| 151 |
+
|
| 152 |
+
if reasoning_content:
|
| 153 |
+
message['reasoning_content'] = reasoning_content
|
| 154 |
+
message['content'] = content
|
| 155 |
+
if tool_calls:
|
| 156 |
+
message['tool_calls'] = tool_calls
|
| 157 |
+
|
| 158 |
+
return message
|
preprocessor_config.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"processor_class": "LongcatNextProcessor",
|
| 3 |
+
"auto_map": {
|
| 4 |
+
"AutoProcessor": "processing_longcat_next.LongcatNextProcessor"
|
| 5 |
+
},
|
| 6 |
+
"spatial_merge_size": 2,
|
| 7 |
+
"max_pixels": 3211264,
|
| 8 |
+
"min_pixels": 50176,
|
| 9 |
+
|
| 10 |
+
"n_fft": 400,
|
| 11 |
+
"num_mel_bins": 128,
|
| 12 |
+
"sampling_rate": 16000,
|
| 13 |
+
"max_audio_seconds": 30,
|
| 14 |
+
"hop_length": 160,
|
| 15 |
+
"kernel_size": 3,
|
| 16 |
+
"stride_size": 2,
|
| 17 |
+
"split_overlap": 0.0,
|
| 18 |
+
"avg_pooler": 4
|
| 19 |
+
}
|
processing_longcat_next.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import Union, List
|
| 3 |
+
from types import SimpleNamespace
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import librosa
|
| 7 |
+
import soundfile as sf
|
| 8 |
+
import numpy as np
|
| 9 |
+
from transformers import AutoFeatureExtractor
|
| 10 |
+
from transformers.audio_utils import mel_filter_bank
|
| 11 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 12 |
+
from transformers.feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
| 13 |
+
from transformers.processing_utils import (
|
| 14 |
+
AudioKwargs,
|
| 15 |
+
ImagesKwargs,
|
| 16 |
+
ProcessingKwargs,
|
| 17 |
+
ProcessorMixin,
|
| 18 |
+
VideosKwargs,
|
| 19 |
+
)
|
| 20 |
+
from transformers.utils import logging
|
| 21 |
+
|
| 22 |
+
logger = logging.get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class LongcatNextProcessorKwargs(ProcessingKwargs, total=False):
|
| 26 |
+
images_kwargs: ImagesKwargs
|
| 27 |
+
videos_kwargs: VideosKwargs
|
| 28 |
+
audio_kwargs: AudioKwargs
|
| 29 |
+
_defaults = {
|
| 30 |
+
"text_kwargs": {
|
| 31 |
+
"padding": False,
|
| 32 |
+
"padding_side": "left",
|
| 33 |
+
"return_attention_mask": False,
|
| 34 |
+
}
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class LongcatNextAudioProcessor(FeatureExtractionMixin):
|
| 39 |
+
|
| 40 |
+
def __init__(self, **kwargs):
|
| 41 |
+
super().__init__(**kwargs)
|
| 42 |
+
self.mel_filters = mel_filter_bank(
|
| 43 |
+
num_frequency_bins=1 + self.n_fft // 2,
|
| 44 |
+
num_mel_filters=self.num_mel_bins,
|
| 45 |
+
min_frequency=0.0,
|
| 46 |
+
max_frequency=self.sampling_rate / 2.0,
|
| 47 |
+
sampling_rate=self.sampling_rate,
|
| 48 |
+
norm="slaney",
|
| 49 |
+
mel_scale="slaney",
|
| 50 |
+
)
|
| 51 |
+
self.window = torch.hann_window(self.n_fft)
|
| 52 |
+
|
| 53 |
+
@staticmethod
|
| 54 |
+
def zero_mean_unit_var_norm(x):
|
| 55 |
+
return (x - x.mean()) / torch.sqrt(x.var() + 1e-8)
|
| 56 |
+
|
| 57 |
+
def load_audio_waveform(self, uri, metadata=None, waveform_tensor=None, return_tensors=True, do_normalize=False):
|
| 58 |
+
if metadata is None or waveform_tensor is None:
|
| 59 |
+
# 使用 librosa 统一处理所有音频格式(包括 mp3, wav, flac 等)
|
| 60 |
+
# librosa.load 返回的已经是归一化的 float32 数据
|
| 61 |
+
waveform_np, sample_rate = librosa.load(uri, sr=None, mono=False)
|
| 62 |
+
|
| 63 |
+
# 转换为 tensor,确保维度为 (channels, samples)
|
| 64 |
+
if waveform_np.ndim == 1:
|
| 65 |
+
waveform_tensor = torch.from_numpy(waveform_np).unsqueeze(0)
|
| 66 |
+
else:
|
| 67 |
+
waveform_tensor = torch.from_numpy(waveform_np)
|
| 68 |
+
|
| 69 |
+
# 获取音频元信息
|
| 70 |
+
try:
|
| 71 |
+
sf_info = sf.info(uri)
|
| 72 |
+
metadata = SimpleNamespace(
|
| 73 |
+
sample_rate=sample_rate,
|
| 74 |
+
num_frames=waveform_tensor.shape[1],
|
| 75 |
+
num_channels=waveform_tensor.shape[0],
|
| 76 |
+
bits_per_sample=getattr(sf_info, 'bits_per_sample', 16),
|
| 77 |
+
encoding=getattr(sf_info, 'subtype', 'PCM_F')
|
| 78 |
+
)
|
| 79 |
+
except Exception:
|
| 80 |
+
# 如果 soundfile.info 失败,使用 librosa 提供的信息
|
| 81 |
+
metadata = SimpleNamespace(
|
| 82 |
+
sample_rate=sample_rate,
|
| 83 |
+
num_frames=waveform_tensor.shape[1],
|
| 84 |
+
num_channels=waveform_tensor.shape[0],
|
| 85 |
+
bits_per_sample=16,
|
| 86 |
+
encoding='PCM_F'
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
assert(metadata.num_channels <= 2), "acoustic file with {} channels.".format(metadata.num_channels) # whisper only accept mono channel audio
|
| 90 |
+
|
| 91 |
+
if self.sampling_rate != metadata.sample_rate:
|
| 92 |
+
# 使用 torch.functional 进行重采样
|
| 93 |
+
waveform_tensor = torch.nn.functional.interpolate(
|
| 94 |
+
waveform_tensor.unsqueeze(0),
|
| 95 |
+
size=int(waveform_tensor.shape[1] * self.sampling_rate / metadata.sample_rate),
|
| 96 |
+
mode='linear',
|
| 97 |
+
align_corners=False
|
| 98 |
+
).squeeze(0)
|
| 99 |
+
|
| 100 |
+
# downmix to mono channel https://trac.ffmpeg.org/wiki/AudioChannelManipulation
|
| 101 |
+
if metadata.num_channels > 1:
|
| 102 |
+
waveform_tensor = torch.mean(waveform_tensor, dim=0, keepdim=True)
|
| 103 |
+
|
| 104 |
+
# normalized to zero mean (Qwen Audio没有处理 但Whisper官方实现)
|
| 105 |
+
if do_normalize:
|
| 106 |
+
waveform_tensor = self.zero_mean_unit_var_norm(waveform_tensor)
|
| 107 |
+
|
| 108 |
+
if return_tensors: # (channels, samples)
|
| 109 |
+
return waveform_tensor
|
| 110 |
+
else:
|
| 111 |
+
return waveform_tensor.numpy()
|
| 112 |
+
|
| 113 |
+
def split_with_overlap(self, waveform): # 如果长度超过最大长度限制 分割为带overlap的多段
|
| 114 |
+
channels, wave_samples = waveform.shape
|
| 115 |
+
max_audio_samples = self.max_audio_seconds * self.sampling_rate
|
| 116 |
+
if wave_samples <= max_audio_samples or self.split_overlap < 0:
|
| 117 |
+
return [waveform] # 没有超出最大长度or截断逻辑 统一返回list
|
| 118 |
+
|
| 119 |
+
split_waveform, start = [], 0
|
| 120 |
+
while start < wave_samples: # 统一按秒数对齐overlap
|
| 121 |
+
if start > int(self.sampling_rate * self.split_overlap):
|
| 122 |
+
start -= int(self.sampling_rate * self.split_overlap) # 0表示没有overlap,>0 overlap对应秒数
|
| 123 |
+
end = min(start + max_audio_samples, wave_samples)
|
| 124 |
+
if end - start>= self.n_fft: # 保证至少有一帧数据
|
| 125 |
+
split_waveform.append(waveform[:, start:end]) # 注意这里可能会切割出特别短的片段 需要在预处理判断并丢弃
|
| 126 |
+
start = end
|
| 127 |
+
return split_waveform
|
| 128 |
+
|
| 129 |
+
@classmethod
|
| 130 |
+
def inference_output_length(self, input_length, kernel_size, stride_size, avg_pooler):
|
| 131 |
+
# for whisper + bridge
|
| 132 |
+
encoder_length = (input_length + 2 * (kernel_size // 2) - kernel_size) // 1 + 1 # conv layer1 with pad=1
|
| 133 |
+
encoder_length = (encoder_length + 2 * (kernel_size // 2) - kernel_size) // stride_size + 1 # conv layer2 with pad=1
|
| 134 |
+
if avg_pooler > 1:
|
| 135 |
+
bridge_length = encoder_length // avg_pooler
|
| 136 |
+
return encoder_length, bridge_length
|
| 137 |
+
|
| 138 |
+
def extract_fbank_features(self, waveform):
|
| 139 |
+
# ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py
|
| 140 |
+
channels, wave_samples = waveform.shape
|
| 141 |
+
assert(wave_samples >= self.n_fft)
|
| 142 |
+
valid_frame_nums = min(self.max_audio_seconds * self.sampling_rate // self.hop_length, wave_samples // self.hop_length + 1)
|
| 143 |
+
if wave_samples < self.max_audio_seconds * self.sampling_rate:
|
| 144 |
+
waveform = torch.nn.functional.pad(waveform, (0, self.max_audio_seconds * self.sampling_rate - wave_samples), "constant", 0)
|
| 145 |
+
else:
|
| 146 |
+
waveform = waveform[:, :self.max_audio_seconds * self.sampling_rate]
|
| 147 |
+
|
| 148 |
+
# window = torch.hann_window(self.n_fft)
|
| 149 |
+
stft = torch.stft(waveform, self.n_fft, self.hop_length, window=self.window, return_complex=True) # fft, len(wave) // n_fft // 2 + 1
|
| 150 |
+
magnitudes = stft[..., :-1].abs() ** 2
|
| 151 |
+
|
| 152 |
+
mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
|
| 153 |
+
mel_spec = mel_filters.T @ magnitudes
|
| 154 |
+
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
| 155 |
+
if waveform.dim() == 2:
|
| 156 |
+
max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
|
| 157 |
+
log_spec = torch.maximum(log_spec, max_val - 8.0)
|
| 158 |
+
else:
|
| 159 |
+
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
| 160 |
+
log_spec = (log_spec + 4.0) / 4.0
|
| 161 |
+
|
| 162 |
+
log_spec = log_spec[0].numpy() # (channel, filters, samples) -> (filters, samples)
|
| 163 |
+
log_spec[:, valid_frame_nums:] = 0.0 # pad0
|
| 164 |
+
|
| 165 |
+
return log_spec, valid_frame_nums
|
| 166 |
+
|
| 167 |
+
def process(self, audio_path, **kwargs):
|
| 168 |
+
metadata, waveform_tensors = None, None
|
| 169 |
+
waveforms = self.load_audio_waveform(audio_path, metadata, waveform_tensors, True)
|
| 170 |
+
waveforms = self.split_with_overlap(waveforms)
|
| 171 |
+
|
| 172 |
+
ret_audio, ret_encoder_length, ret_bridge_length = [], [], []
|
| 173 |
+
for i, waveform in enumerate(waveforms):
|
| 174 |
+
audio, input_length = self.extract_fbank_features(waveform)
|
| 175 |
+
encoder_length, bridge_length = self.inference_output_length(input_length, self.kernel_size, self.stride_size, self.avg_pooler)
|
| 176 |
+
if bridge_length <= 0:
|
| 177 |
+
continue
|
| 178 |
+
|
| 179 |
+
ret_audio.append(audio)
|
| 180 |
+
ret_encoder_length.append(encoder_length)
|
| 181 |
+
ret_bridge_length.append(bridge_length)
|
| 182 |
+
return ret_audio, ret_encoder_length, ret_bridge_length
|
| 183 |
+
|
| 184 |
+
def __call__(self, audio: Union[str, List[str]], **kwargs):
|
| 185 |
+
if isinstance(audio, str):
|
| 186 |
+
audio = [audio]
|
| 187 |
+
results = {
|
| 188 |
+
"audio": [],
|
| 189 |
+
"encoder_length": [],
|
| 190 |
+
"bridge_length": [],
|
| 191 |
+
}
|
| 192 |
+
for audio_path in audio:
|
| 193 |
+
audio, encoder_length, bridge_length = self.process(audio_path, **kwargs)
|
| 194 |
+
results["audio"].append(audio)
|
| 195 |
+
results["encoder_length"].append(encoder_length)
|
| 196 |
+
results["bridge_length"].append(bridge_length)
|
| 197 |
+
return results
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class LongcatNextProcessor(ProcessorMixin):
|
| 201 |
+
|
| 202 |
+
attributes = ["image_processor", "video_processor", "audio_processor", "tokenizer"]
|
| 203 |
+
|
| 204 |
+
image_processor_class = "Qwen2VLImageProcessor"
|
| 205 |
+
video_processor_class = "Qwen2VLImageProcessor"
|
| 206 |
+
audio_processor_class = "LongcatNextAudioProcessor"
|
| 207 |
+
tokenizer_class = "AutoTokenizer"
|
| 208 |
+
|
| 209 |
+
def __init__(self, image_processor=None, video_processor=None, audio_processor=None, tokenizer=None, chat_template=None, **kwargs):
|
| 210 |
+
super().__init__(image_processor, video_processor, audio_processor, tokenizer, chat_template=chat_template)
|
| 211 |
+
init_token_list = [
|
| 212 |
+
"image_start_token", "image_end_token", "image_pad_token", "image_newline_token",
|
| 213 |
+
"audio_start_token", "audio_end_token", "audio_pad_token",
|
| 214 |
+
]
|
| 215 |
+
for attr in init_token_list:
|
| 216 |
+
token_str = self.tokenizer.init_kwargs.get(attr)
|
| 217 |
+
token_ids = self.tokenizer.encode(token_str, add_special_tokens=False)
|
| 218 |
+
assert len(token_ids) == 1, (f"{attr}='{token_str}' encode to get {len(token_ids)} id(s) {token_ids}, expect 1 id")
|
| 219 |
+
setattr(self, f"{attr}", token_str)
|
| 220 |
+
setattr(self, f"{attr}_id", token_ids[0])
|
| 221 |
+
|
| 222 |
+
def __call__(
|
| 223 |
+
self,
|
| 224 |
+
text: str,
|
| 225 |
+
**kwargs,
|
| 226 |
+
) -> List["LongcatNextProcessorOutput"]:
|
| 227 |
+
|
| 228 |
+
if text is None:
|
| 229 |
+
raise ValueError("You need to specify either a `text` input to process.")
|
| 230 |
+
|
| 231 |
+
output_kwargs = self._merge_kwargs(
|
| 232 |
+
LongcatNextProcessorKwargs,
|
| 233 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 234 |
+
**kwargs,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
assert isinstance(text, str)
|
| 238 |
+
|
| 239 |
+
image_path_list = re.findall(rf"{self.image_start_token}(.*?){self.image_end_token}", text)
|
| 240 |
+
audio_path_list = re.findall(rf"{self.audio_start_token}(.*?){self.audio_end_token}", text)
|
| 241 |
+
|
| 242 |
+
if len(image_path_list) > 0:
|
| 243 |
+
images_inputs = self.image_processor(images=image_path_list, **output_kwargs["images_kwargs"])
|
| 244 |
+
image_grid_thw = images_inputs["image_grid_thw"]
|
| 245 |
+
for i, image_path in enumerate(image_path_list):
|
| 246 |
+
image_token_num = image_grid_thw[i][0] * (image_grid_thw[i][1]//self.image_processor.spatial_merge_size) * (image_grid_thw[i][2]//self.image_processor.spatial_merge_size)
|
| 247 |
+
text = text.replace(f"{self.image_start_token}{image_path}{self.image_end_token}", f"{self.image_start_token}{self.image_pad_token * image_token_num}{self.image_end_token}")
|
| 248 |
+
else:
|
| 249 |
+
images_inputs = {}
|
| 250 |
+
|
| 251 |
+
if len(audio_path_list) > 0:
|
| 252 |
+
audio_inputs = self.audio_processor(audio=audio_path_list, **output_kwargs["audio_kwargs"])
|
| 253 |
+
for i, audio_path in enumerate(audio_path_list):
|
| 254 |
+
audio_token_num = np.sum(audio_inputs["bridge_length"][i])
|
| 255 |
+
text = text.replace(f"{self.audio_start_token}{audio_path}{self.audio_end_token}", f"{self.audio_start_token}{self.audio_pad_token * audio_token_num}{self.audio_end_token}")
|
| 256 |
+
for key in audio_inputs:
|
| 257 |
+
audio_inputs[key] = [val for b_val in audio_inputs[key] for val in b_val]
|
| 258 |
+
else:
|
| 259 |
+
audio_inputs = {}
|
| 260 |
+
|
| 261 |
+
texts_inputs = self.tokenizer([text], **output_kwargs["text_kwargs"])
|
| 262 |
+
|
| 263 |
+
batch_feature_func = lambda x: BatchFeature(
|
| 264 |
+
data={**x},
|
| 265 |
+
tensor_type=kwargs.get("return_tensors"),
|
| 266 |
+
)
|
| 267 |
+
return (
|
| 268 |
+
batch_feature_func(texts_inputs),
|
| 269 |
+
batch_feature_func({k.replace("image", "visual"): v for k, v in images_inputs.items()}) if len(images_inputs) > 0 else None,
|
| 270 |
+
batch_feature_func(audio_inputs) if len(audio_inputs) > 0 else None,
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class LongcatNextAudioProcessorConfig(PretrainedConfig):
|
| 275 |
+
pass
|
| 276 |
+
AutoFeatureExtractor.register(LongcatNextAudioProcessorConfig, LongcatNextAudioProcessor)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
__all__ = ["LongcatNextAudioProcessor", "LongcatNextProcessor"]
|
refiner_modules.py
ADDED
|
@@ -0,0 +1,1330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ---------------------------------------------------------------------------
|
| 2 |
+
# Standard / third-party imports shared by all sections
|
| 3 |
+
# ---------------------------------------------------------------------------
|
| 4 |
+
|
| 5 |
+
import itertools
|
| 6 |
+
import math
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 9 |
+
|
| 10 |
+
from flash_attn import flash_attn_varlen_func # type: ignore
|
| 11 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # type: ignore
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from torch.nn import RMSNorm
|
| 17 |
+
|
| 18 |
+
from einops import rearrange, repeat
|
| 19 |
+
|
| 20 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 21 |
+
from diffusers.loaders import PeftAdapterMixin
|
| 22 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
| 23 |
+
from diffusers.models.activations import get_activation
|
| 24 |
+
from diffusers.models.attention_processor import Attention
|
| 25 |
+
from diffusers.models.embeddings import Timesteps, get_1d_rotary_pos_embed
|
| 26 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 27 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 28 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
| 29 |
+
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
|
| 30 |
+
|
| 31 |
+
logger = logging.get_logger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def swiglu(x, y):
|
| 35 |
+
return F.silu(x.float(), inplace=False).to(x.dtype) * y
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class TimestepEmbedding(nn.Module):
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
in_channels: int,
|
| 42 |
+
time_embed_dim: int,
|
| 43 |
+
act_fn: str = "silu",
|
| 44 |
+
out_dim: int = None,
|
| 45 |
+
post_act_fn: Optional[str] = None,
|
| 46 |
+
cond_proj_dim=None,
|
| 47 |
+
sample_proj_bias=True,
|
| 48 |
+
):
|
| 49 |
+
super().__init__()
|
| 50 |
+
|
| 51 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
|
| 52 |
+
|
| 53 |
+
if cond_proj_dim is not None:
|
| 54 |
+
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
| 55 |
+
else:
|
| 56 |
+
self.cond_proj = None
|
| 57 |
+
|
| 58 |
+
self.act = get_activation(act_fn)
|
| 59 |
+
|
| 60 |
+
if out_dim is not None:
|
| 61 |
+
time_embed_dim_out = out_dim
|
| 62 |
+
else:
|
| 63 |
+
time_embed_dim_out = time_embed_dim
|
| 64 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
|
| 65 |
+
|
| 66 |
+
if post_act_fn is None:
|
| 67 |
+
self.post_act = None
|
| 68 |
+
else:
|
| 69 |
+
self.post_act = get_activation(post_act_fn)
|
| 70 |
+
|
| 71 |
+
self.initialize_weights()
|
| 72 |
+
|
| 73 |
+
def initialize_weights(self):
|
| 74 |
+
nn.init.normal_(self.linear_1.weight, std=0.02)
|
| 75 |
+
nn.init.zeros_(self.linear_1.bias)
|
| 76 |
+
nn.init.normal_(self.linear_2.weight, std=0.02)
|
| 77 |
+
nn.init.zeros_(self.linear_2.bias)
|
| 78 |
+
|
| 79 |
+
def forward(self, sample, condition=None):
|
| 80 |
+
if condition is not None:
|
| 81 |
+
sample = sample + self.cond_proj(condition)
|
| 82 |
+
sample = self.linear_1(sample)
|
| 83 |
+
if self.act is not None:
|
| 84 |
+
sample = self.act(sample)
|
| 85 |
+
sample = self.linear_2(sample)
|
| 86 |
+
if self.post_act is not None:
|
| 87 |
+
sample = self.post_act(sample)
|
| 88 |
+
return sample
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def apply_rotary_emb(
|
| 92 |
+
x: torch.Tensor,
|
| 93 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
| 94 |
+
use_real: bool = True,
|
| 95 |
+
use_real_unbind_dim: int = -1,
|
| 96 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 97 |
+
"""
|
| 98 |
+
Apply rotary embeddings to input tensors using the given frequency tensor.
|
| 99 |
+
"""
|
| 100 |
+
if use_real:
|
| 101 |
+
cos, sin = freqs_cis # [S, D]
|
| 102 |
+
cos = cos[None, None]
|
| 103 |
+
sin = sin[None, None]
|
| 104 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
| 105 |
+
|
| 106 |
+
if use_real_unbind_dim == -1:
|
| 107 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
| 108 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
| 109 |
+
elif use_real_unbind_dim == -2:
|
| 110 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
|
| 111 |
+
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
| 112 |
+
else:
|
| 113 |
+
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
| 114 |
+
|
| 115 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
| 116 |
+
return out
|
| 117 |
+
else:
|
| 118 |
+
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
|
| 119 |
+
freqs_cis = freqs_cis.unsqueeze(2)
|
| 120 |
+
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
| 121 |
+
return x_out.type_as(x)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@dataclass
|
| 125 |
+
class TeaCacheParams:
|
| 126 |
+
"""
|
| 127 |
+
TeaCache parameters for Transformer2DModel.
|
| 128 |
+
See https://github.com/ali-vilab/TeaCache/ for a more comprehensive understanding.
|
| 129 |
+
"""
|
| 130 |
+
previous_residual: Optional[torch.Tensor] = None
|
| 131 |
+
previous_modulated_inp: Optional[torch.Tensor] = None
|
| 132 |
+
accumulated_rel_l1_distance: float = 0
|
| 133 |
+
is_first_or_last_step: bool = False
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def derivative_approximation(*args, **kwargs):
|
| 137 |
+
pass
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def taylor_formula(*args, **kwargs):
|
| 141 |
+
pass
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def taylor_cache_init(*args, **kwargs):
|
| 145 |
+
pass
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def cache_init(*args, **kwargs):
|
| 149 |
+
pass
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def cal_type(*args, **kwargs):
|
| 153 |
+
pass
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class LuminaRMSNormZero(nn.Module):
|
| 157 |
+
"""
|
| 158 |
+
Norm layer adaptive RMS normalization zero.
|
| 159 |
+
"""
|
| 160 |
+
|
| 161 |
+
def __init__(
|
| 162 |
+
self,
|
| 163 |
+
embedding_dim: int,
|
| 164 |
+
norm_eps: float,
|
| 165 |
+
norm_elementwise_affine: bool,
|
| 166 |
+
):
|
| 167 |
+
super().__init__()
|
| 168 |
+
self.silu = nn.SiLU()
|
| 169 |
+
self.linear = nn.Linear(
|
| 170 |
+
min(embedding_dim, 1024),
|
| 171 |
+
4 * embedding_dim,
|
| 172 |
+
bias=True,
|
| 173 |
+
)
|
| 174 |
+
self.norm = RMSNorm(embedding_dim, eps=norm_eps)
|
| 175 |
+
|
| 176 |
+
def forward(
|
| 177 |
+
self,
|
| 178 |
+
x: torch.Tensor,
|
| 179 |
+
emb: Optional[torch.Tensor] = None,
|
| 180 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 181 |
+
emb = self.linear(self.silu(emb))
|
| 182 |
+
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
|
| 183 |
+
x = self.norm(x) * (1 + scale_msa[:, None])
|
| 184 |
+
return x, gate_msa, scale_mlp, gate_mlp
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class LuminaLayerNormContinuous(nn.Module):
|
| 188 |
+
def __init__(
|
| 189 |
+
self,
|
| 190 |
+
embedding_dim: int,
|
| 191 |
+
conditioning_embedding_dim: int,
|
| 192 |
+
elementwise_affine=True,
|
| 193 |
+
eps=1e-5,
|
| 194 |
+
bias=True,
|
| 195 |
+
norm_type="layer_norm",
|
| 196 |
+
out_dim: Optional[int] = None,
|
| 197 |
+
):
|
| 198 |
+
super().__init__()
|
| 199 |
+
|
| 200 |
+
self.silu = nn.SiLU()
|
| 201 |
+
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
|
| 202 |
+
|
| 203 |
+
if norm_type == "layer_norm":
|
| 204 |
+
self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
| 205 |
+
elif norm_type == "rms_norm":
|
| 206 |
+
self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
|
| 207 |
+
else:
|
| 208 |
+
raise ValueError(f"unknown norm_type {norm_type}")
|
| 209 |
+
|
| 210 |
+
self.linear_2 = None
|
| 211 |
+
if out_dim is not None:
|
| 212 |
+
self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
|
| 213 |
+
|
| 214 |
+
def forward(
|
| 215 |
+
self,
|
| 216 |
+
x: torch.Tensor,
|
| 217 |
+
conditioning_embedding: torch.Tensor,
|
| 218 |
+
) -> torch.Tensor:
|
| 219 |
+
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
|
| 220 |
+
scale = emb
|
| 221 |
+
x = self.norm(x) * (1 + scale)[:, None, :]
|
| 222 |
+
if self.linear_2 is not None:
|
| 223 |
+
x = self.linear_2(x)
|
| 224 |
+
return x
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class LuminaFeedForward(nn.Module):
|
| 228 |
+
def __init__(
|
| 229 |
+
self,
|
| 230 |
+
dim: int,
|
| 231 |
+
inner_dim: int,
|
| 232 |
+
multiple_of: Optional[int] = 256,
|
| 233 |
+
ffn_dim_multiplier: Optional[float] = None,
|
| 234 |
+
):
|
| 235 |
+
super().__init__()
|
| 236 |
+
|
| 237 |
+
if ffn_dim_multiplier is not None:
|
| 238 |
+
inner_dim = int(ffn_dim_multiplier * inner_dim)
|
| 239 |
+
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
|
| 240 |
+
|
| 241 |
+
self.linear_1 = nn.Linear(dim, inner_dim, bias=False)
|
| 242 |
+
self.linear_2 = nn.Linear(inner_dim, dim, bias=False)
|
| 243 |
+
self.linear_3 = nn.Linear(dim, inner_dim, bias=False)
|
| 244 |
+
|
| 245 |
+
def forward(self, x):
|
| 246 |
+
h1, h2 = self.linear_1(x), self.linear_3(x)
|
| 247 |
+
return self.linear_2(swiglu(h1, h2))
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
|
| 251 |
+
def __init__(
|
| 252 |
+
self,
|
| 253 |
+
hidden_size: int = 4096,
|
| 254 |
+
text_feat_dim: int = 2048,
|
| 255 |
+
frequency_embedding_size: int = 256,
|
| 256 |
+
norm_eps: float = 1e-5,
|
| 257 |
+
timestep_scale: float = 1.0,
|
| 258 |
+
) -> None:
|
| 259 |
+
super().__init__()
|
| 260 |
+
|
| 261 |
+
self.time_proj = Timesteps(
|
| 262 |
+
num_channels=frequency_embedding_size,
|
| 263 |
+
flip_sin_to_cos=True,
|
| 264 |
+
downscale_freq_shift=0.0,
|
| 265 |
+
scale=timestep_scale,
|
| 266 |
+
)
|
| 267 |
+
self.timestep_embedder = TimestepEmbedding(
|
| 268 |
+
in_channels=frequency_embedding_size,
|
| 269 |
+
time_embed_dim=min(hidden_size, 1024),
|
| 270 |
+
)
|
| 271 |
+
self.caption_embedder = nn.Sequential(
|
| 272 |
+
RMSNorm(text_feat_dim, eps=norm_eps),
|
| 273 |
+
nn.Linear(text_feat_dim, hidden_size, bias=True),
|
| 274 |
+
)
|
| 275 |
+
self._initialize_weights()
|
| 276 |
+
|
| 277 |
+
def _initialize_weights(self):
|
| 278 |
+
nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02)
|
| 279 |
+
nn.init.zeros_(self.caption_embedder[1].bias)
|
| 280 |
+
|
| 281 |
+
def forward(
|
| 282 |
+
self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
|
| 283 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 284 |
+
timestep_proj = self.time_proj(timestep).to(dtype=dtype)
|
| 285 |
+
time_embed = self.timestep_embedder(timestep_proj)
|
| 286 |
+
caption_embed = self.caption_embedder(text_hidden_states)
|
| 287 |
+
return time_embed, caption_embed
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class AttnProcessorFlash2Varlen:
|
| 291 |
+
"""
|
| 292 |
+
Processor for implementing scaled dot-product attention with flash attention
|
| 293 |
+
and variable length sequences.
|
| 294 |
+
"""
|
| 295 |
+
|
| 296 |
+
def __init__(self) -> None:
|
| 297 |
+
pass
|
| 298 |
+
# if not is_flash_attn_available():
|
| 299 |
+
# raise ImportError(
|
| 300 |
+
# "AttnProcessorFlash2Varlen requires flash_attn. "
|
| 301 |
+
# "Please install flash_attn."
|
| 302 |
+
# )
|
| 303 |
+
|
| 304 |
+
def _upad_input(
|
| 305 |
+
self,
|
| 306 |
+
query_layer: torch.Tensor,
|
| 307 |
+
key_layer: torch.Tensor,
|
| 308 |
+
value_layer: torch.Tensor,
|
| 309 |
+
attention_mask: torch.Tensor,
|
| 310 |
+
query_length: int,
|
| 311 |
+
num_heads: int,
|
| 312 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
|
| 313 |
+
def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
| 314 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
| 315 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
| 316 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
| 317 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
| 318 |
+
return indices, cu_seqlens, max_seqlen_in_batch
|
| 319 |
+
|
| 320 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
| 321 |
+
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
| 322 |
+
|
| 323 |
+
key_layer = index_first_axis(
|
| 324 |
+
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k,
|
| 325 |
+
)
|
| 326 |
+
value_layer = index_first_axis(
|
| 327 |
+
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k,
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
if query_length == kv_seq_len:
|
| 331 |
+
query_layer = index_first_axis(
|
| 332 |
+
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k,
|
| 333 |
+
)
|
| 334 |
+
cu_seqlens_q = cu_seqlens_k
|
| 335 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
| 336 |
+
indices_q = indices_k
|
| 337 |
+
elif query_length == 1:
|
| 338 |
+
max_seqlen_in_batch_q = 1
|
| 339 |
+
cu_seqlens_q = torch.arange(
|
| 340 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
| 341 |
+
)
|
| 342 |
+
indices_q = cu_seqlens_q[:-1]
|
| 343 |
+
query_layer = query_layer.squeeze(1)
|
| 344 |
+
else:
|
| 345 |
+
attention_mask = attention_mask[:, -query_length:]
|
| 346 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
|
| 347 |
+
query_layer, attention_mask
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
return (
|
| 351 |
+
query_layer, key_layer, value_layer, indices_q,
|
| 352 |
+
(cu_seqlens_q, cu_seqlens_k),
|
| 353 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
def __call__(
|
| 357 |
+
self,
|
| 358 |
+
attn: Attention,
|
| 359 |
+
hidden_states: torch.Tensor,
|
| 360 |
+
encoder_hidden_states: torch.Tensor,
|
| 361 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 362 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 363 |
+
base_sequence_length: Optional[int] = None,
|
| 364 |
+
) -> torch.Tensor:
|
| 365 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
| 366 |
+
|
| 367 |
+
query = attn.to_q(hidden_states)
|
| 368 |
+
key = attn.to_k(encoder_hidden_states)
|
| 369 |
+
value = attn.to_v(encoder_hidden_states)
|
| 370 |
+
|
| 371 |
+
query_dim = query.shape[-1]
|
| 372 |
+
inner_dim = key.shape[-1]
|
| 373 |
+
head_dim = query_dim // attn.heads
|
| 374 |
+
dtype = query.dtype
|
| 375 |
+
kv_heads = inner_dim // head_dim
|
| 376 |
+
|
| 377 |
+
query = query.view(batch_size, -1, attn.heads, head_dim)
|
| 378 |
+
key = key.view(batch_size, -1, kv_heads, head_dim)
|
| 379 |
+
value = value.view(batch_size, -1, kv_heads, head_dim)
|
| 380 |
+
|
| 381 |
+
if attn.norm_q is not None:
|
| 382 |
+
query = attn.norm_q(query)
|
| 383 |
+
if attn.norm_k is not None:
|
| 384 |
+
key = attn.norm_k(key)
|
| 385 |
+
|
| 386 |
+
if image_rotary_emb is not None:
|
| 387 |
+
query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
|
| 388 |
+
key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
|
| 389 |
+
|
| 390 |
+
query, key = query.to(dtype), key.to(dtype)
|
| 391 |
+
|
| 392 |
+
if base_sequence_length is not None:
|
| 393 |
+
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
|
| 394 |
+
else:
|
| 395 |
+
softmax_scale = attn.scale
|
| 396 |
+
|
| 397 |
+
(
|
| 398 |
+
query_states, key_states, value_states, indices_q,
|
| 399 |
+
cu_seq_lens, max_seq_lens,
|
| 400 |
+
) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads)
|
| 401 |
+
|
| 402 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
| 403 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
| 404 |
+
|
| 405 |
+
if kv_heads < attn.heads:
|
| 406 |
+
key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
|
| 407 |
+
value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
|
| 408 |
+
|
| 409 |
+
attn_output_unpad = flash_attn_varlen_func(
|
| 410 |
+
query_states, key_states, value_states,
|
| 411 |
+
cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
|
| 412 |
+
max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
|
| 413 |
+
dropout_p=0.0, causal=False, softmax_scale=softmax_scale,
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
|
| 417 |
+
hidden_states = hidden_states.flatten(-2)
|
| 418 |
+
hidden_states = hidden_states.type_as(query)
|
| 419 |
+
|
| 420 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 421 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 422 |
+
return hidden_states
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
class AttnProcessor:
|
| 426 |
+
"""
|
| 427 |
+
Processor for implementing scaled dot-product attention (PyTorch 2.0+).
|
| 428 |
+
"""
|
| 429 |
+
|
| 430 |
+
def __init__(self) -> None:
|
| 431 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 432 |
+
raise ImportError(
|
| 433 |
+
"AttnProcessor requires PyTorch 2.0. "
|
| 434 |
+
"Please upgrade PyTorch to version 2.0 or later."
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
def __call__(
|
| 438 |
+
self,
|
| 439 |
+
attn: Attention,
|
| 440 |
+
hidden_states: torch.Tensor,
|
| 441 |
+
encoder_hidden_states: torch.Tensor,
|
| 442 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 443 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 444 |
+
base_sequence_length: Optional[int] = None,
|
| 445 |
+
) -> torch.Tensor:
|
| 446 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
| 447 |
+
|
| 448 |
+
query = attn.to_q(hidden_states)
|
| 449 |
+
key = attn.to_k(encoder_hidden_states)
|
| 450 |
+
value = attn.to_v(encoder_hidden_states)
|
| 451 |
+
|
| 452 |
+
query_dim = query.shape[-1]
|
| 453 |
+
inner_dim = key.shape[-1]
|
| 454 |
+
head_dim = query_dim // attn.heads
|
| 455 |
+
dtype = query.dtype
|
| 456 |
+
kv_heads = inner_dim // head_dim
|
| 457 |
+
|
| 458 |
+
query = query.view(batch_size, -1, attn.heads, head_dim)
|
| 459 |
+
key = key.view(batch_size, -1, kv_heads, head_dim)
|
| 460 |
+
value = value.view(batch_size, -1, kv_heads, head_dim)
|
| 461 |
+
|
| 462 |
+
if attn.norm_q is not None:
|
| 463 |
+
query = attn.norm_q(query)
|
| 464 |
+
if attn.norm_k is not None:
|
| 465 |
+
key = attn.norm_k(key)
|
| 466 |
+
|
| 467 |
+
if image_rotary_emb is not None:
|
| 468 |
+
query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
|
| 469 |
+
key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
|
| 470 |
+
|
| 471 |
+
query, key = query.to(dtype), key.to(dtype)
|
| 472 |
+
|
| 473 |
+
if base_sequence_length is not None:
|
| 474 |
+
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
|
| 475 |
+
else:
|
| 476 |
+
softmax_scale = attn.scale
|
| 477 |
+
|
| 478 |
+
if attention_mask is not None:
|
| 479 |
+
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
|
| 480 |
+
|
| 481 |
+
query = query.transpose(1, 2)
|
| 482 |
+
key = key.transpose(1, 2)
|
| 483 |
+
value = value.transpose(1, 2)
|
| 484 |
+
|
| 485 |
+
key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
|
| 486 |
+
value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
|
| 487 |
+
|
| 488 |
+
hidden_states = F.scaled_dot_product_attention(
|
| 489 |
+
query, key, value, attn_mask=attention_mask, scale=softmax_scale
|
| 490 |
+
)
|
| 491 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 492 |
+
hidden_states = hidden_states.type_as(query)
|
| 493 |
+
|
| 494 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 495 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 496 |
+
return hidden_states
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
class RotaryPosEmbed(nn.Module):
|
| 501 |
+
def __init__(
|
| 502 |
+
self,
|
| 503 |
+
theta: int,
|
| 504 |
+
axes_dim: Tuple[int, int, int],
|
| 505 |
+
axes_lens: Tuple[int, int, int] = (300, 512, 512),
|
| 506 |
+
patch_size: int = 2,
|
| 507 |
+
):
|
| 508 |
+
super().__init__()
|
| 509 |
+
self.theta = theta
|
| 510 |
+
self.axes_dim = axes_dim
|
| 511 |
+
self.axes_lens = axes_lens
|
| 512 |
+
self.patch_size = patch_size
|
| 513 |
+
|
| 514 |
+
@staticmethod
|
| 515 |
+
def get_freqs_cis(
|
| 516 |
+
axes_dim: Tuple[int, int, int],
|
| 517 |
+
axes_lens: Tuple[int, int, int],
|
| 518 |
+
theta: int,
|
| 519 |
+
) -> List[torch.Tensor]:
|
| 520 |
+
freqs_cis = []
|
| 521 |
+
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
| 522 |
+
for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
|
| 523 |
+
emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
|
| 524 |
+
freqs_cis.append(emb)
|
| 525 |
+
return freqs_cis
|
| 526 |
+
|
| 527 |
+
def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
|
| 528 |
+
device = ids.device
|
| 529 |
+
if ids.device.type == "mps":
|
| 530 |
+
ids = ids.to("cpu")
|
| 531 |
+
|
| 532 |
+
result = []
|
| 533 |
+
for i in range(len(self.axes_dim)):
|
| 534 |
+
freqs = freqs_cis[i].to(ids.device)
|
| 535 |
+
index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
|
| 536 |
+
result.append(
|
| 537 |
+
torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)
|
| 538 |
+
)
|
| 539 |
+
return torch.cat(result, dim=-1).to(device)
|
| 540 |
+
|
| 541 |
+
def forward(
|
| 542 |
+
self,
|
| 543 |
+
freqs_cis,
|
| 544 |
+
attention_mask,
|
| 545 |
+
l_effective_ref_img_len,
|
| 546 |
+
l_effective_img_len,
|
| 547 |
+
ref_img_sizes,
|
| 548 |
+
img_sizes,
|
| 549 |
+
device,
|
| 550 |
+
):
|
| 551 |
+
batch_size = len(attention_mask)
|
| 552 |
+
p = self.patch_size
|
| 553 |
+
|
| 554 |
+
encoder_seq_len = attention_mask.shape[1]
|
| 555 |
+
l_effective_cap_len = attention_mask.sum(dim=1).tolist()
|
| 556 |
+
|
| 557 |
+
seq_lengths = [
|
| 558 |
+
cap_len + sum(ref_img_len) + img_len
|
| 559 |
+
for cap_len, ref_img_len, img_len in zip(
|
| 560 |
+
l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len
|
| 561 |
+
)
|
| 562 |
+
]
|
| 563 |
+
|
| 564 |
+
max_seq_len = max(seq_lengths)
|
| 565 |
+
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
|
| 566 |
+
max_img_len = max(l_effective_img_len)
|
| 567 |
+
|
| 568 |
+
position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
|
| 569 |
+
|
| 570 |
+
for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
|
| 571 |
+
position_ids[i, :cap_seq_len] = repeat(
|
| 572 |
+
torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3"
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
pe_shift = cap_seq_len
|
| 576 |
+
pe_shift_len = cap_seq_len
|
| 577 |
+
|
| 578 |
+
if ref_img_sizes[i] is not None:
|
| 579 |
+
for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
|
| 580 |
+
H, W = ref_img_size
|
| 581 |
+
ref_H_tokens, ref_W_tokens = H // p, W // p
|
| 582 |
+
assert ref_H_tokens * ref_W_tokens == ref_img_len
|
| 583 |
+
|
| 584 |
+
row_ids = repeat(
|
| 585 |
+
torch.arange(ref_H_tokens, dtype=torch.int32, device=device),
|
| 586 |
+
"h -> h w", w=ref_W_tokens,
|
| 587 |
+
).flatten()
|
| 588 |
+
col_ids = repeat(
|
| 589 |
+
torch.arange(ref_W_tokens, dtype=torch.int32, device=device),
|
| 590 |
+
"w -> h w", h=ref_H_tokens,
|
| 591 |
+
).flatten()
|
| 592 |
+
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
|
| 593 |
+
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
|
| 594 |
+
position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
|
| 595 |
+
|
| 596 |
+
pe_shift += max(ref_H_tokens, ref_W_tokens)
|
| 597 |
+
pe_shift_len += ref_img_len
|
| 598 |
+
|
| 599 |
+
H, W = img_sizes[i]
|
| 600 |
+
H_tokens, W_tokens = H // p, W // p
|
| 601 |
+
assert H_tokens * W_tokens == l_effective_img_len[i]
|
| 602 |
+
|
| 603 |
+
row_ids = repeat(
|
| 604 |
+
torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens
|
| 605 |
+
).flatten()
|
| 606 |
+
col_ids = repeat(
|
| 607 |
+
torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens
|
| 608 |
+
).flatten()
|
| 609 |
+
|
| 610 |
+
assert pe_shift_len + l_effective_img_len[i] == seq_len
|
| 611 |
+
position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
|
| 612 |
+
position_ids[i, pe_shift_len: seq_len, 1] = row_ids
|
| 613 |
+
position_ids[i, pe_shift_len: seq_len, 2] = col_ids
|
| 614 |
+
|
| 615 |
+
freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)
|
| 616 |
+
|
| 617 |
+
cap_freqs_cis = torch.zeros(
|
| 618 |
+
batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
|
| 619 |
+
)
|
| 620 |
+
ref_img_freqs_cis = torch.zeros(
|
| 621 |
+
batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
|
| 622 |
+
)
|
| 623 |
+
img_freqs_cis = torch.zeros(
|
| 624 |
+
batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(
|
| 628 |
+
zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)
|
| 629 |
+
):
|
| 630 |
+
cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
|
| 631 |
+
ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[
|
| 632 |
+
i, cap_seq_len:cap_seq_len + sum(ref_img_len)
|
| 633 |
+
]
|
| 634 |
+
img_freqs_cis[i, :img_len] = freqs_cis[
|
| 635 |
+
i,
|
| 636 |
+
cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len,
|
| 637 |
+
]
|
| 638 |
+
|
| 639 |
+
return (
|
| 640 |
+
cap_freqs_cis,
|
| 641 |
+
ref_img_freqs_cis,
|
| 642 |
+
img_freqs_cis,
|
| 643 |
+
freqs_cis,
|
| 644 |
+
l_effective_cap_len,
|
| 645 |
+
seq_lengths,
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
class TransformerBlock(nn.Module):
|
| 650 |
+
"""
|
| 651 |
+
Transformer block for refiner model.
|
| 652 |
+
"""
|
| 653 |
+
|
| 654 |
+
def __init__(
|
| 655 |
+
self,
|
| 656 |
+
dim: int,
|
| 657 |
+
num_attention_heads: int,
|
| 658 |
+
num_kv_heads: int,
|
| 659 |
+
multiple_of: int,
|
| 660 |
+
ffn_dim_multiplier: float,
|
| 661 |
+
norm_eps: float,
|
| 662 |
+
modulation: bool = True,
|
| 663 |
+
) -> None:
|
| 664 |
+
super().__init__()
|
| 665 |
+
self.head_dim = dim // num_attention_heads
|
| 666 |
+
self.modulation = modulation
|
| 667 |
+
|
| 668 |
+
try:
|
| 669 |
+
processor = AttnProcessorFlash2Varlen()
|
| 670 |
+
except ImportError:
|
| 671 |
+
processor = AttnProcessor()
|
| 672 |
+
|
| 673 |
+
self.attn = Attention(
|
| 674 |
+
query_dim=dim,
|
| 675 |
+
cross_attention_dim=None,
|
| 676 |
+
dim_head=dim // num_attention_heads,
|
| 677 |
+
qk_norm="rms_norm",
|
| 678 |
+
heads=num_attention_heads,
|
| 679 |
+
kv_heads=num_kv_heads,
|
| 680 |
+
eps=1e-5,
|
| 681 |
+
bias=False,
|
| 682 |
+
out_bias=False,
|
| 683 |
+
processor=processor,
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
self.feed_forward = LuminaFeedForward(
|
| 687 |
+
dim=dim,
|
| 688 |
+
inner_dim=4 * dim,
|
| 689 |
+
multiple_of=multiple_of,
|
| 690 |
+
ffn_dim_multiplier=ffn_dim_multiplier,
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
if modulation:
|
| 694 |
+
self.norm1 = LuminaRMSNormZero(
|
| 695 |
+
embedding_dim=dim,
|
| 696 |
+
norm_eps=norm_eps,
|
| 697 |
+
norm_elementwise_affine=True,
|
| 698 |
+
)
|
| 699 |
+
else:
|
| 700 |
+
self.norm1 = RMSNorm(dim, eps=norm_eps)
|
| 701 |
+
|
| 702 |
+
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
|
| 703 |
+
self.norm2 = RMSNorm(dim, eps=norm_eps)
|
| 704 |
+
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
|
| 705 |
+
|
| 706 |
+
self.initialize_weights()
|
| 707 |
+
|
| 708 |
+
def initialize_weights(self) -> None:
|
| 709 |
+
nn.init.xavier_uniform_(self.attn.to_q.weight)
|
| 710 |
+
nn.init.xavier_uniform_(self.attn.to_k.weight)
|
| 711 |
+
nn.init.xavier_uniform_(self.attn.to_v.weight)
|
| 712 |
+
nn.init.xavier_uniform_(self.attn.to_out[0].weight)
|
| 713 |
+
|
| 714 |
+
nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
|
| 715 |
+
nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
|
| 716 |
+
nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)
|
| 717 |
+
|
| 718 |
+
if self.modulation:
|
| 719 |
+
nn.init.zeros_(self.norm1.linear.weight)
|
| 720 |
+
nn.init.zeros_(self.norm1.linear.bias)
|
| 721 |
+
|
| 722 |
+
def forward(
|
| 723 |
+
self,
|
| 724 |
+
hidden_states: torch.Tensor,
|
| 725 |
+
attention_mask: torch.Tensor,
|
| 726 |
+
image_rotary_emb: torch.Tensor,
|
| 727 |
+
temb: Optional[torch.Tensor] = None,
|
| 728 |
+
) -> torch.Tensor:
|
| 729 |
+
enable_taylorseer = getattr(self, 'enable_taylorseer', False)
|
| 730 |
+
if enable_taylorseer:
|
| 731 |
+
if self.modulation:
|
| 732 |
+
if temb is None:
|
| 733 |
+
raise ValueError("temb must be provided when modulation is enabled")
|
| 734 |
+
|
| 735 |
+
if self.current['type'] == 'full':
|
| 736 |
+
self.current['module'] = 'total'
|
| 737 |
+
taylor_cache_init(cache_dic=self.cache_dic, current=self.current)
|
| 738 |
+
|
| 739 |
+
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
| 740 |
+
attn_output = self.attn(
|
| 741 |
+
hidden_states=norm_hidden_states,
|
| 742 |
+
encoder_hidden_states=norm_hidden_states,
|
| 743 |
+
attention_mask=attention_mask,
|
| 744 |
+
image_rotary_emb=image_rotary_emb,
|
| 745 |
+
)
|
| 746 |
+
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
|
| 747 |
+
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
|
| 748 |
+
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
|
| 749 |
+
|
| 750 |
+
derivative_approximation(cache_dic=self.cache_dic, current=self.current, feature=hidden_states)
|
| 751 |
+
|
| 752 |
+
elif self.current['type'] == 'Taylor':
|
| 753 |
+
self.current['module'] = 'total'
|
| 754 |
+
hidden_states = taylor_formula(cache_dic=self.cache_dic, current=self.current)
|
| 755 |
+
else:
|
| 756 |
+
norm_hidden_states = self.norm1(hidden_states)
|
| 757 |
+
attn_output = self.attn(
|
| 758 |
+
hidden_states=norm_hidden_states,
|
| 759 |
+
encoder_hidden_states=norm_hidden_states,
|
| 760 |
+
attention_mask=attention_mask,
|
| 761 |
+
image_rotary_emb=image_rotary_emb,
|
| 762 |
+
)
|
| 763 |
+
hidden_states = hidden_states + self.norm2(attn_output)
|
| 764 |
+
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
|
| 765 |
+
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
|
| 766 |
+
else:
|
| 767 |
+
if self.modulation:
|
| 768 |
+
if temb is None:
|
| 769 |
+
raise ValueError("temb must be provided when modulation is enabled")
|
| 770 |
+
|
| 771 |
+
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
| 772 |
+
attn_output = self.attn(
|
| 773 |
+
hidden_states=norm_hidden_states,
|
| 774 |
+
encoder_hidden_states=norm_hidden_states,
|
| 775 |
+
attention_mask=attention_mask,
|
| 776 |
+
image_rotary_emb=image_rotary_emb,
|
| 777 |
+
)
|
| 778 |
+
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
|
| 779 |
+
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
|
| 780 |
+
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
|
| 781 |
+
else:
|
| 782 |
+
norm_hidden_states = self.norm1(hidden_states)
|
| 783 |
+
attn_output = self.attn(
|
| 784 |
+
hidden_states=norm_hidden_states,
|
| 785 |
+
encoder_hidden_states=norm_hidden_states,
|
| 786 |
+
attention_mask=attention_mask,
|
| 787 |
+
image_rotary_emb=image_rotary_emb,
|
| 788 |
+
)
|
| 789 |
+
hidden_states = hidden_states + self.norm2(attn_output)
|
| 790 |
+
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
|
| 791 |
+
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
|
| 792 |
+
|
| 793 |
+
return hidden_states
|
| 794 |
+
|
| 795 |
+
|
| 796 |
+
class Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
| 797 |
+
"""
|
| 798 |
+
Transformer 2D Model.
|
| 799 |
+
"""
|
| 800 |
+
|
| 801 |
+
_supports_gradient_checkpointing = True
|
| 802 |
+
_no_split_modules = ["TransformerBlock"]
|
| 803 |
+
_skip_layerwise_casting_patterns = ["x_embedder", "norm"]
|
| 804 |
+
|
| 805 |
+
@register_to_config
|
| 806 |
+
def __init__(
|
| 807 |
+
self,
|
| 808 |
+
patch_size: int = 2,
|
| 809 |
+
in_channels: int = 16,
|
| 810 |
+
out_channels: Optional[int] = None,
|
| 811 |
+
hidden_size: int = 2304,
|
| 812 |
+
num_layers: int = 26,
|
| 813 |
+
num_refiner_layers: int = 2,
|
| 814 |
+
num_attention_heads: int = 24,
|
| 815 |
+
num_kv_heads: int = 8,
|
| 816 |
+
multiple_of: int = 256,
|
| 817 |
+
ffn_dim_multiplier: Optional[float] = None,
|
| 818 |
+
norm_eps: float = 1e-5,
|
| 819 |
+
axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
|
| 820 |
+
axes_lens: Tuple[int, int, int] = (300, 512, 512),
|
| 821 |
+
text_feat_dim: int = 1024,
|
| 822 |
+
timestep_scale: float = 1.0,
|
| 823 |
+
) -> None:
|
| 824 |
+
super().__init__()
|
| 825 |
+
|
| 826 |
+
if (hidden_size // num_attention_heads) != sum(axes_dim_rope):
|
| 827 |
+
raise ValueError(
|
| 828 |
+
f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) "
|
| 829 |
+
f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})"
|
| 830 |
+
)
|
| 831 |
+
|
| 832 |
+
self.out_channels = out_channels or in_channels
|
| 833 |
+
|
| 834 |
+
self.rope_embedder = RotaryPosEmbed(
|
| 835 |
+
theta=10000,
|
| 836 |
+
axes_dim=axes_dim_rope,
|
| 837 |
+
axes_lens=axes_lens,
|
| 838 |
+
patch_size=patch_size,
|
| 839 |
+
)
|
| 840 |
+
|
| 841 |
+
self.x_embedder = nn.Linear(
|
| 842 |
+
in_features=patch_size * patch_size * in_channels,
|
| 843 |
+
out_features=hidden_size,
|
| 844 |
+
)
|
| 845 |
+
|
| 846 |
+
self.ref_image_patch_embedder = nn.Linear(
|
| 847 |
+
in_features=patch_size * patch_size * in_channels,
|
| 848 |
+
out_features=hidden_size,
|
| 849 |
+
)
|
| 850 |
+
|
| 851 |
+
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
|
| 852 |
+
hidden_size=hidden_size,
|
| 853 |
+
text_feat_dim=text_feat_dim,
|
| 854 |
+
norm_eps=norm_eps,
|
| 855 |
+
timestep_scale=timestep_scale,
|
| 856 |
+
)
|
| 857 |
+
|
| 858 |
+
self.noise_refiner = nn.ModuleList([
|
| 859 |
+
TransformerBlock(
|
| 860 |
+
hidden_size, num_attention_heads, num_kv_heads,
|
| 861 |
+
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True,
|
| 862 |
+
)
|
| 863 |
+
for _ in range(num_refiner_layers)
|
| 864 |
+
])
|
| 865 |
+
|
| 866 |
+
self.ref_image_refiner = nn.ModuleList([
|
| 867 |
+
TransformerBlock(
|
| 868 |
+
hidden_size, num_attention_heads, num_kv_heads,
|
| 869 |
+
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True,
|
| 870 |
+
)
|
| 871 |
+
for _ in range(num_refiner_layers)
|
| 872 |
+
])
|
| 873 |
+
|
| 874 |
+
self.context_refiner = nn.ModuleList([
|
| 875 |
+
TransformerBlock(
|
| 876 |
+
hidden_size, num_attention_heads, num_kv_heads,
|
| 877 |
+
multiple_of, ffn_dim_multiplier, norm_eps, modulation=False,
|
| 878 |
+
)
|
| 879 |
+
for _ in range(num_refiner_layers)
|
| 880 |
+
])
|
| 881 |
+
|
| 882 |
+
self.layers = nn.ModuleList([
|
| 883 |
+
TransformerBlock(
|
| 884 |
+
hidden_size, num_attention_heads, num_kv_heads,
|
| 885 |
+
multiple_of, ffn_dim_multiplier, norm_eps, modulation=True,
|
| 886 |
+
)
|
| 887 |
+
for _ in range(num_layers)
|
| 888 |
+
])
|
| 889 |
+
|
| 890 |
+
self.norm_out = LuminaLayerNormContinuous(
|
| 891 |
+
embedding_dim=hidden_size,
|
| 892 |
+
conditioning_embedding_dim=min(hidden_size, 1024),
|
| 893 |
+
elementwise_affine=False,
|
| 894 |
+
eps=1e-6,
|
| 895 |
+
bias=True,
|
| 896 |
+
out_dim=patch_size * patch_size * self.out_channels,
|
| 897 |
+
)
|
| 898 |
+
|
| 899 |
+
self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size))
|
| 900 |
+
|
| 901 |
+
self.gradient_checkpointing = False
|
| 902 |
+
|
| 903 |
+
self.initialize_weights()
|
| 904 |
+
|
| 905 |
+
self.enable_teacache = False
|
| 906 |
+
self.teacache_rel_l1_thresh = 0.05
|
| 907 |
+
self.teacache_params = TeaCacheParams()
|
| 908 |
+
|
| 909 |
+
coefficients = [-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487]
|
| 910 |
+
self.rescale_func = np.poly1d(coefficients)
|
| 911 |
+
|
| 912 |
+
def initialize_weights(self) -> None:
|
| 913 |
+
nn.init.xavier_uniform_(self.x_embedder.weight)
|
| 914 |
+
nn.init.constant_(self.x_embedder.bias, 0.0)
|
| 915 |
+
|
| 916 |
+
nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
|
| 917 |
+
nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)
|
| 918 |
+
|
| 919 |
+
nn.init.zeros_(self.norm_out.linear_1.weight)
|
| 920 |
+
nn.init.zeros_(self.norm_out.linear_1.bias)
|
| 921 |
+
nn.init.zeros_(self.norm_out.linear_2.weight)
|
| 922 |
+
nn.init.zeros_(self.norm_out.linear_2.bias)
|
| 923 |
+
|
| 924 |
+
nn.init.normal_(self.image_index_embedding, std=0.02)
|
| 925 |
+
|
| 926 |
+
def img_patch_embed_and_refine(
|
| 927 |
+
self,
|
| 928 |
+
hidden_states,
|
| 929 |
+
ref_image_hidden_states,
|
| 930 |
+
padded_img_mask,
|
| 931 |
+
padded_ref_img_mask,
|
| 932 |
+
noise_rotary_emb,
|
| 933 |
+
ref_img_rotary_emb,
|
| 934 |
+
l_effective_ref_img_len,
|
| 935 |
+
l_effective_img_len,
|
| 936 |
+
temb,
|
| 937 |
+
):
|
| 938 |
+
batch_size = len(hidden_states)
|
| 939 |
+
max_combined_img_len = max([
|
| 940 |
+
img_len + sum(ref_img_len)
|
| 941 |
+
for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)
|
| 942 |
+
])
|
| 943 |
+
|
| 944 |
+
hidden_states = self.x_embedder(hidden_states)
|
| 945 |
+
ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
|
| 946 |
+
|
| 947 |
+
for i in range(batch_size):
|
| 948 |
+
shift = 0
|
| 949 |
+
for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
|
| 950 |
+
ref_image_hidden_states[i, shift:shift + ref_img_len, :] = (
|
| 951 |
+
ref_image_hidden_states[i, shift:shift + ref_img_len, :]
|
| 952 |
+
+ self.image_index_embedding[j]
|
| 953 |
+
)
|
| 954 |
+
shift += ref_img_len
|
| 955 |
+
|
| 956 |
+
for layer in self.noise_refiner:
|
| 957 |
+
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
|
| 958 |
+
|
| 959 |
+
flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
|
| 960 |
+
num_ref_images = len(flat_l_effective_ref_img_len)
|
| 961 |
+
max_ref_img_len = max(flat_l_effective_ref_img_len)
|
| 962 |
+
|
| 963 |
+
batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool)
|
| 964 |
+
batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(
|
| 965 |
+
num_ref_images, max_ref_img_len, self.config.hidden_size
|
| 966 |
+
)
|
| 967 |
+
batch_ref_img_rotary_emb = hidden_states.new_zeros(
|
| 968 |
+
num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype
|
| 969 |
+
)
|
| 970 |
+
batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)
|
| 971 |
+
|
| 972 |
+
idx = 0
|
| 973 |
+
for i in range(batch_size):
|
| 974 |
+
shift = 0
|
| 975 |
+
for ref_img_len in l_effective_ref_img_len[i]:
|
| 976 |
+
batch_ref_img_mask[idx, :ref_img_len] = True
|
| 977 |
+
batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len]
|
| 978 |
+
batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len]
|
| 979 |
+
batch_temb[idx] = temb[i]
|
| 980 |
+
shift += ref_img_len
|
| 981 |
+
idx += 1
|
| 982 |
+
|
| 983 |
+
for layer in self.ref_image_refiner:
|
| 984 |
+
batch_ref_image_hidden_states = layer(
|
| 985 |
+
batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb
|
| 986 |
+
)
|
| 987 |
+
|
| 988 |
+
idx = 0
|
| 989 |
+
for i in range(batch_size):
|
| 990 |
+
shift = 0
|
| 991 |
+
for ref_img_len in l_effective_ref_img_len[i]:
|
| 992 |
+
ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len]
|
| 993 |
+
shift += ref_img_len
|
| 994 |
+
idx += 1
|
| 995 |
+
|
| 996 |
+
combined_img_hidden_states = hidden_states.new_zeros(
|
| 997 |
+
batch_size, max_combined_img_len, self.config.hidden_size
|
| 998 |
+
)
|
| 999 |
+
for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
|
| 1000 |
+
combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)]
|
| 1001 |
+
combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len]
|
| 1002 |
+
|
| 1003 |
+
return combined_img_hidden_states
|
| 1004 |
+
|
| 1005 |
+
def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
|
| 1006 |
+
batch_size = len(hidden_states)
|
| 1007 |
+
p = self.config.patch_size
|
| 1008 |
+
device = hidden_states[0].device
|
| 1009 |
+
|
| 1010 |
+
img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
|
| 1011 |
+
l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
|
| 1012 |
+
|
| 1013 |
+
if ref_image_hidden_states is not None and len(ref_image_hidden_states) > 0:
|
| 1014 |
+
ref_img_sizes = [
|
| 1015 |
+
[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None
|
| 1016 |
+
for imgs in ref_image_hidden_states
|
| 1017 |
+
]
|
| 1018 |
+
l_effective_ref_img_len = [
|
| 1019 |
+
[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes]
|
| 1020 |
+
if _ref_img_sizes is not None else [0]
|
| 1021 |
+
for _ref_img_sizes in ref_img_sizes
|
| 1022 |
+
]
|
| 1023 |
+
else:
|
| 1024 |
+
ref_img_sizes = [None for _ in range(batch_size)]
|
| 1025 |
+
l_effective_ref_img_len = [[0] for _ in range(batch_size)]
|
| 1026 |
+
|
| 1027 |
+
max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
|
| 1028 |
+
max_img_len = max(l_effective_img_len)
|
| 1029 |
+
|
| 1030 |
+
flat_ref_img_hidden_states = []
|
| 1031 |
+
for i in range(batch_size):
|
| 1032 |
+
if ref_img_sizes[i] is not None:
|
| 1033 |
+
imgs = []
|
| 1034 |
+
for ref_img in ref_image_hidden_states[i]:
|
| 1035 |
+
C, H, W = ref_img.size()
|
| 1036 |
+
ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
|
| 1037 |
+
imgs.append(ref_img)
|
| 1038 |
+
flat_ref_img_hidden_states.append(torch.cat(imgs, dim=0))
|
| 1039 |
+
else:
|
| 1040 |
+
flat_ref_img_hidden_states.append(None)
|
| 1041 |
+
|
| 1042 |
+
flat_hidden_states = []
|
| 1043 |
+
for i in range(batch_size):
|
| 1044 |
+
img = hidden_states[i]
|
| 1045 |
+
C, H, W = img.size()
|
| 1046 |
+
img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
|
| 1047 |
+
flat_hidden_states.append(img)
|
| 1048 |
+
|
| 1049 |
+
padded_ref_img_hidden_states = torch.zeros(
|
| 1050 |
+
batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1],
|
| 1051 |
+
device=device, dtype=flat_hidden_states[0].dtype,
|
| 1052 |
+
)
|
| 1053 |
+
padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device)
|
| 1054 |
+
for i in range(batch_size):
|
| 1055 |
+
if ref_img_sizes[i] is not None:
|
| 1056 |
+
padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i]
|
| 1057 |
+
padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True
|
| 1058 |
+
|
| 1059 |
+
padded_hidden_states = torch.zeros(
|
| 1060 |
+
batch_size, max_img_len, flat_hidden_states[0].shape[-1],
|
| 1061 |
+
device=device, dtype=flat_hidden_states[0].dtype,
|
| 1062 |
+
)
|
| 1063 |
+
padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
|
| 1064 |
+
for i in range(batch_size):
|
| 1065 |
+
padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i]
|
| 1066 |
+
padded_img_mask[i, :l_effective_img_len[i]] = True
|
| 1067 |
+
|
| 1068 |
+
return (
|
| 1069 |
+
padded_hidden_states,
|
| 1070 |
+
padded_ref_img_hidden_states,
|
| 1071 |
+
padded_img_mask,
|
| 1072 |
+
padded_ref_img_mask,
|
| 1073 |
+
l_effective_ref_img_len,
|
| 1074 |
+
l_effective_img_len,
|
| 1075 |
+
ref_img_sizes,
|
| 1076 |
+
img_sizes,
|
| 1077 |
+
)
|
| 1078 |
+
|
| 1079 |
+
def forward(
|
| 1080 |
+
self,
|
| 1081 |
+
hidden_states: Union[torch.Tensor, List[torch.Tensor]],
|
| 1082 |
+
timestep: torch.Tensor,
|
| 1083 |
+
text_hidden_states: torch.Tensor,
|
| 1084 |
+
freqs_cis: torch.Tensor,
|
| 1085 |
+
text_attention_mask: torch.Tensor,
|
| 1086 |
+
ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None,
|
| 1087 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 1088 |
+
return_dict: bool = False,
|
| 1089 |
+
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
| 1090 |
+
enable_taylorseer = getattr(self, 'enable_taylorseer', False)
|
| 1091 |
+
if enable_taylorseer:
|
| 1092 |
+
cal_type(self.cache_dic, self.current)
|
| 1093 |
+
|
| 1094 |
+
if attention_kwargs is not None:
|
| 1095 |
+
attention_kwargs = attention_kwargs.copy()
|
| 1096 |
+
lora_scale = attention_kwargs.pop("scale", 1.0)
|
| 1097 |
+
else:
|
| 1098 |
+
lora_scale = 1.0
|
| 1099 |
+
|
| 1100 |
+
if USE_PEFT_BACKEND:
|
| 1101 |
+
scale_lora_layers(self, lora_scale)
|
| 1102 |
+
else:
|
| 1103 |
+
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
| 1104 |
+
logger.warning(
|
| 1105 |
+
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
| 1106 |
+
)
|
| 1107 |
+
|
| 1108 |
+
batch_size = len(hidden_states)
|
| 1109 |
+
is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)
|
| 1110 |
+
|
| 1111 |
+
if is_hidden_states_tensor:
|
| 1112 |
+
assert hidden_states.ndim == 4
|
| 1113 |
+
hidden_states = [_hidden_states for _hidden_states in hidden_states]
|
| 1114 |
+
|
| 1115 |
+
device = hidden_states[0].device
|
| 1116 |
+
|
| 1117 |
+
assert isinstance(text_hidden_states, torch.Tensor), \
|
| 1118 |
+
f"text_hidden_states must be Tensor, got {type(text_hidden_states)}. " \
|
| 1119 |
+
f"Check if freqs_cis and text_hidden_states are swapped in the caller."
|
| 1120 |
+
|
| 1121 |
+
temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
|
| 1122 |
+
|
| 1123 |
+
(
|
| 1124 |
+
hidden_states,
|
| 1125 |
+
ref_image_hidden_states,
|
| 1126 |
+
img_mask,
|
| 1127 |
+
ref_img_mask,
|
| 1128 |
+
l_effective_ref_img_len,
|
| 1129 |
+
l_effective_img_len,
|
| 1130 |
+
ref_img_sizes,
|
| 1131 |
+
img_sizes,
|
| 1132 |
+
) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
|
| 1133 |
+
|
| 1134 |
+
(
|
| 1135 |
+
context_rotary_emb,
|
| 1136 |
+
ref_img_rotary_emb,
|
| 1137 |
+
noise_rotary_emb,
|
| 1138 |
+
rotary_emb,
|
| 1139 |
+
encoder_seq_lengths,
|
| 1140 |
+
seq_lengths,
|
| 1141 |
+
) = self.rope_embedder(
|
| 1142 |
+
freqs_cis,
|
| 1143 |
+
text_attention_mask,
|
| 1144 |
+
l_effective_ref_img_len,
|
| 1145 |
+
l_effective_img_len,
|
| 1146 |
+
ref_img_sizes,
|
| 1147 |
+
img_sizes,
|
| 1148 |
+
device,
|
| 1149 |
+
)
|
| 1150 |
+
|
| 1151 |
+
# 2. Context refinement
|
| 1152 |
+
for layer in self.context_refiner:
|
| 1153 |
+
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
|
| 1154 |
+
|
| 1155 |
+
combined_img_hidden_states = self.img_patch_embed_and_refine(
|
| 1156 |
+
hidden_states,
|
| 1157 |
+
ref_image_hidden_states,
|
| 1158 |
+
img_mask,
|
| 1159 |
+
ref_img_mask,
|
| 1160 |
+
noise_rotary_emb,
|
| 1161 |
+
ref_img_rotary_emb,
|
| 1162 |
+
l_effective_ref_img_len,
|
| 1163 |
+
l_effective_img_len,
|
| 1164 |
+
temb,
|
| 1165 |
+
)
|
| 1166 |
+
|
| 1167 |
+
# 3. Joint Transformer blocks
|
| 1168 |
+
max_seq_len = max(seq_lengths)
|
| 1169 |
+
|
| 1170 |
+
attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
|
| 1171 |
+
joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
|
| 1172 |
+
for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
|
| 1173 |
+
attention_mask[i, :seq_len] = True
|
| 1174 |
+
joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len]
|
| 1175 |
+
joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len]
|
| 1176 |
+
|
| 1177 |
+
hidden_states = joint_hidden_states
|
| 1178 |
+
|
| 1179 |
+
if self.enable_teacache:
|
| 1180 |
+
teacache_hidden_states = hidden_states.clone()
|
| 1181 |
+
teacache_temb = temb.clone()
|
| 1182 |
+
modulated_inp, _, _, _ = self.layers[0].norm1(teacache_hidden_states, teacache_temb)
|
| 1183 |
+
if self.teacache_params.is_first_or_last_step:
|
| 1184 |
+
should_calc = True
|
| 1185 |
+
self.teacache_params.accumulated_rel_l1_distance = 0
|
| 1186 |
+
else:
|
| 1187 |
+
self.teacache_params.accumulated_rel_l1_distance += self.rescale_func(
|
| 1188 |
+
((modulated_inp - self.teacache_params.previous_modulated_inp).abs().mean()
|
| 1189 |
+
/ self.teacache_params.previous_modulated_inp.abs().mean()).cpu().item()
|
| 1190 |
+
)
|
| 1191 |
+
if self.teacache_params.accumulated_rel_l1_distance < self.teacache_rel_l1_thresh:
|
| 1192 |
+
should_calc = False
|
| 1193 |
+
else:
|
| 1194 |
+
should_calc = True
|
| 1195 |
+
self.teacache_params.accumulated_rel_l1_distance = 0
|
| 1196 |
+
self.teacache_params.previous_modulated_inp = modulated_inp
|
| 1197 |
+
|
| 1198 |
+
if self.enable_teacache:
|
| 1199 |
+
if not should_calc:
|
| 1200 |
+
hidden_states += self.teacache_params.previous_residual
|
| 1201 |
+
else:
|
| 1202 |
+
ori_hidden_states = hidden_states.clone()
|
| 1203 |
+
for layer_idx, layer in enumerate(self.layers):
|
| 1204 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 1205 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 1206 |
+
layer, hidden_states, attention_mask, rotary_emb, temb
|
| 1207 |
+
)
|
| 1208 |
+
else:
|
| 1209 |
+
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
|
| 1210 |
+
self.teacache_params.previous_residual = hidden_states - ori_hidden_states
|
| 1211 |
+
else:
|
| 1212 |
+
if enable_taylorseer:
|
| 1213 |
+
self.current['stream'] = 'layers_stream'
|
| 1214 |
+
|
| 1215 |
+
for layer_idx, layer in enumerate(self.layers):
|
| 1216 |
+
if enable_taylorseer:
|
| 1217 |
+
layer.current = self.current
|
| 1218 |
+
layer.cache_dic = self.cache_dic
|
| 1219 |
+
layer.enable_taylorseer = True
|
| 1220 |
+
self.current['layer'] = layer_idx
|
| 1221 |
+
|
| 1222 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 1223 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 1224 |
+
layer, hidden_states, attention_mask, rotary_emb, temb
|
| 1225 |
+
)
|
| 1226 |
+
else:
|
| 1227 |
+
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
|
| 1228 |
+
|
| 1229 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 1230 |
+
|
| 1231 |
+
p = self.config.patch_size
|
| 1232 |
+
output = []
|
| 1233 |
+
for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
|
| 1234 |
+
height, width = img_size
|
| 1235 |
+
output.append(rearrange(
|
| 1236 |
+
hidden_states[i][seq_len - img_len:seq_len],
|
| 1237 |
+
'(h w) (p1 p2 c) -> c (h p1) (w p2)',
|
| 1238 |
+
h=height // p, w=width // p, p1=p, p2=p,
|
| 1239 |
+
))
|
| 1240 |
+
if is_hidden_states_tensor:
|
| 1241 |
+
output = torch.stack(output, dim=0)
|
| 1242 |
+
|
| 1243 |
+
if USE_PEFT_BACKEND:
|
| 1244 |
+
unscale_lora_layers(self, lora_scale)
|
| 1245 |
+
|
| 1246 |
+
if enable_taylorseer:
|
| 1247 |
+
self.current['step'] += 1
|
| 1248 |
+
|
| 1249 |
+
if not return_dict:
|
| 1250 |
+
return output
|
| 1251 |
+
return Transformer2DModelOutput(sample=output)
|
| 1252 |
+
|
| 1253 |
+
|
| 1254 |
+
# ---------------------------------------------------------------------------
|
| 1255 |
+
# FlowMatch Euler Discrete Scheduler (merged from scheduling_flow_match_euler_discrete.py)
|
| 1256 |
+
# ---------------------------------------------------------------------------
|
| 1257 |
+
|
| 1258 |
+
@dataclass
|
| 1259 |
+
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
|
| 1260 |
+
prev_sample: torch.FloatTensor
|
| 1261 |
+
|
| 1262 |
+
|
| 1263 |
+
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
| 1264 |
+
_compatibles = []
|
| 1265 |
+
order = 1
|
| 1266 |
+
|
| 1267 |
+
@register_to_config
|
| 1268 |
+
def __init__(self, num_train_timesteps: int = 1000, dynamic_time_shift: bool = False):
|
| 1269 |
+
timesteps = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float32)[:-1]
|
| 1270 |
+
self.timesteps = timesteps
|
| 1271 |
+
self._step_index = None
|
| 1272 |
+
self._begin_index = None
|
| 1273 |
+
|
| 1274 |
+
@property
|
| 1275 |
+
def step_index(self):
|
| 1276 |
+
return self._step_index
|
| 1277 |
+
|
| 1278 |
+
@property
|
| 1279 |
+
def begin_index(self):
|
| 1280 |
+
return self._begin_index
|
| 1281 |
+
|
| 1282 |
+
def set_begin_index(self, begin_index: int = 0):
|
| 1283 |
+
self._begin_index = begin_index
|
| 1284 |
+
|
| 1285 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
| 1286 |
+
if schedule_timesteps is None:
|
| 1287 |
+
schedule_timesteps = self._timesteps
|
| 1288 |
+
indices = (schedule_timesteps == timestep).nonzero()
|
| 1289 |
+
pos = 1 if len(indices) > 1 else 0
|
| 1290 |
+
return indices[pos].item()
|
| 1291 |
+
|
| 1292 |
+
def set_timesteps(self, num_inference_steps=None, device=None, timesteps=None, num_tokens=None):
|
| 1293 |
+
if timesteps is None:
|
| 1294 |
+
self.num_inference_steps = num_inference_steps
|
| 1295 |
+
timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
|
| 1296 |
+
if self.config.dynamic_time_shift and num_tokens is not None:
|
| 1297 |
+
m = np.sqrt(num_tokens) / 40
|
| 1298 |
+
timesteps = timesteps / (m - m * timesteps + timesteps)
|
| 1299 |
+
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
|
| 1300 |
+
_timesteps = torch.cat([timesteps, torch.ones(1, device=timesteps.device)])
|
| 1301 |
+
self.timesteps = timesteps
|
| 1302 |
+
self._timesteps = _timesteps
|
| 1303 |
+
self._step_index = None
|
| 1304 |
+
self._begin_index = None
|
| 1305 |
+
|
| 1306 |
+
def _init_step_index(self, timestep):
|
| 1307 |
+
if self.begin_index is None:
|
| 1308 |
+
if isinstance(timestep, torch.Tensor):
|
| 1309 |
+
timestep = timestep.to(self.timesteps.device)
|
| 1310 |
+
self._step_index = self.index_for_timestep(timestep)
|
| 1311 |
+
else:
|
| 1312 |
+
self._step_index = self._begin_index
|
| 1313 |
+
|
| 1314 |
+
def step(self, model_output, timestep, sample, generator=None, return_dict=True):
|
| 1315 |
+
if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
|
| 1316 |
+
raise ValueError("Pass scheduler.timesteps values, not integer indices.")
|
| 1317 |
+
if self.step_index is None:
|
| 1318 |
+
self._init_step_index(timestep)
|
| 1319 |
+
sample = sample.to(torch.float32)
|
| 1320 |
+
t = self._timesteps[self.step_index]
|
| 1321 |
+
t_next = self._timesteps[self.step_index + 1]
|
| 1322 |
+
prev_sample = sample + (t_next - t) * model_output
|
| 1323 |
+
prev_sample = prev_sample.to(model_output.dtype)
|
| 1324 |
+
self._step_index += 1
|
| 1325 |
+
if not return_dict:
|
| 1326 |
+
return (prev_sample,)
|
| 1327 |
+
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
|
| 1328 |
+
|
| 1329 |
+
def __len__(self):
|
| 1330 |
+
return self.config.num_train_timesteps
|
requirements-post.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
flash-attn==2.7.4.post1
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.6.0
|
| 2 |
+
torchvision==0.21.0
|
| 3 |
+
torchaudio==2.6.0
|
| 4 |
+
accelerate==1.10.0
|
| 5 |
+
transformers==4.57.6
|
| 6 |
+
librosa==0.11.0
|
| 7 |
+
diffusers==0.34.0
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,2294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": false,
|
| 3 |
+
"add_eos_token": true,
|
| 4 |
+
"add_prefix_space": false,
|
| 5 |
+
"added_tokens_decoder": {
|
| 6 |
+
"0": {
|
| 7 |
+
"content": "<longcat_unk>",
|
| 8 |
+
"lstrip": false,
|
| 9 |
+
"normalized": false,
|
| 10 |
+
"rstrip": false,
|
| 11 |
+
"single_word": false,
|
| 12 |
+
"special": true
|
| 13 |
+
},
|
| 14 |
+
"1": {
|
| 15 |
+
"content": "<longcat_s>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": false,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false,
|
| 20 |
+
"special": true
|
| 21 |
+
},
|
| 22 |
+
"2": {
|
| 23 |
+
"content": "</longcat_s>",
|
| 24 |
+
"lstrip": false,
|
| 25 |
+
"normalized": false,
|
| 26 |
+
"rstrip": false,
|
| 27 |
+
"single_word": false,
|
| 28 |
+
"special": true
|
| 29 |
+
},
|
| 30 |
+
"3": {
|
| 31 |
+
"content": "<longcat_pad>",
|
| 32 |
+
"lstrip": false,
|
| 33 |
+
"normalized": false,
|
| 34 |
+
"rstrip": false,
|
| 35 |
+
"single_word": false,
|
| 36 |
+
"special": true
|
| 37 |
+
},
|
| 38 |
+
"4": {
|
| 39 |
+
"content": "<shift_unk>",
|
| 40 |
+
"lstrip": false,
|
| 41 |
+
"normalized": false,
|
| 42 |
+
"rstrip": false,
|
| 43 |
+
"single_word": false,
|
| 44 |
+
"special": true
|
| 45 |
+
},
|
| 46 |
+
"5": {
|
| 47 |
+
"content": "<shift_s>",
|
| 48 |
+
"lstrip": false,
|
| 49 |
+
"normalized": false,
|
| 50 |
+
"rstrip": false,
|
| 51 |
+
"single_word": false,
|
| 52 |
+
"special": true
|
| 53 |
+
},
|
| 54 |
+
"6": {
|
| 55 |
+
"content": "</shift_s>",
|
| 56 |
+
"lstrip": false,
|
| 57 |
+
"normalized": false,
|
| 58 |
+
"rstrip": false,
|
| 59 |
+
"single_word": false,
|
| 60 |
+
"special": true
|
| 61 |
+
},
|
| 62 |
+
"7": {
|
| 63 |
+
"content": "<shift_pad>",
|
| 64 |
+
"lstrip": false,
|
| 65 |
+
"normalized": false,
|
| 66 |
+
"rstrip": false,
|
| 67 |
+
"single_word": false,
|
| 68 |
+
"special": true
|
| 69 |
+
},
|
| 70 |
+
"8": {
|
| 71 |
+
"content": "<mask_0>",
|
| 72 |
+
"lstrip": false,
|
| 73 |
+
"normalized": false,
|
| 74 |
+
"rstrip": false,
|
| 75 |
+
"single_word": false,
|
| 76 |
+
"special": true
|
| 77 |
+
},
|
| 78 |
+
"9": {
|
| 79 |
+
"content": "<reponame>",
|
| 80 |
+
"lstrip": false,
|
| 81 |
+
"normalized": false,
|
| 82 |
+
"rstrip": false,
|
| 83 |
+
"single_word": false,
|
| 84 |
+
"special": true
|
| 85 |
+
},
|
| 86 |
+
"10": {
|
| 87 |
+
"content": "<filename>",
|
| 88 |
+
"lstrip": false,
|
| 89 |
+
"normalized": false,
|
| 90 |
+
"rstrip": false,
|
| 91 |
+
"single_word": false,
|
| 92 |
+
"special": true
|
| 93 |
+
},
|
| 94 |
+
"11": {
|
| 95 |
+
"content": "<gh_stars>",
|
| 96 |
+
"lstrip": false,
|
| 97 |
+
"normalized": false,
|
| 98 |
+
"rstrip": false,
|
| 99 |
+
"single_word": false,
|
| 100 |
+
"special": true
|
| 101 |
+
},
|
| 102 |
+
"12": {
|
| 103 |
+
"content": "<issue_start>",
|
| 104 |
+
"lstrip": false,
|
| 105 |
+
"normalized": false,
|
| 106 |
+
"rstrip": false,
|
| 107 |
+
"single_word": false,
|
| 108 |
+
"special": true
|
| 109 |
+
},
|
| 110 |
+
"13": {
|
| 111 |
+
"content": "<issue_comment>",
|
| 112 |
+
"lstrip": false,
|
| 113 |
+
"normalized": false,
|
| 114 |
+
"rstrip": false,
|
| 115 |
+
"single_word": false,
|
| 116 |
+
"special": true
|
| 117 |
+
},
|
| 118 |
+
"14": {
|
| 119 |
+
"content": "<issue_closed>",
|
| 120 |
+
"lstrip": false,
|
| 121 |
+
"normalized": false,
|
| 122 |
+
"rstrip": false,
|
| 123 |
+
"single_word": false,
|
| 124 |
+
"special": true
|
| 125 |
+
},
|
| 126 |
+
"15": {
|
| 127 |
+
"content": "<jupyter_start>",
|
| 128 |
+
"lstrip": false,
|
| 129 |
+
"normalized": false,
|
| 130 |
+
"rstrip": false,
|
| 131 |
+
"single_word": false,
|
| 132 |
+
"special": true
|
| 133 |
+
},
|
| 134 |
+
"16": {
|
| 135 |
+
"content": "<jupyter_text>",
|
| 136 |
+
"lstrip": false,
|
| 137 |
+
"normalized": false,
|
| 138 |
+
"rstrip": false,
|
| 139 |
+
"single_word": false,
|
| 140 |
+
"special": true
|
| 141 |
+
},
|
| 142 |
+
"17": {
|
| 143 |
+
"content": "<jupyter_code>",
|
| 144 |
+
"lstrip": false,
|
| 145 |
+
"normalized": false,
|
| 146 |
+
"rstrip": false,
|
| 147 |
+
"single_word": false,
|
| 148 |
+
"special": true
|
| 149 |
+
},
|
| 150 |
+
"18": {
|
| 151 |
+
"content": "<jupyter_output>",
|
| 152 |
+
"lstrip": false,
|
| 153 |
+
"normalized": false,
|
| 154 |
+
"rstrip": false,
|
| 155 |
+
"single_word": false,
|
| 156 |
+
"special": true
|
| 157 |
+
},
|
| 158 |
+
"19": {
|
| 159 |
+
"content": "<empty_output>",
|
| 160 |
+
"lstrip": false,
|
| 161 |
+
"normalized": false,
|
| 162 |
+
"rstrip": false,
|
| 163 |
+
"single_word": false,
|
| 164 |
+
"special": true
|
| 165 |
+
},
|
| 166 |
+
"20": {
|
| 167 |
+
"content": "<commit_before>",
|
| 168 |
+
"lstrip": false,
|
| 169 |
+
"normalized": false,
|
| 170 |
+
"rstrip": false,
|
| 171 |
+
"single_word": false,
|
| 172 |
+
"special": true
|
| 173 |
+
},
|
| 174 |
+
"21": {
|
| 175 |
+
"content": "<commit_msg>",
|
| 176 |
+
"lstrip": false,
|
| 177 |
+
"normalized": false,
|
| 178 |
+
"rstrip": false,
|
| 179 |
+
"single_word": false,
|
| 180 |
+
"special": true
|
| 181 |
+
},
|
| 182 |
+
"22": {
|
| 183 |
+
"content": "<commit_after>",
|
| 184 |
+
"lstrip": false,
|
| 185 |
+
"normalized": false,
|
| 186 |
+
"rstrip": false,
|
| 187 |
+
"single_word": false,
|
| 188 |
+
"special": true
|
| 189 |
+
},
|
| 190 |
+
"23": {
|
| 191 |
+
"content": "<program_lang>",
|
| 192 |
+
"lstrip": false,
|
| 193 |
+
"normalized": false,
|
| 194 |
+
"rstrip": false,
|
| 195 |
+
"single_word": false,
|
| 196 |
+
"special": true
|
| 197 |
+
},
|
| 198 |
+
"24": {
|
| 199 |
+
"content": "<|image_placeholder|>",
|
| 200 |
+
"lstrip": false,
|
| 201 |
+
"normalized": false,
|
| 202 |
+
"rstrip": false,
|
| 203 |
+
"single_word": false,
|
| 204 |
+
"special": true
|
| 205 |
+
},
|
| 206 |
+
"25": {
|
| 207 |
+
"content": "<|url_placeholder|>",
|
| 208 |
+
"lstrip": false,
|
| 209 |
+
"normalized": false,
|
| 210 |
+
"rstrip": false,
|
| 211 |
+
"single_word": false,
|
| 212 |
+
"special": true
|
| 213 |
+
},
|
| 214 |
+
"26": {
|
| 215 |
+
"content": "<|hyperlink_placeholder|>",
|
| 216 |
+
"lstrip": false,
|
| 217 |
+
"normalized": false,
|
| 218 |
+
"rstrip": false,
|
| 219 |
+
"single_word": false,
|
| 220 |
+
"special": true
|
| 221 |
+
},
|
| 222 |
+
"27": {
|
| 223 |
+
"content": "<|table_placeholder|>",
|
| 224 |
+
"lstrip": false,
|
| 225 |
+
"normalized": false,
|
| 226 |
+
"rstrip": false,
|
| 227 |
+
"single_word": false,
|
| 228 |
+
"special": true
|
| 229 |
+
},
|
| 230 |
+
"28": {
|
| 231 |
+
"content": "<|equation_placeholder|>",
|
| 232 |
+
"lstrip": false,
|
| 233 |
+
"normalized": false,
|
| 234 |
+
"rstrip": false,
|
| 235 |
+
"single_word": false,
|
| 236 |
+
"special": true
|
| 237 |
+
},
|
| 238 |
+
"29": {
|
| 239 |
+
"content": "<|code_placeholder|>",
|
| 240 |
+
"lstrip": false,
|
| 241 |
+
"normalized": false,
|
| 242 |
+
"rstrip": false,
|
| 243 |
+
"single_word": false,
|
| 244 |
+
"special": true
|
| 245 |
+
},
|
| 246 |
+
"30": {
|
| 247 |
+
"content": "<|reference_placeholder|>",
|
| 248 |
+
"lstrip": false,
|
| 249 |
+
"normalized": false,
|
| 250 |
+
"rstrip": false,
|
| 251 |
+
"single_word": false,
|
| 252 |
+
"special": true
|
| 253 |
+
},
|
| 254 |
+
"31": {
|
| 255 |
+
"content": "<|endoftext|>",
|
| 256 |
+
"lstrip": false,
|
| 257 |
+
"normalized": false,
|
| 258 |
+
"rstrip": false,
|
| 259 |
+
"single_word": false,
|
| 260 |
+
"special": true
|
| 261 |
+
},
|
| 262 |
+
"32": {
|
| 263 |
+
"content": "<fim_prefix>",
|
| 264 |
+
"lstrip": false,
|
| 265 |
+
"normalized": false,
|
| 266 |
+
"rstrip": false,
|
| 267 |
+
"single_word": false,
|
| 268 |
+
"special": true
|
| 269 |
+
},
|
| 270 |
+
"33": {
|
| 271 |
+
"content": "<fim_middle>",
|
| 272 |
+
"lstrip": false,
|
| 273 |
+
"normalized": false,
|
| 274 |
+
"rstrip": false,
|
| 275 |
+
"single_word": false,
|
| 276 |
+
"special": true
|
| 277 |
+
},
|
| 278 |
+
"34": {
|
| 279 |
+
"content": "<fim_suffix>",
|
| 280 |
+
"lstrip": false,
|
| 281 |
+
"normalized": false,
|
| 282 |
+
"rstrip": false,
|
| 283 |
+
"single_word": false,
|
| 284 |
+
"special": true
|
| 285 |
+
},
|
| 286 |
+
"35": {
|
| 287 |
+
"content": "<fim_pad>",
|
| 288 |
+
"lstrip": false,
|
| 289 |
+
"normalized": false,
|
| 290 |
+
"rstrip": false,
|
| 291 |
+
"single_word": false,
|
| 292 |
+
"special": true
|
| 293 |
+
},
|
| 294 |
+
"36": {
|
| 295 |
+
"content": "<longcat_think>",
|
| 296 |
+
"lstrip": false,
|
| 297 |
+
"normalized": false,
|
| 298 |
+
"rstrip": false,
|
| 299 |
+
"single_word": false,
|
| 300 |
+
"special": false
|
| 301 |
+
},
|
| 302 |
+
"37": {
|
| 303 |
+
"content": "</longcat_think>",
|
| 304 |
+
"lstrip": false,
|
| 305 |
+
"normalized": false,
|
| 306 |
+
"rstrip": false,
|
| 307 |
+
"single_word": false,
|
| 308 |
+
"special": false
|
| 309 |
+
},
|
| 310 |
+
"38": {
|
| 311 |
+
"content": "<longcat_answer>",
|
| 312 |
+
"lstrip": false,
|
| 313 |
+
"normalized": false,
|
| 314 |
+
"rstrip": false,
|
| 315 |
+
"single_word": false,
|
| 316 |
+
"special": false
|
| 317 |
+
},
|
| 318 |
+
"39": {
|
| 319 |
+
"content": "</longcat_answer>",
|
| 320 |
+
"lstrip": false,
|
| 321 |
+
"normalized": false,
|
| 322 |
+
"rstrip": false,
|
| 323 |
+
"single_word": false,
|
| 324 |
+
"special": false
|
| 325 |
+
},
|
| 326 |
+
"40": {
|
| 327 |
+
"content": "<longcat_files>",
|
| 328 |
+
"lstrip": false,
|
| 329 |
+
"normalized": false,
|
| 330 |
+
"rstrip": false,
|
| 331 |
+
"single_word": false,
|
| 332 |
+
"special": false
|
| 333 |
+
},
|
| 334 |
+
"41": {
|
| 335 |
+
"content": "</longcat_files>",
|
| 336 |
+
"lstrip": false,
|
| 337 |
+
"normalized": false,
|
| 338 |
+
"rstrip": false,
|
| 339 |
+
"single_word": false,
|
| 340 |
+
"special": false
|
| 341 |
+
},
|
| 342 |
+
"42": {
|
| 343 |
+
"content": "<longcat_tool_call>",
|
| 344 |
+
"lstrip": false,
|
| 345 |
+
"normalized": false,
|
| 346 |
+
"rstrip": false,
|
| 347 |
+
"single_word": false,
|
| 348 |
+
"special": false
|
| 349 |
+
},
|
| 350 |
+
"43": {
|
| 351 |
+
"content": "</longcat_tool_call>",
|
| 352 |
+
"lstrip": false,
|
| 353 |
+
"normalized": false,
|
| 354 |
+
"rstrip": false,
|
| 355 |
+
"single_word": false,
|
| 356 |
+
"special": false
|
| 357 |
+
},
|
| 358 |
+
"44": {
|
| 359 |
+
"content": "<longcat_tool_declare>",
|
| 360 |
+
"lstrip": false,
|
| 361 |
+
"normalized": false,
|
| 362 |
+
"rstrip": false,
|
| 363 |
+
"single_word": false,
|
| 364 |
+
"special": true
|
| 365 |
+
},
|
| 366 |
+
"45": {
|
| 367 |
+
"content": "</longcat_tool_declare>",
|
| 368 |
+
"lstrip": false,
|
| 369 |
+
"normalized": false,
|
| 370 |
+
"rstrip": false,
|
| 371 |
+
"single_word": false,
|
| 372 |
+
"special": true
|
| 373 |
+
},
|
| 374 |
+
"46": {
|
| 375 |
+
"content": "<longcat_system>",
|
| 376 |
+
"lstrip": false,
|
| 377 |
+
"normalized": false,
|
| 378 |
+
"rstrip": false,
|
| 379 |
+
"single_word": false,
|
| 380 |
+
"special": true
|
| 381 |
+
},
|
| 382 |
+
"47": {
|
| 383 |
+
"content": "<longcat_user>",
|
| 384 |
+
"lstrip": false,
|
| 385 |
+
"normalized": false,
|
| 386 |
+
"rstrip": false,
|
| 387 |
+
"single_word": false,
|
| 388 |
+
"special": true
|
| 389 |
+
},
|
| 390 |
+
"48": {
|
| 391 |
+
"content": "<longcat_assistant>",
|
| 392 |
+
"lstrip": false,
|
| 393 |
+
"normalized": false,
|
| 394 |
+
"rstrip": false,
|
| 395 |
+
"single_word": false,
|
| 396 |
+
"special": true
|
| 397 |
+
},
|
| 398 |
+
"49": {
|
| 399 |
+
"content": "<longcat_tool_response>",
|
| 400 |
+
"lstrip": false,
|
| 401 |
+
"normalized": false,
|
| 402 |
+
"rstrip": false,
|
| 403 |
+
"single_word": false,
|
| 404 |
+
"special": false
|
| 405 |
+
},
|
| 406 |
+
"50": {
|
| 407 |
+
"content": "</longcat_tool_response>",
|
| 408 |
+
"lstrip": false,
|
| 409 |
+
"normalized": false,
|
| 410 |
+
"rstrip": false,
|
| 411 |
+
"single_word": false,
|
| 412 |
+
"special": false
|
| 413 |
+
},
|
| 414 |
+
"51": {
|
| 415 |
+
"content": "<longcat_arg_key>",
|
| 416 |
+
"lstrip": false,
|
| 417 |
+
"normalized": false,
|
| 418 |
+
"rstrip": false,
|
| 419 |
+
"single_word": false,
|
| 420 |
+
"special": false
|
| 421 |
+
},
|
| 422 |
+
"52": {
|
| 423 |
+
"content": "</longcat_arg_key>",
|
| 424 |
+
"lstrip": false,
|
| 425 |
+
"normalized": false,
|
| 426 |
+
"rstrip": false,
|
| 427 |
+
"single_word": false,
|
| 428 |
+
"special": false
|
| 429 |
+
},
|
| 430 |
+
"53": {
|
| 431 |
+
"content": "<longcat_arg_value>",
|
| 432 |
+
"lstrip": false,
|
| 433 |
+
"normalized": false,
|
| 434 |
+
"rstrip": false,
|
| 435 |
+
"single_word": false,
|
| 436 |
+
"special": false
|
| 437 |
+
},
|
| 438 |
+
"54": {
|
| 439 |
+
"content": "</longcat_arg_value>",
|
| 440 |
+
"lstrip": false,
|
| 441 |
+
"normalized": false,
|
| 442 |
+
"rstrip": false,
|
| 443 |
+
"single_word": false,
|
| 444 |
+
"special": false
|
| 445 |
+
},
|
| 446 |
+
"55": {
|
| 447 |
+
"content": "<mask_31>",
|
| 448 |
+
"lstrip": false,
|
| 449 |
+
"normalized": false,
|
| 450 |
+
"rstrip": false,
|
| 451 |
+
"single_word": false,
|
| 452 |
+
"special": true
|
| 453 |
+
},
|
| 454 |
+
"56": {
|
| 455 |
+
"content": "<mask_32>",
|
| 456 |
+
"lstrip": false,
|
| 457 |
+
"normalized": false,
|
| 458 |
+
"rstrip": false,
|
| 459 |
+
"single_word": false,
|
| 460 |
+
"special": true
|
| 461 |
+
},
|
| 462 |
+
"57": {
|
| 463 |
+
"content": "<mask_33>",
|
| 464 |
+
"lstrip": false,
|
| 465 |
+
"normalized": false,
|
| 466 |
+
"rstrip": false,
|
| 467 |
+
"single_word": false,
|
| 468 |
+
"special": true
|
| 469 |
+
},
|
| 470 |
+
"58": {
|
| 471 |
+
"content": "<mask_34>",
|
| 472 |
+
"lstrip": false,
|
| 473 |
+
"normalized": false,
|
| 474 |
+
"rstrip": false,
|
| 475 |
+
"single_word": false,
|
| 476 |
+
"special": true
|
| 477 |
+
},
|
| 478 |
+
"59": {
|
| 479 |
+
"content": "<mask_35>",
|
| 480 |
+
"lstrip": false,
|
| 481 |
+
"normalized": false,
|
| 482 |
+
"rstrip": false,
|
| 483 |
+
"single_word": false,
|
| 484 |
+
"special": true
|
| 485 |
+
},
|
| 486 |
+
"60": {
|
| 487 |
+
"content": "<mask_36>",
|
| 488 |
+
"lstrip": false,
|
| 489 |
+
"normalized": false,
|
| 490 |
+
"rstrip": false,
|
| 491 |
+
"single_word": false,
|
| 492 |
+
"special": true
|
| 493 |
+
},
|
| 494 |
+
"61": {
|
| 495 |
+
"content": "<mask_37>",
|
| 496 |
+
"lstrip": false,
|
| 497 |
+
"normalized": false,
|
| 498 |
+
"rstrip": false,
|
| 499 |
+
"single_word": false,
|
| 500 |
+
"special": true
|
| 501 |
+
},
|
| 502 |
+
"62": {
|
| 503 |
+
"content": "<mask_38>",
|
| 504 |
+
"lstrip": false,
|
| 505 |
+
"normalized": false,
|
| 506 |
+
"rstrip": false,
|
| 507 |
+
"single_word": false,
|
| 508 |
+
"special": true
|
| 509 |
+
},
|
| 510 |
+
"63": {
|
| 511 |
+
"content": "<mask_39>",
|
| 512 |
+
"lstrip": false,
|
| 513 |
+
"normalized": false,
|
| 514 |
+
"rstrip": false,
|
| 515 |
+
"single_word": false,
|
| 516 |
+
"special": true
|
| 517 |
+
},
|
| 518 |
+
"64": {
|
| 519 |
+
"content": "<mask_40>",
|
| 520 |
+
"lstrip": false,
|
| 521 |
+
"normalized": false,
|
| 522 |
+
"rstrip": false,
|
| 523 |
+
"single_word": false,
|
| 524 |
+
"special": true
|
| 525 |
+
},
|
| 526 |
+
"65": {
|
| 527 |
+
"content": "<mask_41>",
|
| 528 |
+
"lstrip": false,
|
| 529 |
+
"normalized": false,
|
| 530 |
+
"rstrip": false,
|
| 531 |
+
"single_word": false,
|
| 532 |
+
"special": true
|
| 533 |
+
},
|
| 534 |
+
"66": {
|
| 535 |
+
"content": "<mask_42>",
|
| 536 |
+
"lstrip": false,
|
| 537 |
+
"normalized": false,
|
| 538 |
+
"rstrip": false,
|
| 539 |
+
"single_word": false,
|
| 540 |
+
"special": true
|
| 541 |
+
},
|
| 542 |
+
"67": {
|
| 543 |
+
"content": "<mask_43>",
|
| 544 |
+
"lstrip": false,
|
| 545 |
+
"normalized": false,
|
| 546 |
+
"rstrip": false,
|
| 547 |
+
"single_word": false,
|
| 548 |
+
"special": true
|
| 549 |
+
},
|
| 550 |
+
"68": {
|
| 551 |
+
"content": "<mask_44>",
|
| 552 |
+
"lstrip": false,
|
| 553 |
+
"normalized": false,
|
| 554 |
+
"rstrip": false,
|
| 555 |
+
"single_word": false,
|
| 556 |
+
"special": true
|
| 557 |
+
},
|
| 558 |
+
"69": {
|
| 559 |
+
"content": "<mask_45>",
|
| 560 |
+
"lstrip": false,
|
| 561 |
+
"normalized": false,
|
| 562 |
+
"rstrip": false,
|
| 563 |
+
"single_word": false,
|
| 564 |
+
"special": true
|
| 565 |
+
},
|
| 566 |
+
"70": {
|
| 567 |
+
"content": "<mask_46>",
|
| 568 |
+
"lstrip": false,
|
| 569 |
+
"normalized": false,
|
| 570 |
+
"rstrip": false,
|
| 571 |
+
"single_word": false,
|
| 572 |
+
"special": true
|
| 573 |
+
},
|
| 574 |
+
"71": {
|
| 575 |
+
"content": "<mask_47>",
|
| 576 |
+
"lstrip": false,
|
| 577 |
+
"normalized": false,
|
| 578 |
+
"rstrip": false,
|
| 579 |
+
"single_word": false,
|
| 580 |
+
"special": true
|
| 581 |
+
},
|
| 582 |
+
"72": {
|
| 583 |
+
"content": "<mask_48>",
|
| 584 |
+
"lstrip": false,
|
| 585 |
+
"normalized": false,
|
| 586 |
+
"rstrip": false,
|
| 587 |
+
"single_word": false,
|
| 588 |
+
"special": true
|
| 589 |
+
},
|
| 590 |
+
"73": {
|
| 591 |
+
"content": "<mask_49>",
|
| 592 |
+
"lstrip": false,
|
| 593 |
+
"normalized": false,
|
| 594 |
+
"rstrip": false,
|
| 595 |
+
"single_word": false,
|
| 596 |
+
"special": true
|
| 597 |
+
},
|
| 598 |
+
"74": {
|
| 599 |
+
"content": "<mask_50>",
|
| 600 |
+
"lstrip": false,
|
| 601 |
+
"normalized": false,
|
| 602 |
+
"rstrip": false,
|
| 603 |
+
"single_word": false,
|
| 604 |
+
"special": true
|
| 605 |
+
},
|
| 606 |
+
"75": {
|
| 607 |
+
"content": "<mask_51>",
|
| 608 |
+
"lstrip": false,
|
| 609 |
+
"normalized": false,
|
| 610 |
+
"rstrip": false,
|
| 611 |
+
"single_word": false,
|
| 612 |
+
"special": true
|
| 613 |
+
},
|
| 614 |
+
"76": {
|
| 615 |
+
"content": "<mask_52>",
|
| 616 |
+
"lstrip": false,
|
| 617 |
+
"normalized": false,
|
| 618 |
+
"rstrip": false,
|
| 619 |
+
"single_word": false,
|
| 620 |
+
"special": true
|
| 621 |
+
},
|
| 622 |
+
"77": {
|
| 623 |
+
"content": "<mask_53>",
|
| 624 |
+
"lstrip": false,
|
| 625 |
+
"normalized": false,
|
| 626 |
+
"rstrip": false,
|
| 627 |
+
"single_word": false,
|
| 628 |
+
"special": true
|
| 629 |
+
},
|
| 630 |
+
"78": {
|
| 631 |
+
"content": "<mask_54>",
|
| 632 |
+
"lstrip": false,
|
| 633 |
+
"normalized": false,
|
| 634 |
+
"rstrip": false,
|
| 635 |
+
"single_word": false,
|
| 636 |
+
"special": true
|
| 637 |
+
},
|
| 638 |
+
"79": {
|
| 639 |
+
"content": "<mask_55>",
|
| 640 |
+
"lstrip": false,
|
| 641 |
+
"normalized": false,
|
| 642 |
+
"rstrip": false,
|
| 643 |
+
"single_word": false,
|
| 644 |
+
"special": true
|
| 645 |
+
},
|
| 646 |
+
"80": {
|
| 647 |
+
"content": "<mask_56>",
|
| 648 |
+
"lstrip": false,
|
| 649 |
+
"normalized": false,
|
| 650 |
+
"rstrip": false,
|
| 651 |
+
"single_word": false,
|
| 652 |
+
"special": true
|
| 653 |
+
},
|
| 654 |
+
"81": {
|
| 655 |
+
"content": "<mask_57>",
|
| 656 |
+
"lstrip": false,
|
| 657 |
+
"normalized": false,
|
| 658 |
+
"rstrip": false,
|
| 659 |
+
"single_word": false,
|
| 660 |
+
"special": true
|
| 661 |
+
},
|
| 662 |
+
"82": {
|
| 663 |
+
"content": "<mask_58>",
|
| 664 |
+
"lstrip": false,
|
| 665 |
+
"normalized": false,
|
| 666 |
+
"rstrip": false,
|
| 667 |
+
"single_word": false,
|
| 668 |
+
"special": true
|
| 669 |
+
},
|
| 670 |
+
"83": {
|
| 671 |
+
"content": "<mask_59>",
|
| 672 |
+
"lstrip": false,
|
| 673 |
+
"normalized": false,
|
| 674 |
+
"rstrip": false,
|
| 675 |
+
"single_word": false,
|
| 676 |
+
"special": true
|
| 677 |
+
},
|
| 678 |
+
"84": {
|
| 679 |
+
"content": "<mask_60>",
|
| 680 |
+
"lstrip": false,
|
| 681 |
+
"normalized": false,
|
| 682 |
+
"rstrip": false,
|
| 683 |
+
"single_word": false,
|
| 684 |
+
"special": true
|
| 685 |
+
},
|
| 686 |
+
"85": {
|
| 687 |
+
"content": "<mask_61>",
|
| 688 |
+
"lstrip": false,
|
| 689 |
+
"normalized": false,
|
| 690 |
+
"rstrip": false,
|
| 691 |
+
"single_word": false,
|
| 692 |
+
"special": true
|
| 693 |
+
},
|
| 694 |
+
"86": {
|
| 695 |
+
"content": "<mask_62>",
|
| 696 |
+
"lstrip": false,
|
| 697 |
+
"normalized": false,
|
| 698 |
+
"rstrip": false,
|
| 699 |
+
"single_word": false,
|
| 700 |
+
"special": true
|
| 701 |
+
},
|
| 702 |
+
"87": {
|
| 703 |
+
"content": "<mask_63>",
|
| 704 |
+
"lstrip": false,
|
| 705 |
+
"normalized": false,
|
| 706 |
+
"rstrip": false,
|
| 707 |
+
"single_word": false,
|
| 708 |
+
"special": true
|
| 709 |
+
},
|
| 710 |
+
"88": {
|
| 711 |
+
"content": "<mask_64>",
|
| 712 |
+
"lstrip": false,
|
| 713 |
+
"normalized": false,
|
| 714 |
+
"rstrip": false,
|
| 715 |
+
"single_word": false,
|
| 716 |
+
"special": true
|
| 717 |
+
},
|
| 718 |
+
"89": {
|
| 719 |
+
"content": "<mask_65>",
|
| 720 |
+
"lstrip": false,
|
| 721 |
+
"normalized": false,
|
| 722 |
+
"rstrip": false,
|
| 723 |
+
"single_word": false,
|
| 724 |
+
"special": true
|
| 725 |
+
},
|
| 726 |
+
"90": {
|
| 727 |
+
"content": "<mask_66>",
|
| 728 |
+
"lstrip": false,
|
| 729 |
+
"normalized": false,
|
| 730 |
+
"rstrip": false,
|
| 731 |
+
"single_word": false,
|
| 732 |
+
"special": true
|
| 733 |
+
},
|
| 734 |
+
"91": {
|
| 735 |
+
"content": "<mask_67>",
|
| 736 |
+
"lstrip": false,
|
| 737 |
+
"normalized": false,
|
| 738 |
+
"rstrip": false,
|
| 739 |
+
"single_word": false,
|
| 740 |
+
"special": true
|
| 741 |
+
},
|
| 742 |
+
"92": {
|
| 743 |
+
"content": "<mask_68>",
|
| 744 |
+
"lstrip": false,
|
| 745 |
+
"normalized": false,
|
| 746 |
+
"rstrip": false,
|
| 747 |
+
"single_word": false,
|
| 748 |
+
"special": true
|
| 749 |
+
},
|
| 750 |
+
"93": {
|
| 751 |
+
"content": "<mask_69>",
|
| 752 |
+
"lstrip": false,
|
| 753 |
+
"normalized": false,
|
| 754 |
+
"rstrip": false,
|
| 755 |
+
"single_word": false,
|
| 756 |
+
"special": true
|
| 757 |
+
},
|
| 758 |
+
"94": {
|
| 759 |
+
"content": "<mask_70>",
|
| 760 |
+
"lstrip": false,
|
| 761 |
+
"normalized": false,
|
| 762 |
+
"rstrip": false,
|
| 763 |
+
"single_word": false,
|
| 764 |
+
"special": true
|
| 765 |
+
},
|
| 766 |
+
"95": {
|
| 767 |
+
"content": "<mask_71>",
|
| 768 |
+
"lstrip": false,
|
| 769 |
+
"normalized": false,
|
| 770 |
+
"rstrip": false,
|
| 771 |
+
"single_word": false,
|
| 772 |
+
"special": true
|
| 773 |
+
},
|
| 774 |
+
"96": {
|
| 775 |
+
"content": "<mask_72>",
|
| 776 |
+
"lstrip": false,
|
| 777 |
+
"normalized": false,
|
| 778 |
+
"rstrip": false,
|
| 779 |
+
"single_word": false,
|
| 780 |
+
"special": true
|
| 781 |
+
},
|
| 782 |
+
"97": {
|
| 783 |
+
"content": "<mask_73>",
|
| 784 |
+
"lstrip": false,
|
| 785 |
+
"normalized": false,
|
| 786 |
+
"rstrip": false,
|
| 787 |
+
"single_word": false,
|
| 788 |
+
"special": true
|
| 789 |
+
},
|
| 790 |
+
"98": {
|
| 791 |
+
"content": "<mask_74>",
|
| 792 |
+
"lstrip": false,
|
| 793 |
+
"normalized": false,
|
| 794 |
+
"rstrip": false,
|
| 795 |
+
"single_word": false,
|
| 796 |
+
"special": true
|
| 797 |
+
},
|
| 798 |
+
"99": {
|
| 799 |
+
"content": "<mask_75>",
|
| 800 |
+
"lstrip": false,
|
| 801 |
+
"normalized": false,
|
| 802 |
+
"rstrip": false,
|
| 803 |
+
"single_word": false,
|
| 804 |
+
"special": true
|
| 805 |
+
},
|
| 806 |
+
"100": {
|
| 807 |
+
"content": "<mask_76>",
|
| 808 |
+
"lstrip": false,
|
| 809 |
+
"normalized": false,
|
| 810 |
+
"rstrip": false,
|
| 811 |
+
"single_word": false,
|
| 812 |
+
"special": true
|
| 813 |
+
},
|
| 814 |
+
"101": {
|
| 815 |
+
"content": "<mask_77>",
|
| 816 |
+
"lstrip": false,
|
| 817 |
+
"normalized": false,
|
| 818 |
+
"rstrip": false,
|
| 819 |
+
"single_word": false,
|
| 820 |
+
"special": true
|
| 821 |
+
},
|
| 822 |
+
"102": {
|
| 823 |
+
"content": "<mask_78>",
|
| 824 |
+
"lstrip": false,
|
| 825 |
+
"normalized": false,
|
| 826 |
+
"rstrip": false,
|
| 827 |
+
"single_word": false,
|
| 828 |
+
"special": true
|
| 829 |
+
},
|
| 830 |
+
"103": {
|
| 831 |
+
"content": "<mask_79>",
|
| 832 |
+
"lstrip": false,
|
| 833 |
+
"normalized": false,
|
| 834 |
+
"rstrip": false,
|
| 835 |
+
"single_word": false,
|
| 836 |
+
"special": true
|
| 837 |
+
},
|
| 838 |
+
"104": {
|
| 839 |
+
"content": "<mask_80>",
|
| 840 |
+
"lstrip": false,
|
| 841 |
+
"normalized": false,
|
| 842 |
+
"rstrip": false,
|
| 843 |
+
"single_word": false,
|
| 844 |
+
"special": true
|
| 845 |
+
},
|
| 846 |
+
"105": {
|
| 847 |
+
"content": "<mask_81>",
|
| 848 |
+
"lstrip": false,
|
| 849 |
+
"normalized": false,
|
| 850 |
+
"rstrip": false,
|
| 851 |
+
"single_word": false,
|
| 852 |
+
"special": true
|
| 853 |
+
},
|
| 854 |
+
"106": {
|
| 855 |
+
"content": "<mask_82>",
|
| 856 |
+
"lstrip": false,
|
| 857 |
+
"normalized": false,
|
| 858 |
+
"rstrip": false,
|
| 859 |
+
"single_word": false,
|
| 860 |
+
"special": true
|
| 861 |
+
},
|
| 862 |
+
"107": {
|
| 863 |
+
"content": "<mask_83>",
|
| 864 |
+
"lstrip": false,
|
| 865 |
+
"normalized": false,
|
| 866 |
+
"rstrip": false,
|
| 867 |
+
"single_word": false,
|
| 868 |
+
"special": true
|
| 869 |
+
},
|
| 870 |
+
"108": {
|
| 871 |
+
"content": "<mask_84>",
|
| 872 |
+
"lstrip": false,
|
| 873 |
+
"normalized": false,
|
| 874 |
+
"rstrip": false,
|
| 875 |
+
"single_word": false,
|
| 876 |
+
"special": true
|
| 877 |
+
},
|
| 878 |
+
"109": {
|
| 879 |
+
"content": "<mask_85>",
|
| 880 |
+
"lstrip": false,
|
| 881 |
+
"normalized": false,
|
| 882 |
+
"rstrip": false,
|
| 883 |
+
"single_word": false,
|
| 884 |
+
"special": true
|
| 885 |
+
},
|
| 886 |
+
"110": {
|
| 887 |
+
"content": "<mask_86>",
|
| 888 |
+
"lstrip": false,
|
| 889 |
+
"normalized": false,
|
| 890 |
+
"rstrip": false,
|
| 891 |
+
"single_word": false,
|
| 892 |
+
"special": true
|
| 893 |
+
},
|
| 894 |
+
"111": {
|
| 895 |
+
"content": "<mask_87>",
|
| 896 |
+
"lstrip": false,
|
| 897 |
+
"normalized": false,
|
| 898 |
+
"rstrip": false,
|
| 899 |
+
"single_word": false,
|
| 900 |
+
"special": true
|
| 901 |
+
},
|
| 902 |
+
"112": {
|
| 903 |
+
"content": "<mask_88>",
|
| 904 |
+
"lstrip": false,
|
| 905 |
+
"normalized": false,
|
| 906 |
+
"rstrip": false,
|
| 907 |
+
"single_word": false,
|
| 908 |
+
"special": true
|
| 909 |
+
},
|
| 910 |
+
"113": {
|
| 911 |
+
"content": "<mask_89>",
|
| 912 |
+
"lstrip": false,
|
| 913 |
+
"normalized": false,
|
| 914 |
+
"rstrip": false,
|
| 915 |
+
"single_word": false,
|
| 916 |
+
"special": true
|
| 917 |
+
},
|
| 918 |
+
"114": {
|
| 919 |
+
"content": "<mask_90>",
|
| 920 |
+
"lstrip": false,
|
| 921 |
+
"normalized": false,
|
| 922 |
+
"rstrip": false,
|
| 923 |
+
"single_word": false,
|
| 924 |
+
"special": true
|
| 925 |
+
},
|
| 926 |
+
"115": {
|
| 927 |
+
"content": "<mask_91>",
|
| 928 |
+
"lstrip": false,
|
| 929 |
+
"normalized": false,
|
| 930 |
+
"rstrip": false,
|
| 931 |
+
"single_word": false,
|
| 932 |
+
"special": true
|
| 933 |
+
},
|
| 934 |
+
"116": {
|
| 935 |
+
"content": "<mask_92>",
|
| 936 |
+
"lstrip": false,
|
| 937 |
+
"normalized": false,
|
| 938 |
+
"rstrip": false,
|
| 939 |
+
"single_word": false,
|
| 940 |
+
"special": true
|
| 941 |
+
},
|
| 942 |
+
"117": {
|
| 943 |
+
"content": "<mask_93>",
|
| 944 |
+
"lstrip": false,
|
| 945 |
+
"normalized": false,
|
| 946 |
+
"rstrip": false,
|
| 947 |
+
"single_word": false,
|
| 948 |
+
"special": true
|
| 949 |
+
},
|
| 950 |
+
"118": {
|
| 951 |
+
"content": "<mask_94>",
|
| 952 |
+
"lstrip": false,
|
| 953 |
+
"normalized": false,
|
| 954 |
+
"rstrip": false,
|
| 955 |
+
"single_word": false,
|
| 956 |
+
"special": true
|
| 957 |
+
},
|
| 958 |
+
"119": {
|
| 959 |
+
"content": "<mask_95>",
|
| 960 |
+
"lstrip": false,
|
| 961 |
+
"normalized": false,
|
| 962 |
+
"rstrip": false,
|
| 963 |
+
"single_word": false,
|
| 964 |
+
"special": true
|
| 965 |
+
},
|
| 966 |
+
"120": {
|
| 967 |
+
"content": "<mask_96>",
|
| 968 |
+
"lstrip": false,
|
| 969 |
+
"normalized": false,
|
| 970 |
+
"rstrip": false,
|
| 971 |
+
"single_word": false,
|
| 972 |
+
"special": true
|
| 973 |
+
},
|
| 974 |
+
"121": {
|
| 975 |
+
"content": "<mask_97>",
|
| 976 |
+
"lstrip": false,
|
| 977 |
+
"normalized": false,
|
| 978 |
+
"rstrip": false,
|
| 979 |
+
"single_word": false,
|
| 980 |
+
"special": true
|
| 981 |
+
},
|
| 982 |
+
"122": {
|
| 983 |
+
"content": "<mask_98>",
|
| 984 |
+
"lstrip": false,
|
| 985 |
+
"normalized": false,
|
| 986 |
+
"rstrip": false,
|
| 987 |
+
"single_word": false,
|
| 988 |
+
"special": true
|
| 989 |
+
},
|
| 990 |
+
"123": {
|
| 991 |
+
"content": "<mask_99>",
|
| 992 |
+
"lstrip": false,
|
| 993 |
+
"normalized": false,
|
| 994 |
+
"rstrip": false,
|
| 995 |
+
"single_word": false,
|
| 996 |
+
"special": true
|
| 997 |
+
},
|
| 998 |
+
"124": {
|
| 999 |
+
"content": "<mask_100>",
|
| 1000 |
+
"lstrip": false,
|
| 1001 |
+
"normalized": false,
|
| 1002 |
+
"rstrip": false,
|
| 1003 |
+
"single_word": false,
|
| 1004 |
+
"special": true
|
| 1005 |
+
},
|
| 1006 |
+
"125": {
|
| 1007 |
+
"content": "<mask_101>",
|
| 1008 |
+
"lstrip": false,
|
| 1009 |
+
"normalized": false,
|
| 1010 |
+
"rstrip": false,
|
| 1011 |
+
"single_word": false,
|
| 1012 |
+
"special": true
|
| 1013 |
+
},
|
| 1014 |
+
"126": {
|
| 1015 |
+
"content": "<mask_102>",
|
| 1016 |
+
"lstrip": false,
|
| 1017 |
+
"normalized": false,
|
| 1018 |
+
"rstrip": false,
|
| 1019 |
+
"single_word": false,
|
| 1020 |
+
"special": true
|
| 1021 |
+
},
|
| 1022 |
+
"127": {
|
| 1023 |
+
"content": "<mask_103>",
|
| 1024 |
+
"lstrip": false,
|
| 1025 |
+
"normalized": false,
|
| 1026 |
+
"rstrip": false,
|
| 1027 |
+
"single_word": false,
|
| 1028 |
+
"special": true
|
| 1029 |
+
},
|
| 1030 |
+
"128": {
|
| 1031 |
+
"content": "<mask_104>",
|
| 1032 |
+
"lstrip": false,
|
| 1033 |
+
"normalized": false,
|
| 1034 |
+
"rstrip": false,
|
| 1035 |
+
"single_word": false,
|
| 1036 |
+
"special": true
|
| 1037 |
+
},
|
| 1038 |
+
"129": {
|
| 1039 |
+
"content": "<mask_105>",
|
| 1040 |
+
"lstrip": false,
|
| 1041 |
+
"normalized": false,
|
| 1042 |
+
"rstrip": false,
|
| 1043 |
+
"single_word": false,
|
| 1044 |
+
"special": true
|
| 1045 |
+
},
|
| 1046 |
+
"130": {
|
| 1047 |
+
"content": "<mask_106>",
|
| 1048 |
+
"lstrip": false,
|
| 1049 |
+
"normalized": false,
|
| 1050 |
+
"rstrip": false,
|
| 1051 |
+
"single_word": false,
|
| 1052 |
+
"special": true
|
| 1053 |
+
},
|
| 1054 |
+
"131": {
|
| 1055 |
+
"content": "<mask_107>",
|
| 1056 |
+
"lstrip": false,
|
| 1057 |
+
"normalized": false,
|
| 1058 |
+
"rstrip": false,
|
| 1059 |
+
"single_word": false,
|
| 1060 |
+
"special": true
|
| 1061 |
+
},
|
| 1062 |
+
"132": {
|
| 1063 |
+
"content": "<mask_108>",
|
| 1064 |
+
"lstrip": false,
|
| 1065 |
+
"normalized": false,
|
| 1066 |
+
"rstrip": false,
|
| 1067 |
+
"single_word": false,
|
| 1068 |
+
"special": true
|
| 1069 |
+
},
|
| 1070 |
+
"133": {
|
| 1071 |
+
"content": "<mask_109>",
|
| 1072 |
+
"lstrip": false,
|
| 1073 |
+
"normalized": false,
|
| 1074 |
+
"rstrip": false,
|
| 1075 |
+
"single_word": false,
|
| 1076 |
+
"special": true
|
| 1077 |
+
},
|
| 1078 |
+
"134": {
|
| 1079 |
+
"content": "<mask_110>",
|
| 1080 |
+
"lstrip": false,
|
| 1081 |
+
"normalized": false,
|
| 1082 |
+
"rstrip": false,
|
| 1083 |
+
"single_word": false,
|
| 1084 |
+
"special": true
|
| 1085 |
+
},
|
| 1086 |
+
"135": {
|
| 1087 |
+
"content": "<mask_111>",
|
| 1088 |
+
"lstrip": false,
|
| 1089 |
+
"normalized": false,
|
| 1090 |
+
"rstrip": false,
|
| 1091 |
+
"single_word": false,
|
| 1092 |
+
"special": true
|
| 1093 |
+
},
|
| 1094 |
+
"136": {
|
| 1095 |
+
"content": "<mask_112>",
|
| 1096 |
+
"lstrip": false,
|
| 1097 |
+
"normalized": false,
|
| 1098 |
+
"rstrip": false,
|
| 1099 |
+
"single_word": false,
|
| 1100 |
+
"special": true
|
| 1101 |
+
},
|
| 1102 |
+
"137": {
|
| 1103 |
+
"content": "<mask_113>",
|
| 1104 |
+
"lstrip": false,
|
| 1105 |
+
"normalized": false,
|
| 1106 |
+
"rstrip": false,
|
| 1107 |
+
"single_word": false,
|
| 1108 |
+
"special": true
|
| 1109 |
+
},
|
| 1110 |
+
"138": {
|
| 1111 |
+
"content": "<mask_114>",
|
| 1112 |
+
"lstrip": false,
|
| 1113 |
+
"normalized": false,
|
| 1114 |
+
"rstrip": false,
|
| 1115 |
+
"single_word": false,
|
| 1116 |
+
"special": true
|
| 1117 |
+
},
|
| 1118 |
+
"139": {
|
| 1119 |
+
"content": "<mask_115>",
|
| 1120 |
+
"lstrip": false,
|
| 1121 |
+
"normalized": false,
|
| 1122 |
+
"rstrip": false,
|
| 1123 |
+
"single_word": false,
|
| 1124 |
+
"special": true
|
| 1125 |
+
},
|
| 1126 |
+
"140": {
|
| 1127 |
+
"content": "<mask_116>",
|
| 1128 |
+
"lstrip": false,
|
| 1129 |
+
"normalized": false,
|
| 1130 |
+
"rstrip": false,
|
| 1131 |
+
"single_word": false,
|
| 1132 |
+
"special": true
|
| 1133 |
+
},
|
| 1134 |
+
"141": {
|
| 1135 |
+
"content": "<mask_117>",
|
| 1136 |
+
"lstrip": false,
|
| 1137 |
+
"normalized": false,
|
| 1138 |
+
"rstrip": false,
|
| 1139 |
+
"single_word": false,
|
| 1140 |
+
"special": true
|
| 1141 |
+
},
|
| 1142 |
+
"142": {
|
| 1143 |
+
"content": "<mask_118>",
|
| 1144 |
+
"lstrip": false,
|
| 1145 |
+
"normalized": false,
|
| 1146 |
+
"rstrip": false,
|
| 1147 |
+
"single_word": false,
|
| 1148 |
+
"special": true
|
| 1149 |
+
},
|
| 1150 |
+
"143": {
|
| 1151 |
+
"content": "<mask_119>",
|
| 1152 |
+
"lstrip": false,
|
| 1153 |
+
"normalized": false,
|
| 1154 |
+
"rstrip": false,
|
| 1155 |
+
"single_word": false,
|
| 1156 |
+
"special": true
|
| 1157 |
+
},
|
| 1158 |
+
"144": {
|
| 1159 |
+
"content": "<mask_120>",
|
| 1160 |
+
"lstrip": false,
|
| 1161 |
+
"normalized": false,
|
| 1162 |
+
"rstrip": false,
|
| 1163 |
+
"single_word": false,
|
| 1164 |
+
"special": true
|
| 1165 |
+
},
|
| 1166 |
+
"145": {
|
| 1167 |
+
"content": "<mask_121>",
|
| 1168 |
+
"lstrip": false,
|
| 1169 |
+
"normalized": false,
|
| 1170 |
+
"rstrip": false,
|
| 1171 |
+
"single_word": false,
|
| 1172 |
+
"special": true
|
| 1173 |
+
},
|
| 1174 |
+
"146": {
|
| 1175 |
+
"content": "<mask_122>",
|
| 1176 |
+
"lstrip": false,
|
| 1177 |
+
"normalized": false,
|
| 1178 |
+
"rstrip": false,
|
| 1179 |
+
"single_word": false,
|
| 1180 |
+
"special": true
|
| 1181 |
+
},
|
| 1182 |
+
"147": {
|
| 1183 |
+
"content": "<mask_123>",
|
| 1184 |
+
"lstrip": false,
|
| 1185 |
+
"normalized": false,
|
| 1186 |
+
"rstrip": false,
|
| 1187 |
+
"single_word": false,
|
| 1188 |
+
"special": true
|
| 1189 |
+
},
|
| 1190 |
+
"148": {
|
| 1191 |
+
"content": "<mask_124>",
|
| 1192 |
+
"lstrip": false,
|
| 1193 |
+
"normalized": false,
|
| 1194 |
+
"rstrip": false,
|
| 1195 |
+
"single_word": false,
|
| 1196 |
+
"special": true
|
| 1197 |
+
},
|
| 1198 |
+
"149": {
|
| 1199 |
+
"content": "<mask_125>",
|
| 1200 |
+
"lstrip": false,
|
| 1201 |
+
"normalized": false,
|
| 1202 |
+
"rstrip": false,
|
| 1203 |
+
"single_word": false,
|
| 1204 |
+
"special": true
|
| 1205 |
+
},
|
| 1206 |
+
"150": {
|
| 1207 |
+
"content": "<mask_126>",
|
| 1208 |
+
"lstrip": false,
|
| 1209 |
+
"normalized": false,
|
| 1210 |
+
"rstrip": false,
|
| 1211 |
+
"single_word": false,
|
| 1212 |
+
"special": true
|
| 1213 |
+
},
|
| 1214 |
+
"151": {
|
| 1215 |
+
"content": "<mask_127>",
|
| 1216 |
+
"lstrip": false,
|
| 1217 |
+
"normalized": false,
|
| 1218 |
+
"rstrip": false,
|
| 1219 |
+
"single_word": false,
|
| 1220 |
+
"special": true
|
| 1221 |
+
},
|
| 1222 |
+
"152": {
|
| 1223 |
+
"content": "<mask_128>",
|
| 1224 |
+
"lstrip": false,
|
| 1225 |
+
"normalized": false,
|
| 1226 |
+
"rstrip": false,
|
| 1227 |
+
"single_word": false,
|
| 1228 |
+
"special": true
|
| 1229 |
+
},
|
| 1230 |
+
"153": {
|
| 1231 |
+
"content": "<mask_129>",
|
| 1232 |
+
"lstrip": false,
|
| 1233 |
+
"normalized": false,
|
| 1234 |
+
"rstrip": false,
|
| 1235 |
+
"single_word": false,
|
| 1236 |
+
"special": true
|
| 1237 |
+
},
|
| 1238 |
+
"154": {
|
| 1239 |
+
"content": "<mask_130>",
|
| 1240 |
+
"lstrip": false,
|
| 1241 |
+
"normalized": false,
|
| 1242 |
+
"rstrip": false,
|
| 1243 |
+
"single_word": false,
|
| 1244 |
+
"special": true
|
| 1245 |
+
},
|
| 1246 |
+
"155": {
|
| 1247 |
+
"content": "<mask_131>",
|
| 1248 |
+
"lstrip": false,
|
| 1249 |
+
"normalized": false,
|
| 1250 |
+
"rstrip": false,
|
| 1251 |
+
"single_word": false,
|
| 1252 |
+
"special": true
|
| 1253 |
+
},
|
| 1254 |
+
"156": {
|
| 1255 |
+
"content": "<mask_132>",
|
| 1256 |
+
"lstrip": false,
|
| 1257 |
+
"normalized": false,
|
| 1258 |
+
"rstrip": false,
|
| 1259 |
+
"single_word": false,
|
| 1260 |
+
"special": true
|
| 1261 |
+
},
|
| 1262 |
+
"157": {
|
| 1263 |
+
"content": "<mask_133>",
|
| 1264 |
+
"lstrip": false,
|
| 1265 |
+
"normalized": false,
|
| 1266 |
+
"rstrip": false,
|
| 1267 |
+
"single_word": false,
|
| 1268 |
+
"special": true
|
| 1269 |
+
},
|
| 1270 |
+
"158": {
|
| 1271 |
+
"content": "<mask_134>",
|
| 1272 |
+
"lstrip": false,
|
| 1273 |
+
"normalized": false,
|
| 1274 |
+
"rstrip": false,
|
| 1275 |
+
"single_word": false,
|
| 1276 |
+
"special": true
|
| 1277 |
+
},
|
| 1278 |
+
"159": {
|
| 1279 |
+
"content": "<mask_135>",
|
| 1280 |
+
"lstrip": false,
|
| 1281 |
+
"normalized": false,
|
| 1282 |
+
"rstrip": false,
|
| 1283 |
+
"single_word": false,
|
| 1284 |
+
"special": true
|
| 1285 |
+
},
|
| 1286 |
+
"160": {
|
| 1287 |
+
"content": "<mask_136>",
|
| 1288 |
+
"lstrip": false,
|
| 1289 |
+
"normalized": false,
|
| 1290 |
+
"rstrip": false,
|
| 1291 |
+
"single_word": false,
|
| 1292 |
+
"special": true
|
| 1293 |
+
},
|
| 1294 |
+
"161": {
|
| 1295 |
+
"content": "<mask_137>",
|
| 1296 |
+
"lstrip": false,
|
| 1297 |
+
"normalized": false,
|
| 1298 |
+
"rstrip": false,
|
| 1299 |
+
"single_word": false,
|
| 1300 |
+
"special": true
|
| 1301 |
+
},
|
| 1302 |
+
"162": {
|
| 1303 |
+
"content": "<mask_138>",
|
| 1304 |
+
"lstrip": false,
|
| 1305 |
+
"normalized": false,
|
| 1306 |
+
"rstrip": false,
|
| 1307 |
+
"single_word": false,
|
| 1308 |
+
"special": true
|
| 1309 |
+
},
|
| 1310 |
+
"163": {
|
| 1311 |
+
"content": "<mask_139>",
|
| 1312 |
+
"lstrip": false,
|
| 1313 |
+
"normalized": false,
|
| 1314 |
+
"rstrip": false,
|
| 1315 |
+
"single_word": false,
|
| 1316 |
+
"special": true
|
| 1317 |
+
},
|
| 1318 |
+
"164": {
|
| 1319 |
+
"content": "<mask_140>",
|
| 1320 |
+
"lstrip": false,
|
| 1321 |
+
"normalized": false,
|
| 1322 |
+
"rstrip": false,
|
| 1323 |
+
"single_word": false,
|
| 1324 |
+
"special": true
|
| 1325 |
+
},
|
| 1326 |
+
"165": {
|
| 1327 |
+
"content": "<mask_141>",
|
| 1328 |
+
"lstrip": false,
|
| 1329 |
+
"normalized": false,
|
| 1330 |
+
"rstrip": false,
|
| 1331 |
+
"single_word": false,
|
| 1332 |
+
"special": true
|
| 1333 |
+
},
|
| 1334 |
+
"166": {
|
| 1335 |
+
"content": "<mask_142>",
|
| 1336 |
+
"lstrip": false,
|
| 1337 |
+
"normalized": false,
|
| 1338 |
+
"rstrip": false,
|
| 1339 |
+
"single_word": false,
|
| 1340 |
+
"special": true
|
| 1341 |
+
},
|
| 1342 |
+
"167": {
|
| 1343 |
+
"content": "<mask_143>",
|
| 1344 |
+
"lstrip": false,
|
| 1345 |
+
"normalized": false,
|
| 1346 |
+
"rstrip": false,
|
| 1347 |
+
"single_word": false,
|
| 1348 |
+
"special": true
|
| 1349 |
+
},
|
| 1350 |
+
"168": {
|
| 1351 |
+
"content": "<mask_144>",
|
| 1352 |
+
"lstrip": false,
|
| 1353 |
+
"normalized": false,
|
| 1354 |
+
"rstrip": false,
|
| 1355 |
+
"single_word": false,
|
| 1356 |
+
"special": true
|
| 1357 |
+
},
|
| 1358 |
+
"169": {
|
| 1359 |
+
"content": "<mask_145>",
|
| 1360 |
+
"lstrip": false,
|
| 1361 |
+
"normalized": false,
|
| 1362 |
+
"rstrip": false,
|
| 1363 |
+
"single_word": false,
|
| 1364 |
+
"special": true
|
| 1365 |
+
},
|
| 1366 |
+
"170": {
|
| 1367 |
+
"content": "<mask_146>",
|
| 1368 |
+
"lstrip": false,
|
| 1369 |
+
"normalized": false,
|
| 1370 |
+
"rstrip": false,
|
| 1371 |
+
"single_word": false,
|
| 1372 |
+
"special": true
|
| 1373 |
+
},
|
| 1374 |
+
"171": {
|
| 1375 |
+
"content": "<mask_147>",
|
| 1376 |
+
"lstrip": false,
|
| 1377 |
+
"normalized": false,
|
| 1378 |
+
"rstrip": false,
|
| 1379 |
+
"single_word": false,
|
| 1380 |
+
"special": true
|
| 1381 |
+
},
|
| 1382 |
+
"172": {
|
| 1383 |
+
"content": "<mask_148>",
|
| 1384 |
+
"lstrip": false,
|
| 1385 |
+
"normalized": false,
|
| 1386 |
+
"rstrip": false,
|
| 1387 |
+
"single_word": false,
|
| 1388 |
+
"special": true
|
| 1389 |
+
},
|
| 1390 |
+
"173": {
|
| 1391 |
+
"content": "<mask_149>",
|
| 1392 |
+
"lstrip": false,
|
| 1393 |
+
"normalized": false,
|
| 1394 |
+
"rstrip": false,
|
| 1395 |
+
"single_word": false,
|
| 1396 |
+
"special": true
|
| 1397 |
+
},
|
| 1398 |
+
"174": {
|
| 1399 |
+
"content": "<mask_150>",
|
| 1400 |
+
"lstrip": false,
|
| 1401 |
+
"normalized": false,
|
| 1402 |
+
"rstrip": false,
|
| 1403 |
+
"single_word": false,
|
| 1404 |
+
"special": true
|
| 1405 |
+
},
|
| 1406 |
+
"175": {
|
| 1407 |
+
"content": "<mask_151>",
|
| 1408 |
+
"lstrip": false,
|
| 1409 |
+
"normalized": false,
|
| 1410 |
+
"rstrip": false,
|
| 1411 |
+
"single_word": false,
|
| 1412 |
+
"special": true
|
| 1413 |
+
},
|
| 1414 |
+
"176": {
|
| 1415 |
+
"content": "<mask_152>",
|
| 1416 |
+
"lstrip": false,
|
| 1417 |
+
"normalized": false,
|
| 1418 |
+
"rstrip": false,
|
| 1419 |
+
"single_word": false,
|
| 1420 |
+
"special": true
|
| 1421 |
+
},
|
| 1422 |
+
"177": {
|
| 1423 |
+
"content": "<mask_153>",
|
| 1424 |
+
"lstrip": false,
|
| 1425 |
+
"normalized": false,
|
| 1426 |
+
"rstrip": false,
|
| 1427 |
+
"single_word": false,
|
| 1428 |
+
"special": true
|
| 1429 |
+
},
|
| 1430 |
+
"178": {
|
| 1431 |
+
"content": "<mask_154>",
|
| 1432 |
+
"lstrip": false,
|
| 1433 |
+
"normalized": false,
|
| 1434 |
+
"rstrip": false,
|
| 1435 |
+
"single_word": false,
|
| 1436 |
+
"special": true
|
| 1437 |
+
},
|
| 1438 |
+
"179": {
|
| 1439 |
+
"content": "<mask_155>",
|
| 1440 |
+
"lstrip": false,
|
| 1441 |
+
"normalized": false,
|
| 1442 |
+
"rstrip": false,
|
| 1443 |
+
"single_word": false,
|
| 1444 |
+
"special": true
|
| 1445 |
+
},
|
| 1446 |
+
"180": {
|
| 1447 |
+
"content": "<mask_156>",
|
| 1448 |
+
"lstrip": false,
|
| 1449 |
+
"normalized": false,
|
| 1450 |
+
"rstrip": false,
|
| 1451 |
+
"single_word": false,
|
| 1452 |
+
"special": true
|
| 1453 |
+
},
|
| 1454 |
+
"181": {
|
| 1455 |
+
"content": "<mask_157>",
|
| 1456 |
+
"lstrip": false,
|
| 1457 |
+
"normalized": false,
|
| 1458 |
+
"rstrip": false,
|
| 1459 |
+
"single_word": false,
|
| 1460 |
+
"special": true
|
| 1461 |
+
},
|
| 1462 |
+
"182": {
|
| 1463 |
+
"content": "<mask_158>",
|
| 1464 |
+
"lstrip": false,
|
| 1465 |
+
"normalized": false,
|
| 1466 |
+
"rstrip": false,
|
| 1467 |
+
"single_word": false,
|
| 1468 |
+
"special": true
|
| 1469 |
+
},
|
| 1470 |
+
"183": {
|
| 1471 |
+
"content": "<mask_159>",
|
| 1472 |
+
"lstrip": false,
|
| 1473 |
+
"normalized": false,
|
| 1474 |
+
"rstrip": false,
|
| 1475 |
+
"single_word": false,
|
| 1476 |
+
"special": true
|
| 1477 |
+
},
|
| 1478 |
+
"184": {
|
| 1479 |
+
"content": "<mask_160>",
|
| 1480 |
+
"lstrip": false,
|
| 1481 |
+
"normalized": false,
|
| 1482 |
+
"rstrip": false,
|
| 1483 |
+
"single_word": false,
|
| 1484 |
+
"special": true
|
| 1485 |
+
},
|
| 1486 |
+
"185": {
|
| 1487 |
+
"content": "<mask_161>",
|
| 1488 |
+
"lstrip": false,
|
| 1489 |
+
"normalized": false,
|
| 1490 |
+
"rstrip": false,
|
| 1491 |
+
"single_word": false,
|
| 1492 |
+
"special": true
|
| 1493 |
+
},
|
| 1494 |
+
"186": {
|
| 1495 |
+
"content": "<mask_162>",
|
| 1496 |
+
"lstrip": false,
|
| 1497 |
+
"normalized": false,
|
| 1498 |
+
"rstrip": false,
|
| 1499 |
+
"single_word": false,
|
| 1500 |
+
"special": true
|
| 1501 |
+
},
|
| 1502 |
+
"187": {
|
| 1503 |
+
"content": "<mask_163>",
|
| 1504 |
+
"lstrip": false,
|
| 1505 |
+
"normalized": false,
|
| 1506 |
+
"rstrip": false,
|
| 1507 |
+
"single_word": false,
|
| 1508 |
+
"special": true
|
| 1509 |
+
},
|
| 1510 |
+
"188": {
|
| 1511 |
+
"content": "<mask_164>",
|
| 1512 |
+
"lstrip": false,
|
| 1513 |
+
"normalized": false,
|
| 1514 |
+
"rstrip": false,
|
| 1515 |
+
"single_word": false,
|
| 1516 |
+
"special": true
|
| 1517 |
+
},
|
| 1518 |
+
"189": {
|
| 1519 |
+
"content": "<mask_165>",
|
| 1520 |
+
"lstrip": false,
|
| 1521 |
+
"normalized": false,
|
| 1522 |
+
"rstrip": false,
|
| 1523 |
+
"single_word": false,
|
| 1524 |
+
"special": true
|
| 1525 |
+
},
|
| 1526 |
+
"190": {
|
| 1527 |
+
"content": "<mask_166>",
|
| 1528 |
+
"lstrip": false,
|
| 1529 |
+
"normalized": false,
|
| 1530 |
+
"rstrip": false,
|
| 1531 |
+
"single_word": false,
|
| 1532 |
+
"special": true
|
| 1533 |
+
},
|
| 1534 |
+
"191": {
|
| 1535 |
+
"content": "<mask_167>",
|
| 1536 |
+
"lstrip": false,
|
| 1537 |
+
"normalized": false,
|
| 1538 |
+
"rstrip": false,
|
| 1539 |
+
"single_word": false,
|
| 1540 |
+
"special": true
|
| 1541 |
+
},
|
| 1542 |
+
"192": {
|
| 1543 |
+
"content": "<mask_168>",
|
| 1544 |
+
"lstrip": false,
|
| 1545 |
+
"normalized": false,
|
| 1546 |
+
"rstrip": false,
|
| 1547 |
+
"single_word": false,
|
| 1548 |
+
"special": true
|
| 1549 |
+
},
|
| 1550 |
+
"193": {
|
| 1551 |
+
"content": "<mask_169>",
|
| 1552 |
+
"lstrip": false,
|
| 1553 |
+
"normalized": false,
|
| 1554 |
+
"rstrip": false,
|
| 1555 |
+
"single_word": false,
|
| 1556 |
+
"special": true
|
| 1557 |
+
},
|
| 1558 |
+
"194": {
|
| 1559 |
+
"content": "<mask_170>",
|
| 1560 |
+
"lstrip": false,
|
| 1561 |
+
"normalized": false,
|
| 1562 |
+
"rstrip": false,
|
| 1563 |
+
"single_word": false,
|
| 1564 |
+
"special": true
|
| 1565 |
+
},
|
| 1566 |
+
"195": {
|
| 1567 |
+
"content": "<mask_171>",
|
| 1568 |
+
"lstrip": false,
|
| 1569 |
+
"normalized": false,
|
| 1570 |
+
"rstrip": false,
|
| 1571 |
+
"single_word": false,
|
| 1572 |
+
"special": true
|
| 1573 |
+
},
|
| 1574 |
+
"196": {
|
| 1575 |
+
"content": "<mask_172>",
|
| 1576 |
+
"lstrip": false,
|
| 1577 |
+
"normalized": false,
|
| 1578 |
+
"rstrip": false,
|
| 1579 |
+
"single_word": false,
|
| 1580 |
+
"special": true
|
| 1581 |
+
},
|
| 1582 |
+
"197": {
|
| 1583 |
+
"content": "<mask_173>",
|
| 1584 |
+
"lstrip": false,
|
| 1585 |
+
"normalized": false,
|
| 1586 |
+
"rstrip": false,
|
| 1587 |
+
"single_word": false,
|
| 1588 |
+
"special": true
|
| 1589 |
+
},
|
| 1590 |
+
"198": {
|
| 1591 |
+
"content": "<mask_174>",
|
| 1592 |
+
"lstrip": false,
|
| 1593 |
+
"normalized": false,
|
| 1594 |
+
"rstrip": false,
|
| 1595 |
+
"single_word": false,
|
| 1596 |
+
"special": true
|
| 1597 |
+
},
|
| 1598 |
+
"199": {
|
| 1599 |
+
"content": "<mask_175>",
|
| 1600 |
+
"lstrip": false,
|
| 1601 |
+
"normalized": false,
|
| 1602 |
+
"rstrip": false,
|
| 1603 |
+
"single_word": false,
|
| 1604 |
+
"special": true
|
| 1605 |
+
},
|
| 1606 |
+
"200": {
|
| 1607 |
+
"content": "<mask_176>",
|
| 1608 |
+
"lstrip": false,
|
| 1609 |
+
"normalized": false,
|
| 1610 |
+
"rstrip": false,
|
| 1611 |
+
"single_word": false,
|
| 1612 |
+
"special": true
|
| 1613 |
+
},
|
| 1614 |
+
"201": {
|
| 1615 |
+
"content": "<mask_177>",
|
| 1616 |
+
"lstrip": false,
|
| 1617 |
+
"normalized": false,
|
| 1618 |
+
"rstrip": false,
|
| 1619 |
+
"single_word": false,
|
| 1620 |
+
"special": true
|
| 1621 |
+
},
|
| 1622 |
+
"202": {
|
| 1623 |
+
"content": "<mask_178>",
|
| 1624 |
+
"lstrip": false,
|
| 1625 |
+
"normalized": false,
|
| 1626 |
+
"rstrip": false,
|
| 1627 |
+
"single_word": false,
|
| 1628 |
+
"special": true
|
| 1629 |
+
},
|
| 1630 |
+
"203": {
|
| 1631 |
+
"content": "<mask_179>",
|
| 1632 |
+
"lstrip": false,
|
| 1633 |
+
"normalized": false,
|
| 1634 |
+
"rstrip": false,
|
| 1635 |
+
"single_word": false,
|
| 1636 |
+
"special": true
|
| 1637 |
+
},
|
| 1638 |
+
"204": {
|
| 1639 |
+
"content": "<mask_180>",
|
| 1640 |
+
"lstrip": false,
|
| 1641 |
+
"normalized": false,
|
| 1642 |
+
"rstrip": false,
|
| 1643 |
+
"single_word": false,
|
| 1644 |
+
"special": true
|
| 1645 |
+
},
|
| 1646 |
+
"205": {
|
| 1647 |
+
"content": "<mask_181>",
|
| 1648 |
+
"lstrip": false,
|
| 1649 |
+
"normalized": false,
|
| 1650 |
+
"rstrip": false,
|
| 1651 |
+
"single_word": false,
|
| 1652 |
+
"special": true
|
| 1653 |
+
},
|
| 1654 |
+
"206": {
|
| 1655 |
+
"content": "<mask_182>",
|
| 1656 |
+
"lstrip": false,
|
| 1657 |
+
"normalized": false,
|
| 1658 |
+
"rstrip": false,
|
| 1659 |
+
"single_word": false,
|
| 1660 |
+
"special": true
|
| 1661 |
+
},
|
| 1662 |
+
"207": {
|
| 1663 |
+
"content": "<mask_183>",
|
| 1664 |
+
"lstrip": false,
|
| 1665 |
+
"normalized": false,
|
| 1666 |
+
"rstrip": false,
|
| 1667 |
+
"single_word": false,
|
| 1668 |
+
"special": true
|
| 1669 |
+
},
|
| 1670 |
+
"208": {
|
| 1671 |
+
"content": "<mask_184>",
|
| 1672 |
+
"lstrip": false,
|
| 1673 |
+
"normalized": false,
|
| 1674 |
+
"rstrip": false,
|
| 1675 |
+
"single_word": false,
|
| 1676 |
+
"special": true
|
| 1677 |
+
},
|
| 1678 |
+
"209": {
|
| 1679 |
+
"content": "<mask_185>",
|
| 1680 |
+
"lstrip": false,
|
| 1681 |
+
"normalized": false,
|
| 1682 |
+
"rstrip": false,
|
| 1683 |
+
"single_word": false,
|
| 1684 |
+
"special": true
|
| 1685 |
+
},
|
| 1686 |
+
"210": {
|
| 1687 |
+
"content": "<mask_186>",
|
| 1688 |
+
"lstrip": false,
|
| 1689 |
+
"normalized": false,
|
| 1690 |
+
"rstrip": false,
|
| 1691 |
+
"single_word": false,
|
| 1692 |
+
"special": true
|
| 1693 |
+
},
|
| 1694 |
+
"211": {
|
| 1695 |
+
"content": "<mask_187>",
|
| 1696 |
+
"lstrip": false,
|
| 1697 |
+
"normalized": false,
|
| 1698 |
+
"rstrip": false,
|
| 1699 |
+
"single_word": false,
|
| 1700 |
+
"special": true
|
| 1701 |
+
},
|
| 1702 |
+
"212": {
|
| 1703 |
+
"content": "<mask_188>",
|
| 1704 |
+
"lstrip": false,
|
| 1705 |
+
"normalized": false,
|
| 1706 |
+
"rstrip": false,
|
| 1707 |
+
"single_word": false,
|
| 1708 |
+
"special": true
|
| 1709 |
+
},
|
| 1710 |
+
"213": {
|
| 1711 |
+
"content": "<mask_189>",
|
| 1712 |
+
"lstrip": false,
|
| 1713 |
+
"normalized": false,
|
| 1714 |
+
"rstrip": false,
|
| 1715 |
+
"single_word": false,
|
| 1716 |
+
"special": true
|
| 1717 |
+
},
|
| 1718 |
+
"214": {
|
| 1719 |
+
"content": "<mask_190>",
|
| 1720 |
+
"lstrip": false,
|
| 1721 |
+
"normalized": false,
|
| 1722 |
+
"rstrip": false,
|
| 1723 |
+
"single_word": false,
|
| 1724 |
+
"special": true
|
| 1725 |
+
},
|
| 1726 |
+
"215": {
|
| 1727 |
+
"content": "<mask_191>",
|
| 1728 |
+
"lstrip": false,
|
| 1729 |
+
"normalized": false,
|
| 1730 |
+
"rstrip": false,
|
| 1731 |
+
"single_word": false,
|
| 1732 |
+
"special": true
|
| 1733 |
+
},
|
| 1734 |
+
"216": {
|
| 1735 |
+
"content": "<mask_192>",
|
| 1736 |
+
"lstrip": false,
|
| 1737 |
+
"normalized": false,
|
| 1738 |
+
"rstrip": false,
|
| 1739 |
+
"single_word": false,
|
| 1740 |
+
"special": true
|
| 1741 |
+
},
|
| 1742 |
+
"217": {
|
| 1743 |
+
"content": "<mask_193>",
|
| 1744 |
+
"lstrip": false,
|
| 1745 |
+
"normalized": false,
|
| 1746 |
+
"rstrip": false,
|
| 1747 |
+
"single_word": false,
|
| 1748 |
+
"special": true
|
| 1749 |
+
},
|
| 1750 |
+
"218": {
|
| 1751 |
+
"content": "<mask_194>",
|
| 1752 |
+
"lstrip": false,
|
| 1753 |
+
"normalized": false,
|
| 1754 |
+
"rstrip": false,
|
| 1755 |
+
"single_word": false,
|
| 1756 |
+
"special": true
|
| 1757 |
+
},
|
| 1758 |
+
"219": {
|
| 1759 |
+
"content": "<mask_195>",
|
| 1760 |
+
"lstrip": false,
|
| 1761 |
+
"normalized": false,
|
| 1762 |
+
"rstrip": false,
|
| 1763 |
+
"single_word": false,
|
| 1764 |
+
"special": true
|
| 1765 |
+
},
|
| 1766 |
+
"220": {
|
| 1767 |
+
"content": "<mask_196>",
|
| 1768 |
+
"lstrip": false,
|
| 1769 |
+
"normalized": false,
|
| 1770 |
+
"rstrip": false,
|
| 1771 |
+
"single_word": false,
|
| 1772 |
+
"special": true
|
| 1773 |
+
},
|
| 1774 |
+
"221": {
|
| 1775 |
+
"content": "<mask_197>",
|
| 1776 |
+
"lstrip": false,
|
| 1777 |
+
"normalized": false,
|
| 1778 |
+
"rstrip": false,
|
| 1779 |
+
"single_word": false,
|
| 1780 |
+
"special": true
|
| 1781 |
+
},
|
| 1782 |
+
"222": {
|
| 1783 |
+
"content": "<mask_198>",
|
| 1784 |
+
"lstrip": false,
|
| 1785 |
+
"normalized": false,
|
| 1786 |
+
"rstrip": false,
|
| 1787 |
+
"single_word": false,
|
| 1788 |
+
"special": true
|
| 1789 |
+
},
|
| 1790 |
+
"223": {
|
| 1791 |
+
"content": "<mask_199>",
|
| 1792 |
+
"lstrip": false,
|
| 1793 |
+
"normalized": false,
|
| 1794 |
+
"rstrip": false,
|
| 1795 |
+
"single_word": false,
|
| 1796 |
+
"special": true
|
| 1797 |
+
},
|
| 1798 |
+
"131072": {
|
| 1799 |
+
"content": "<mask_131048>",
|
| 1800 |
+
"lstrip": false,
|
| 1801 |
+
"normalized": false,
|
| 1802 |
+
"rstrip": false,
|
| 1803 |
+
"single_word": false,
|
| 1804 |
+
"special": true
|
| 1805 |
+
},
|
| 1806 |
+
"131073": {
|
| 1807 |
+
"content": "<mask_131049>",
|
| 1808 |
+
"lstrip": false,
|
| 1809 |
+
"normalized": false,
|
| 1810 |
+
"rstrip": false,
|
| 1811 |
+
"single_word": false,
|
| 1812 |
+
"special": true
|
| 1813 |
+
},
|
| 1814 |
+
"131074": {
|
| 1815 |
+
"content": "<mask_131050>",
|
| 1816 |
+
"lstrip": false,
|
| 1817 |
+
"normalized": false,
|
| 1818 |
+
"rstrip": false,
|
| 1819 |
+
"single_word": false,
|
| 1820 |
+
"special": true
|
| 1821 |
+
},
|
| 1822 |
+
"131075": {
|
| 1823 |
+
"content": "<mask_131051>",
|
| 1824 |
+
"lstrip": false,
|
| 1825 |
+
"normalized": false,
|
| 1826 |
+
"rstrip": false,
|
| 1827 |
+
"single_word": false,
|
| 1828 |
+
"special": true
|
| 1829 |
+
},
|
| 1830 |
+
"131076": {
|
| 1831 |
+
"content": "<mask_131052>",
|
| 1832 |
+
"lstrip": false,
|
| 1833 |
+
"normalized": false,
|
| 1834 |
+
"rstrip": false,
|
| 1835 |
+
"single_word": false,
|
| 1836 |
+
"special": true
|
| 1837 |
+
},
|
| 1838 |
+
"131077": {
|
| 1839 |
+
"content": "<mask_131053>",
|
| 1840 |
+
"lstrip": false,
|
| 1841 |
+
"normalized": false,
|
| 1842 |
+
"rstrip": false,
|
| 1843 |
+
"single_word": false,
|
| 1844 |
+
"special": true
|
| 1845 |
+
},
|
| 1846 |
+
"131078": {
|
| 1847 |
+
"content": "<mask_131054>",
|
| 1848 |
+
"lstrip": false,
|
| 1849 |
+
"normalized": false,
|
| 1850 |
+
"rstrip": false,
|
| 1851 |
+
"single_word": false,
|
| 1852 |
+
"special": true
|
| 1853 |
+
},
|
| 1854 |
+
"131079": {
|
| 1855 |
+
"content": "<mask_131055>",
|
| 1856 |
+
"lstrip": false,
|
| 1857 |
+
"normalized": false,
|
| 1858 |
+
"rstrip": false,
|
| 1859 |
+
"single_word": false,
|
| 1860 |
+
"special": true
|
| 1861 |
+
},
|
| 1862 |
+
"131080": {
|
| 1863 |
+
"content": "<mask_131056>",
|
| 1864 |
+
"lstrip": false,
|
| 1865 |
+
"normalized": false,
|
| 1866 |
+
"rstrip": false,
|
| 1867 |
+
"single_word": false,
|
| 1868 |
+
"special": true
|
| 1869 |
+
},
|
| 1870 |
+
"131081": {
|
| 1871 |
+
"content": "<mask_131057>",
|
| 1872 |
+
"lstrip": false,
|
| 1873 |
+
"normalized": false,
|
| 1874 |
+
"rstrip": false,
|
| 1875 |
+
"single_word": false,
|
| 1876 |
+
"special": true
|
| 1877 |
+
},
|
| 1878 |
+
"131082": {
|
| 1879 |
+
"content": "<mask_131058>",
|
| 1880 |
+
"lstrip": false,
|
| 1881 |
+
"normalized": false,
|
| 1882 |
+
"rstrip": false,
|
| 1883 |
+
"single_word": false,
|
| 1884 |
+
"special": true
|
| 1885 |
+
},
|
| 1886 |
+
"131083": {
|
| 1887 |
+
"content": "<mask_131059>",
|
| 1888 |
+
"lstrip": false,
|
| 1889 |
+
"normalized": false,
|
| 1890 |
+
"rstrip": false,
|
| 1891 |
+
"single_word": false,
|
| 1892 |
+
"special": true
|
| 1893 |
+
},
|
| 1894 |
+
"131084": {
|
| 1895 |
+
"content": "<mask_131060>",
|
| 1896 |
+
"lstrip": false,
|
| 1897 |
+
"normalized": false,
|
| 1898 |
+
"rstrip": false,
|
| 1899 |
+
"single_word": false,
|
| 1900 |
+
"special": true
|
| 1901 |
+
},
|
| 1902 |
+
"131085": {
|
| 1903 |
+
"content": "<mask_131061>",
|
| 1904 |
+
"lstrip": false,
|
| 1905 |
+
"normalized": false,
|
| 1906 |
+
"rstrip": false,
|
| 1907 |
+
"single_word": false,
|
| 1908 |
+
"special": true
|
| 1909 |
+
},
|
| 1910 |
+
"131086": {
|
| 1911 |
+
"content": "<mask_131062>",
|
| 1912 |
+
"lstrip": false,
|
| 1913 |
+
"normalized": false,
|
| 1914 |
+
"rstrip": false,
|
| 1915 |
+
"single_word": false,
|
| 1916 |
+
"special": true
|
| 1917 |
+
},
|
| 1918 |
+
"131087": {
|
| 1919 |
+
"content": "<mask_131063>",
|
| 1920 |
+
"lstrip": false,
|
| 1921 |
+
"normalized": false,
|
| 1922 |
+
"rstrip": false,
|
| 1923 |
+
"single_word": false,
|
| 1924 |
+
"special": true
|
| 1925 |
+
},
|
| 1926 |
+
"131088": {
|
| 1927 |
+
"content": "<mask_131064>",
|
| 1928 |
+
"lstrip": false,
|
| 1929 |
+
"normalized": false,
|
| 1930 |
+
"rstrip": false,
|
| 1931 |
+
"single_word": false,
|
| 1932 |
+
"special": true
|
| 1933 |
+
},
|
| 1934 |
+
"131089": {
|
| 1935 |
+
"content": "<mask_131065>",
|
| 1936 |
+
"lstrip": false,
|
| 1937 |
+
"normalized": false,
|
| 1938 |
+
"rstrip": false,
|
| 1939 |
+
"single_word": false,
|
| 1940 |
+
"special": true
|
| 1941 |
+
},
|
| 1942 |
+
"131090": {
|
| 1943 |
+
"content": "<longcat_img_token_size>",
|
| 1944 |
+
"lstrip": false,
|
| 1945 |
+
"normalized": false,
|
| 1946 |
+
"rstrip": false,
|
| 1947 |
+
"single_word": false,
|
| 1948 |
+
"special": true
|
| 1949 |
+
},
|
| 1950 |
+
"131091": {
|
| 1951 |
+
"content": "</longcat_img_token_size>",
|
| 1952 |
+
"lstrip": false,
|
| 1953 |
+
"normalized": false,
|
| 1954 |
+
"rstrip": false,
|
| 1955 |
+
"single_word": false,
|
| 1956 |
+
"special": true
|
| 1957 |
+
},
|
| 1958 |
+
"131092": {
|
| 1959 |
+
"content": "<mask_131068>",
|
| 1960 |
+
"lstrip": false,
|
| 1961 |
+
"normalized": false,
|
| 1962 |
+
"rstrip": false,
|
| 1963 |
+
"single_word": false,
|
| 1964 |
+
"special": true
|
| 1965 |
+
},
|
| 1966 |
+
"131093": {
|
| 1967 |
+
"content": "<mask_131069>",
|
| 1968 |
+
"lstrip": false,
|
| 1969 |
+
"normalized": false,
|
| 1970 |
+
"rstrip": false,
|
| 1971 |
+
"single_word": false,
|
| 1972 |
+
"special": true
|
| 1973 |
+
},
|
| 1974 |
+
"131094": {
|
| 1975 |
+
"content": "<mask_131070>",
|
| 1976 |
+
"lstrip": false,
|
| 1977 |
+
"normalized": false,
|
| 1978 |
+
"rstrip": false,
|
| 1979 |
+
"single_word": false,
|
| 1980 |
+
"special": true
|
| 1981 |
+
},
|
| 1982 |
+
"131095": {
|
| 1983 |
+
"content": "<mask_131071>",
|
| 1984 |
+
"lstrip": false,
|
| 1985 |
+
"normalized": false,
|
| 1986 |
+
"rstrip": false,
|
| 1987 |
+
"single_word": false,
|
| 1988 |
+
"special": true
|
| 1989 |
+
},
|
| 1990 |
+
"131096": {
|
| 1991 |
+
"content": "<longcat_point_start>",
|
| 1992 |
+
"lstrip": false,
|
| 1993 |
+
"normalized": false,
|
| 1994 |
+
"rstrip": false,
|
| 1995 |
+
"single_word": false,
|
| 1996 |
+
"special": true
|
| 1997 |
+
},
|
| 1998 |
+
"131097": {
|
| 1999 |
+
"content": "<longcat_point_end>",
|
| 2000 |
+
"lstrip": false,
|
| 2001 |
+
"normalized": false,
|
| 2002 |
+
"rstrip": false,
|
| 2003 |
+
"single_word": false,
|
| 2004 |
+
"special": true
|
| 2005 |
+
},
|
| 2006 |
+
"131098": {
|
| 2007 |
+
"content": "<longcat_point_delim>",
|
| 2008 |
+
"lstrip": false,
|
| 2009 |
+
"normalized": false,
|
| 2010 |
+
"rstrip": false,
|
| 2011 |
+
"single_word": false,
|
| 2012 |
+
"special": true
|
| 2013 |
+
},
|
| 2014 |
+
"131099": {
|
| 2015 |
+
"content": "<longcat_polygon_start>",
|
| 2016 |
+
"lstrip": false,
|
| 2017 |
+
"normalized": false,
|
| 2018 |
+
"rstrip": false,
|
| 2019 |
+
"single_word": false,
|
| 2020 |
+
"special": true
|
| 2021 |
+
},
|
| 2022 |
+
"131100": {
|
| 2023 |
+
"content": "<longcat_polygon_end>",
|
| 2024 |
+
"lstrip": false,
|
| 2025 |
+
"normalized": false,
|
| 2026 |
+
"rstrip": false,
|
| 2027 |
+
"single_word": false,
|
| 2028 |
+
"special": true
|
| 2029 |
+
},
|
| 2030 |
+
"131101": {
|
| 2031 |
+
"content": "<mask_131077>",
|
| 2032 |
+
"lstrip": false,
|
| 2033 |
+
"normalized": false,
|
| 2034 |
+
"rstrip": false,
|
| 2035 |
+
"single_word": false,
|
| 2036 |
+
"special": true
|
| 2037 |
+
},
|
| 2038 |
+
"131102": {
|
| 2039 |
+
"content": "<mask_131078>",
|
| 2040 |
+
"lstrip": false,
|
| 2041 |
+
"normalized": false,
|
| 2042 |
+
"rstrip": false,
|
| 2043 |
+
"single_word": false,
|
| 2044 |
+
"special": true
|
| 2045 |
+
},
|
| 2046 |
+
"131103": {
|
| 2047 |
+
"content": "<longcat_audio_start>",
|
| 2048 |
+
"lstrip": false,
|
| 2049 |
+
"normalized": false,
|
| 2050 |
+
"rstrip": false,
|
| 2051 |
+
"single_word": false,
|
| 2052 |
+
"special": true
|
| 2053 |
+
},
|
| 2054 |
+
"131104": {
|
| 2055 |
+
"content": "<longcat_audio_end>",
|
| 2056 |
+
"lstrip": false,
|
| 2057 |
+
"normalized": false,
|
| 2058 |
+
"rstrip": false,
|
| 2059 |
+
"single_word": false,
|
| 2060 |
+
"special": true
|
| 2061 |
+
},
|
| 2062 |
+
"131105": {
|
| 2063 |
+
"content": "<longcat_audio_pad>",
|
| 2064 |
+
"lstrip": false,
|
| 2065 |
+
"normalized": false,
|
| 2066 |
+
"rstrip": false,
|
| 2067 |
+
"single_word": false,
|
| 2068 |
+
"special": true
|
| 2069 |
+
},
|
| 2070 |
+
"131106": {
|
| 2071 |
+
"content": "<longcat_img_start>",
|
| 2072 |
+
"lstrip": false,
|
| 2073 |
+
"normalized": false,
|
| 2074 |
+
"rstrip": false,
|
| 2075 |
+
"single_word": false,
|
| 2076 |
+
"special": true
|
| 2077 |
+
},
|
| 2078 |
+
"131107": {
|
| 2079 |
+
"content": "<longcat_img_end>",
|
| 2080 |
+
"lstrip": false,
|
| 2081 |
+
"normalized": false,
|
| 2082 |
+
"rstrip": false,
|
| 2083 |
+
"single_word": false,
|
| 2084 |
+
"special": true
|
| 2085 |
+
},
|
| 2086 |
+
"131108": {
|
| 2087 |
+
"content": "<longcat_img_pad>",
|
| 2088 |
+
"lstrip": false,
|
| 2089 |
+
"normalized": false,
|
| 2090 |
+
"rstrip": false,
|
| 2091 |
+
"single_word": false,
|
| 2092 |
+
"special": true
|
| 2093 |
+
},
|
| 2094 |
+
"131109": {
|
| 2095 |
+
"content": "<longcat_img_newline>",
|
| 2096 |
+
"lstrip": false,
|
| 2097 |
+
"normalized": false,
|
| 2098 |
+
"rstrip": false,
|
| 2099 |
+
"single_word": false,
|
| 2100 |
+
"special": true
|
| 2101 |
+
},
|
| 2102 |
+
"131110": {
|
| 2103 |
+
"content": "<longcat_box_start>",
|
| 2104 |
+
"lstrip": false,
|
| 2105 |
+
"normalized": false,
|
| 2106 |
+
"rstrip": false,
|
| 2107 |
+
"single_word": false,
|
| 2108 |
+
"special": true
|
| 2109 |
+
},
|
| 2110 |
+
"131111": {
|
| 2111 |
+
"content": "<longcat_box_end>",
|
| 2112 |
+
"lstrip": false,
|
| 2113 |
+
"normalized": false,
|
| 2114 |
+
"rstrip": false,
|
| 2115 |
+
"single_word": false,
|
| 2116 |
+
"special": true
|
| 2117 |
+
},
|
| 2118 |
+
"131112": {
|
| 2119 |
+
"content": "<longcat_box_delim>",
|
| 2120 |
+
"lstrip": false,
|
| 2121 |
+
"normalized": false,
|
| 2122 |
+
"rstrip": false,
|
| 2123 |
+
"single_word": false,
|
| 2124 |
+
"special": true
|
| 2125 |
+
},
|
| 2126 |
+
"131113": {
|
| 2127 |
+
"content": "<longcat_ref_start>",
|
| 2128 |
+
"lstrip": false,
|
| 2129 |
+
"normalized": false,
|
| 2130 |
+
"rstrip": false,
|
| 2131 |
+
"single_word": false,
|
| 2132 |
+
"special": true
|
| 2133 |
+
},
|
| 2134 |
+
"131114": {
|
| 2135 |
+
"content": "<longcat_ref_end>",
|
| 2136 |
+
"lstrip": false,
|
| 2137 |
+
"normalized": false,
|
| 2138 |
+
"rstrip": false,
|
| 2139 |
+
"single_word": false,
|
| 2140 |
+
"special": true
|
| 2141 |
+
},
|
| 2142 |
+
"131115": {
|
| 2143 |
+
"content": "<longcat_img_delim>",
|
| 2144 |
+
"lstrip": false,
|
| 2145 |
+
"normalized": false,
|
| 2146 |
+
"rstrip": false,
|
| 2147 |
+
"single_word": false,
|
| 2148 |
+
"special": true
|
| 2149 |
+
},
|
| 2150 |
+
"131116": {
|
| 2151 |
+
"content": "<longcat_audio_delim>",
|
| 2152 |
+
"lstrip": false,
|
| 2153 |
+
"normalized": false,
|
| 2154 |
+
"rstrip": false,
|
| 2155 |
+
"single_word": false,
|
| 2156 |
+
"special": true
|
| 2157 |
+
},
|
| 2158 |
+
"131117": {
|
| 2159 |
+
"content": "<longcat_video_palce>",
|
| 2160 |
+
"lstrip": false,
|
| 2161 |
+
"normalized": false,
|
| 2162 |
+
"rstrip": false,
|
| 2163 |
+
"single_word": false,
|
| 2164 |
+
"special": true
|
| 2165 |
+
},
|
| 2166 |
+
"131118": {
|
| 2167 |
+
"content": "<longcat_video_start>",
|
| 2168 |
+
"lstrip": false,
|
| 2169 |
+
"normalized": false,
|
| 2170 |
+
"rstrip": false,
|
| 2171 |
+
"single_word": false,
|
| 2172 |
+
"special": true
|
| 2173 |
+
},
|
| 2174 |
+
"131119": {
|
| 2175 |
+
"content": "<longcat_video_end>",
|
| 2176 |
+
"lstrip": false,
|
| 2177 |
+
"normalized": false,
|
| 2178 |
+
"rstrip": false,
|
| 2179 |
+
"single_word": false,
|
| 2180 |
+
"special": true
|
| 2181 |
+
},
|
| 2182 |
+
"131120": {
|
| 2183 |
+
"content": "<longcat_audiotext_start>",
|
| 2184 |
+
"lstrip": false,
|
| 2185 |
+
"normalized": false,
|
| 2186 |
+
"rstrip": false,
|
| 2187 |
+
"single_word": false,
|
| 2188 |
+
"special": true
|
| 2189 |
+
},
|
| 2190 |
+
"131121": {
|
| 2191 |
+
"content": "<longcat_audiotext_end>",
|
| 2192 |
+
"lstrip": false,
|
| 2193 |
+
"normalized": false,
|
| 2194 |
+
"rstrip": false,
|
| 2195 |
+
"single_word": false,
|
| 2196 |
+
"special": true
|
| 2197 |
+
},
|
| 2198 |
+
"131122": {
|
| 2199 |
+
"content": "<longcat_audiotext_pad>",
|
| 2200 |
+
"lstrip": false,
|
| 2201 |
+
"normalized": false,
|
| 2202 |
+
"rstrip": false,
|
| 2203 |
+
"single_word": false,
|
| 2204 |
+
"special": true
|
| 2205 |
+
},
|
| 2206 |
+
"131123": {
|
| 2207 |
+
"content": "<longcat_audiogen_start>",
|
| 2208 |
+
"lstrip": false,
|
| 2209 |
+
"normalized": false,
|
| 2210 |
+
"rstrip": false,
|
| 2211 |
+
"single_word": false,
|
| 2212 |
+
"special": true
|
| 2213 |
+
},
|
| 2214 |
+
"131124": {
|
| 2215 |
+
"content": "<longcat_audiogen_end>",
|
| 2216 |
+
"lstrip": false,
|
| 2217 |
+
"normalized": false,
|
| 2218 |
+
"rstrip": false,
|
| 2219 |
+
"single_word": false,
|
| 2220 |
+
"special": true
|
| 2221 |
+
}
|
| 2222 |
+
},
|
| 2223 |
+
"additional_special_tokens": [
|
| 2224 |
+
"<mask_131048>",
|
| 2225 |
+
"<mask_131049>",
|
| 2226 |
+
"<mask_131050>",
|
| 2227 |
+
"<mask_131051>",
|
| 2228 |
+
"<mask_131052>",
|
| 2229 |
+
"<mask_131053>",
|
| 2230 |
+
"<mask_131054>",
|
| 2231 |
+
"<mask_131055>",
|
| 2232 |
+
"<mask_131056>",
|
| 2233 |
+
"<mask_131057>",
|
| 2234 |
+
"<mask_131058>",
|
| 2235 |
+
"<mask_131059>",
|
| 2236 |
+
"<mask_131060>",
|
| 2237 |
+
"<mask_131061>",
|
| 2238 |
+
"<mask_131062>",
|
| 2239 |
+
"<mask_131063>",
|
| 2240 |
+
"<mask_131064>",
|
| 2241 |
+
"<mask_131065>",
|
| 2242 |
+
"<longcat_img_token_size>",
|
| 2243 |
+
"</longcat_img_token_size>",
|
| 2244 |
+
"<mask_131068>",
|
| 2245 |
+
"<mask_131069>",
|
| 2246 |
+
"<mask_131070>",
|
| 2247 |
+
"<mask_131071>",
|
| 2248 |
+
"<longcat_point_start>",
|
| 2249 |
+
"<longcat_point_end>",
|
| 2250 |
+
"<longcat_point_delim>",
|
| 2251 |
+
"<longcat_polygon_start>",
|
| 2252 |
+
"<longcat_polygon_end>",
|
| 2253 |
+
"<mask_131077>",
|
| 2254 |
+
"<mask_131078>",
|
| 2255 |
+
"<longcat_audio_start>",
|
| 2256 |
+
"<longcat_audio_end>",
|
| 2257 |
+
"<longcat_audio_pad>",
|
| 2258 |
+
"<longcat_img_start>",
|
| 2259 |
+
"<longcat_img_end>",
|
| 2260 |
+
"<longcat_img_pad>",
|
| 2261 |
+
"<longcat_img_newline>",
|
| 2262 |
+
"<longcat_box_start>",
|
| 2263 |
+
"<longcat_box_end>",
|
| 2264 |
+
"<longcat_box_delim>",
|
| 2265 |
+
"<longcat_ref_start>",
|
| 2266 |
+
"<longcat_ref_end>",
|
| 2267 |
+
"<longcat_img_delim>",
|
| 2268 |
+
"<longcat_audio_delim>",
|
| 2269 |
+
"<longcat_video_palce>",
|
| 2270 |
+
"<longcat_video_start>",
|
| 2271 |
+
"<longcat_video_end>",
|
| 2272 |
+
"<longcat_audiotext_start>",
|
| 2273 |
+
"<longcat_audiotext_end>",
|
| 2274 |
+
"<longcat_audiotext_pad>",
|
| 2275 |
+
"<longcat_audiogen_start>",
|
| 2276 |
+
"<longcat_audiogen_end>"
|
| 2277 |
+
],
|
| 2278 |
+
"bos_token": "<longcat_s>",
|
| 2279 |
+
"chat_template": "{%- set tool_choice = tool_choice | default('auto') %}\n{%- set ns = namespace(tool_types = [], last_query_index = -1, suffix_to_move = '') %}\n\n{%- if tools and tool_choice != 'none' %}\n {{- \"<longcat_tool_declare>\\n\"-}}\n {{- \"# Tools\\n\" }}\n {{- \"You have access to the following tools:\\n\\n\" }}\n {%- for tool in tools %}\n {%- if tool.type not in ns.tool_types %}\n {%- set ns.tool_types = ns.tool_types + [tool.type] %}\n {{- \"## Tool namespace: \" ~ tool.type ~ \"\\n\\n\" }}\n {%- endif %}\n {%- if tool.type == 'code_interpreter' %}\n {%- set tool = {\"type\":\"code_interpreter\",\"function\":{\"name\":\"code_interpreter_preview\",\"description\":\"The code will be executed in a stateful Jupyter notebook sandbox environment, only supports local computation, data processing, and file operations.\\nCode sandbox environment (network isolated) Any external network requests or online API calls are prohibited.\\nIf online functionality is needed, please use other permitted tools.\\nCode will respond with the output of the execution or time out after 60.0 seconds. \",\"parameters\":{\"type\":\"object\",\"properties\":{\"language\":{\"type\":\"string\",\"description\":\"The programming language of the code to be executed. Available values: python (Default), java, go, js, ts, c, c++.\"},\"code\":{\"type\":\"string\",\"description\":\"Python code to be executed must not include the following:\\n- Importing network libraries such as requests, httplib, etc.\\n- Any form of HTTP requests.\\n- External API calls.\\n- Network port operations. Example: ```python\\nimport pandas as pd\\npd.DataFrame({'A':[1,2]})\\n```\"},\"timeout\":{\"type\":\"number\",\"description\":\"The maximum execution time of the code, in seconds. Default is 60.0.\"}}},\"required\":[\"code\"]}} %}\n {%- endif %}\n {{- \"### Tool name: \" + tool.function.name + \"\\n\" }}\n {{- \"Description: \" + tool.function.description + \"\\n\\n\" }}\n {{- \"InputSchema: \" + tool.function.parameters | tojson(ensure_ascii=False) + \"\\n\\n\" }}\n {%- endfor %}\n {{- '**Note**: For each function call, output the function name and arguments within the following XML format:\\n<longcat_tool_call>{function-name}\\n<longcat_arg_key>{arg-key-1}</longcat_arg_key>\\n<longcat_arg_value>{arg-value-1}</longcat_arg_value>\\n<longcat_arg_key>{arg-key-2}</longcat_arg_key>\\n<longcat_arg_value>{arg-value-2}</longcat_arg_value>\\n...\\n</longcat_tool_call>\\n' }}\n {{- \"</longcat_tool_declare>\"-}}\n {%- for idx in range(messages|length - 1) %}\n {%- set msg = messages[idx] %}\n {%- if msg.role == 'assistant' and not msg.tool_calls %}\n {%- set ns.last_query_index = idx %}\n {%- endif %}\n {%- endfor%}\n{%- endif %}\n\n{%- for msg in messages %}\n {%- if msg.role == \"system\" %}\n {{- \"<longcat_system>\" + msg.content }}\n {%- elif msg.role == \"user\" %}\n {{- \"<longcat_user>\" }}\n {%- if msg[\"files\"] %}\n {{- '<longcat_files>\\n' ~ msg.files | tojson(indent=2) ~ '\\n</longcat_files>' }}\n {%- endif %}\n\n {%- if add_generation_prompt and loop.last and msg.content is string and msg.content.endswith(\"<longcat_img_start>\") %}\n {%- set ns.suffix_to_move = \"<longcat_img_start>\" %}\n {{- msg.content[:-19] }}\n {%- elif add_generation_prompt and loop.last and msg.content is string and msg.content.endswith(\"<longcat_audiogen_start>\") %}\n {%- set ns.suffix_to_move = \"<longcat_audiogen_start>\" %}\n {{- msg.content[:-24] }}\n {%- else %}\n {{- msg.content }}\n {%- endif %}\n\n {%- elif msg.role == \"assistant\" %}\n {{- \"<longcat_assistant>\" }}\n {%- if enable_thinking == true and msg.reasoning_content and ns.tool_types != [] and loop.index0 > ns.last_query_index %}\n {{- \"\\n<longcat_think>\\n\" ~ msg.reasoning_content ~ \"\\n</longcat_think>\\n\" }}\n {%- endif %}\n {%- if msg.content%}\n {{- msg.content }}\n {%- endif %}\n {%- if msg.tool_calls %}\n {%- for tool_call in msg.tool_calls -%}\n {{- \"<longcat_tool_call>\" ~ tool_call.function.name ~ \"\\n\" -}}\n {% set _args = tool_call.function.arguments %}\n {% for k, v in _args.items() %}\n {{- \"<longcat_arg_key>\" ~ k ~ \"</longcat_arg_key>\\n\" -}}\n {{- \"<longcat_arg_value>\" ~ (v if v is string else v | tojson(ensure_ascii=False)) ~ \"</longcat_arg_value>\\n\" -}}\n {% endfor %}\n {{- \"</longcat_tool_call>\\n\" }}\n {%- endfor %}\n {%- endif %}\n {{- \"</longcat_s>\" -}}\n {%- elif msg.role == \"tool\" %}\n {%- if messages[loop.index0 - 1].role != \"tool\"%}\n {{- \"<longcat_user>\" -}}\n {%- endif %}\n {{- \"<longcat_tool_response>\" ~ msg.content ~ \"</longcat_tool_response>\"-}}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {%- if enable_thinking == true %}\n {{- \" /think_on\" }}\n {%- if thinking_budget %}\n {%- if thinking_budget < 1024 %}\n {%- set thinking_budget = 1024 %}\n {%- endif%}\n {{- \"\\nthinking_budget: < \" ~ thinking_budget ~ \".\"}}\n {%- endif %}\n {{- \" <longcat_assistant><longcat_think>\\n\"}}\n {%- elif enable_thinking == false %}\n {{- \" /think_off <longcat_assistant><longcat_think>\\n\\n</longcat_think>\\n\" }}\n {%- else %}\n {{- \"<longcat_assistant>\" ~ ns.suffix_to_move }}\n {%- endif %}\n{%- endif %}",
|
| 2280 |
+
"clean_up_tokenization_spaces": false,
|
| 2281 |
+
"eos_token": "</longcat_s>",
|
| 2282 |
+
"model_max_length": 131072,
|
| 2283 |
+
"pad_token": "<longcat_pad>",
|
| 2284 |
+
"sp_model_kwargs": {},
|
| 2285 |
+
"tokenizer_class": "BloomTokenizer",
|
| 2286 |
+
"unk_token": "<longcat_unk>",
|
| 2287 |
+
"image_start_token": "<longcat_img_start>",
|
| 2288 |
+
"image_end_token": "<longcat_img_end>",
|
| 2289 |
+
"image_pad_token": "<longcat_img_pad>",
|
| 2290 |
+
"image_newline_token": "<longcat_img_newline>",
|
| 2291 |
+
"audio_start_token": "<longcat_audio_start>",
|
| 2292 |
+
"audio_end_token": "<longcat_audio_end>",
|
| 2293 |
+
"audio_pad_token": "<longcat_audio_pad>"
|
| 2294 |
+
}
|