Upload MiniCPM-o 4.5 MLX 4-bit quantized model
- .gitattributes +1 -0
- README.md +137 -0
- added_tokens.json +107 -0
- chat_template.jinja +88 -0
- config.json +297 -0
- configuration_minicpmo.py +260 -0
- generation_config.json +12 -0
- model-00001-of-00002.safetensors +3 -0
- model-00002-of-00002.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_minicpmo.py +0 -0
- modeling_navit_siglip.py +981 -0
- preprocessor_config.json +35 -0
- processing_minicpmo.py +1665 -0
- processor_config.json +102 -0
- special_tokens_map.json +88 -0
- tokenization_minicpmo_fast.py +120 -0
- tokenizer.json +3 -0
- tokenizer_config.json +22 -0
- utils.py +2417 -0
- vocab.json +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,137 @@
---
license: apache-2.0
license_link: https://github.com/OpenBMB/MiniCPM-V/blob/main/LICENSE
base_model: openbmb/MiniCPM-o-4_5
tags:
- mlx
- vision
- multimodal
- vlm
- minicpm
- apple-silicon
- quantized
language:
- en
- zh
- id
- fr
- de
library_name: mlx
pipeline_tag: image-text-to-text
---

# MiniCPM-o 4.5 — MLX 4-bit Quantized

A 4-bit quantized [MLX](https://github.com/ml-explore/mlx) conversion of [openbmb/MiniCPM-o-4_5](https://huggingface.co/openbmb/MiniCPM-o-4_5) for fast inference on Apple Silicon (M1/M2/M3/M4).

## Model Details

| Property | Value |
|---|---|
| **Base model** | [openbmb/MiniCPM-o-4_5](https://huggingface.co/openbmb/MiniCPM-o-4_5) |
| **Architecture** | SigLIP2 (27 layers) + Perceiver Resampler + Qwen3 LLM (36 layers) |
| **Parameters** | ~8B |
| **Quantization** | 4-bit (5.255 effective bits) — LLM quantized, vision encoder & resampler full precision |
| **Size on disk** | ~5.3 GB |
| **Framework** | [MLX](https://github.com/ml-explore/mlx) via [mlx-vlm](https://github.com/Blaizzy/mlx-vlm) |
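
A quick way to see where the 5.255 effective bits figure comes from: a back-of-envelope check, assuming MLX's affine mode stores one float16 scale and one float16 bias per quantization group (verify against your MLX version).

```python
# Back-of-envelope check on "effective bits" (assumption: affine mode keeps
# one float16 scale and one float16 bias per group of 64 weights).
bits, group_size = 4, 64
quantized_bits = bits + (16 + 16) / group_size
print(quantized_bits)  # 4.5 bits per quantized LLM weight
# The reported 5.255 effective bits averages over the whole checkpoint,
# which also contains the full-precision vision encoder and resampler.
```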

## Performance (M4 Pro, 24 GB RAM)

| Mode | Prompt Processing | Generation | Peak Memory |
|------|-------------------|------------|-------------|
| Text-only | ~100 tok/s | ~55 tok/s | ~5.8 GB |
| Image + Text | ~150 tok/s | ~51 tok/s | ~6.5 GB |

## Capabilities

- Image understanding & description
- OCR / text extraction from images
- Chart & diagram analysis
- Math equation solving from images
- Visual reasoning & counting
- Code generation
- Multilingual (English, Chinese, Indonesian, French, German, etc.)

## Requirements

- Apple Silicon Mac (M1 or later)
- Python 3.10+
- ~8 GB free RAM

```bash
pip install mlx-vlm torch transformers Pillow
```

## Quick Start

### Python API

```python
from mlx_vlm import load
from mlx_vlm.generate import generate_step
import mlx.core as mx

model, processor = load("andrevp/MiniCPM-o-4_5-MLX", trust_remote_code=True)

# Text-only
text = "<|im_start|>user\nWhat is machine learning?<|im_end|>\n<|im_start|>assistant\n"
input_ids = mx.array(processor.tokenizer(text, return_tensors="np")["input_ids"])

tokens = []
for token, _ in generate_step(input_ids, model, None, None, temp=0.0):
    tok_val = token.item()
    tokens.append(tok_val)
    if processor.tokenizer.decode([tok_val]) in ["<|im_end|>", "<|endoftext|>"]:
        break

print(processor.tokenizer.decode(tokens, skip_special_tokens=True))
```
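
For image + text prompts, mlx-vlm's high-level helpers can be used instead of driving `generate_step` by hand. The snippet below is a minimal sketch, not a verified recipe: `apply_chat_template` and `generate` have shifted signatures across mlx-vlm releases, and `photo.jpg` is a placeholder path, so check the version you have installed.

```python
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

model, processor = load("andrevp/MiniCPM-o-4_5-MLX", trust_remote_code=True)
config = load_config("andrevp/MiniCPM-o-4_5-MLX")

# Build a chat-formatted prompt with one image slot, then generate.
prompt = apply_chat_template(processor, config, "What's in this image?", num_images=1)
output = generate(model, processor, prompt, image=["photo.jpg"], max_tokens=256)
print(output)
```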

### Chat Script

A standalone `chat_minicpmo.py` script is available in the [conversion repository](https://github.com/andrevp):

```bash
# Single-shot with image
python chat_minicpmo.py photo.jpg -p "What's in this image?"

# Single-shot text-only
python chat_minicpmo.py -p "Explain quantum computing briefly."

# Interactive mode
python chat_minicpmo.py

# Interactive with pre-loaded image
python chat_minicpmo.py photo.jpg
```

Interactive commands: `/image <path>` | `/clear` | `/quit`
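
If the conversion repository is unavailable, a stripped-down interactive loop can be assembled from the Quick Start snippet above. This is a hypothetical, text-only sketch (no `/image` handling, and it re-encodes the full history each turn rather than reusing a KV cache); it is not the actual `chat_minicpmo.py`.

```python
# Hypothetical minimal REPL sketch, not the actual chat_minicpmo.py.
from mlx_vlm import load
from mlx_vlm.generate import generate_step
import mlx.core as mx

model, processor = load("andrevp/MiniCPM-o-4_5-MLX", trust_remote_code=True)
history = ""

while True:
    user = input("> ").strip()
    if user == "/quit":
        break
    if user == "/clear":
        history = ""
        continue
    history += f"<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n"
    ids = mx.array(processor.tokenizer(history, return_tensors="np")["input_ids"])
    reply_tokens = []
    for token, _ in generate_step(ids, model, None, None, temp=0.0):
        tok = token.item()
        if processor.tokenizer.decode([tok]) in ["<|im_end|>", "<|endoftext|>"]:
            break
        reply_tokens.append(tok)
    reply = processor.tokenizer.decode(reply_tokens, skip_special_tokens=True)
    history += reply + "<|im_end|>\n"
    print(reply)
```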

## Quantization Details

- **LLM layers**: 4-bit quantized (group_size=64, affine mode)
- **Vision encoder (SigLIP2)**: full precision (not quantized)
- **Perceiver Resampler**: full precision (not quantized)
- **Weight breakdown**: 907 LLM keys (quantized) + 437 vision keys + 17 resampler keys (full precision), as illustrated by the sketch below
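
For readers reproducing the conversion, MLX exposes this kind of selective quantization through `mlx.nn.quantize` and its `class_predicate` hook. The sketch below is illustrative only: the module paths (`vpm`, `resampler`) and the `model` variable are assumptions about how the loaded module tree is named, not the exact script used for this upload.

```python
import mlx.nn as nn

# `model` is assumed to be the already-loaded MiniCPM-o module tree.
# Quantize only Linear layers whose path lies outside the vision tower and
# resampler; everything else stays in full precision.
def quantize_llm_only(path: str, module: nn.Module) -> bool:
    if not isinstance(module, nn.Linear):
        return False
    return not path.startswith(("vpm", "resampler"))  # assumed submodule names

nn.quantize(model, group_size=64, bits=4, class_predicate=quantize_llm_only)
```

If the naming assumptions hold, this reproduces the split in the breakdown above: quantized LLM weights, full-precision vision encoder and resampler.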

## Limitations

- **Vision-language only**: audio input (Whisper encoder) and TTS output (CosyVoice2) from the original model are not included in this conversion.
- **Single image per turn**: the model processes one image at a time.
- Quantization may slightly reduce output quality compared to the full-precision model.

## License

This model is released under the **Apache-2.0** license, following the original [openbmb/MiniCPM-o-4_5](https://huggingface.co/openbmb/MiniCPM-o-4_5) license.

See the [original license](https://github.com/OpenBMB/MiniCPM-V/blob/main/LICENSE) for full terms.

## Disclaimer

> As an LMM, MiniCPM-o 4.5 generates content by learning from a large amount of multimodal corpora, but it cannot comprehend, express personal opinions, or make value judgments. Anything generated by MiniCPM-o 4.5 does not represent the views and positions of the model developers. We will not be liable for any problems arising from the use of the MiniCPM-o models, including but not limited to data security issues, risks of public opinion, or any risks and problems arising from the misdirection, misuse, or dissemination of the model.

## Credits

- **Original model**: [OpenBMB](https://github.com/OpenBMB) — [MiniCPM-o 4.5](https://huggingface.co/openbmb/MiniCPM-o-4_5)
- **MLX framework**: [Apple ML Explore](https://github.com/ml-explore/mlx)
- **mlx-vlm**: [Prince Canuma](https://github.com/Blaizzy/mlx-vlm)
added_tokens.json
ADDED
@@ -0,0 +1,107 @@
{
  "</answer>": 151686,
  "</box>": 151674,
  "</focus>": 151688,
  "</image>": 151670,
  "</image_id>": 151682,
  "</image_save_to>": 151696,
  "</line>": 151690,
  "</perception>": 151692,
  "</point>": 151678,
  "</quad>": 151676,
  "</ref>": 151672,
  "</slice>": 151680,
  "</source_image>": 151694,
  "</think>": 151668,
  "</tool_call>": 151658,
  "</tool_response>": 151666,
  "</unit>": 151684,
  "<answer>": 151685,
  "<box>": 151673,
  "<focus>": 151687,
  "<image>": 151669,
  "<image_id>": 151681,
  "<image_save_to>": 151695,
  "<line>": 151689,
  "<perception>": 151691,
  "<point>": 151677,
  "<quad>": 151675,
  "<ref>": 151671,
  "<slice>": 151679,
  "<source_image>": 151693,
  "<think>": 151667,
  "<tool_call>": 151657,
  "<tool_response>": 151665,
  "<unit>": 151683,
  "<|audio_end|>": 151699,
  "<|audio_start|>": 151697,
  "<|audio|>": 151698,
  "<|box_end|>": 151649,
  "<|box_start|>": 151648,
  "<|emotion_end|>": 151711,
  "<|emotion_start|>": 151710,
  "<|endoftext|>": 151643,
  "<|file_sep|>": 151664,
  "<|fim_middle|>": 151660,
  "<|fim_pad|>": 151662,
  "<|fim_prefix|>": 151659,
  "<|fim_suffix|>": 151661,
  "<|im_end|>": 151645,
  "<|im_start|>": 151644,
  "<|image_pad|>": 151655,
  "<|interrupt|>": 151707,
  "<|listen|>": 151705,
  "<|object_ref_end|>": 151647,
  "<|object_ref_start|>": 151646,
  "<|pitch_end|>": 151715,
  "<|pitch_start|>": 151714,
  "<|quad_end|>": 151651,
  "<|quad_start|>": 151650,
  "<|repo_name|>": 151663,
  "<|speak|>": 151706,
  "<|speed_end|>": 151713,
  "<|speed_start|>": 151712,
  "<|spk_bos|>": 151700,
  "<|spk_eos|>": 151702,
  "<|spk|>": 151701,
  "<|turn_bos|>": 151716,
  "<|timbre_10|>": 151726,
  "<|timbre_11|>": 151727,
  "<|timbre_12|>": 151728,
  "<|timbre_13|>": 151729,
  "<|timbre_14|>": 151730,
  "<|timbre_15|>": 151731,
  "<|timbre_16|>": 151732,
  "<|timbre_17|>": 151733,
  "<|timbre_18|>": 151734,
  "<|timbre_19|>": 151735,
  "<|turn_eos|>": 151717,
  "<|timbre_20|>": 151736,
  "<|timbre_21|>": 151737,
  "<|timbre_22|>": 151738,
  "<|timbre_23|>": 151739,
  "<|timbre_24|>": 151740,
  "<|timbre_25|>": 151741,
  "<|timbre_26|>": 151742,
  "<|timbre_27|>": 151743,
  "<|timbre_28|>": 151744,
  "<|timbre_29|>": 151745,
  "<|chunk_eos|>": 151718,
  "<|timbre_30|>": 151746,
  "<|timbre_31|>": 151747,
  "<|chunk_bos|>": 151719,
  "<|chunk_tts_bos|>": 151720,
  "<|chunk_tts_eos|>": 151721,
  "<|tts_pad|>": 151722,
  "<|timbre_7|>": 151723,
  "<|timbre_8|>": 151724,
  "<|timbre_9|>": 151725,
  "<|tts_bos|>": 151703,
  "<|tts_eos|>": 151704,
  "<|vad_end|>": 151709,
  "<|vad_start|>": 151708,
  "<|video_pad|>": 151656,
  "<|vision_end|>": 151653,
  "<|vision_pad|>": 151654,
  "<|vision_start|>": 151652
}
chat_template.jinja
ADDED
@@ -0,0 +1,88 @@
{%- if tools %}
    {{- '<|im_start|>system\n' }}
    {%- if messages[0].role == 'system' %}
        {{- messages[0].content + '\n\n' }}
    {%- endif %}
    {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
    {%- for tool in tools %}
        {{- "\n" }}
        {{- tool | tojson }}
    {%- endfor %}
    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
    {%- if messages[0].role == 'system' %}
        {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
    {%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
    {%- set index = (messages|length - 1) - loop.index0 %}
    {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
        {%- set ns.multi_step_tool = false %}
        {%- set ns.last_query_index = index %}
    {%- endif %}
{%- endfor %}
{%- for message in messages %}
    {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
    {%- elif message.role == "assistant" %}
        {%- set content = message.content %}
        {%- set reasoning_content = '' %}
        {%- if message.reasoning_content is defined and message.reasoning_content is not none %}
            {%- set reasoning_content = message.reasoning_content %}
        {%- else %}
            {%- if '</think>' in message.content %}
                {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
                {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
            {%- endif %}
        {%- endif %}
        {%- if loop.index0 > ns.last_query_index %}
            {%- if loop.last or (not loop.last and reasoning_content) %}
                {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
            {%- else %}
                {{- '<|im_start|>' + message.role + '\n' + content }}
            {%- endif %}
        {%- else %}
            {{- '<|im_start|>' + message.role + '\n' + content }}
        {%- endif %}
        {%- if message.tool_calls %}
            {%- for tool_call in message.tool_calls %}
                {%- if (loop.first and content) or (not loop.first) %}
                    {{- '\n' }}
                {%- endif %}
                {%- if tool_call.function %}
                    {%- set tool_call = tool_call.function %}
                {%- endif %}
                {{- '<tool_call>\n{"name": "' }}
                {{- tool_call.name }}
                {{- '", "arguments": ' }}
                {%- if tool_call.arguments is string %}
                    {{- tool_call.arguments }}
                {%- else %}
                    {{- tool_call.arguments | tojson }}
                {%- endif %}
                {{- '}\n</tool_call>' }}
            {%- endfor %}
        {%- endif %}
        {{- '<|im_end|>\n' }}
    {%- elif message.role == "tool" %}
        {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
            {{- '<|im_start|>user' }}
        {%- endif %}
        {{- '\n<tool_response>\n' }}
        {{- message.content }}
        {{- '\n</tool_response>' }}
        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
            {{- '<|im_end|>\n' }}
        {%- endif %}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
    {%- if enable_thinking is defined and enable_thinking is false %}
        {{- '<think>\n\n</think>\n\n' }}
    {%- endif %}
    {%- if use_tts_template is defined and use_tts_template is true %}
        {{- '<|tts_bos|>' }}
    {%- endif %}
{%- endif %}
config.json
ADDED
@@ -0,0 +1,297 @@
{
  "architectures": [
    "MiniCPMO"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "audio_chunk_length": 1.0,
  "audio_config": {
    "_attn_implementation_autoset": true,
    "_name_or_path": "openai/whisper-medium",
    "activation_dropout": 0.0,
    "activation_function": "gelu",
    "apply_spec_augment": false,
    "architectures": [
      "MiniCPMWhisperEncoder"
    ],
    "attention_dropout": 0.0,
    "begin_suppress_tokens": [220, 50257],
    "bos_token_id": 50257,
    "classifier_proj_size": 256,
    "d_model": 1024,
    "decoder_attention_heads": 16,
    "decoder_ffn_dim": 4096,
    "decoder_layerdrop": 0.0,
    "decoder_layers": 24,
    "decoder_start_token_id": 50258,
    "dropout": 0.0,
    "encoder_attention_heads": 16,
    "encoder_ffn_dim": 4096,
    "encoder_layerdrop": 0.0,
    "encoder_layers": 24,
    "eos_token_id": 50257,
    "forced_decoder_ids": [[1, 50259], [2, 50359], [3, 50363]],
    "init_std": 0.02,
    "mask_feature_length": 10,
    "mask_feature_min_masks": 0,
    "mask_feature_prob": 0.0,
    "mask_time_length": 10,
    "mask_time_min_masks": 2,
    "mask_time_prob": 0.05,
    "max_length": 448,
    "max_source_positions": 1500,
    "max_target_positions": 448,
    "median_filter_width": 7,
    "model_type": "whisper",
    "num_hidden_layers": 24,
    "num_mel_bins": 80,
    "pad_token_id": 50257,
    "scale_embedding": false,
    "suppress_tokens": [
      1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63,
      90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350,
      1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667,
      6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562,
      13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075,
      21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470,
      36865, 42863, 47425, 49870, 50254, 50258, 50358, 50359, 50360, 50361,
      50362
    ],
    "torch_dtype": "float32",
    "use_cache": true,
    "use_weighted_layer_sum": false,
    "vocab_size": 51865
  },
  "audio_pool_step": 5,
  "auto_map": {
    "AutoConfig": "configuration_minicpmo.MiniCPMOConfig",
    "AutoModel": "modeling_minicpmo.MiniCPMO",
    "AutoModelForCausalLM": "modeling_minicpmo.MiniCPMO"
  },
  "batch_vision_input": true,
  "bos_token_id": 151643,
  "drop_vision_last_layer": false,
  "eos_token_id": [151645, 151643],
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "image_size": 448,
  "init_audio": true,
  "init_tts": true,
  "init_vision": true,
  "initializer_range": 0.02,
  "intermediate_size": 12288,
  "listen_speak_type": "asr",
  "max_position_embeddings": 40960,
  "max_window_layers": 36,
  "model_type": "minicpmo",
  "num_attention_heads": 32,
  "num_hidden_layers": 36,
  "num_key_value_heads": 8,
  "patch_size": 14,
  "quantization": {
    "group_size": 64,
    "bits": 4,
    "mode": "affine"
  },
  "quantization_config": {
    "group_size": 64,
    "bits": 4,
    "mode": "affine"
  },
  "query_num": 64,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000,
  "slice_config": {
    "max_slice_nums": 1,
    "model_type": "minicpmv",
    "patch_size": 14,
    "scale_resolution": 448
  },
  "slice_mode": true,
  "sliding_window": null,
  "stream_input": true,
  "tie_word_embeddings": false,
  "transformers_version": "4.51.0",
  "tts_config": {
    "_attn_implementation_autoset": true,
    "attention_type": "full_attention",
    "attn_implementation": "sdpa",
    "audio_bos_token_id": 151687,
    "audio_tokenizer_sample_rate": 16000,
    "audio_tokenizer_type": "s3tokenizer",
    "aug_layer_loss_weight": false,
    "aug_loss_weight": false,
    "backbone_model": "llama",
    "condition_type": "hidden_text_merge",
    "cosyvoice_config_path": null,
    "cosyvoice_model_dir": null,
    "filter_tts_loss": false,
    "hidden_act": "silu",
    "hidden_size": 768,
    "interleaved": false,
    "intermediate_size": 3072,
    "llm_dim": 4096,
    "llm_dim_model_base": 256,
    "llm_down_scale": false,
    "llm_hidden_size": 4096,
    "llm_intermediate_size": 768,
    "long_weight": 0.1,
    "max_position_embeddings": 4096,
    "model_type": "minicpmtts",
    "normalize_projected_hidden": true,
    "num_attention_heads": 12,
    "num_audio_tokens": 6562,
    "num_hidden_layers": 20,
    "num_key_value_heads": 12,
    "num_mel_bins": 100,
    "num_text_tokens": 152064,
    "num_vq": 1,
    "projector_type": "mlp",
    "recomputed_chunks": 1,
    "s3_stream_chunk_size": 25,
    "s3_stream_generate": false,
    "s3_stream_n_timesteps": 10,
    "s3_stream_prelook_size": 3,
    "short_weight": 0.1,
    "streaming": false,
    "streaming_audio_chunk_size": 50,
    "streaming_sliding_window": false,
    "streaming_sliding_window_audio_frame_rate": 50,
    "streaming_sliding_window_audio_init_text_length": 10,
    "streaming_sliding_window_audio_window_size": 300,
    "streaming_sliding_window_average_speed": 5,
    "streaming_sliding_window_fast_speed": 7,
    "streaming_sliding_window_max_text_len": 500,
    "streaming_sliding_window_slow_speed": 3,
    "streaming_sliding_window_text_window_size": 50,
    "streaming_text_chunk_max": 7,
    "streaming_text_chunk_min": 3,
    "streaming_text_reserved_len": 300,
    "text_eos_token_id": 151692,
    "tts_filter_loss_fix": false,
    "use_llm_hidden_state": false,
    "use_text": true,
    "window_size": 2
  },
  "use_cache": true,
  "use_image_id": true,
  "use_sliding_window": false,
  "version": "4.5",
  "vision_batch_size": 16,
  "vision_config": {
    "_attn_implementation_autoset": true,
    "attention_dropout": 0.0,
    "hidden_act": "gelu_pytorch_tanh",
    "hidden_size": 1152,
    "image_size": 980,
    "intermediate_size": 4304,
    "layer_norm_eps": 1e-06,
    "model_type": "siglip_vision_model",
    "num_attention_heads": 16,
    "num_channels": 3,
    "num_hidden_layers": 27,
    "patch_size": 14
  },
  "vocab_size": 151748
}
configuration_minicpmo.py
ADDED
@@ -0,0 +1,260 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2026 The OpenBMB Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Union

from transformers import PretrainedConfig
from transformers import Qwen3Config
from transformers import WhisperConfig
from transformers.utils import logging

from .modeling_navit_siglip import SiglipVisionConfig

logger = logging.get_logger(__name__)


class MiniCPMVSliceConfig(PretrainedConfig):
    model_type = "minicpmv"

    def __init__(
        self,
        patch_size=14,
        max_slice_nums=9,
        scale_resolution=448,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.max_slice_nums = max_slice_nums
        self.scale_resolution = scale_resolution

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        if config_dict.get("model_type") == "minicpmv":
            config_dict = config_dict["slice_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)


class MiniCPMTTSConfig(PretrainedConfig):
    model_type = "minicpmtts"

    def __init__(
        self,
        llm_dim: int = 2560,
        llm_intermediate_size: int = 768,
        llm_down_scale: bool = False,
        llm_dim_model_base: int = 256,
        projector_type: str = "mlp",
        hidden_act: str = "silu",
        aug_loss_weight: bool = False,
        aug_layer_loss_weight: bool = False,
        filter_tts_loss: bool = False,
        tts_filter_loss_fix: bool = False,
        long_weight: float = 0.1,
        short_weight: float = 0.1,
        hidden_size: int = 768,
        intermediate_size: int = 3072,
        num_attention_heads: int = 12,
        num_hidden_layers: int = 20,
        num_key_value_heads: int = 12,
        max_position_embeddings: int = 4096,
        num_audio_tokens: int = 4097,
        num_text_tokens: int = 21178,
        num_mel_bins: int = 100,
        num_vq: int = 1,
        use_llm_hidden_state: bool = False,
        audio_bos_token_id: int = 21132,
        text_eos_token_id: int = 21133,
        use_text: bool = True,
        streaming: bool = False,
        streaming_text_chunk_min: int = 3,
        streaming_text_chunk_max: int = 7,
        streaming_text_reserved_len: int = 300,
        streaming_audio_chunk_size: int = 50,
        attn_implementation: str = "sdpa",
        condition_type: str = "llm_hidden",
        backbone_model: str = "llama",
        audio_tokenizer_type: str = "wavtokenizer",
        audio_tokenizer_sample_rate: int = 24000,
        streaming_sliding_window: bool = False,
        streaming_sliding_window_max_text_len: int = 500,
        streaming_sliding_window_average_speed: int = 5,
        streaming_sliding_window_fast_speed: int = 7,
        streaming_sliding_window_slow_speed: int = 3,
        streaming_sliding_window_audio_frame_rate: int = 50,
        streaming_sliding_window_audio_init_text_length: int = 10,
        streaming_sliding_window_audio_window_size: int = 300,
        normalize_projected_hidden: bool = False,
        interleaved: bool = False,
        attention_type: str = "sliding_recompute",
        recomputed_chunks: int = 1,
        window_size: int = 2,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.llm_dim = llm_dim
        self.llm_hidden_size = llm_dim
        self.llm_intermediate_size = llm_intermediate_size
        self.llm_down_scale = llm_down_scale
        self.llm_dim_model_base = llm_dim_model_base
        self.projector_type = projector_type
        self.aug_loss_weight = aug_loss_weight
        self.aug_layer_loss_weight = aug_layer_loss_weight
        self.tts_filter_loss_fix = tts_filter_loss_fix
        self.filter_tts_loss = filter_tts_loss
        self.long_weight = long_weight
        self.short_weight = short_weight
        self.hidden_act = hidden_act

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_key_value_heads = num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.num_audio_tokens = num_audio_tokens
        self.num_text_tokens = num_text_tokens
        self.num_mel_bins = num_mel_bins
        self.num_vq = num_vq
        self.use_llm_hidden_state = use_llm_hidden_state
        self.audio_bos_token_id = audio_bos_token_id
        self.text_eos_token_id = text_eos_token_id
        self.use_text = use_text
        self.streaming = streaming
        self.streaming_text_chunk_min = streaming_text_chunk_min
        self.streaming_text_chunk_max = streaming_text_chunk_max
        self.streaming_text_reserved_len = streaming_text_reserved_len
        self.streaming_audio_chunk_size = streaming_audio_chunk_size
        self.attn_implementation = attn_implementation
        self.condition_type = condition_type
        self.backbone_model = backbone_model
        self.audio_tokenizer_type = audio_tokenizer_type
        self.audio_tokenizer_sample_rate = audio_tokenizer_sample_rate

        self.streaming_sliding_window = streaming_sliding_window
        self.streaming_sliding_window_max_text_len = streaming_sliding_window_max_text_len
        self.streaming_sliding_window_average_speed = streaming_sliding_window_average_speed
        self.streaming_sliding_window_fast_speed = streaming_sliding_window_fast_speed
        self.streaming_sliding_window_slow_speed = streaming_sliding_window_slow_speed
        self.streaming_sliding_window_audio_frame_rate = streaming_sliding_window_audio_frame_rate
        self.streaming_sliding_window_audio_init_text_length = streaming_sliding_window_audio_init_text_length
        self.streaming_sliding_window_audio_window_size = streaming_sliding_window_audio_window_size

        self.normalize_projected_hidden = normalize_projected_hidden

        self.interleaved = interleaved
        self.attention_type = attention_type
        self.recomputed_chunks = recomputed_chunks
        self.window_size = window_size


class MiniCPMOConfig(Qwen3Config):
    model_type = "minicpmo"
    keys_to_ignore_at_inference = ["past_key_values"]

    default_vision_config = {
        "hidden_size": 1152,
        "image_size": 980,
        "intermediate_size": 4304,
        "model_type": "siglip",
        "num_attention_heads": 16,
        "num_hidden_layers": 27,
        "patch_size": 14,
    }

    def __init__(
        self,
        use_cache=True,
        query_num=64,
        image_size=448,
        drop_vision_last_layer=True,
        batch_vision_input=True,
        slice_config=None,
        vision_config=None,
        audio_config=None,
        tts_config=None,
        use_image_id=True,
        vision_batch_size=16,
        audio_pool_step=5,
        audio_chunk_length=1.0,
        stream_input=False,
        listen_speak_type="asr",
        init_vision=True,
        init_audio=True,
        init_tts=True,
        **kwargs,
    ):
        self.use_cache = use_cache
        self.query_num = query_num
        self.image_size = image_size
        self.drop_vision_last_layer = drop_vision_last_layer
        self.batch_vision_input = batch_vision_input
        self.use_image_id = use_image_id
        self.vision_batch_size = vision_batch_size
        self.audio_pool_step = audio_pool_step
        self.audio_chunk_length = audio_chunk_length
        self.stream_input = stream_input
        self.listen_speak_type = listen_speak_type

        self.init_vision = init_vision
        self.init_audio = init_audio
        self.init_tts = init_tts

        if slice_config is None:
            self.slice_config = MiniCPMVSliceConfig(max_slice_nums=1)
        else:
            self.slice_config = MiniCPMVSliceConfig(**slice_config)
        self.slice_mode = True

        # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit add tgt_sizes
        if vision_config is None:
            self.vision_config = SiglipVisionConfig(**self.default_vision_config)
            logger.info("vision_config is None, using default vision config")
        elif isinstance(vision_config, dict):
            self.vision_config = SiglipVisionConfig(**vision_config)
        elif isinstance(vision_config, SiglipVisionConfig):
            self.vision_config = vision_config

        if audio_config is None:
            self.audio_config = WhisperConfig()
        elif isinstance(audio_config, dict):
            self.audio_config = WhisperConfig(**audio_config)
        elif isinstance(audio_config, WhisperConfig):
            self.audio_config = audio_config

        if tts_config is None:
            self.tts_config = MiniCPMTTSConfig()
        elif isinstance(tts_config, dict):
            self.tts_config = MiniCPMTTSConfig(**tts_config)
        elif isinstance(tts_config, MiniCPMTTSConfig):
            self.tts_config = tts_config

        self.patch_size = self.vision_config.patch_size

        super().__init__(**kwargs)
generation_config.json
ADDED
@@ -0,0 +1,12 @@
{
  "bos_token_id": 151643,
  "do_sample": true,
  "eos_token_id": [
    151645,
    151643
  ],
  "pad_token_id": 151643,
  "temperature": 0.6,
  "top_k": 20,
  "top_p": 0.95
}
model-00001-of-00002.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5e0d2dbc6eacf34177bcf3badf36ead8719f48e1b4c8adb2a2754be5a73218e6
size 5092993723

model-00002-of-00002.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2b35835d4c370c057a6c11d87cb91565f54e973c09a9c6b8ee5564140ce34d2a
size 527444905
model.safetensors.index.json
ADDED
The diff for this file is too large to render. See raw diff.

modeling_minicpmo.py
ADDED
The diff for this file is too large to render. See raw diff.
modeling_navit_siglip.py
ADDED
|
@@ -0,0 +1,981 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch Siglip model."""
|
| 16 |
+
# Copied from HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
import os
|
| 21 |
+
import warnings
|
| 22 |
+
from dataclasses import dataclass
|
| 23 |
+
from typing import Optional
|
| 24 |
+
from typing import Tuple
|
| 25 |
+
from typing import Union
|
| 26 |
+
|
| 27 |
+
import numpy as np
|
| 28 |
+
import torch
|
| 29 |
+
import torch.nn.functional as F
|
| 30 |
+
import torch.utils.checkpoint
|
| 31 |
+
from torch import nn
|
| 32 |
+
from torch.nn.init import _calculate_fan_in_and_fan_out
|
| 33 |
+
from transformers.activations import ACT2FN
|
| 34 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 35 |
+
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
| 36 |
+
from transformers.modeling_outputs import BaseModelOutput
|
| 37 |
+
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
| 38 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 39 |
+
from transformers.utils import add_start_docstrings
|
| 40 |
+
from transformers.utils import add_start_docstrings_to_model_forward
|
| 41 |
+
from transformers.utils import is_flash_attn_2_available
|
| 42 |
+
from transformers.utils import logging
|
| 43 |
+
from transformers.utils import ModelOutput
|
| 44 |
+
from transformers.utils import replace_return_docstrings
|
| 45 |
+
|
| 46 |
+
logger = logging.get_logger(__name__)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class SiglipVisionConfig(PretrainedConfig):
|
| 50 |
+
r"""
|
| 51 |
+
This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
|
| 52 |
+
Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
| 53 |
+
configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
|
| 54 |
+
[google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
|
| 55 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 56 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 57 |
+
Args:
|
| 58 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 59 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 60 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 61 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 62 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 63 |
+
Number of hidden layers in the Transformer encoder.
|
| 64 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 65 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 66 |
+
num_channels (`int`, *optional*, defaults to 3):
|
| 67 |
+
Number of channels in the input images.
|
| 68 |
+
image_size (`int`, *optional*, defaults to 224):
|
| 69 |
+
The size (resolution) of each image.
|
| 70 |
+
patch_size (`int`, *optional*, defaults to 16):
|
| 71 |
+
The size (resolution) of each patch.
|
| 72 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
| 73 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 74 |
+
`"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
|
| 75 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 76 |
+
The epsilon used by the layer normalization layers.
|
| 77 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 78 |
+
The dropout ratio for the attention probabilities.
|
| 79 |
+
Example:
|
| 80 |
+
```python
|
| 81 |
+
>>> from transformers import SiglipVisionConfig, SiglipVisionModel
|
| 82 |
+
>>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
|
| 83 |
+
>>> configuration = SiglipVisionConfig()
|
| 84 |
+
>>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
|
| 85 |
+
>>> model = SiglipVisionModel(configuration)
|
| 86 |
+
>>> # Accessing the model configuration
|
| 87 |
+
>>> configuration = model.config
|
| 88 |
+
```"""
|
| 89 |
+
|
| 90 |
+
model_type = "siglip_vision_model"
|
| 91 |
+
|
| 92 |
+
def __init__(
|
| 93 |
+
self,
|
| 94 |
+
hidden_size=768,
|
| 95 |
+
intermediate_size=3072,
|
| 96 |
+
num_hidden_layers=12,
|
| 97 |
+
num_attention_heads=12,
|
| 98 |
+
num_channels=3,
|
| 99 |
+
image_size=224,
|
| 100 |
+
patch_size=16,
|
| 101 |
+
hidden_act="gelu_pytorch_tanh",
|
| 102 |
+
layer_norm_eps=1e-6,
|
| 103 |
+
attention_dropout=0.0,
|
| 104 |
+
**kwargs,
|
| 105 |
+
):
|
| 106 |
+
super().__init__(**kwargs)
|
| 107 |
+
|
| 108 |
+
self.hidden_size = hidden_size
|
| 109 |
+
self.intermediate_size = intermediate_size
|
| 110 |
+
self.num_hidden_layers = num_hidden_layers
|
| 111 |
+
self.num_attention_heads = num_attention_heads
|
| 112 |
+
self.num_channels = num_channels
|
| 113 |
+
self.patch_size = patch_size
|
| 114 |
+
self.image_size = image_size
|
| 115 |
+
self.attention_dropout = attention_dropout
|
| 116 |
+
self.layer_norm_eps = layer_norm_eps
|
| 117 |
+
self.hidden_act = hidden_act
|
| 118 |
+
|
| 119 |
+
@classmethod
|
| 120 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
| 121 |
+
cls._set_token_in_kwargs(kwargs)
|
| 122 |
+
|
| 123 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
| 124 |
+
|
| 125 |
+
# get the vision config dict if we are loading from SiglipConfig
|
| 126 |
+
if config_dict.get("model_type") == "siglip":
|
| 127 |
+
config_dict = config_dict["vision_config"]
|
| 128 |
+
|
| 129 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
| 130 |
+
logger.warning(
|
| 131 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
| 132 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
return cls.from_dict(config_dict, **kwargs)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
|
| 139 |
+
|
| 140 |
+
SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
| 141 |
+
"google/siglip-base-patch16-224",
|
| 142 |
+
# See all SigLIP models at https://huggingface.co/models?filter=siglip
|
| 143 |
+
]
|
| 144 |
+
|
| 145 |
+
if is_flash_attn_2_available():
|
| 146 |
+
from flash_attn import flash_attn_func
|
| 147 |
+
from flash_attn import flash_attn_varlen_func
|
| 148 |
+
from flash_attn.bert_padding import index_first_axis # noqa
|
| 149 |
+
from flash_attn.bert_padding import pad_input
|
| 150 |
+
from flash_attn.bert_padding import unpad_input
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
| 154 |
+
def _get_unpad_data(attention_mask):
|
| 155 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
| 156 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
| 157 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
| 158 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
| 159 |
+
return (
|
| 160 |
+
indices,
|
| 161 |
+
cu_seqlens,
|
| 162 |
+
max_seqlen_in_batch,
|
| 163 |
+
)


def _trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to
    # [2l-1, 2u-1].
    tensor.uniform_(2 * l - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    if tensor.dtype in [torch.float16, torch.bfloat16]:
        # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
        og_dtype = tensor.dtype
        tensor = tensor.to(torch.float32)
        tensor.erfinv_()
        tensor = tensor.to(og_dtype)
    else:
        tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.0))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    if tensor.dtype == torch.float16:
        # The `clamp_` op is not (yet?) defined in float16+cpu
        tensor = tensor.to(torch.float32)
        tensor.clamp_(min=a, max=b)
        tensor = tensor.to(torch.float16)
    else:
        tensor.clamp_(min=a, max=b)


def trunc_normal_tf_(
    tensor: torch.Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
) -> torch.Tensor:
    """Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \\leq \\text{mean} \\leq b`.

    NOTE: this 'tf' variant behaves closer to the Tensorflow / JAX impl where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsequently scaled and shifted by the mean and std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    """
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)
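A minimal sketch of the practical consequence of the 'tf' variant, using `trunc_normal_tf_` as defined above: the bounds apply to the standard normal before scaling, so every value lands in `[mean + a*std, mean + b*std]`:

```python
import torch

# With mean=1.0, std=0.02 and the default bounds a=-2, b=2, the 'tf' variant
# clips at mean +/- 2*std, i.e. all values end up in [0.96, 1.04].
w = torch.empty(4, 4)
trunc_normal_tf_(w, mean=1.0, std=0.02)
assert w.min() >= 0.96 - 1e-6 and w.max() <= 1.04 + 1e-6
```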


def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")


def default_flax_embed_init(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="normal")
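The `0.87962566103423978` constant above is the standard deviation of a standard normal truncated to (-2, 2); dividing by it makes the post-truncation variance equal to `scale / fan`. A short sketch of the arithmetic for a layer with 768 input features under fan-in scaling:

```python
import math

fan_in = 768                      # input features of the layer
variance = 1.0 / fan_in           # scale=1.0, mode="fan_in"
std_target = math.sqrt(variance)                # desired std after truncation: ~0.0361
std_sample = std_target / 0.87962566103423978   # std handed to trunc_normal_tf_: ~0.0410
print(round(std_target, 4), round(std_sample, 4))  # 0.0361 0.041
```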


@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
class SiglipVisionModelOutput(ModelOutput):
    """
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class SiglipVisionEmbeddings(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        patch_attention_mask: torch.BoolTensor,
        tgt_sizes: Optional[torch.IntTensor] = None,
    ) -> torch.Tensor:
        batch_size = pixel_values.size(0)

        patch_embeds = self.patch_embedding(pixel_values)
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
        max_nb_patches_h, max_nb_patches_w = (
            max_im_h // self.patch_size,
            max_im_w // self.patch_size,
        )
        boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
        position_ids = torch.full(
            size=(
                batch_size,
                max_nb_patches_h * max_nb_patches_w,
            ),
            fill_value=0,
        )

        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
            if tgt_sizes is not None:
                nb_patches_h = tgt_sizes[batch_idx][0]
                nb_patches_w = tgt_sizes[batch_idx][1]
            else:
                nb_patches_h = p_attn_mask[:, 0].sum()
                nb_patches_w = p_attn_mask[0].sum()

            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)

            bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
            bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

            pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids

        position_ids = position_ids.to(self.position_embedding.weight.device)

        embeddings = embeddings + self.position_embedding(position_ids)
        return embeddings
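The loop above implements the NaViT-style, interpolation-free position ids this file is named after: each image's patch grid is mapped onto the fixed `num_patches_per_side x num_patches_per_side` embedding grid by bucketizing fractional coordinates. A standalone sketch in plain `torch`, with toy sizes, of the same computation for a 2x4-patch image on a 4x4 embedding grid:

```python
import torch

num_patches_per_side = 4                            # embedding table covers 4 x 4 = 16 positions
boundaries = torch.arange(1 / 4, 1.0, 1 / 4)        # tensor([0.25, 0.50, 0.75])

nb_patches_h, nb_patches_w = 2, 4                   # this image is 2 patches tall, 4 wide
frac_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)  # tensor([0.0, 0.5])
frac_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)  # tensor([0.0, 0.25, 0.5, 0.75])

bucket_h = torch.bucketize(frac_h, boundaries, right=True)  # tensor([0, 2])
bucket_w = torch.bucketize(frac_w, boundaries, right=True)  # tensor([0, 1, 2, 3])

pos_ids = (bucket_h[:, None] * num_patches_per_side + bucket_w).flatten()
print(pos_ids)  # tensor([0, 1, 2, 3, 8, 9, 10, 11]) -- rows 0 and 2 of the 4x4 grid
```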


class SiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        batch_size, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        k_v_seq_len = key_states.shape[-2]
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale

        if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
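Minus the shape checks, this forward is standard scaled-dot-product attention; on recent PyTorch the same computation (up to dropout randomness) can be sketched with the fused kernel. This is illustrative only, not a path the module itself takes:

```python
import torch
import torch.nn.functional as F

batch, heads, q_len, head_dim = 1, 12, 16, 64
q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, q_len, head_dim)
v = torch.randn(batch, heads, q_len, head_dim)

# Manual path, as in SiglipAttention.forward (eval mode, no mask).
scale = head_dim ** -0.5
weights = torch.softmax(q @ k.transpose(2, 3) * scale, dim=-1)
manual = weights @ v

# Fused equivalent; default scale is also 1/sqrt(head_dim).
fused = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(manual, fused, atol=1e-5)
```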


class SiglipFlashAttention2(SiglipAttention):
    """
    Siglip flash attention module. This module inherits from `SiglipAttention` as the weights of the module stay
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_causal = False  # Hack to make sure we don't use a causal mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        # if past_key_value is not None:
        #     cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        #     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (LlamaRMSNorm handles it correctly)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                "The input hidden states seems to be silently casted in float32, this might be related to the fact"
                " you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """

        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
        causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(query_states, key_states, value_states, attention_mask, query_length)

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )


# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
class SiglipMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
class SiglipEncoderLayer(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.self_attn = SiglipAttention(config) if not self._use_flash_attention_2 else SiglipFlashAttention2(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class SiglipPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = SiglipVisionConfig
    base_model_prefix = "siglip"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""

        if isinstance(module, SiglipVisionEmbeddings):
            width = self.config.hidden_size
            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
        elif isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
        elif isinstance(module, SiglipAttention):
            nn.init.normal_(module.q_proj.weight)
            nn.init.normal_(module.k_proj.weight)
            nn.init.normal_(module.v_proj.weight)
            nn.init.normal_(module.out_proj.weight)
            nn.init.zeros_(module.q_proj.bias)
            nn.init.zeros_(module.k_proj.bias)
            nn.init.zeros_(module.v_proj.bias)
            nn.init.zeros_(module.out_proj.bias)
        elif isinstance(module, SiglipMLP):
            nn.init.normal_(module.fc1.weight)
            nn.init.normal_(module.fc2.weight)
            nn.init.normal_(module.fc1.bias, std=1e-6)
            nn.init.normal_(module.fc2.bias, std=1e-6)
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


SIGLIP_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`SiglipVisionConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


SIGLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
class SiglipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`SiglipEncoderLayer`].

    Args:
        config: SiglipVisionConfig
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    # Ignore copy
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )


@add_start_docstrings(
    """The vision model from SigLIP without any head or projection on top.""",
    SIGLIP_START_DOCSTRING,
)
class SiglipVisionTransformer(SiglipPreTrainedModel):
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"
    _supports_flash_attn_2 = True
    _no_split_modules = []

    def __init__(self, config: SiglipVisionConfig):
        super().__init__(config)
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.embeddings.patch_embedding

    @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
    def forward(
        self,
        pixel_values,
        patch_attention_mask: Optional[torch.BoolTensor] = None,
        tgt_sizes: Optional[torch.IntTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size = pixel_values.size(0)
        if patch_attention_mask is None:
            patch_attention_mask = torch.ones(
                size=(
                    batch_size,
                    pixel_values.size(2) // self.config.patch_size,
                    pixel_values.size(3) // self.config.patch_size,
                ),
                dtype=torch.bool,
                device=pixel_values.device,
            )

        hidden_states = self.embeddings(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
            tgt_sizes=tgt_sizes,
        )

        patch_attention_mask = patch_attention_mask.view(batch_size, -1)
        # The call to `_upad_input` in `_flash_attention_forward` is expensive
        # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
        # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
        if not torch.any(~patch_attention_mask):
            attention_mask = None
        else:
            attention_mask = (
                _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
                if not self._use_flash_attention_2
                else patch_attention_mask
            )

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.post_layernorm(last_hidden_state)

        if not return_dict:
            return (last_hidden_state, None) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=None,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
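A hedged end-to-end sketch of driving this vision tower directly, assuming the file is importable as `modeling_navit_siglip` (its filename in this repo) and a recent `transformers` that sets `_attn_implementation` on configs; the checkpoint itself wires the tower up through `modeling_minicpmo.py` instead:

```python
import torch
from modeling_navit_siglip import SiglipVisionConfig, SiglipVisionTransformer

config = SiglipVisionConfig(hidden_size=768, num_hidden_layers=2, image_size=448, patch_size=14)
model = SiglipVisionTransformer(config).eval()

# One 448x448 image -> a 32x32 grid of 14-pixel patches.
pixel_values = torch.randn(1, 3, 448, 448)
with torch.no_grad():
    out = model(pixel_values)  # patch_attention_mask defaults to all ones
print(out.last_hidden_state.shape)  # torch.Size([1, 1024, 768])
```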
preprocessor_config.json
ADDED
@@ -0,0 +1,35 @@
{
  "image_processor_type": "MiniCPMVImageProcessor",
  "feature_extractor_type": "MiniCPMAAudioProcessor",
  "auto_map": {
    "AutoProcessor": "processing_minicpmo.MiniCPMOProcessor",
    "AutoImageProcessor": "processing_minicpmo.MiniCPMVImageProcessor",
    "AutoFeatureExtractor": "processing_minicpmo.MiniCPMAAudioProcessor"
  },
  "processor_class": "MiniCPMOProcessor",
  "max_slice_nums": 9,
  "scale_resolution": 448,
  "patch_size": 14,
  "use_image_id": true,
  "image_feature_size": 64,
  "im_start": "<image>",
  "im_end": "</image>",
  "slice_start": "<slice>",
  "slice_end": "</slice>",
  "unk": "<unk>",
  "im_id_start": "<image_id>",
  "im_id_end": "</image_id>",
  "slice_mode": true,
  "audio_pool_step": 5,
  "norm_mean": [
    0.5,
    0.5,
    0.5
  ],
  "norm_std": [
    0.5,
    0.5,
    0.5
  ],
  "version": 4.5
}
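The `auto_map` block is what lets `transformers` resolve the custom processing classes from this repository rather than from the library itself. A minimal loading sketch (hedged: `trust_remote_code=True` is required for repo-hosted code, and the path below stands in for wherever this repository lives):

```python
from transformers import AutoProcessor

# auto_map routes AutoProcessor to processing_minicpmo.MiniCPMOProcessor in this repo.
processor = AutoProcessor.from_pretrained(
    "path/to/this-repo",  # local checkout or the Hub repo id
    trust_remote_code=True,
)
print(type(processor).__name__)  # MiniCPMOProcessor
```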
processing_minicpmo.py
ADDED
@@ -0,0 +1,1665 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2026 The OpenBMB Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import math
import re
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor
from transformers.audio_utils import spectrogram
from transformers.audio_utils import window_function
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_processing_utils import BatchFeature
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import ChannelDimension
from transformers.image_utils import ImageInput
from transformers.image_utils import infer_channel_dimension_format
from transformers.image_utils import is_torch_tensor
from transformers.image_utils import to_numpy_array
from transformers.image_utils import valid_images
from transformers.models.whisper.feature_extraction_whisper import WhisperFeatureExtractor
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTokenizedInput
from transformers.tokenization_utils_base import TextInput
from transformers.utils import is_torch_device
from transformers.utils import is_torch_dtype
from transformers.utils import requires_backends
from transformers.utils import TensorType


def recursive_converter(converter, value):
    if isinstance(value, list):
        new_value = []
        for v in value:
            new_value += [recursive_converter(converter, v)]
        return new_value
    else:
        return converter(value)


class MiniCPMOBatchFeature(BatchFeature):
    """Extend from BatchFeature for supporting various image sizes."""

    def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
        super().__init__(data)
        self.convert_to_tensors(tensor_type=tensor_type)

    def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None, **kwargs):
        if tensor_type is None:
            return self

        is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)

        def converter(value):
            try:
                if not is_tensor(value):
                    tensor = as_tensor(value)
                    return tensor
                return value
            except Exception:
                if key == "overflowing_values":
                    raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
                raise ValueError(
                    "Unable to create tensor, you should probably activate padding "
                    "with 'padding=True' to have batched tensors with the same length."
                )

        for key, value in self.items():
            self[key] = recursive_converter(converter, value)
        return self

    def to(self, *args, **kwargs) -> "MiniCPMOBatchFeature":
        requires_backends(self, ["torch"])
        import torch

        def cast_tensor(v):
            if not torch.is_tensor(v):
                return v

            if torch.is_floating_point(v):
                return v.to(*args, **kwargs)
            elif device is not None:
                return v.to(device=device)
            else:
                return v

        new_data = {}
        device = kwargs.get("device")
        if device is None and len(args) > 0:
            arg = args[0]
            if is_torch_dtype(arg):
                pass
            elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
                device = arg
            else:
                raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")

        # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
        for k, v in self.items():
            new_data[k] = recursive_converter(cast_tensor, v)
        self.data = new_data
        return self
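`MiniCPMOBatchFeature.to` deliberately casts only floating-point tensors, so integer token ids survive a dtype move. A minimal sketch of the resulting behaviour, with hypothetical tensors on CPU:

```python
import torch

feats = MiniCPMOBatchFeature({
    "input_ids": [torch.tensor([1, 2, 3])],    # LongTensor, nested in a list
    "pixel_values": [torch.randn(3, 14, 14)],  # float32
})
feats = feats.to(torch.float16)
print(feats["input_ids"][0].dtype)     # torch.int64  (left untouched)
print(feats["pixel_values"][0].dtype)  # torch.float16
```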
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
+class MiniCPMVImageProcessor(BaseImageProcessor):
+    model_input_names = ["pixel_values"]
+
+    def __init__(self, max_slice_nums=9, scale_resolution=448, patch_size=14, **kwargs):
+        super().__init__(**kwargs)
+        self.max_slice_nums = max_slice_nums
+        self.scale_resolution = scale_resolution
+        self.patch_size = patch_size
+        self.use_image_id = kwargs.pop("use_image_id", False)
+        self.image_feature_size = kwargs.pop("image_feature_size", 64)
+        self.im_start_token = kwargs.pop("im_start", "<image>")
+        self.im_end_token = kwargs.pop("im_end", "</image>")
+        self.slice_start_token = kwargs.pop("slice_start", "<slice>")
+        self.slice_end_token = kwargs.pop("slice_end", "</slice>")
+        self.unk_token = kwargs.pop("unk", "<unk>")
+        self.im_id_start = kwargs.pop("im_id_start", "<image_id>")
+        self.im_id_end = kwargs.pop("im_id_end", "</image_id>")
+        self.slice_mode = kwargs.pop("slice_mode", True)
+
+        self.mean = np.array(kwargs.pop("norm_mean", [0.5, 0.5, 0.5]))
+        self.std = np.array(kwargs.pop("norm_std", [0.5, 0.5, 0.5]))
+        self.version = kwargs.pop("version", 2.0)
+
+    @staticmethod
+    def ensure_divide(length, patch_size):
+        return max(round(length / patch_size) * patch_size, patch_size)
+
+    def find_best_resize(self, original_size, scale_resolution, patch_size, allow_upscale=False):
+        width, height = original_size
+        if (width * height > scale_resolution * scale_resolution) or allow_upscale:
+            r = width / height
+            height = int(scale_resolution / math.sqrt(r))
+            width = int(height * r)
+        best_width = self.ensure_divide(width, patch_size)
+        best_height = self.ensure_divide(height, patch_size)
+        return best_width, best_height
+
+    def get_refine_size(self, original_size, grid, scale_resolution, patch_size, allow_upscale=False):
+        width, height = original_size
+        grid_x, grid_y = grid
+
+        refine_width = self.ensure_divide(width, grid_x)
+        refine_height = self.ensure_divide(height, grid_y)
+
+        grid_width = refine_width / grid_x
+        grid_height = refine_height / grid_y
+
+        best_grid_size = self.find_best_resize(
+            (grid_width, grid_height), scale_resolution, patch_size, allow_upscale=allow_upscale
+        )
+        refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
+        return refine_size
+
+    @staticmethod
+    def split_to_patches(image, grid):
+        patches = []
+        width, height = image.size
+        grid_x = int(width / grid[0])
+        grid_y = int(height / grid[1])
+        for i in range(0, height, grid_y):
+            images = []
+            for j in range(0, width, grid_x):
+                box = (j, i, j + grid_x, i + grid_y)
+                patch = image.crop(box)
+                images.append(patch)
+            patches.append(images)
+        return patches
+
+    def slice_image(self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False):
+        original_size = image.size
+        source_image = None
+        best_grid = self.get_sliced_grid(original_size, max_slice_nums, never_split)
+        patches = []
+
+        if best_grid is None:
+            # no need to slice; upsample
+            best_size = self.find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True)
+            source_image = image.resize(best_size, resample=Image.Resampling.BICUBIC)
+        else:
+            # source image: down-sample and ensure it is divisible by patch_size
+            best_resize = self.find_best_resize(original_size, scale_resolution, patch_size)
+            source_image = image.copy().resize(best_resize, resample=Image.Resampling.BICUBIC)
+            refine_size = self.get_refine_size(
+                original_size, best_grid, scale_resolution, patch_size, allow_upscale=True
+            )
+            refine_image = image.resize(refine_size, resample=Image.Resampling.BICUBIC)
+            patches = self.split_to_patches(refine_image, best_grid)
+
+        return source_image, patches, best_grid
+
+    def get_grid_placeholder(self, grid):
+        if grid is None:
+            return ""
+        slice_image_placeholder = (
+            self.slice_start_token + self.unk_token * self.image_feature_size + self.slice_end_token
+        )
+
+        cols = grid[0]
+        rows = grid[1]
+        slices = []
+        for i in range(rows):
+            lines = []
+            for j in range(cols):
+                lines.append(slice_image_placeholder)
+            slices.append("".join(lines))
+
+        slice_placeholder = "\n".join(slices)
+        return slice_placeholder
+
+    def get_image_id_placeholder(self, idx=0):
+        return f"{self.im_id_start}{idx}{self.im_id_end}"
+
+    def get_sliced_images(self, image, max_slice_nums=None):
+        slice_images = []
+
+        if not self.slice_mode:
+            return [image]
+
+        max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
+        assert max_slice_nums > 0
+        source_image, patches, sliced_grid = self.slice_image(
+            image, max_slice_nums, self.scale_resolution, self.patch_size  # defaults: 9, 448, 14
+        )
+
+        slice_images.append(source_image)
+        if len(patches) > 0:
+            for i in range(len(patches)):
+                for j in range(len(patches[0])):
+                    slice_images.append(patches[i][j])
+        return slice_images
+
+    def get_sliced_grid(self, image_size, max_slice_nums, never_split=False):
+        original_width, original_height = image_size
+        log_ratio = math.log(original_width / original_height)
+        ratio = original_width * original_height / (self.scale_resolution * self.scale_resolution)
+        multiple = min(math.ceil(ratio), max_slice_nums)
+        if multiple <= 1 or never_split:
+            return None
+        candidate_split_grids_nums = []
+        for i in [multiple - 1, multiple, multiple + 1]:
+            if i == 1 or i > max_slice_nums:
+                continue
+            candidate_split_grids_nums.append(i)
+
+        candidate_grids = []
+        for split_grids_nums in candidate_split_grids_nums:
+            m = 1
+            while m <= split_grids_nums:
+                if split_grids_nums % m == 0:
+                    candidate_grids.append([m, split_grids_nums // m])
+                m += 1
+
+        best_grid = [1, 1]
+        min_error = float("inf")
+        for grid in candidate_grids:
+            error = abs(log_ratio - math.log(grid[0] / grid[1]))
+            if error < min_error:
+                best_grid = grid
+                min_error = error
+
+        return best_grid
+
+    def get_slice_image_placeholder(self, image_size, image_idx=0, max_slice_nums=None, use_image_id=None):
+        max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
+        assert max_slice_nums > 0
+        grid = self.get_sliced_grid(image_size=image_size, max_slice_nums=max_slice_nums)
+
+        image_placeholder = self.im_start_token + self.unk_token * self.image_feature_size + self.im_end_token
+        use_image_id = self.use_image_id if use_image_id is None else bool(use_image_id)
+        if use_image_id:
+            final_placeholder = self.get_image_id_placeholder(image_idx) + image_placeholder
+        else:
+            final_placeholder = image_placeholder
+
+        if self.slice_mode:
+            final_placeholder = final_placeholder + self.get_grid_placeholder(grid=grid)
+        return final_placeholder
+
+    @staticmethod
+    def to_pil_image(image, rescale=None) -> Image.Image:
+        """Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back
+        as the last axis if needed.
+
+        Args:
+            image (`Image.Image` or `numpy.ndarray` or `torch.Tensor`):
+                The image to convert to the PIL Image format.
+            rescale (`bool`, *optional*):
+                Whether to apply the scaling factor (to make pixel values integers between 0 and 255). Will
+                default to `True` if the image type is a floating type, `False` otherwise.
+        """
+        if isinstance(image, Image.Image):
+            return image
+        if is_torch_tensor(image):
+            image = image.numpy()
+
+        if isinstance(image, np.ndarray):
+            if rescale is None:
+                # rescale defaults to True when the array is of floating type.
+                rescale = isinstance(image.flat[0], np.floating)
+            # If the channel has been moved to the first dim, we put it back at the end.
+            if image.ndim == 3 and image.shape[0] in [1, 3]:
+                image = image.transpose(1, 2, 0)
+            if rescale:
+                image = image * 255
+            image = image.astype(np.uint8)
+            return Image.fromarray(image)
+        return image
+
+    def reshape_by_patch(self, image):
+        image = torch.from_numpy(image)
+        patch_size = self.patch_size
+        patches = torch.nn.functional.unfold(image, (patch_size, patch_size), stride=(patch_size, patch_size))
+
+        patches = patches.reshape(image.size(0), patch_size, patch_size, -1)
+        patches = patches.permute(0, 1, 3, 2).reshape(image.size(0), patch_size, -1)
+        return patches.numpy()
+
+    def preprocess(
+        self,
+        images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]],
+        do_pad: Optional[bool] = True,
+        max_slice_nums: int = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        **kwargs,
+    ) -> MiniCPMOBatchFeature:
+        if isinstance(images, Image.Image):
+            images_list = [[images]]
+        elif isinstance(images[0], Image.Image):
+            images_list = [images]
+        else:
+            images_list = images
+
+        new_images_list = []
+        image_sizes_list = []
+        tgt_sizes_list = []
+
+        for _images in images_list:
+            if _images is None or len(_images) == 0:
+                new_images_list.append([])
+                image_sizes_list.append([])
+                tgt_sizes_list.append([])
+                continue
+            if not valid_images(_images):
+                raise ValueError(
+                    "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                    "torch.Tensor, tf.Tensor or jax.ndarray."
+                )
+
+            _images = [self.to_pil_image(image).convert("RGB") for image in _images]
+            input_data_format = infer_channel_dimension_format(np.array(_images[0]))
+
+            new_images = []
+            image_sizes = [image.size for image in _images]
+            tgt_sizes = []
+            for image in _images:
+                image_patches = self.get_sliced_images(image, max_slice_nums)
+                image_patches = [to_numpy_array(image).astype(np.float32) / 255 for image in image_patches]
+                image_patches = [
+                    self.normalize(image=image, mean=self.mean, std=self.std, input_data_format=input_data_format)
+                    for image in image_patches
+                ]
+                image_patches = [
+                    to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
+                    for image in image_patches
+                ]
+                for slice_image in image_patches:
+                    new_images.append(self.reshape_by_patch(slice_image))
+                    tgt_sizes.append(
+                        np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size))
+                    )
+
+            if tgt_sizes:
+                tgt_sizes = np.vstack(tgt_sizes)
+
+            new_images_list.append(new_images)
+            image_sizes_list.append(image_sizes)
+            tgt_sizes_list.append(tgt_sizes)
+        return MiniCPMOBatchFeature(
+            data={"pixel_values": new_images_list, "image_sizes": image_sizes_list, "tgt_sizes": tgt_sizes_list},
+            tensor_type=return_tensors,
+        )
+
+
+AutoImageProcessor.register("MiniCPMVImageProcessor", MiniCPMVImageProcessor)
+
+
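+# Worked example of the slicing scheme above (illustrative; numbers follow the
+# defaults max_slice_nums=9, scale_resolution=448, patch_size=14):
+#
+#     proc = MiniCPMVImageProcessor()
+#     # a 1344x896 image covers 1344*896 / 448**2 = 6 tiles, so a 3x2 grid wins
+#     proc.get_sliced_grid((1344, 896), max_slice_nums=9)   # -> [3, 2]
+#     proc.find_best_resize((1344, 896), 448, 14)           # -> (546, 364), multiples of 14
+
+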
+def chunk_audio(audio: np.ndarray, max_duration_seconds: int = 30, sample_rate: int = 16000) -> List[np.ndarray]:
+    """Split long audio into chunks.
+
+    Args:
+        audio: 1-D waveform.
+        max_duration_seconds: maximum duration of each chunk, in seconds.
+        sample_rate: sampling rate of the waveform, in Hz.
+
+    Returns:
+        A list of chunks, each at most `max_duration_seconds` long.
+    """
+    max_len = int(max_duration_seconds * sample_rate)
+
+    if len(audio) <= max_len:
+        return [audio]
+
+    chunks = []
+    for i in range(0, len(audio), max_len):
+        chunk = audio[i : i + max_len]
+        chunks.append(chunk)
+
+    return chunks
+
+
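+# Example (illustrative): a 65 s clip at 16 kHz splits into 30 s + 30 s + 5 s pieces:
+#
+#     sr = 16000
+#     chunks = chunk_audio(np.zeros(65 * sr, dtype=np.float32), 30, sr)
+#     [len(c) // sr for c in chunks]   # -> [30, 30, 5]
+
+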
+def process_audio_batch(
+    audios: Union[np.ndarray, List[np.ndarray], List[List[np.ndarray]]],
+    feature_extractor,
+    sampling_rate: int = 16000,
+    max_duration_seconds: int = 30,
+    return_attention_mask: bool = True,
+) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+    """Extract audio mel features.
+
+    Args:
+        audios: a single waveform, a list of waveforms, or a batch of lists of waveforms.
+        feature_extractor: a WhisperFeatureExtractor (or compatible) instance.
+        sampling_rate: sampling rate of the waveforms, in Hz.
+        max_duration_seconds: audios longer than this are split into chunks first.
+        return_attention_mask: whether to trim features to their actual length.
+
+    Returns:
+        (audio_features, audio_feature_lens)
+        audio_features: [batch_size, n_mels, max_frames]
+        audio_feature_lens: per-batch tensors of actual frame counts
+    """
+    if isinstance(audios, np.ndarray):
+        audios_list = [[audios]]
+    elif len(audios) > 0 and isinstance(audios[0], np.ndarray):
+        audios_list = [audios]
+    else:
+        audios_list = audios
+
+    audio_features_all = []
+    audio_feature_lens_list = []
+
+    for batch_audios in audios_list:
+        batch_lens = []
+
+        for audio in batch_audios:
+            chunks = chunk_audio(audio, max_duration_seconds, sampling_rate)
+
+            for chunk in chunks:
+                audio_input = feature_extractor(
+                    chunk,
+                    sampling_rate=sampling_rate,
+                    return_tensors="pt",
+                    padding="max_length",
+                    return_attention_mask=return_attention_mask,
+                )
+
+                audio_feature = audio_input["input_features"]  # [1, 80, frames]
+
+                if return_attention_mask:
+                    actual_len = audio_input["attention_mask"].sum(dim=1)  # Tensor([frames])
+                    audio_feature = audio_feature[:, :, : actual_len[0]]
+                    batch_lens.append(actual_len[0])
+                else:
+                    batch_lens.append(torch.tensor(audio_feature.shape[2]))
+
+                audio_features_all.append(audio_feature.squeeze(0))  # [80, frames]
+
+        if len(batch_lens) > 0:
+            audio_feature_lens_list.append(torch.hstack(batch_lens))
+        else:
+            audio_feature_lens_list.append(torch.tensor([]))
+
+    # pad to the same length
+    if audio_features_all:
+        audio_features = torch.nn.utils.rnn.pad_sequence(
+            [feat.transpose(0, 1) for feat in audio_features_all], batch_first=True, padding_value=0.0
+        ).transpose(1, 2)  # [batch, 80, max_frames]
+    else:
+        audio_features = torch.tensor([])
+
+    return audio_features, audio_feature_lens_list
+
+
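+# Sketch: two sub-30 s clips form one batch; each clip contributes one chunk, so the
+# padded output is [2, 80, max_frames] and lens[0] holds both frame counts
+# (`fe` here stands for a WhisperFeatureExtractor-compatible instance):
+#
+#     feats, lens = process_audio_batch([clip_a, clip_b], feature_extractor=fe)
+
+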
+def regroup_audio_features(
+    audio_features: torch.Tensor, audio_feature_lens: List[torch.Tensor], regroup_seconds: int, fps: int = 100
+) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+    """Regroup audio features into fixed-duration segments.
+
+    Args:
+        audio_features: [batch, n_mels, frames]
+        audio_feature_lens: each batch entry's actual length
+        regroup_seconds: regrouped segment duration (seconds)
+        fps: frames per second
+
+    Returns:
+        (regrouped_features, regrouped_lens)
+    """
+    # flatten into a continuous frame sequence
+    all_lens = []
+    for lens in audio_feature_lens:
+        if isinstance(lens, torch.Tensor):
+            all_lens.extend(lens.tolist())
+        elif isinstance(lens, list):
+            all_lens.extend([int(x) for x in lens])
+
+    if len(all_lens) == 0:
+        return torch.tensor([]), []
+
+    # concatenate all valid features
+    flat_slices = [audio_features[i, :, :L] for i, L in enumerate(all_lens)]  # [n_mels, L]
+
+    if len(flat_slices) == 1:
+        full_feat = flat_slices[0]
+    else:
+        full_feat = torch.cat(flat_slices, dim=1)  # [n_mels, total_frames]
+
+    # split into fixed-size frame segments
+    frames_per_seg = int(regroup_seconds * fps)
+    segments = []
+
+    for start in range(0, full_feat.size(1), frames_per_seg):
+        seg = full_feat[:, start : start + frames_per_seg]
+        if seg.size(1) > 0:
+            segments.append(seg)
+
+    if len(segments) == 0:
+        return torch.tensor([]), []
+
+    # pad and convert to a batch
+    seg_lens = [s.size(1) for s in segments]
+    segs_transposed = [s.transpose(0, 1) for s in segments]
+
+    padded = torch.nn.utils.rnn.pad_sequence(segs_transposed, batch_first=True, padding_value=0.0)  # [N, max_T, n_mels]
+
+    padded = padded.transpose(1, 2)  # [N, n_mels, max_T]
+    lens_tensor = torch.tensor(seg_lens, dtype=torch.int32, device=padded.device)
+
+    return padded, [lens_tensor]
+
+
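+# Sketch of the regrouping arithmetic: at fps=100, 45 s of features (4500 frames)
+# regrouped into 20 s segments yields frame counts 2000, 2000, 500:
+#
+#     feats = torch.zeros(1, 80, 4500)
+#     segs, seg_lens = regroup_audio_features(feats, [torch.tensor([4500])], 20, fps=100)
+#     segs.shape      # torch.Size([3, 80, 2000])
+#     seg_lens[0]     # tensor([2000, 2000,  500], dtype=torch.int32)
+
+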
+class MiniCPMAAudioProcessor(WhisperFeatureExtractor):
+    """
+    On top of WhisperFeatureExtractor:
+    - supports dynamic_log_norm (the original max-8dB rule, with an adjustable dynamic_range_db)
+    - or a fixed log_floor_db (e.g. -10dB)
+    - the fixed floor exists because the streaming scheme cannot use the dynamic setting
+    - either can be changed mid-stream through set_spac_log_norm
+    Both paths (torch / numpy) keep a consistent clipping and scaling order:
+    log10 -> (dynamic/fixed lower-limit clipping) -> (+4)/4
+    """
+
+    def __init__(
+        self,
+        *args,
+        dynamic_log_norm: bool = True,
+        dynamic_range_db: float = 8.0,
+        log_floor_db: float = -10.0,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.dynamic_log_norm = bool(dynamic_log_norm)
+        self.dynamic_range_db = float(dynamic_range_db)
+        self.log_floor_db = float(log_floor_db)
+
+    def set_spac_log_norm(
+        self,
+        dynamic_range_db: Optional[float] = None,
+        log_floor_db: Optional[float] = None,
+        *,
+        inplace: bool = True,
+    ) -> "MiniCPMAAudioProcessor":
+        """Hot-update the dynamic/fixed lower-limit strategy.
+
+        Args:
+            dynamic_range_db: dynamic range (dB); passing a value switches to dynamic mode
+                (threshold = max - dynamic_range_db). None means keep unchanged.
+            log_floor_db: fixed log floor (dB, usually <= 0); passing a value switches to
+                fixed-floor mode. None means keep unchanged.
+            inplace: True modifies the current instance; False returns a modified shallow copy.
+
+        Returns:
+            self or a new instance (when inplace=False).
+        """
+
+        target = self if inplace else copy.copy(self)
+
+        if dynamic_range_db is not None:
+            val = float(dynamic_range_db)
+            if val < 0:
+                raise ValueError("dynamic_range_db must be >= 0.")
+            target.dynamic_log_norm = True  # explicitly switch to dynamic mode
+            target.dynamic_range_db = val
+
+        if log_floor_db is not None:
+            val = float(log_floor_db)
+            # log10(mel) usually peaks near 0 dB, so the floor should be <= 0; validate loosely
+            if val > 0:
+                raise ValueError("log_floor_db should be <= 0 (log10 scale).")
+            target.dynamic_log_norm = False  # explicitly switch to fixed lower-limit mode
+            target.log_floor_db = val
+
+        return target
+
+    def _np_extract_fbank_features(self, waveform_batch: np.ndarray, device: str) -> np.ndarray:
+        """NumPy path consistent with upstream, but with the max-8dB rule replaced by configurable dynamic/fixed lower-limit clipping."""
+        if device != "cpu":
+            raise ValueError(
+                f"Got device `{device}` for feature extraction, but feature extraction on CUDA accelerator "
+                "devices requires torch. Set device='cpu' or install torch."
+            )
+
+        log_spec_batch: List[np.ndarray] = []
+        for waveform in waveform_batch:
+            # compute the log10 mel spectrogram
+            log_spec = spectrogram(
+                waveform,
+                window_function(self.n_fft, "hann"),
+                frame_length=self.n_fft,
+                hop_length=self.hop_length,
+                power=2.0,
+                dither=self.dither,
+                mel_filters=self.mel_filters,
+                log_mel="log10",
+            )
+            # consistent with upstream: drop the last frame
+            log_spec = log_spec[:, :-1]
+
+            # dynamic/fixed clipping
+            if self.dynamic_log_norm:
+                threshold = log_spec.max() - self.dynamic_range_db
+                log_spec = np.maximum(log_spec, threshold)
+            else:
+                log_spec = np.maximum(log_spec, self.log_floor_db)
+
+            # consistent with Whisper's linear scaling
+            log_spec = (log_spec + 4.0) / 4.0
+
+            log_spec_batch.append(log_spec)
+
+        return np.array(log_spec_batch)
+
+    def _torch_extract_fbank_features(self, waveform: np.ndarray, device: str = "cpu") -> np.ndarray:
+        if torch is None:
+            raise RuntimeError("PyTorch is not installed, cannot compute STFT on GPU.")
+
+        waveform = torch.from_numpy(waveform).to(device, torch.float32)
+        window = torch.hann_window(self.n_fft, device=device)
+
+        if self.dither != 0.0:
+            waveform = waveform + self.dither * torch.randn_like(waveform)
+
+        stft = torch.stft(waveform, n_fft=self.n_fft, hop_length=self.hop_length, window=window, return_complex=True)
+        magnitudes = stft[..., :-1].abs() ** 2
+
+        mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32)  # [n_mels, 1+n_fft//2]
+        mel_spec = mel_filters.T @ magnitudes  # [..., n_mels, T]
+
+        log_spec = torch.clamp(mel_spec, min=1e-10).log10()  # <= 0
+
+        if self.dynamic_log_norm:
+            if waveform.dim() == 2:
+                max_val_t = log_spec.max(dim=2, keepdim=True)[0]  # over T
+                max_val_bt = max_val_t.max(dim=1, keepdim=True)[0]  # over mel
+                threshold = max_val_bt - self.dynamic_range_db
+                log_spec = torch.maximum(log_spec, threshold)
+            else:
+                threshold = log_spec.max() - self.dynamic_range_db
+                log_spec = torch.maximum(log_spec, threshold)
+        else:
+            floor_tensor = torch.tensor(self.log_floor_db, dtype=log_spec.dtype, device=log_spec.device)
+            log_spec = torch.maximum(log_spec, floor_tensor)
+
+        log_spec = (log_spec + 4.0) / 4.0
+
+        if device != "cpu":
+            log_spec = log_spec.detach().cpu()
+        return log_spec.numpy()
+
+    def process(self, *args, **kwargs):
+        """Alias of __call__ for convenience."""
+        return self.__call__(*args, **kwargs)
+
+
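+# Sketch: switching the log-mel floor at runtime (same API as above):
+#
+#     ap = MiniCPMAAudioProcessor()                # WhisperFeatureExtractor defaults
+#     ap.set_spac_log_norm(log_floor_db=-10)       # fixed floor: clip log10(mel) at -10
+#     ap.set_spac_log_norm(dynamic_range_db=8)     # dynamic: clip at max - 8 dB (Whisper-style)
+
+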
+class StreamingMelProcessorExact:
+    """Streaming mel processor that is strictly equivalent to offline extraction.
+
+    - accumulates all historical audio into a buffer; after each addition, the same
+      feature_extractor recomputes the mel over the entire buffer.
+    - only emits "stable" frames: frames whose center does not depend on future (right)
+      context, i.e. center + n_fft//2 <= current buffer length.
+    - emits the last batch of frames at the end (flush), guaranteeing exact consistency
+      with an offline full computation.
+
+    Cost: each call performs feature extraction over the accumulated buffer (could be
+    optimized to be incremental if needed).
+    """
+
+    def __init__(
+        self,
+        feature_extractor: MiniCPMAAudioProcessor,
+        chunk_ms: int = 100,
+        first_chunk_ms: Optional[int] = None,
+        sample_rate: int = 16000,
+        n_fft: int = 400,
+        hop_length: int = 160,
+        n_mels: int = 80,
+        cnn_redundancy_ms: int = 10,  # (given in ms, usually 10ms = 1 frame)
+        # sliding window parameters
+        enable_sliding_window: bool = False,  # whether to enable the sliding window
+        slide_trigger_seconds: float = 30.0,  # trigger threshold for the sliding window, in seconds
+        slide_stride_seconds: float = 10.0,  # stride of the sliding window, in seconds
+    ):
+        self.feature_extractor = feature_extractor
+        self.chunk_ms = chunk_ms
+        self.first_chunk_ms = first_chunk_ms if first_chunk_ms is not None else chunk_ms
+        self.sample_rate = sample_rate
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.n_mels = n_mels
+
+        self.chunk_samples = int(round(chunk_ms * sample_rate / 1000))
+        self.chunk_frames = self.chunk_samples // hop_length
+        # align to hop_length to avoid frame boundary issues
+        hop = self.hop_length
+        raw_first_samples = int(round(self.first_chunk_ms * sample_rate / 1000))
+        aligned_first = max(hop, (raw_first_samples // hop) * hop)
+        self.first_chunk_samples = aligned_first
+        self.half_window = n_fft // 2  # required right context
+
+        # redundancy (in frames), usually <= 1 frame: 10ms -> 1 frame
+        self.cnn_redundancy_ms = cnn_redundancy_ms
+        self.cnn_redundancy_samples = int(cnn_redundancy_ms * sample_rate / 1000)
+        self.cnn_redundancy_frames = max(0, self.cnn_redundancy_samples // hop_length)
+
+        # sliding window configuration (trigger mode)
+        self.enable_sliding_window = enable_sliding_window
+        self.trigger_seconds = slide_trigger_seconds
+        self.slide_seconds = slide_stride_seconds
+
+        # shift/base (global frame coordinates)
+        self.left_samples_dropped = 0  # samples dropped from the left
+        self.base_T = 0  # index of the "global frame" corresponding to mel_full[:, :, 0]
+
+        self.reset()
+
+    def reset(self):
+        self.buffer = np.zeros(0, dtype=np.float32)
+        self.last_emitted_T = 0
+        self.total_samples_processed = 0
+        self.chunk_count = 0
+        self.is_first = True
+        self.left_samples_dropped = 0
+        self.base_T = 0
+
+    def get_chunk_size(self) -> int:
+        return self.first_chunk_samples if self.is_first else self.chunk_samples
+
+    def get_expected_output_frames(self) -> int:
+        raise NotImplementedError("get_expected_output_frames is not implemented")
+
+    def _extract_full(self) -> torch.Tensor:
+        # when the buffer is shorter than n_fft, Whisper's internal STFT raises in
+        # center=True pad mode (pad greater than input length); there is no stable
+        # frame to output yet, so fail fast instead.
+        if len(self.buffer) < self.n_fft:
+            raise ValueError(f"buffer length is shorter than n_fft {len(self.buffer)} < {self.n_fft}")
+        # if the buffer is shorter than 5s, use set_spac_log_norm(log_floor_db=-10)
+        if len(self.buffer) < 5 * self.sample_rate:
+            # TODO: ideally this threshold should be chosen experimentally; the current
+            # value is empirical, see the main implementation of MiniCPMAAudioProcessor
+            self.feature_extractor.set_spac_log_norm(log_floor_db=-10)
+        # if the buffer is longer than 5s, use set_spac_log_norm(dynamic_range_db=8)
+        else:
+            self.feature_extractor.set_spac_log_norm(dynamic_range_db=8)
+        feats = self.feature_extractor(
+            self.buffer,
+            sampling_rate=self.sample_rate,
+            return_tensors="pt",
+            padding=False,
+        )
+        return feats.input_features  # [1, 80, T]
+
+    def _stable_frames_count(self) -> int:
+        # number of stable frames = floor((len(buffer) - half_window) / hop) + 1, minimum 0
+        L = int(self.buffer.shape[0])
+        if L <= 0:
+            return 0
+        if L < self.half_window:
+            return 0
+        return max(0, (L - self.half_window) // self.hop_length + 1)
+
+    def _maybe_slide_buffer(self):
+        """Trigger-mode sliding window: when the buffer reaches the trigger threshold, slide it by a fixed stride."""
+        if not self.enable_sliding_window:
+            return
+
+        sr = self.sample_rate
+        hop = self.hop_length
+        L = len(self.buffer)
+
+        # convert seconds to samples
+        trigger_samples = int(self.trigger_seconds * sr)
+        stride_samples = int(self.slide_seconds * sr)
+
+        # check whether the trigger threshold is reached
+        if L < trigger_samples:
+            return
+
+        # number of samples to drop (a fixed slide of stride_samples)
+        drop = stride_samples
+
+        # we must not drop left context that subsequent emission still needs;
+        # in trigger mode we only have to protect the minimum necessary data,
+        # i.e. ensure we do not discard frames that may be needed in the future
+        last_emitted_local = self.last_emitted_T - self.base_T
+
+        # only protect the necessary context (e.g. the most recent 1 second of data)
+        min_keep_seconds = 1.0  # keep at least 1 second of data to ensure continuity
+        min_keep_samples = int(min_keep_seconds * sr)
+
+        # guard_samples is the minimum number of samples we must keep
+        guard_samples = min(min_keep_samples, L - drop)
+
+        # limits: do not exceed the safe boundary, and align to hop
+        max_allowed_drop = max(0, L - guard_samples)
+        drop = min(drop, max_allowed_drop)
+        drop = (drop // hop) * hop
+
+        if drop <= 0:
+            return
+
+        # actually drop & update the base
+        self.buffer = self.buffer[drop:]
+        self.left_samples_dropped += drop
+        self.base_T += drop // hop
+
+    def process(self, audio_chunk: np.ndarray, is_last_chunk: bool = False) -> Tuple[torch.Tensor, Dict]:
+        self.chunk_count += 1
+        # append to the buffer
+        if len(self.buffer) == 0:
+            self.buffer = audio_chunk.astype(np.float32, copy=True)
+        else:
+            self.buffer = np.concatenate([self.buffer, audio_chunk.astype(np.float32, copy=True)])
+
+        # sliding window handling
+        self._maybe_slide_buffer()
+
+        # full extraction (over the current window)
+        mel_full = self._extract_full()
+        T_full = mel_full.shape[-1]  # local frames in the current window
+        stable_T = min(T_full, self._stable_frames_count())  # local stable frames
+        stable_T_global = self.base_T + stable_T  # map to global frame coordinates
+
+        # plan the core frames for the current emission (global coordinates)
+        core_start_g = self.last_emitted_T
+        core_end_g = core_start_g + self.chunk_frames
+        required_stable_g = core_end_g + self.cnn_redundancy_frames
+
+        if stable_T_global >= required_stable_g or is_last_chunk:
+            emit_start_g = max(0, core_start_g - self.cnn_redundancy_frames)
+            emit_end_g = core_end_g + self.cnn_redundancy_frames
+
+            # global -> local index
+            emit_start = max(0, emit_start_g - self.base_T)
+            emit_end = emit_end_g - self.base_T
+            emit_start = max(0, min(emit_start, T_full))
+            emit_end = max(emit_start, min(emit_end, T_full))
+
+            mel_output = mel_full[:, :, emit_start:emit_end]
+            self.last_emitted_T = core_end_g  # only advance the core frame pointer (global)
+        else:
+            mel_output = mel_full[:, :, 0:0]
+
+        self.total_samples_processed += len(audio_chunk)
+        self.is_first = False
+
+        info = {
+            "type": "exact_chunk",
+            "chunk_number": self.chunk_count,
+            "emitted_frames": mel_output.shape[-1],
+            "stable_T": stable_T,
+            "T_full": T_full,
+            "base_T": self.base_T,
+            "stable_T_global": stable_T_global,
+            "buffer_len_samples": int(self.buffer.shape[0]),
+            "left_samples_dropped": self.left_samples_dropped,
+            "core_start": core_start_g,  # original field name kept; holds the global value
+            "core_end": core_end_g,  # same as above
+        }
+        return mel_output, info
+
+    def flush(self) -> torch.Tensor:
+        """Called when the stream ends; emits the remaining frames, ensuring consistency with offline extraction (computed in global coordinates)."""
+        if len(self.buffer) == 0:
+            return torch.zeros(1, 80, 0)
+
+        mel_full = self._extract_full()
+        T_local = mel_full.shape[-1]
+        T_global = self.base_T + T_local
+
+        if self.last_emitted_T < T_global:
+            start_l = max(0, self.last_emitted_T - self.base_T)
+            tail = mel_full[:, :, start_l:]
+            self.last_emitted_T = T_global
+            return tail
+        return mel_full[:, :, 0:0]
+
+    def get_config(self) -> Dict:
+        return {
+            "chunk_ms": self.chunk_ms,
+            "first_chunk_ms": self.first_chunk_ms,
+            "effective_first_chunk_ms": self.first_chunk_samples / self.sample_rate * 1000.0,
+            "sample_rate": self.sample_rate,
+            "n_fft": self.n_fft,
+            "hop_length": self.hop_length,
+            "cnn_redundancy_ms": self.cnn_redundancy_ms,
+            "cnn_redundancy_frames": self.cnn_redundancy_frames,
+            "enable_sliding_window": self.enable_sliding_window,
+            "trigger_seconds": self.trigger_seconds,
+            "slide_seconds": self.slide_seconds,
+        }
+
+    def get_state(self) -> Dict:
+        return {
+            "chunk_count": self.chunk_count,
+            "last_emitted_T": self.last_emitted_T,
+            "total_samples_processed": self.total_samples_processed,
+            "buffer_len": int(self.buffer.shape[0]),
+            "base_T": self.base_T,
+            "left_samples_dropped": self.left_samples_dropped,
+        }
+
+    def get_snapshot(self) -> Dict:
+        """Get a complete state snapshot (including the buffer), used for fast-start recovery.
+
+        Returns:
+            A dictionary containing the complete state, which can later be restored.
+        """
+        buffer_copy = self.buffer.copy()
+        snapshot = {
+            "chunk_count": self.chunk_count,
+            "last_emitted_T": self.last_emitted_T,
+            "total_samples_processed": self.total_samples_processed,
+            "buffer": buffer_copy,
+            "base_T": self.base_T,
+            "left_samples_dropped": self.left_samples_dropped,
+            "is_first": self.is_first,
+            # save the feature_extractor state (key point: keeps mel extraction deterministic)
+            "fe_dynamic_log_norm": getattr(self.feature_extractor, "dynamic_log_norm", None),
+            "fe_dynamic_range_db": getattr(self.feature_extractor, "dynamic_range_db", None),
+            "fe_log_floor_db": getattr(self.feature_extractor, "log_floor_db", None),
+        }
+
+        return snapshot
+
+    def restore_snapshot(self, snapshot: Dict) -> None:
+        """Restore state from a snapshot.
+
+        Args:
+            snapshot: the snapshot dictionary returned by get_snapshot
+        """
+        # record the state before restoration
+        prev_state = {
+            "chunk_count": self.chunk_count,
+            "last_emitted_T": self.last_emitted_T,
+            "buffer_len": len(self.buffer),
+        }
+
+        # restore state
+        self.chunk_count = snapshot["chunk_count"]
+        self.last_emitted_T = snapshot["last_emitted_T"]
+        self.total_samples_processed = snapshot["total_samples_processed"]
+        self.buffer = snapshot["buffer"].copy()  # copy the buffer
+        self.base_T = snapshot["base_T"]
+        self.left_samples_dropped = snapshot["left_samples_dropped"]
+        self.is_first = snapshot["is_first"]
+
+        # restore the feature_extractor state (key point: keeps mel extraction deterministic)
+        if snapshot.get("fe_dynamic_log_norm") is not None:
+            self.feature_extractor.dynamic_log_norm = snapshot["fe_dynamic_log_norm"]
+        if snapshot.get("fe_dynamic_range_db") is not None:
+            self.feature_extractor.dynamic_range_db = snapshot["fe_dynamic_range_db"]
+        if snapshot.get("fe_log_floor_db") is not None:
+            self.feature_extractor.log_floor_db = snapshot["fe_log_floor_db"]
+
+
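+# Streaming usage sketch (illustrative driver loop; `ap` is a MiniCPMAAudioProcessor):
+#
+#     smp = StreamingMelProcessorExact(feature_extractor=ap, chunk_ms=100)
+#     for chunk in audio_chunks:            # np.float32 pieces of a 16 kHz stream
+#         mel, info = smp.process(chunk)    # [1, 80, n]; empty until enough context
+#     tail = smp.flush()                    # remaining frames, matching offline output
+
+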
+class MiniCPMOProcessor(ProcessorMixin):
+    attributes = ["image_processor", "audio_processor", "tokenizer"]
+    audio_processor_class = "AutoFeatureExtractor"
+    image_processor_class = "AutoImageProcessor"
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(self, image_processor=None, audio_processor=None, tokenizer=None, **kwargs):
+        super().__init__(image_processor, audio_processor, tokenizer)
+
+        self.version = image_processor.version if image_processor else None
+        # audio feature pooling step; must stay consistent with config.audio_pool_step
+        self.pool_step = kwargs.get("audio_pool_step", 5)
+
+        # initialize the streaming audio processor
+        self._streaming_mel_processor = None
+        if audio_processor is not None:
+            self._init_streaming_processor()
+
+    def get_audio_placeholder(
+        self,
+        audio_lens: int,
+        chunk_input: bool = True,
+        chunk_length: int = 1,
+    ) -> str:
+        """
+        Public method to get the audio placeholder string for vLLM integration.
+
+        Args:
+            audio_lens: Length of the audio in samples
+            chunk_input: Whether to use chunked processing
+            chunk_length: Chunk length in seconds
+
+        Returns:
+            Audio placeholder string
+        """
+        pool_step = self.pool_step
+        feature_lens = math.ceil(audio_lens / self.audio_processor.hop_length)
+
+        feature_lens = (feature_lens - 1) // 2 + 1
+        output_lens = (feature_lens - pool_step) // pool_step + 1
+
+        if chunk_input:
+            fbank_feat_in_chunk = int(chunk_length * 100)
+            cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1
+            audio_embeds_in_chunk = (cnn_feat_in_chunk - pool_step) // pool_step + 1
+            num_audio_chunks = (output_lens + audio_embeds_in_chunk - 1) // audio_embeds_in_chunk
+
+            place_holders = ""
+            total_unk_len = 0
+            for _ in range(num_audio_chunks):
+                unk_len = min(audio_embeds_in_chunk, output_lens - total_unk_len)
+                place_holders += self.tokenizer.audio_start + "<unk>" * unk_len + self.tokenizer.audio_end
+                total_unk_len += unk_len
+            audio_placeholder = place_holders
+        else:
+            audio_placeholder = self.tokenizer.audio_start + "<unk>" * output_lens + self.tokenizer.audio_end
+
+        return audio_placeholder
+
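+    # Worked example of the placeholder arithmetic above (illustrative): 16 000 samples
+    # (1 s at 16 kHz) with hop_length=160 and pool_step=5 give ceil(16000/160) = 100 mel
+    # frames -> (100-1)//2+1 = 50 CNN frames -> (50-5)//5+1 = 10 <unk> tokens, which fit
+    # in a single 1 s chunk when chunk_input=True.
+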
+    def _init_streaming_processor(
+        self,
+        chunk_ms: int = 100,
+        cnn_redundancy_ms: int = 0,
+        *,
+        mode: str = "exact",
+        first_chunk_ms: Optional[int] = None,
+        enable_sliding_window: bool = False,
+        slide_trigger_seconds: float = 30.0,
+        slide_stride_seconds: float = 10.0,
+    ):
+        """Initialize the streaming processor.
+
+        Args:
+            chunk_ms: chunk size in milliseconds, which is also the sliding step.
+            cnn_redundancy_ms: CNN boundary redundancy in milliseconds (on both sides); 0 means standard mode.
+            mode: streaming processing mode; currently only "exact" is supported
+            first_chunk_ms: size of the first chunk (milliseconds); defaults to chunk_ms if not given
+            enable_sliding_window: whether to enable the sliding window (trigger mode)
+            slide_trigger_seconds: trigger threshold for the sliding window, in seconds
+            slide_stride_seconds: stride of the sliding window, in seconds
+        """
+        if mode == "exact":
+            self._streaming_mel_processor = StreamingMelProcessorExact(
+                feature_extractor=self.audio_processor,
+                chunk_ms=chunk_ms,
+                first_chunk_ms=first_chunk_ms,
+                sample_rate=16000,
+                cnn_redundancy_ms=cnn_redundancy_ms,
+                enable_sliding_window=enable_sliding_window,
+                slide_trigger_seconds=slide_trigger_seconds,
+                slide_stride_seconds=slide_stride_seconds,
+            )
+        else:
+            raise ValueError(f"Unsupported mode: {mode}, only 'exact' is supported")
+        self._streaming_mode = "exact"  # only "exact" reaches this point
+
+    def set_streaming_mode(
+        self,
+        mode: str = "exact",
+        chunk_ms: int = 100,
+        cnn_redundancy_ms: int = 0,
+        *,
+        first_chunk_ms: Optional[int] = None,
+        enable_sliding_window: bool = False,
+        slide_trigger_seconds: float = 30.0,
+        slide_stride_seconds: float = 10.0,
+    ):
+        """Set the streaming processing mode.
+
+        Args:
+            mode: streaming processing mode; currently only "exact" is supported
+            chunk_ms: chunk size in milliseconds, which is also the sliding step.
+            cnn_redundancy_ms: CNN boundary redundancy in milliseconds (on both sides); 0 means standard mode.
+            first_chunk_ms: size of the first chunk (milliseconds); defaults to chunk_ms if not given
+            enable_sliding_window: whether to enable the sliding window (trigger mode)
+            slide_trigger_seconds: trigger threshold for the sliding window, in seconds
+            slide_stride_seconds: stride of the sliding window, in seconds
+        """
+        if self.audio_processor is None:
+            raise ValueError("audio_processor is not set, cannot initialize the streaming processor")
+        self._init_streaming_processor(
+            chunk_ms=chunk_ms,
+            cnn_redundancy_ms=cnn_redundancy_ms,
+            mode=mode,
+            first_chunk_ms=first_chunk_ms,
+            enable_sliding_window=enable_sliding_window,
+            slide_trigger_seconds=slide_trigger_seconds,
+            slide_stride_seconds=slide_stride_seconds,
+        )
+
+    def process_image(
+        self,
+        images: Optional[ImageInput] = None,
+        do_pad: bool = True,
+        max_slice_nums: int = 1,
+        return_tensors: str = "pt",
+    ) -> MiniCPMOBatchFeature:
+        """Process image data.
+
+        Args:
+            images: input images
+            do_pad: whether to pad
+            max_slice_nums: maximum number of slices
+            return_tensors: return tensor type
+        Returns:
+            MiniCPMOBatchFeature object
+        """
+        if images is None:
+            return MiniCPMOBatchFeature(data={"pixel_values": [[]], "image_sizes": [[]], "tgt_sizes": [[]]})
+
+        result = self.image_processor(
+            images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors
+        )
+
+        model_inputs = {
+            "pixel_values": result.get("pixel_values", [[]]),
+            "image_sizes": result.get("image_sizes", [[]]),
+            "tgt_sizes": result.get("tgt_sizes", [[]]),
+        }
+
+        return MiniCPMOBatchFeature(data=model_inputs)
+
+    def process_audio(
+        self,
+        audios: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
+        sampling_rate: int = 16000,
+        regroup_to_seconds: Optional[int] = None,
+        fps: int = 100,
+    ) -> MiniCPMOBatchFeature:
+        """Process audio data in batch.
+
+        Args:
+            audios: audio data
+            sampling_rate: sampling rate
+            regroup_to_seconds: regrouped segment duration in seconds
+            fps: frames per second
+        Returns:
+            MiniCPMOBatchFeature object
+        """
+        if audios is None:
+            return MiniCPMOBatchFeature(data={"audio_features": [], "audio_feature_lens": []})
+
+        audio_features, audio_feature_lens = process_audio_batch(
+            audios=audios,
+            feature_extractor=self.audio_processor,
+            sampling_rate=sampling_rate,
+            max_duration_seconds=30,
+            return_attention_mask=True,
+        )
+
+        if regroup_to_seconds is not None and len(audio_features) > 0:
+            audio_features, audio_feature_lens = regroup_audio_features(
+                audio_features=audio_features,
+                audio_feature_lens=audio_feature_lens,
+                regroup_seconds=regroup_to_seconds,
+                fps=fps,
+            )
+
+        model_inputs = {"audio_features": audio_features, "audio_feature_lens": audio_feature_lens}
+
+        return MiniCPMOBatchFeature(data=model_inputs)
+
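+    # Sketch: the offline audio path (assumes a 16 kHz np.float32 clip `wav` and a
+    # constructed `processor`):
+    #
+    #     out = processor.process_audio(wav)
+    #     out["audio_features"].shape                           # [n_chunks, 80, max_frames]
+    #     processor.process_audio(wav, regroup_to_seconds=20)   # fixed 20 s segments instead
+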
+    def process_audio_streaming(
+        self,
+        audio_chunk: np.ndarray,
+        reset: bool = False,
+        return_batch_feature: bool = False,
+        is_last_chunk: bool = False,
+    ) -> Union[Tuple[torch.Tensor, dict], MiniCPMOBatchFeature]:
+        """Process an audio chunk in streaming mode.
+
+        Args:
+            audio_chunk: audio data chunk (any length, e.g. 125ms first, then 100ms)
+            reset: whether to reset the processor state
+            return_batch_feature: whether to return the MiniCPMOBatchFeature format (consistent with process_audio)
+            is_last_chunk: whether this is the final chunk of the stream
+        Returns:
+            If return_batch_feature=False:
+                (audio_features, info)
+                - audio_features: [1, 80, n_frames] mel features
+                - info: processing information dictionary
+            If return_batch_feature=True:
+                MiniCPMOBatchFeature object, containing:
+                - audio_features: [1, 80, n_frames] mel features
+                - audio_feature_lens: [tensor([n_frames])]
+                - streaming_info: processing information
+        """
+        if self._streaming_mel_processor is None:
+            raise ValueError("Streaming processor not initialized, please ensure audio_processor is set")
+
+        if reset:
+            self._streaming_mel_processor.reset()
+
+        # process the chunk
+        mel_features, info = self._streaming_mel_processor.process(audio_chunk, is_last_chunk=is_last_chunk)
+
+        # determine the return format based on the parameters
+        if return_batch_feature:
+            # return the same format as process_audio
+            # note: info carries emitted_frames, the number of frames actually output
+            n_frames = info.get("emitted_frames", mel_features.shape[-1])
+            model_inputs = {
+                "audio_features": mel_features,
+                "audio_feature_lens": [torch.tensor([n_frames])],
+                "streaming_info": info,  # attach the streaming processing information
+            }
+            return MiniCPMOBatchFeature(data=model_inputs)
+        else:
+            return mel_features, info
+
+    def reset_streaming(self):
+        if self._streaming_mel_processor is not None:
+            self._streaming_mel_processor.reset()
+
+    def get_streaming_chunk_size(self) -> int:
+        if self._streaming_mel_processor is None:
+            raise ValueError("Streaming processor not initialized")
+        return self._streaming_mel_processor.get_chunk_size()
+
+    def configure_streaming(
+        self,
+        chunk_ms: int = 100,
+        enable_sliding_window: bool = False,
+        slide_trigger_seconds: float = 30.0,
+        slide_stride_seconds: float = 10.0,
+    ):
+        """Configure streaming processor parameters.
+
+        Args:
+            chunk_ms: chunk size in milliseconds
+            enable_sliding_window: whether to enable the sliding window (trigger mode)
+            slide_trigger_seconds: trigger threshold for the sliding window, in seconds
+            slide_stride_seconds: stride of the sliding window, in seconds
+        """
+        if self.audio_processor is None:
+            raise ValueError("audio_processor is not set")
+
+        self._init_streaming_processor(
+            chunk_ms=chunk_ms,
+            enable_sliding_window=enable_sliding_window,
+            slide_trigger_seconds=slide_trigger_seconds,
+            slide_stride_seconds=slide_stride_seconds,
+        )
+
+    def get_streaming_config(self) -> dict:
+        if self._streaming_mel_processor is None:
+            return {}
+        return self._streaming_mel_processor.get_config()
+
+    def get_streaming_state(self) -> dict:
+        if self._streaming_mel_processor is None:
+            return {}
+        return self._streaming_mel_processor.get_state()
+
+    def get_streaming_snapshot(self) -> dict:
+        if self._streaming_mel_processor is None:
+            return {}
+        return self._streaming_mel_processor.get_snapshot()
+
+    def restore_streaming_snapshot(self, snapshot: dict) -> None:
+        if self._streaming_mel_processor is None:
+            return
+        if not snapshot:
+            return
+        self._streaming_mel_processor.restore_snapshot(snapshot)
+
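+    # Sketch: chunked streaming with snapshot/rollback (hypothetical driver code):
+    #
+    #     processor.set_streaming_mode(chunk_ms=100)
+    #     snap = processor.get_streaming_snapshot()            # checkpoint before a chunk
+    #     mel, info = processor.process_audio_streaming(chunk)
+    #     processor.restore_streaming_snapshot(snap)           # rewind and replay the chunk
+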
+    def __call__(
+        self,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
+        images: ImageInput = None,
+        audios: Union[np.ndarray, List[np.ndarray], List[List[np.ndarray]]] = None,
+        audio_parts: Optional[list] = None,
+        max_length: Optional[int] = None,
+        do_pad: Optional[bool] = True,
+        max_slice_nums: int = None,
+        use_image_id: bool = True,
+        stream_input: bool = False,
+        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+        sampling_rate: Optional[int] = 16000,
+        online_streaming: bool = False,
+        audio_chunk_idx: int = 0,
+        is_last_chunk: bool = False,
+        **kwargs,
+    ) -> MiniCPMOBatchFeature:
+        if images is not None:
+            image_inputs = self.process_image(
+                images=images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors
+            )
+        else:
+            image_inputs = None
+
+        audio_features, audio_feature_lens, audio_phs = self.audio_feature_extract(
+            audios,
+            audio_parts,
+            stream_input,
+            sampling_rate,
+            online_streaming=online_streaming,
+            is_last_chunk=is_last_chunk,
+        )
+
+        model_inputs = self._convert_omni_to_inputs(
+            image_inputs,
+            audio_phs,
+            text,
+            max_slice_nums=max_slice_nums,
+            use_image_id=use_image_id,
+            max_length=max_length,
+            **kwargs,
+        )
+
+        model_inputs["audio_features"] = audio_features
+        model_inputs["audio_feature_lens"] = audio_feature_lens
+
+        result = MiniCPMOBatchFeature(data={**model_inputs})
+
+        if online_streaming:
+            result.use_extra_context = True
+            result.prefix_extra_frames = 0 if audio_chunk_idx == 0 else 2
+            result.suffix_extra_frames = 2
+            result.chunk_idx = audio_chunk_idx
+
+        return result
+
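+    # Sketch: the multimodal entry point (hypothetical `processor`, a `prompt` containing
+    # the matching image/audio placeholders, a PIL `pil_image`, and a 16 kHz `wav`):
+    #
+    #     batch = processor(text=prompt, images=[pil_image], audios=wav)
+    #     batch["input_ids"], batch["pixel_values"], batch["audio_features"]
+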
+    def audio_feature_extract(
+        self,
+        audios: Union[np.ndarray, List[np.ndarray], List[List[np.ndarray]], None] = None,
+        audio_parts: Optional[list] = None,
+        stream_input: Optional[bool] = False,
+        sampling_rate: Optional[int] = None,
+        chunk_length: Optional[int] = 1,
+        online_streaming: bool = False,
+        is_last_chunk: bool = False,
+        **kwargs,
+    ):
+        if audios is None:
+            return [], [], []
+
+        if isinstance(audios, np.ndarray):
+            audios_list = [[audios]]
+        elif isinstance(audios[0], np.ndarray):
+            audios_list = [audios]
+        else:
+            audios_list = audios
+
+        if audio_parts is not None:
+            assert len(audio_parts) == len(audios_list)
+            for parts, audios in zip(audio_parts, audios_list):
+                assert len(parts) == len(audios)
+
+        audio_feature_lens_list = []
+        audio_ph_list = []
+        audio_features_all = []
+
+        # the audio placeholder does not depend on audio_parts
+        for audios in audios_list:
+            if audios:
+                audio_ph_list.append(
+                    [
+                        self.get_audio_placeholder(len(a), chunk_input=stream_input, chunk_length=chunk_length)
+                        for a in audios
+                    ]
+                )
+            else:
+                audio_ph_list.append([])
+
+        for idx, audios in enumerate(audios_list):
+            if audio_parts is not None:
+                # merge consecutive clips that share the same audio part id
+                audio_part = audio_parts[idx]
+                merge_audio = []
+                cur_audio = []
+                for aid, (part, audio) in enumerate(zip(audio_part, audios)):
+                    if aid == 0 or audio_part[aid] == audio_part[aid - 1]:
+                        cur_audio.append(audio)
+                    else:
+                        merge_audio.append(np.hstack(cur_audio))
+                        cur_audio = [audio]
+                if cur_audio:
+                    merge_audio.append(np.hstack(cur_audio))
+            else:
+                merge_audio = audios
+
+            # if an audio exceeds 30 seconds, split it into 30-second chunks
+            final_merge_audio = []
+            max_audio_inp_len = 30 * sampling_rate
+            for audio in merge_audio:
+                if len(audio) <= max_audio_inp_len:
+                    final_merge_audio.append(audio)
+                else:
+                    for i in range(math.ceil(len(audio) / max_audio_inp_len)):
+                        final_merge_audio.append(audio[i * max_audio_inp_len : (i + 1) * max_audio_inp_len])
+
+            audio_feature_lens = []
+
+            if audios:
+                if online_streaming:
+                    # online streaming: only a single audio is supported; use the
+                    # process_audio_streaming return format directly
+                    assert (
+                        len(final_merge_audio) == 1
+                    ), f"online streaming mode only supports a single audio, currently there are {len(final_merge_audio)}"
+                    audio = final_merge_audio[0]
+                    result = self.process_audio_streaming(
+                        audio, reset=False, return_batch_feature=True, is_last_chunk=is_last_chunk
+                    )
+                    audio_features_all.append(
+                        result["audio_features"].squeeze(0)
+                    )  # [1, 80, T] -> [80, T], consistent with batch processing
+                    audio_feature_lens_list.append(result["audio_feature_lens"][0])
+                else:
+                    # batch processing
+                    audio_inputs = self.audio_processor(
+                        final_merge_audio,
+                        sampling_rate=sampling_rate,
+                        return_attention_mask=True,
+                        padding="max_length",
+                        return_tensors="pt",
+                        **kwargs,
+                    )
+                    audio_feature = audio_inputs["input_features"]
+                    actual_lens = audio_inputs["attention_mask"].sum(dim=1)
+
+                    for feat, lens in zip(audio_feature, actual_lens):
+                        audio_features_all.append(feat[:, :lens])
+                        audio_feature_lens.append(lens)
+
+                    audio_feature_lens = torch.hstack(audio_feature_lens)
+                    audio_feature_lens_list.append(audio_feature_lens)
+            else:
+                audio_feature_lens_list.append([])
+
+        if audio_features_all:
+            audio_features = [i.permute(1, 0) for i in audio_features_all]
+            audio_features = torch.nn.utils.rnn.pad_sequence(
+                audio_features, batch_first=True, padding_value=0.0
+            ).permute(0, 2, 1)
+        else:
+            audio_features = []
+
+        return audio_features, audio_feature_lens_list, audio_ph_list
+
def _convert(self, input_str, max_inp_length: Optional[int] = None):
|
| 1487 |
+
old_input_ids = self.tokenizer.encode(input_str)
|
| 1488 |
+
|
| 1489 |
+
listen_token_id = self.tokenizer.convert_tokens_to_ids("<|listen|>")
|
| 1490 |
+
input_ids = []
|
| 1491 |
+
for token in old_input_ids:
|
| 1492 |
+
if token != listen_token_id:
|
| 1493 |
+
input_ids.append(token)
|
| 1494 |
+
|
| 1495 |
+
if max_inp_length is not None:
|
| 1496 |
+
input_ids = input_ids[:max_inp_length]
|
| 1497 |
+
input_ids = torch.tensor(input_ids, dtype=torch.int32)
|
| 1498 |
+
|
| 1499 |
+
## image bound
|
| 1500 |
+
start_cond = (input_ids == self.tokenizer.im_start_id) | (input_ids == self.tokenizer.slice_start_id)
|
| 1501 |
+
end_cond = (input_ids == self.tokenizer.im_end_id) | (input_ids == self.tokenizer.slice_end_id)
|
| 1502 |
+
|
| 1503 |
+
image_start_idx = torch.where(start_cond)[0]
|
| 1504 |
+
image_start_idx += 1
|
| 1505 |
+
image_end_idx = torch.where(end_cond)[0]
|
| 1506 |
+
|
| 1507 |
+
valid_image_nums = max(len(image_start_idx), len(image_end_idx))
|
| 1508 |
+
|
| 1509 |
+
image_bounds = torch.hstack(
|
| 1510 |
+
[
|
| 1511 |
+
image_start_idx[:valid_image_nums].unsqueeze(-1),
|
| 1512 |
+
image_end_idx[:valid_image_nums].unsqueeze(-1),
|
| 1513 |
+
]
|
| 1514 |
+
)
|
| 1515 |
+
|
| 1516 |
+
## audio bound
|
| 1517 |
+
audio_start_idx = torch.where(input_ids == self.tokenizer.audio_start_id)[0]
|
| 1518 |
+
audio_end_idx = torch.where(input_ids == self.tokenizer.audio_end_id)[0]
|
| 1519 |
+
assert len(audio_start_idx) == len(audio_end_idx)
|
| 1520 |
+
audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
|
| 1521 |
+
|
| 1522 |
+
spk_start_idx = torch.where(input_ids == self.tokenizer.spk_start_id)[0]
|
| 1523 |
+
spk_end_idx = torch.where(input_ids == self.tokenizer.spk_end_id)[0]
|
| 1524 |
+
assert len(spk_start_idx) == len(spk_end_idx)
|
| 1525 |
+
spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
|
| 1526 |
+
|
| 1527 |
+
return input_ids, image_bounds, audio_bounds, spk_bounds
|
| 1528 |
+
|
| 1529 |
+
def _convert_omni_to_inputs(
|
| 1530 |
+
self,
|
| 1531 |
+
images,
|
| 1532 |
+
audio_phs,
|
| 1533 |
+
texts: Union[str, List[str]],
|
| 1534 |
+
truncation=None,
|
| 1535 |
+
max_length=None,
|
| 1536 |
+
max_slice_nums=None,
|
| 1537 |
+
use_image_id=None,
|
| 1538 |
+
return_tensors=None,
|
| 1539 |
+
**kwargs,
|
| 1540 |
+
):
|
| 1541 |
+
if images is None and audio_phs is None:
|
| 1542 |
+
model_inputs = self.tokenizer(
|
| 1543 |
+
texts, return_tensors=return_tensors, truncation=truncation, max_length=max_length, **kwargs
|
| 1544 |
+
)
|
| 1545 |
+
return MiniCPMOBatchFeature(data={**model_inputs})
|
| 1546 |
+
|
| 1547 |
+
image_pattern = "<image>./</image>"
|
| 1548 |
+
audio_pattern = "<audio>./</audio>"
|
| 1549 |
+
split_pattern = f"({image_pattern}|{audio_pattern})"
|
| 1550 |
+
|
| 1551 |
+
if isinstance(texts, str):
|
| 1552 |
+
texts = [texts]
|
| 1553 |
+
|
| 1554 |
+
bs = len(texts)
|
| 1555 |
+
if images is not None:
|
| 1556 |
+
images, image_sizes, tgt_sizes = images["pixel_values"], images["image_sizes"], images["tgt_sizes"]
|
| 1557 |
+
else:
|
| 1558 |
+
images, image_sizes, tgt_sizes = [[]] * bs, [[]] * bs, [[]] * bs
|
| 1559 |
+
|
| 1560 |
+
input_ids_list = []
|
| 1561 |
+
image_bounds_list = []
|
| 1562 |
+
audio_bounds_list = []
|
| 1563 |
+
spk_bounds_list = []
|
| 1564 |
+
|
| 1565 |
+
for index, text in enumerate(texts):
|
| 1566 |
+
text_chunks = re.split(split_pattern, text)
|
| 1567 |
+
|
| 1568 |
+
image_tags = re.findall(image_pattern, text)
|
| 1569 |
+
audio_tags = re.findall(audio_pattern, text)
|
| 1570 |
+
|
| 1571 |
+
if image_tags:
|
| 1572 |
+
assert images is not None
|
| 1573 |
+
assert len(image_tags) == len(image_sizes[index])
|
| 1574 |
+
if audio_tags:
|
| 1575 |
+
assert audio_phs is not None
|
| 1576 |
+
assert len(audio_tags) == len(audio_phs[index])
|
| 1577 |
+
|
| 1578 |
+
image_id = 0
|
| 1579 |
+
audio_id = 0
|
| 1580 |
+
for i, chunk in enumerate(text_chunks):
|
| 1581 |
+
if chunk == image_pattern:
|
| 1582 |
+
image_placeholder = self.image_processor.get_slice_image_placeholder(
|
| 1583 |
+
image_sizes[index][image_id], image_id, max_slice_nums, use_image_id
|
| 1584 |
+
)
|
| 1585 |
+
image_id += 1
|
| 1586 |
+
text_chunks[i] = image_placeholder
|
| 1587 |
+
elif chunk == audio_pattern:
|
| 1588 |
+
audio_placeholder = audio_phs[index][audio_id]
|
| 1589 |
+
audio_id += 1
|
| 1590 |
+
text_chunks[i] = audio_placeholder
|
| 1591 |
+
|
| 1592 |
+
final_text = "".join(text_chunks)
|
| 1593 |
+
input_ids, image_bounds, audio_bounds, spk_bounds = self._convert(final_text, max_length)
|
| 1594 |
+
|
| 1595 |
+
input_ids_list.append(input_ids)
|
| 1596 |
+
image_bounds_list.append(image_bounds)
|
| 1597 |
+
audio_bounds_list.append(audio_bounds)
|
| 1598 |
+
spk_bounds_list.append(spk_bounds)
|
| 1599 |
+
|
| 1600 |
+
padded_input_ids, padding_lengths = self.pad(input_ids_list, padding_side="left")
|
| 1601 |
+
attention_mask = torch.ones_like(padded_input_ids, dtype=torch.bool)
|
| 1602 |
+
for i, length in enumerate(padding_lengths):
|
| 1603 |
+
image_bounds_list[i] = image_bounds_list[i] + length
|
| 1604 |
+
audio_bounds_list[i] = audio_bounds_list[i] + length
|
| 1605 |
+
spk_bounds_list[i] = spk_bounds_list[i] + length
|
| 1606 |
+
attention_mask[i, :length] = False
|
| 1607 |
+
|
| 1608 |
+
data = {
|
| 1609 |
+
"input_ids": padded_input_ids,
|
| 1610 |
+
"attention_mask": attention_mask,
|
| 1611 |
+
"pixel_values": images,
|
| 1612 |
+
"image_sizes": image_sizes,
|
| 1613 |
+
"image_bound": image_bounds_list,
|
| 1614 |
+
"tgt_sizes": tgt_sizes,
|
| 1615 |
+
"audio_bounds": audio_bounds_list,
|
| 1616 |
+
"spk_bounds": spk_bounds_list,
|
| 1617 |
+
}
|
| 1618 |
+
|
| 1619 |
+
return data
|
| 1620 |
+
|
| 1621 |
+
def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
|
| 1622 |
+
items = []
|
| 1623 |
+
if isinstance(inputs[0], list):
|
| 1624 |
+
assert isinstance(inputs[0][0], torch.Tensor)
|
| 1625 |
+
for it in inputs:
|
| 1626 |
+
for tr in it:
|
| 1627 |
+
items.append(tr)
|
| 1628 |
+
else:
|
| 1629 |
+
assert isinstance(inputs[0], torch.Tensor)
|
| 1630 |
+
items = inputs
|
| 1631 |
+
|
| 1632 |
+
batch_size = len(items)
|
| 1633 |
+
shape = items[0].shape
|
| 1634 |
+
dim = len(shape)
|
| 1635 |
+
assert dim <= 2
|
| 1636 |
+
if max_length is None:
|
| 1637 |
+
max_length = 0
|
| 1638 |
+
max_length = max(max_length, max(item.shape[-1] for item in items))
|
| 1639 |
+
min_length = min(item.shape[-1] for item in items)
|
| 1640 |
+
dtype = items[0].dtype
|
| 1641 |
+
|
| 1642 |
+
if dim == 0:
|
| 1643 |
+
return torch.stack([item for item in items], dim=0), [0]
|
| 1644 |
+
elif dim == 1:
|
| 1645 |
+
if max_length == min_length:
|
| 1646 |
+
return torch.stack([item for item in items], dim=0), [0] * batch_size
|
| 1647 |
+
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
|
| 1648 |
+
else:
|
| 1649 |
+
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
|
| 1650 |
+
|
| 1651 |
+
padding_length = []
|
| 1652 |
+
for i, item in enumerate(items):
|
| 1653 |
+
if dim == 1:
|
| 1654 |
+
if padding_side == "left":
|
| 1655 |
+
tensor[i, -len(item) :] = item.clone()
|
| 1656 |
+
else:
|
| 1657 |
+
tensor[i, : len(item)] = item.clone()
|
| 1658 |
+
elif dim == 2:
|
| 1659 |
+
if padding_side == "left":
|
| 1660 |
+
tensor[i, -len(item) :, :] = item.clone()
|
| 1661 |
+
else:
|
| 1662 |
+
tensor[i, : len(item), :] = item.clone()
|
| 1663 |
+
padding_length.append(tensor.shape[-1] - len(item))
|
| 1664 |
+
|
| 1665 |
+
return tensor, padding_length
|
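
For reference, the left-padding behavior of `pad` above (and the padding offsets that `_convert_omni_to_inputs` adds to the image/audio/speaker bounds) can be reproduced in a few lines. A minimal sketch with toy tensors, not the real vocabulary:

```python
import torch

seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]  # hypothetical token ids

max_len = max(s.shape[-1] for s in seqs)
padded = torch.zeros((len(seqs), max_len), dtype=seqs[0].dtype)  # padding_value = 0
pad_lens = []
for i, s in enumerate(seqs):
    padded[i, -len(s):] = s  # left padding keeps sequences right-aligned
    pad_lens.append(max_len - len(s))

print(padded)    # tensor([[5, 6, 7], [0, 8, 9]])
print(pad_lens)  # [0, 1] -- the per-row offsets added to the bounds tensors
```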
processor_config.json
ADDED
@@ -0,0 +1,102 @@
{
  "audio_processor": {
    "audio_pool_step": 5,
    "auto_map": {
      "AutoFeatureExtractor": "processing_minicpmo.MiniCPMAAudioProcessor",
      "AutoImageProcessor": "processing_minicpmo.MiniCPMVImageProcessor",
      "AutoProcessor": "processing_minicpmo.MiniCPMOProcessor"
    },
    "chunk_length": 30,
    "dither": 0.0,
    "dynamic_log_norm": true,
    "dynamic_range_db": 8.0,
    "feature_extractor_type": "MiniCPMAAudioProcessor",
    "feature_size": 80,
    "hop_length": 160,
    "im_end": "</image>",
    "im_id_end": "</image_id>",
    "im_id_start": "<image_id>",
    "im_start": "<image>",
    "image_feature_size": 64,
    "image_processor_type": "MiniCPMVImageProcessor",
    "log_floor_db": -10.0,
    "max_slice_nums": 9,
    "n_fft": 400,
    "n_samples": 480000,
    "nb_max_frames": 3000,
    "norm_mean": [0.5, 0.5, 0.5],
    "norm_std": [0.5, 0.5, 0.5],
    "padding_side": "right",
    "padding_value": 0.0,
    "patch_size": 14,
    "return_attention_mask": false,
    "sampling_rate": 16000,
    "scale_resolution": 448,
    "slice_end": "</slice>",
    "slice_mode": true,
    "slice_start": "<slice>",
    "unk": "<unk>",
    "use_image_id": true,
    "version": 4.5
  },
  "auto_map": {
    "AutoProcessor": "processing_minicpmo.MiniCPMOProcessor"
  },
  "image_processor": {
    "audio_pool_step": 5,
    "auto_map": {
      "AutoFeatureExtractor": "processing_minicpmo.MiniCPMAAudioProcessor",
      "AutoImageProcessor": "processing_minicpmo.MiniCPMVImageProcessor",
      "AutoProcessor": "processing_minicpmo.MiniCPMOProcessor"
    },
    "im_end": "</image>",
    "im_end_token": "</image>",
    "im_id_end": "</image_id>",
    "im_id_start": "<image_id>",
    "im_start": "<image>",
    "im_start_token": "<image>",
    "image_feature_size": 64,
    "image_processor_type": "MiniCPMVImageProcessor",
    "max_slice_nums": 9,
    "mean": [0.5, 0.5, 0.5],
    "norm_mean": [0.5, 0.5, 0.5],
    "norm_std": [0.5, 0.5, 0.5],
    "patch_size": 14,
    "scale_resolution": 448,
    "slice_end": "</slice>",
    "slice_end_token": "</slice>",
    "slice_mode": true,
    "slice_start": "<slice>",
    "slice_start_token": "<slice>",
    "std": [0.5, 0.5, 0.5],
    "unk": "<unk>",
    "unk_token": "<unk>",
    "use_image_id": true,
    "version": 4.5
  },
  "processor_class": "MiniCPMOProcessor"
}
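
Since `auto_map` routes `AutoProcessor` to `processing_minicpmo.MiniCPMOProcessor`, the standard `trust_remote_code` loading path should resolve the custom processor. A minimal sketch (the repo id is a placeholder for this repository's path):

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "path/to/this-repo",  # placeholder -- substitute the actual repo id
    trust_remote_code=True,
)

# Image and audio positions in prompts use the patterns consumed by
# _convert_omni_to_inputs above: "<image>./</image>" and "<audio>./</audio>".
```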
special_tokens_map.json
ADDED
@@ -0,0 +1,88 @@
{
  "additional_special_tokens": [
    "<unk>",
    "<image>",
    "</image>",
    "<ref>",
    "</ref>",
    "<box>",
    "</box>",
    "<quad>",
    "</quad>",
    "<point>",
    "</point>",
    "<slice>",
    "</slice>",
    "<image_id>",
    "</image_id>",
    "<unit>",
    "</unit>",
    "<answer>",
    "</answer>",
    "<focus>",
    "</focus>",
    "<line>",
    "</line>",
    "<perception>",
    "</perception>",
    "<source_image>",
    "</source_image>",
    "<image_save_to>",
    "</image_save_to>",
    "<|audio_start|>",
    "<|audio|>",
    "<|audio_end|>",
    "<|spk_bos|>",
    "<|spk|>",
    "<|spk_eos|>",
    "<|tts_bos|>",
    "<|tts_eos|>",
    "<|listen|>",
    "<|speak|>",
    "<|interrupt|>",
    "<|vad_start|>",
    "<|vad_end|>",
    "<|emotion_start|>",
    "<|emotion_end|>",
    "<|speed_start|>",
    "<|speed_end|>",
    "<|pitch_start|>",
    "<|pitch_end|>",
    "<|turn_bos|>",
    "<|turn_eos|>",
    "<|chunk_eos|>",
    "<|chunk_bos|>",
    "<|chunk_tts_bos|>",
    "<|chunk_tts_eos|>",
    "<|tts_pad|>",
    "<|timbre_7|>",
    "<|timbre_8|>",
    "<|timbre_9|>",
    "<|timbre_10|>",
    "<|timbre_11|>",
    "<|timbre_12|>",
    "<|timbre_13|>",
    "<|timbre_14|>",
    "<|timbre_15|>",
    "<|timbre_16|>",
    "<|timbre_17|>",
    "<|timbre_18|>",
    "<|timbre_19|>",
    "<|timbre_20|>",
    "<|timbre_21|>",
    "<|timbre_22|>",
    "<|timbre_23|>",
    "<|timbre_24|>",
    "<|timbre_25|>",
    "<|timbre_26|>",
    "<|timbre_27|>",
    "<|timbre_28|>",
    "<|timbre_29|>",
    "<|timbre_30|>",
    "<|timbre_31|>"
  ],
  "bos_token": "<|im_start|>",
  "eos_token": "<|im_end|>",
  "pad_token": "<|endoftext|>",
  "unk_token": "<unk>"
}
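
The `<|im_start|>`/`<|im_end|>` BOS/EOS pair indicates a ChatML-style turn format; the authoritative template ships in `chat_template.jinja`. A hedged sketch of what a single turn looks like as a raw string:

```python
# Hedged sketch only -- in practice, build prompts with tokenizer.apply_chat_template.
prompt = (
    "<|im_start|>user\n"
    "Describe this image: <image>./</image><|im_end|>\n"
    "<|im_start|>assistant\n"
)
```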
tokenization_minicpmo_fast.py
ADDED
@@ -0,0 +1,120 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2026 The OpenBMB Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from transformers import Qwen2TokenizerFast


class MiniCPMOTokenizerFast(Qwen2TokenizerFast):
    def __init__(self, **kwargs):
        self._bad_token_ids = kwargs.pop("bad_token_ids", [])

        super().__init__(**kwargs)

        # image
        self.im_start = "<image>"
        self.im_end = "</image>"
        self.ref_start = "<ref>"
        self.ref_end = "</ref>"
        self.box_start = "<box>"
        self.box_end = "</box>"
        self.quad_start = "<quad>"
        self.quad_end = "</quad>"
        self.slice_start = "<slice>"
        self.slice_end = "</slice>"
        self.im_id_start = "<image_id>"
        self.im_id_end = "</image_id>"

        # audio
        self.audio_start = "<|audio_start|>"
        self.audio_end = "<|audio_end|>"
        self.spk_start = "<|spk_bos|>"
        self.spk_end = "<|spk_eos|>"
        self.tts_start = "<|tts_bos|>"
        self.tts_end = "<|tts_eos|>"

    @property
    def eos_id(self):
        return self.eos_token_id

    @property
    def bos_id(self):
        return self.bos_token_id

    @property
    def unk_id(self):
        return self.unk_token_id

    @property
    def im_start_id(self):
        return self.convert_tokens_to_ids(self.im_start)

    @property
    def im_end_id(self):
        return self.convert_tokens_to_ids(self.im_end)

    @property
    def slice_start_id(self):
        return self.convert_tokens_to_ids(self.slice_start)

    @property
    def slice_end_id(self):
        return self.convert_tokens_to_ids(self.slice_end)

    @property
    def im_id_start_id(self):
        return self.convert_tokens_to_ids(self.im_id_start)

    @property
    def im_id_end_id(self):
        return self.convert_tokens_to_ids(self.im_id_end)

    @property
    def audio_start_id(self):
        return self.convert_tokens_to_ids(self.audio_start)

    @property
    def audio_end_id(self):
        return self.convert_tokens_to_ids(self.audio_end)

    @property
    def spk_start_id(self):
        return self.convert_tokens_to_ids(self.spk_start)

    @property
    def spk_end_id(self):
        return self.convert_tokens_to_ids(self.spk_end)

    @property
    def tts_start_id(self):
        return self.convert_tokens_to_ids(self.tts_start)

    @property
    def tts_end_id(self):
        return self.convert_tokens_to_ids(self.tts_end)

    @staticmethod
    def escape(text: str) -> str:
        return text

    @staticmethod
    def unescape(text: str) -> str:
        return text

    @property
    def bad_token_ids(self) -> List[int]:
        return self._bad_token_ids
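
A short usage sketch of the convenience properties above (the repo id is a placeholder; `trust_remote_code=True` is needed so `auto_map` can resolve this class):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/this-repo", trust_remote_code=True)

print(tokenizer.im_start_id, tokenizer.im_end_id)        # ids of <image> / </image>
print(tokenizer.audio_start_id, tokenizer.audio_end_id)  # ids of <|audio_start|> / <|audio_end|>
```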
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:66664f87759d9e829e7ef0ded96976727374dcd7ca6f3ae9bfe89bbda541e5af
size 11437708
tokenizer_config.json
ADDED
@@ -0,0 +1,22 @@
{
  "add_prefix_space": false,
  "auto_map": {
    "AutoProcessor": "processing_minicpmo.MiniCPMOProcessor",
    "AutoTokenizer": [
      null,
      "tokenization_minicpmo_fast.MiniCPMOTokenizerFast"
    ]
  },
  "backend": "tokenizers",
  "bos_token": "<|im_start|>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "<|im_end|>",
  "errors": "replace",
  "is_local": true,
  "model_max_length": 131072,
  "pad_token": "<|endoftext|>",
  "processor_class": "MiniCPMOProcessor",
  "split_special_tokens": false,
  "tokenizer_class": "MiniCPMOTokenizer",
  "unk_token": "<unk>"
}
utils.py
ADDED
@@ -0,0 +1,2417 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2026 The OpenBMB Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Tuple
from typing import Union

import torch
import torch.nn.functional as F
import torch.nn.utils.parametrize as P
from transformers.cache_utils import DynamicCache

logger = logging.getLogger(__name__)


# text
@dataclass
class GenerateChunkOutput:
    chunk_token_ids: torch.Tensor
    current_inputs_embeds: torch.Tensor
    input_last_hidden_states: Optional[torch.Tensor]  # for tts use_speaker_embedding
    last_hidden_states: Optional[torch.Tensor]  # for tts input feature (projector_semantic)
    past_key_values: Optional[torch.Tensor]
    finished: bool

class ChunkPrefillChunkGenerate:
    def __init__(self, model, tokenizer, terminators):
        self.tokenizer = tokenizer
        self.model = model
        self.terminators = terminators
        self.terminators_ids = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
        self.embedding_layer = self.model.get_input_embeddings()

        self.forbidden_tokens = [
            ":",
            ":",
            ";",
            "#",
            "“",
            "”",
            "‘",
            "’",
            "@",
            "*",
            "【",
            "】",
            "「",
            "」",
            "(",
            ")",
            "(",
            ")",
            "[",
            "]",
            "&",
            "/",
            "$",
        ]

        self.forbidden_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in self.forbidden_tokens]
        bad_token_ids = getattr(tokenizer, "bad_token_ids", [])
        if bad_token_ids:
            self.forbidden_token_ids.extend(bad_token_ids)

    @staticmethod
    def prepare_generation_config(do_sample, max_new_tokens=50, min_new_tokens=0, **kwargs):
        num_beams = kwargs.get("num_beams", 3)
        generation_config = {
            "num_beams": num_beams,
            "top_p": 0.8,
            "top_k": 100,
            "temperature": 0.7,
            "do_sample": True,
            "repetition_penalty": 1.05,
        }

        if do_sample:
            generation_config.update(
                {
                    "top_p": 0.8,
                    "top_k": 100,
                    "temperature": 0.7,
                    "do_sample": True,
                    "repetition_penalty": 1.05,
                }
            )
        elif num_beams > 1:
            generation_config.update({"num_beams": num_beams, "repetition_penalty": 1.2, "do_sample": False})
        else:
            generation_config.update({"do_sample": False, "repetition_penalty": 1.05})

        generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())
        generation_config["min_new_tokens"] = min_new_tokens
        generation_config["max_new_tokens"] = max_new_tokens

        return generation_config

    def chunk_generate(
        self,
        inputs_embeds: torch.Tensor,
        past_key_values,
        is_first_generate_chunk: bool,
        chunk_size: int,
        return_hidden_states: bool,
        do_sample: bool,
        temperature: float,
        top_p: float,
        top_k: int,
        repetition_penalty: float = 1.05,
        length_penalty: float = 1.0,
        all_input_ids: Optional[torch.Tensor] = None,
    ) -> GenerateChunkOutput:
        """
        Args:
            inputs_embeds: [1, seq_len, hidden_dim], input embeddings of the current chunk.
            past_key_values: [num_layers, 2, batch_size, num_heads, seq_len, head_dim], past key values for the LLM.
            is_first_generate_chunk: bool, whether this is the first generate chunk.
            chunk_size: int, the size of the current chunk; default is 10, fixed during training.
            return_hidden_states: bool, whether to return the hidden states; default is True.
            do_sample: bool, whether to sample from the model; default is True.
            temperature: float, the temperature for the model; default is 0.7.
            top_p: float, the top-p for the model; default is 0.8.
            top_k: int, the top-k for the model; default is 100.
            repetition_penalty: float, the repetition penalty for the model; default is 1.05.
            length_penalty: float, the length penalty for the model; default is 1.0. A higher value means more detailed generation.
            all_input_ids: Optional[torch.Tensor], the input ids for the current chunk.
        """

        finished = False
        current_inputs_embeds = inputs_embeds.clone()
        input_last_hidden_states = []
        last_hidden_states = []
        generated_tokens = []

        for token_idx in range(chunk_size):
            if is_first_generate_chunk and token_idx == 0:
                # first generate chunk: prefill the full inputs_embeds
                model_inputs = {
                    "inputs_embeds": current_inputs_embeds,
                    "past_key_values": past_key_values,
                    "use_cache": True,
                    "output_hidden_states": return_hidden_states,
                }
            else:  # all other cases: prefill only the latest generated token
                model_inputs = {
                    "inputs_embeds": current_inputs_embeds[:, -1:, :],
                    "past_key_values": past_key_values,
                    "use_cache": True,
                    "output_hidden_states": return_hidden_states,
                }

            with torch.no_grad():
                outputs = self.model(**model_inputs)

            # last token's logits
            logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=inputs_embeds.device)

            # forbid specific tokens from decoding (same effect as model.generate's suppress_tokens)
            if self.forbidden_token_ids:
                logits[:, self.forbidden_token_ids] = float("-inf")

            past_key_values = outputs.past_key_values

            PENALTY_WINDOW_SIZE = 128

            # apply repetition penalty
            if repetition_penalty != 1.0:
                # collect the token ids subject to the penalty
                if all_input_ids is not None:
                    # use the global input ids (original input plus the generated part)
                    if len(generated_tokens) > 0:
                        generated_token_ids = torch.cat(generated_tokens, dim=1)
                        current_sequence = torch.cat(
                            [
                                all_input_ids[:, -PENALTY_WINDOW_SIZE:],
                                generated_token_ids,
                            ],
                            dim=1,
                        )
                    else:
                        current_sequence = all_input_ids[:, -PENALTY_WINDOW_SIZE:]
                    unique_token_ids = torch.unique(current_sequence.squeeze(0))
                elif len(generated_tokens) > 0:
                    # fall back to the original logic: only use generated tokens
                    generated_token_ids = torch.cat(generated_tokens, dim=1).squeeze(0)
                    unique_token_ids = torch.unique(generated_token_ids)
                else:
                    unique_token_ids = torch.tensor([], dtype=torch.long, device=logits.device)

                for token_id in unique_token_ids:
                    if logits[0, token_id] > 0:
                        logits[0, token_id] = logits[0, token_id] / repetition_penalty
                    else:
                        logits[0, token_id] = logits[0, token_id] * repetition_penalty

            # apply length penalty; a higher value means more detailed generation
            if length_penalty != 1.0:
                for eos_token_id in self.terminators_ids:
                    if logits[0, eos_token_id] > 0:
                        logits[0, eos_token_id] = logits[0, eos_token_id] / length_penalty
                    else:
                        logits[0, eos_token_id] = logits[0, eos_token_id] * length_penalty

            # apply temperature
            if temperature != 1.0:
                logits = logits / temperature

            if do_sample:
                # Top-k filtering
                if top_k > 0:
                    top_k_logits, top_k_indices = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits_filtered = torch.full_like(logits, float("-inf"))
                    logits_filtered.scatter_(1, top_k_indices, top_k_logits)
                    logits = logits_filtered

                # Top-p filtering
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                    # remove tokens with cumulative probability greater than top_p
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    logits[indices_to_remove] = float("-inf")

                # sampling
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            if return_hidden_states:
                if is_first_generate_chunk and token_idx == 0:
                    input_last_hidden_states.append(outputs.hidden_states[-1])
                else:
                    last_hidden_states.append(outputs.hidden_states[-1])

            # stop generating on a terminator token
            if next_token.item() in self.terminators_ids:
                finished = True
                break

            generated_tokens.append(next_token)

            # convert the new token to an embedding and concatenate
            next_token_embed = self.embedding_layer(next_token)

            # update inputs_embeds, growing it by one position
            current_inputs_embeds = torch.cat([current_inputs_embeds, next_token_embed], dim=1)

        if len(generated_tokens) > 0:
            chunk_token_ids = torch.cat(generated_tokens, dim=1)
        else:
            # special case: the chunk's first prediction is an eos token; return a tensor with shape (1, 0)
            if finished:
                chunk_token_ids = torch.zeros((1, 0), dtype=torch.long, device=current_inputs_embeds.device)
            else:
                raise Exception("this should not happen")

        if len(last_hidden_states) > 0:
            last_hidden_states = torch.cat(last_hidden_states, dim=1)
        else:
            # special case: the chunk finished without producing new hidden states
            # (fixed: torch.cat on an empty list would raise, so return None instead)
            if finished:
                last_hidden_states = None
            else:
                raise Exception("this should not happen")

        if len(input_last_hidden_states) > 0:
            input_last_hidden_states = torch.cat(input_last_hidden_states, dim=1)
        else:
            input_last_hidden_states = None

        return GenerateChunkOutput(
            chunk_token_ids=chunk_token_ids,
            current_inputs_embeds=current_inputs_embeds,
            input_last_hidden_states=input_last_hidden_states,
            last_hidden_states=last_hidden_states,
            past_key_values=past_key_values,
            finished=finished,
        )

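
For reference, the top-k/top-p filtering inside `chunk_generate` can be exercised standalone. A self-contained sketch with a toy five-token vocabulary (the constants are illustrative, not the model defaults):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])  # toy vocabulary
top_k, top_p = 3, 0.9

# top-k: keep only the k largest logits
top_k_logits, top_k_indices = torch.topk(logits, min(top_k, logits.size(-1)))
filtered = torch.full_like(logits, float("-inf"))
filtered.scatter_(1, top_k_indices, top_k_logits)

# top-p: drop the tail of the sorted cumulative distribution
sorted_logits, sorted_indices = torch.sort(filtered, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs > top_p
remove[..., 1:] = remove[..., :-1].clone()
remove[..., 0] = 0
filtered[remove.scatter(1, sorted_indices, remove)] = float("-inf")

next_token = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)
```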
def streaming_token_decoder(token_iterator, tokenizer, skip_special_tokens=False):
    """
    Incrementally decode tokens from an iterator, handling partial multi-byte characters.

    When streaming tokens, multi-byte characters (like Chinese) may be split across multiple
    tokens. Decoding partial tokens results in replacement characters (U+FFFD). This function
    buffers tokens and only yields complete characters.

    Args:
        token_iterator: An iterator yielding (token_ids, is_finished) tuples.
            token_ids can be torch.Tensor or any iterable of integers.
        tokenizer: The tokenizer to use for decoding.
        skip_special_tokens: Whether to skip special tokens during decoding.

    Yields:
        (decoded_text, is_finished) tuples where decoded_text is the new text since the last yield.
    """
    accumulated_token_ids = []
    yielded_text_len = 0

    for token_ids, is_finished in token_iterator:
        # Accumulate token IDs
        if torch.is_tensor(token_ids):
            accumulated_token_ids.extend(token_ids.reshape(-1).tolist())
        else:
            accumulated_token_ids.extend(list(token_ids) if hasattr(token_ids, "__iter__") else [token_ids])

        # Decode all accumulated tokens
        full_decoded = tokenizer.decode(accumulated_token_ids, skip_special_tokens=skip_special_tokens)

        if is_finished:
            # Final chunk: yield all remaining text
            new_text = full_decoded[yielded_text_len:]
            yield new_text, is_finished
        else:
            # Find a safe prefix without incomplete multi-byte characters.
            # The replacement character '�' (U+FFFD) indicates incomplete decoding.
            new_text = full_decoded[yielded_text_len:]

            # Hold back text ending with the replacement character (incomplete UTF-8 sequence)
            safe_end = len(new_text)
            while safe_end > 0 and new_text[safe_end - 1] == "\ufffd":
                safe_end -= 1

            safe_text = new_text[:safe_end] if safe_end > 0 else ""
            yielded_text_len += len(safe_text)
            yield safe_text, is_finished

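
A quick illustration of why `streaming_token_decoder` holds back text ending in U+FFFD: decoding a partial multi-byte character produces the replacement character until the byte sequence completes.

```python
data = "你".encode("utf-8")  # three bytes: b'\xe4\xbd\xa0'
print(data[:2].decode("utf-8", errors="replace"))  # '�' -- incomplete sequence
print(data.decode("utf-8"))                        # '你' -- complete
```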
def torch_clone_recursive(obj):
    """Recursively clone nested containers of torch.Tensors.

    Supported container types: dict, list, tuple. Any other non-Tensor
    object raises a ValueError.
    """
    if torch.is_tensor(obj):
        return obj.clone()
    elif isinstance(obj, dict):
        return {k: torch_clone_recursive(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [torch_clone_recursive(v) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(torch_clone_recursive(v) for v in obj)
    else:
        raise ValueError(f"Unsupported type: {type(obj)}")


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate half the hidden dims of the input for RoPE."""
    dim = x.shape[-1]
    x1 = x[..., : dim // 2]
    x2 = x[..., dim // 2 :]
    return torch.cat((-x2, x1), dim=-1)

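
`rotate_half` is the standard RoPE helper: rotary embeddings are applied as `q * cos + rotate_half(q) * sin`. A minimal numeric check with a hypothetical head dim of 4:

```python
import torch

q = torch.tensor([1.0, 2.0, 3.0, 4.0])
x1, x2 = q[:2], q[2:]
print(torch.cat((-x2, x1)))  # tensor([-3., -4., 1., 2.]) == rotate_half(q)
```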
@dataclass
class SpeculativeSnapshot:
    """Snapshot for VAD speculative rollback.

    Used in VAD speculative execution: a snapshot is created after streaming_prefill
    and before streaming_generate. If speculation fails (the user continues speaking),
    the state can be restored to continue streaming_prefill.

    Implementation:
    - LLM KV cache: only the length is recorded; restore by truncation (zero extra VRAM)
    - Audio KV cache: requires cloning, as generate sets it to None
    - Mel processor: a full state snapshot is saved (including the buffer)
    """

    # KV cache lengths (for truncation-based recovery)
    llm_cache_length: int
    audio_cache_length: int

    # session state
    new_user_msg: bool
    llm_generated: bool
    llm_generate_completed: bool

    # round management
    next_round_id: int
    pending_round_id: Optional[int]
    omni_chunk_history_length: int

    # TTS state (requires cloning, but usually small)
    tts_last_turn_tokens: Optional[torch.Tensor]

    # streaming processor state
    audio_chunk_idx: int

    # mel processor state snapshot (including the buffer)
    mel_processor_snapshot: Optional[dict] = None

    # audio encoder KV cache (cloned to ensure determinism after recovery)
    audio_past_key_values: Optional[tuple] = None

    # timestamp (for debugging)
    timestamp: float = 0.0

    # debug fields: used to verify the correctness of recovery
    llm_cache_checksum: Optional[float] = None  # sum of the LLM KV cache's first-layer K
    audio_cache_checksum: Optional[float] = None  # sum of the audio KV cache's first-layer K
    mel_buffer_checksum: Optional[float] = None  # sum of the mel buffer

    # RNG state (key to keeping dithering etc. deterministic after recovery)
    rng_state_cpu: Optional[torch.Tensor] = None  # torch CPU RNG state
    rng_state_cuda: Optional[torch.Tensor] = None  # torch CUDA RNG state (if on GPU)

    def summary(self) -> str:
        mel_buf_len = 0
        if self.mel_processor_snapshot:
            buf = self.mel_processor_snapshot.get("buffer")
            if buf is not None:
                mel_buf_len = len(buf)
        return (
            f"llm_cache={self.llm_cache_length}, "
            f"audio_cache={self.audio_cache_length}, "
            f"audio_chunk_idx={self.audio_chunk_idx}, "
            f"mel_buffer={mel_buf_len}, "
            f"history_len={self.omni_chunk_history_length}, "
            f"new_user_msg={self.new_user_msg}, "
            f"llm_generated={self.llm_generated}"
        )

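
The zero-copy LLM recovery described above relies on truncating the KV cache back to the recorded length. A hedged sketch with `transformers`' `DynamicCache` (imported at the top of this file); `crop()` and `get_seq_length()` are assumed to be available in the installed transformers version:

```python
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
# ... streaming_prefill grows the cache ...
snapshot_len = cache.get_seq_length()  # record at snapshot time

# ... speculative streaming_generate appends more entries ...

cache.crop(snapshot_len)  # roll back by truncation -- no cloning required
```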
| 452 |
+
# tts
|
| 453 |
+
@dataclass
|
| 454 |
+
class TTSSamplingParams:
|
| 455 |
+
top_p: float = 0.85
|
| 456 |
+
min_p: float = 0.01
|
| 457 |
+
top_k: int = 25
|
| 458 |
+
repetition_penalty: float = 1.05
|
| 459 |
+
temperature: float = 0.8
|
| 460 |
+
win_size: int = 16
|
| 461 |
+
tau_r: float = 0.1
|
| 462 |
+
|
| 463 |
+
|
class TTSStreamingGenerator:
    """
    Streaming generator for TTS that processes chunks and yields audio tokens in real time.

    Supported attention types:
    - full_attention: full attention; all tokens can attend to each other
    - sliding_window: sliding-window attention; the KV cache is truncated to a fixed size (token_window_size)
    - sliding_recompute: sliding recompute; only the previous chunk is kept and recomputed together with the current chunk
    - reindex: keep the first chunk as a sink; reindex the sliding-window positions via RoPE rotation
    """

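    # KV-cache layout per attention type, illustratively (c_i = chunk-i condition
    # embeddings, a_i = chunk-i audio tokens; widths not to scale):
    #   full_attention:    [c0 a0 c1 a1 c2 a2 ...]           cache grows unbounded
    #   sliding_window:    [... last token_window_size ...]  hard truncation after each step
    #   sliding_recompute: [c_{i-1} a_{i-1} c_i]             cache rebuilt from scratch per chunk
    #   reindex:           [c0 a0 | c_i a_i]                 sink kept, last chunk RoPE-shifted
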
    def __init__(
        self,
        model,
        temperature: float,
        eos_token: Union[int, torch.Tensor],
        chunk_size: int = 25,  # s3tokenizer: 1 s of audio = 25 tokens
        tts_last_turn_tokens: torch.Tensor = None,
        logits_processors=None,
        logits_warpers=None,
    ):
        self.tts = model
        self.device = model.device
        self.temperature = torch.tensor([temperature], dtype=torch.float, device=self.device)
        self.eos_token = (
            torch.tensor(eos_token, device=self.device) if isinstance(eos_token, int) else eos_token.to(self.device)
        )

        self.num_vq = model.num_vq
        self.num_audio_tokens = model.num_audio_tokens
        self.recomputed_chunks = model.recomputed_chunks
        self.emb_code = model.emb_code
        self.head_code = model.head_code

        # Attention type and window sizes
        self.attention_type = model.attention_type  # "full_attention", "sliding_window", "sliding_recompute", "reindex"
        self.chunk_window_size = model.chunk_window_size  # chunk-level window for sliding_recompute (default 2)
        self.token_window_size = model.token_window_size  # token-level window for sliding_window/reindex (default 300)

        # RoPE config (for reindex mode)
        self.rope_theta = model.model.config.rope_theta
        self.head_dim = model.model.config.hidden_size // model.model.config.num_attention_heads

        # Logits processors
        self.logits_processors = logits_processors if logits_processors is not None else []
        # Logits warpers (like TopP/TopK), separate from processors
        self.logits_warpers = logits_warpers if logits_warpers is not None else []

        # Initialize state
        self.past_key_values = None
        self.text_start_pos = 0
        self.idx = -1  # starts at -1; becomes 0 on the first call
        self.all_conditions = []
        self.all_generated_tokens = []
        self.tts_last_turn_tokens = tts_last_turn_tokens
        self.spk_emb = None

        audio_bos = [self.tts.audio_bos_token_id]
        audio_bos = torch.Tensor(audio_bos).to(self.tts.emb_text.weight.device, dtype=torch.long)

        self.audio_bos_embeds = self.tts.emb_text(audio_bos).unsqueeze(0)
        self.text_eos_embed = self.tts.emb_text(
            torch.tensor(
                [self.tts.config.text_eos_token_id],
                device=self.tts.emb_text.weight.device,
                dtype=torch.long,
            )
        ).unsqueeze(0)

        # Buffer state, used to fill up chunk_size before yielding to the caller
        self.chunk_size = chunk_size
        self._token_buffer: List[torch.Tensor] = []

        # Chunk info tracking for sliding_recompute and reindex
        self._chunk_info: List[dict] = []
        self._total_seq_len = 0

        # Reindex mode: track the sink (first chunk) length
        self._sink_kv_len = 0

    def _build_recompute_inputs(self, current_condition: torch.Tensor) -> torch.Tensor:
        """Build recompute inputs for sliding_recompute mode."""
        if len(self._chunk_info) == 0:
            return current_condition

        prev_chunk = self._chunk_info[-1]
        prev_condition = prev_chunk["condition"]
        prev_audio_tokens = prev_chunk["audio_tokens"]

        recompute_list = [prev_condition]
        if len(prev_audio_tokens) > 0:
            prev_audio_embeds = torch.cat([self.emb_code[0](tok) for tok in prev_audio_tokens], dim=1)
            recompute_list.append(prev_audio_embeds)

        recompute_list.append(current_condition)
        return torch.cat(recompute_list, dim=1)

    def _truncate_kv_cache_sliding_window(self):
        """Truncate the KV cache for sliding_window mode."""
        if self.past_key_values is None:
            return

        if hasattr(self.past_key_values, "get_seq_length"):
            current_kv_len = self.past_key_values.get_seq_length()
        else:
            current_kv_len = self.past_key_values[0][0].shape[2]

        if current_kv_len <= self.token_window_size:
            return

        new_cache = DynamicCache()
        num_layers = (
            len(self.past_key_values.key_cache)
            if hasattr(self.past_key_values, "key_cache")
            else len(self.past_key_values)
        )

        for layer_idx in range(num_layers):
            if hasattr(self.past_key_values, "key_cache"):
                key = self.past_key_values.key_cache[layer_idx][:, :, -self.token_window_size :, :]
                value = self.past_key_values.value_cache[layer_idx][:, :, -self.token_window_size :, :]
            else:
                key = self.past_key_values[layer_idx][0][:, :, -self.token_window_size :, :]
                value = self.past_key_values[layer_idx][1][:, :, -self.token_window_size :, :]
            new_cache.update(key, value, layer_idx)

        self.past_key_values = new_cache

    @staticmethod
    def _apply_rope_rotation(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Apply a RoPE rotation to a tensor."""
        return x * cos + rotate_half(x) * sin

    def _compute_rope_cos_sin(self, positions: torch.Tensor, device: torch.device, dtype: torch.dtype):
        """Compute RoPE cos and sin for the given positions."""
        dim_half = self.head_dim // 2
        freq_seq = torch.arange(0, dim_half, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (self.rope_theta ** (freq_seq / dim_half))

        # positions: [seq_len]
        angles = positions.float().unsqueeze(-1) * inv_freq.unsqueeze(0)  # [seq_len, dim_half]
        angles = torch.cat([angles, angles], dim=-1)  # [seq_len, head_dim]

        cos = angles.cos().to(dtype)
        sin = angles.sin().to(dtype)
        return cos, sin

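    # A hedged sanity sketch (not called anywhere) for the two RoPE helpers
    # above: rotating by a position delta and then by its negation is the
    # identity, since the rotations compose additively per 2D plane.
    def _example_rope_roundtrip_check(self) -> bool:
        pos = torch.tensor([3.0], device=self.device)
        cos, sin = self._compute_rope_cos_sin(pos, self.device, torch.float32)
        cos_n, sin_n = self._compute_rope_cos_sin(-pos, self.device, torch.float32)
        x = torch.randn(1, 1, 1, self.head_dim, device=self.device)
        y = self._apply_rope_rotation(self._apply_rope_rotation(x, cos, sin), cos_n, sin_n)
        return torch.allclose(x, y, atol=1e-4)
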
    def _reindex_kv_cache(self):
        """
        Reindex the KV cache for reindex mode:
        1. Keep the first chunk as an attention sink
        2. Keep the last chunk
        3. Discard the middle chunks
        4. Reindex the last chunk's key positions to sit right after the sink via RoPE rotation
        """
        if self.past_key_values is None or len(self._chunk_info) < 2:
            return

        # Get the current KV cache length
        if hasattr(self.past_key_values, "get_seq_length"):
            current_kv_len = self.past_key_values.get_seq_length()
        else:
            current_kv_len = self.past_key_values[0][0].shape[2]

        # Sink length (first chunk)
        sink_len = self._chunk_info[0]["condition_len"] + self._chunk_info[0]["audio_token_count"]

        # Last chunk length
        last_chunk = self._chunk_info[-1]
        last_chunk_len = last_chunk["condition_len"] + last_chunk["audio_token_count"]

        keep_len = sink_len + last_chunk_len

        # Get device and dtype
        device = self.past_key_values.key_cache[0].device
        dtype = self.past_key_values.key_cache[0].dtype

        if current_kv_len <= keep_len:
            last_chunk_kv_len = current_kv_len - sink_len
            if last_chunk_kv_len <= 0:
                return
            self.text_start_pos = current_kv_len
            return

        # Step 1: truncate the KV cache, keeping the sink and the last chunk
        new_cache = DynamicCache()
        num_layers = len(self.past_key_values.key_cache)

        original_start_pos = current_kv_len - last_chunk_len
        new_start_pos = sink_len
        delta = new_start_pos - original_start_pos  # a scalar constant
        delta_positions = torch.full((last_chunk_len,), delta, dtype=torch.float32, device=device)

        # Compute the rotation cos/sin
        cos, sin = self._compute_rope_cos_sin(delta_positions, device, dtype)
        cos = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, head_dim]
        sin = sin.unsqueeze(0).unsqueeze(0)

        for layer_idx in range(num_layers):
            key_full = self.past_key_values.key_cache[layer_idx]
            value_full = self.past_key_values.value_cache[layer_idx]

            # Extract the sink and the last chunk
            key_sink = key_full[:, :, :sink_len, :]
            value_sink = value_full[:, :, :sink_len, :]
            key_last = key_full[:, :, -last_chunk_len:, :]
            value_last = value_full[:, :, -last_chunk_len:, :]

            # Apply the RoPE rotation to reindex the key positions
            key_last_reindexed = self._apply_rope_rotation(key_last, cos, sin)

            # Concatenate the sink and the reindexed last chunk
            key = torch.cat([key_sink, key_last_reindexed], dim=2)
            value = torch.cat([value_sink, value_last], dim=2)

            new_cache.update(key, value, layer_idx)

        self.past_key_values = new_cache

        # Update text_start_pos to reflect the new positions
        self.text_start_pos = sink_len + last_chunk_len

    @torch.inference_mode()
    def generate_with_buffer(
        self,
        condition: torch.Tensor,
        text_finished: bool = False,
        max_new_token: int = 500,
    ):
        """Take one condition-embedding chunk, generate audio tokens one at a
        time, and accumulate them in the buffer; yield only once the buffer
        holds chunk_size tokens.

        Yields:
            (tokens, is_final): tokens is a [1, <=chunk_size] torch.Tensor of
            audio codec ids; is_final marks the last batch of the turn.
        """
        self.idx += 1
        self.device = self.tts.device

        # If the text is finished, first concatenate the text EOS
        if text_finished:
            condition = torch.cat([condition, self.text_eos_embed], dim=1)

        # Always concatenate the audio BOS
        condition = torch.cat([condition, self.audio_bos_embeds], dim=1).to(self.device)

        self.all_conditions.append(condition)

        # Initialize the current chunk info
        current_chunk_info = {
            "condition_len": condition.shape[1],
            "audio_token_count": 0,
            "condition": condition.clone(),
            "audio_tokens": [],
        }

        # Handle the different attention types
        if self.attention_type == "sliding_recompute" and self.idx >= 1:
            # sliding_recompute: discard the KV cache, recompute with previous + current chunk
            self.past_key_values = None
            current_condition = self._build_recompute_inputs(condition)
            self.text_start_pos = 0
        elif self.attention_type == "reindex" and self.idx >= 1:
            # reindex: truncate the KV cache keeping sink + last chunk, reindex positions via RoPE
            self._reindex_kv_cache()
            current_condition = condition
            # Always update text_start_pos from the actual KV cache length (as in the reference code)
            if self.past_key_values is not None:
                if hasattr(self.past_key_values, "get_seq_length"):
                    kv_len = self.past_key_values.get_seq_length()
                else:
                    kv_len = self.past_key_values[0][0].shape[2]
                self.text_start_pos = kv_len
        else:
            current_condition = condition

        condition_length = current_condition.shape[1]
        prefill_len = condition_length
        finished = torch.zeros(1, dtype=torch.bool, device=self.device)
        chunk_generated_tokens = []

        for t in range(max_new_token):
            if t == 0:
                inputs_embeds = current_condition
                pos_ids = torch.arange(
                    self.text_start_pos,
                    self.text_start_pos + condition_length,
                    dtype=torch.long,
                    device=self.device,
                ).unsqueeze(0)
            else:
                last = self.all_generated_tokens[-1]
                # last: [1, 1], used directly as a code id
                inputs_embeds = self.emb_code[0](last)
                pos_ids = torch.tensor(
                    [self.text_start_pos + prefill_len + t - 1],
                    dtype=torch.long,
                    device=self.device,
                ).unsqueeze(0)

            outputs = self.tts.model(
                position_ids=pos_ids,
                past_key_values=self.past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=True,
            )
            hidden_states = outputs.last_hidden_state

            # Handle the KV cache according to the attention type
            self.past_key_values = outputs.past_key_values
            if self.attention_type == "sliding_window":
                self._truncate_kv_cache_sliding_window()

            with P.cached():
                logits = torch.empty(
                    hidden_states.size(0),
                    hidden_states.size(1),
                    self.num_audio_tokens,
                    self.num_vq,
                    dtype=torch.float,
                    device=self.device,
                )
                for num_vq_iter in range(self.num_vq):
                    x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
                    logits[..., num_vq_iter] = x
                    del x

            del hidden_states

            logits = logits[:, -1].float()

            logits = logits.permute(0, 2, 1)
            logits = logits.reshape(-1, logits.size(2))

            logits /= self.temperature

            audio_bos = len(self.all_generated_tokens) == 0 and t == 0

            if not audio_bos:
                # Use the generated tokens as input to the processors/warpers (aligned with modeling_minicpmo)
                all_generated_tokens = torch.cat(self.all_generated_tokens, dim=1).to(self.device)  # [1, T]
                for processor in self.logits_processors:
                    logits = processor(all_generated_tokens, logits)

                for warper in self.logits_warpers:
                    logits = warper(all_generated_tokens, logits)
                del all_generated_tokens

            # Sample the next token (only the first codebook is used, same as generate)
            scores = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(scores, num_samples=1)  # [(B*num_vq), 1]
            next_id = idx_next.view(-1, self.num_vq)[:, 0:1]  # take only the first codebook → [B, 1]
            del scores

            if next_id.eq(
                self.eos_token
            ).any():  # audio EOS generated: this chunk is finished, stop generating new tokens
                finished[:] = True
            else:  # the EOS token is never added to the buffer
                # Normalize next_id to shape [1, 1] without the num_vq dimension
                if next_id.dim() == 0:  # scalar
                    next_tok = next_id.unsqueeze(0).unsqueeze(0)  # [1, 1]
                elif next_id.dim() == 1:  # 1D [1]
                    next_tok = next_id.unsqueeze(0)  # [1, 1]
                else:
                    next_tok = next_id

                self.all_generated_tokens.append(next_tok)
                chunk_generated_tokens.append(next_tok)

                # Update the chunk info for sliding_recompute
                current_chunk_info["audio_tokens"].append(next_tok.clone())
                current_chunk_info["audio_token_count"] += 1

                self._token_buffer.append(next_tok)

            if len(self._token_buffer) == 0:
                # Case 1: on the last text chunk, yield an empty batch
                if text_finished:
                    yield torch.empty(1, 0, dtype=torch.long, device=self.device), True
                    break
                # Case 2: not the last text chunk, just stop
                else:
                    break
            else:  # the buffer has content
                # Case 1: the buffer holds at least chunk_size tokens, yield them
                if len(self._token_buffer) >= self.chunk_size:
                    batch = torch.cat(self._token_buffer[: self.chunk_size], dim=1)  # [1, chunk_size]
                    yield batch, False
                    # Discard the yielded part
                    self._token_buffer = self._token_buffer[self.chunk_size :]

                # Case 2: the buffer holds fewer than chunk_size tokens
                else:
                    if finished.all():
                        # Generation finished: on the last text chunk, yield whatever remains
                        if text_finished:
                            batch = torch.cat(self._token_buffer, dim=1)  # [1, <chunk_size]
                            yield batch, True
                            self._token_buffer = []
                            break
                        else:
                            # Not the last text chunk: wait for the next text chunk to fill the buffer
                            break
                    else:  # this audio chunk is not finished, keep generating
                        continue

        # Save the current chunk info for sliding_recompute and reindex
        self._chunk_info.append(current_chunk_info)
        self._total_seq_len += condition.shape[1] + len(chunk_generated_tokens)

        # Update text_start_pos according to the attention type
        if self.attention_type == "sliding_recompute":
            # sliding_recompute: reset at the next chunk start; update normally here
            self.text_start_pos += prefill_len + len(chunk_generated_tokens)
        elif self.attention_type == "reindex":
            # reindex: position based on the actual KV cache length (positions were reindexed to be contiguous)
            if self.past_key_values is not None:
                if hasattr(self.past_key_values, "get_seq_length"):
                    self.text_start_pos = self.past_key_values.get_seq_length()
                else:
                    self.text_start_pos = self.past_key_values[0][0].shape[2]
            else:
                self.text_start_pos += condition.shape[1] + len(chunk_generated_tokens)
        else:
            self.text_start_pos += condition.shape[1] + len(chunk_generated_tokens)
        # Note: tokens remaining in the buffer are kept and accumulated on the next call


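# A hedged driver sketch for TTSStreamingGenerator: `tts_model` and the
# per-chunk condition embeddings are assumed to come from the surrounding
# pipeline, and the audio EOS id (625 here) is illustrative, not the model's
# actual value.
def _example_tts_driver(tts_model, conditions: List[torch.Tensor]):
    gen = TTSStreamingGenerator(model=tts_model, temperature=0.8, eos_token=625)
    audio_chunks = []
    for i, cond in enumerate(conditions):
        is_last_text_chunk = i == len(conditions) - 1
        for tokens, is_final in gen.generate_with_buffer(cond, text_finished=is_last_text_chunk):
            audio_chunks.append(tokens)  # [1, <=chunk_size] audio codec tokens
            if is_final:
                break
    return audio_chunks

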
# sliding window
@dataclass
class StreamingWindowConfig:
    text_window_high_tokens: int = 8000
    text_window_low_tokens: int = 6000


@dataclass
class DuplexWindowConfig:
    """Duplex sliding-window configuration.

    Sliding-window modes:
    - "off": disable the sliding window
    - "basic": basic sliding window (triggered by cache length)
    - "context": sliding window with context (triggered by unit count; generated text is preserved in "previous")
    """

    # Sliding-window mode
    sliding_window_mode: str = "off"  # "off" / "basic" / "context"

    # Basic sliding-window parameters
    basic_window_high_tokens: int = 8000  # high watermark: trigger the sliding window when exceeded
    basic_window_low_tokens: int = 6000  # low watermark: shrink back to this value after sliding

    # Context sliding-window parameters
    context_previous_max_tokens: int = 500  # maximum token count of the "previous" content
    context_max_units: int = 24  # maximum unit count (trigger the sliding window when exceeded)

    # Verification mode (for comparison tests)
    verify_mode: bool = False  # whether to enable the verification log


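# Example (illustrative values): enabling the context-preserving window on a
# StreamDecoder instance would look like
#
#   decoder._window_config = DuplexWindowConfig(
#       sliding_window_mode="context",
#       context_previous_max_tokens=500,
#       context_max_units=24,
#   )
#
# "basic" mode instead watches only basic_window_high_tokens / basic_window_low_tokens.

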
def as_dynamic_cache(past_key_values):
    """Convert a legacy tuple cache to a DynamicCache if needed."""
    if isinstance(past_key_values, DynamicCache):
        return past_key_values

    if isinstance(past_key_values, tuple):
        return DynamicCache.from_legacy_cache(past_key_values)

    return past_key_values


def get_kv_cache_length(cache) -> int:
    """Get the sequence length of a KV cache.

    Args:
        cache: DynamicCache or tuple-based cache

    Returns:
        The number of tokens in the cache
    """
    if cache is None:
        return 0

    if isinstance(cache, DynamicCache):
        if not cache.key_cache or not cache.key_cache[0].numel():
            return 0
        return cache.key_cache[0].shape[-2]

    if isinstance(cache, tuple):
        return cache[0][0].shape[2]

    return 0


def get_rotary_cos_sin(
    head_dim: int,
    positions: torch.Tensor,
    device: torch.device,
    dtype: torch.dtype,
    rope_theta: float = 10000.0,
    inv_freq_cache: Optional[Dict[Tuple, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute RoPE cos and sin components for the given positions.

    Args:
        head_dim: Dimension of each attention head
        positions: Position indices tensor
        device: Target device
        dtype: Target dtype
        rope_theta: RoPE base frequency (default 10000.0)
        inv_freq_cache: Optional cache dict for inverse frequencies

    Returns:
        Tuple of (cos, sin) tensors with shape [1, 1, seq_len, head_dim]
    """
    cache_key = (head_dim, device)

    inv_freq = inv_freq_cache.get(cache_key) if inv_freq_cache is not None else None
    if inv_freq is None or inv_freq.device != device or inv_freq.shape[0] != head_dim // 2:
        exponent = torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim
        inv_freq = 1.0 / (rope_theta**exponent)
        if inv_freq_cache is not None:
            inv_freq_cache[cache_key] = inv_freq

    positions = positions.to(device=device, dtype=torch.float32)
    angles = torch.einsum("i,j->ij", positions, inv_freq)
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    # Use cat instead of repeat_interleave, consistent with the model's original RotaryEmbedding.
    # Original: emb = torch.cat((freqs, freqs), dim=-1) -> [f0, f1, ..., f_{d/2}, f0, f1, ..., f_{d/2}]
    cos_full = torch.cat([cos, cos], dim=-1).to(dtype=dtype)
    sin_full = torch.cat([sin, sin], dim=-1).to(dtype=dtype)
    cos_full = cos_full.unsqueeze(0).unsqueeze(0)
    sin_full = sin_full.unsqueeze(0).unsqueeze(0)
    return cos_full, sin_full


def realign_rotary_suffix(
    suffix_keys: torch.Tensor,
    old_positions: torch.Tensor,
    new_positions: torch.Tensor,
    rope_theta: float = 10000.0,
    inv_freq_cache: Optional[Dict[Tuple, torch.Tensor]] = None,
) -> torch.Tensor:
    """Realign the RoPE position encoding after cache eviction.

    When tokens are dropped from the middle of a cache, the suffix tokens
    need their RoPE embeddings recalculated with new position indices.

    Args:
        suffix_keys: Key tensor to realign, shape [batch, heads, seq_len, head_dim]
        old_positions: Original position indices
        new_positions: New position indices after eviction
        rope_theta: RoPE base frequency
        inv_freq_cache: Optional cache dict for inverse frequencies

    Returns:
        Realigned key tensor with the same shape as the input
    """
    if suffix_keys.numel() == 0:
        return suffix_keys

    head_dim = suffix_keys.shape[-1]
    device = suffix_keys.device
    dtype = suffix_keys.dtype

    # Compute the old-position cos/sin
    cos_old, sin_old = get_rotary_cos_sin(head_dim, old_positions, device, dtype, rope_theta, inv_freq_cache)

    # Inverse transform: recover the original keys
    base = cos_old * suffix_keys - sin_old * rotate_half(suffix_keys)

    # Compute the new-position cos/sin
    cos_new, sin_new = get_rotary_cos_sin(head_dim, new_positions, device, dtype, rope_theta, inv_freq_cache)

    # Forward transform: re-encode with the new positions
    return cos_new * base + sin_new * rotate_half(base)


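# A hedged self-check sketch for realign_rotary_suffix (not called anywhere):
# RoPE-encoding random keys at positions p and then realigning them to
# positions q should match encoding the raw keys at q directly.
def _example_check_realign(head_dim: int = 64) -> bool:
    raw = torch.randn(1, 2, 5, head_dim)
    p = torch.arange(10, 15)
    q = torch.arange(3, 8)
    cos_p, sin_p = get_rotary_cos_sin(head_dim, p, raw.device, raw.dtype)
    cos_q, sin_q = get_rotary_cos_sin(head_dim, q, raw.device, raw.dtype)
    at_p = cos_p * raw + sin_p * rotate_half(raw)  # forward RoPE at positions p
    direct = cos_q * raw + sin_q * rotate_half(raw)  # forward RoPE at positions q
    realigned = realign_rotary_suffix(at_p, p, q)
    return torch.allclose(realigned, direct, atol=1e-5)

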
def drop_tokens_from_cache(
    cache: Optional[DynamicCache | Tuple],
    length: int,
    preserve: int,
    position_offset: int,
    rope_theta: float = 10000.0,
    inv_freq_cache: Optional[Dict[Tuple, torch.Tensor]] = None,
) -> Tuple[Optional[DynamicCache], int, bool]:
    """Drop tokens from a KV cache while preserving the system prompt.

    Removes tokens in the range [preserve, preserve + length) from the cache,
    realigning the RoPE embeddings of the suffix.

    Args:
        cache: DynamicCache or tuple-based cache (will be converted to DynamicCache)
        length: Number of tokens to drop
        preserve: Number of tokens to preserve at the start (system prompt)
        position_offset: Current position offset for RoPE calculation
        rope_theta: RoPE base frequency
        inv_freq_cache: Optional cache dict for inverse frequencies

    Returns:
        Tuple of (cache, new_position_offset, success)
        Note: a tuple cache will be converted to DynamicCache. Modification is in-place.
    """
    if cache is None or length <= 0:
        return cache, position_offset, False

    cache = as_dynamic_cache(cache)

    total_len = get_kv_cache_length(cache)
    if total_len <= 0:
        return cache, position_offset, False

    preserve = min(preserve, total_len)
    available = total_len - preserve

    if available < length:
        logger.warning(
            "Cannot drop %d tokens: only %d available (total=%d, preserve=%d)",
            length,
            available,
            total_len,
            preserve,
        )
        return cache, position_offset, False

    suffix_len = total_len - preserve - length
    # Note: after RoPE reindexing, the cache positions are already compressed
    # (starting from preserve), so position_offset must not be added here; use
    # the actual layout of the current cache instead.
    suffix_offset = preserve + length  # current position of the suffix in the cache
    prefix_offset = preserve  # new position of the suffix (right after the preserved prefix)

    # Prepare position tensors for the RoPE realignment
    old_positions = None
    new_positions = None
    if suffix_len > 0:
        device = cache.key_cache[0].device
        old_positions = torch.arange(
            suffix_offset,
            suffix_offset + suffix_len,
            device=device,
            dtype=torch.long,
        )
        new_positions = torch.arange(
            prefix_offset,
            prefix_offset + suffix_len,
            device=device,
            dtype=torch.long,
        )

    keep_len = total_len - length

    # Process each layer (in-place modification)
    for layer_idx in range(len(cache.key_cache)):
        key_tensor = cache.key_cache[layer_idx]
        value_tensor = cache.value_cache[layer_idx]

        if not key_tensor.numel():
            continue

        # Preserve the prefix (system prompt)
        prefix_keys = key_tensor[:, :, :preserve, :]
        prefix_values = value_tensor[:, :, :preserve, :]

        if suffix_len > 0:
            # Keep and realign the suffix
            suffix_keys = key_tensor[:, :, preserve + length :, :]
            suffix_values = value_tensor[:, :, preserve + length :, :]

            if old_positions is not None and new_positions is not None and suffix_keys.numel():
                suffix_keys = realign_rotary_suffix(
                    suffix_keys,
                    old_positions,
                    new_positions,
                    rope_theta,
                    inv_freq_cache,
                )

            cache.key_cache[layer_idx] = torch.cat([prefix_keys, suffix_keys], dim=-2).contiguous()
            cache.value_cache[layer_idx] = torch.cat([prefix_values, suffix_values], dim=-2).contiguous()
        else:
            cache.key_cache[layer_idx] = prefix_keys.contiguous()
            cache.value_cache[layer_idx] = prefix_values.contiguous()

    cache.crop(keep_len)
    cache._seen_tokens = max(keep_len, 0)

    new_offset = position_offset + length
    logger.debug("Dropped %d tokens from cache, new length=%d", length, keep_len)

    return cache, new_offset, True


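# A hedged usage sketch for drop_tokens_from_cache: evict the 128 tokens that
# immediately follow a 32-token system prompt while keeping RoPE positions
# consistent. `past` is assumed to be a transformers DynamicCache (or legacy
# tuple) produced by a previous forward pass.
def _example_evict_after_system_prompt(past):
    inv_freq_cache: Dict[Tuple, torch.Tensor] = {}
    past, new_offset, ok = drop_tokens_from_cache(
        cache=past,
        length=128,
        preserve=32,
        position_offset=0,
        rope_theta=10000.0,
        inv_freq_cache=inv_freq_cache,
    )
    if ok:
        logger.debug("evicted 128 tokens, new position offset=%d", new_offset)
    return past

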
# stream decoder
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
    """Filter a logits row with top-k and/or top-p (nucleus) filtering.

    Note: the final scatter below indexes row 0, so this assumes batch size 1.
    """
    logits = logits.clone()

    # Top-k filtering
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    # Top-p (nucleus) filtering
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        probs = F.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(probs, dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        # Keep the first token that exceeds top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[0, indices_to_remove] = filter_value

    return logits


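# Usage sketch: filter a [1, vocab] logits row and sample from what remains.
def _example_sample_filtered(logits: torch.Tensor) -> torch.Tensor:
    filtered = top_k_top_p_filtering(logits, top_k=25, top_p=0.85)
    probs = F.softmax(filtered, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # [1, 1] sampled token id

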
class StreamDecoder:
    def __init__(self, llm, tokenizer, special_token_ids=None, forbidden_token_ids=None):
        self.m = llm
        self.tokenizer = tokenizer
        self.listen_id = self.tokenizer.eos_token_id

        self.chunk_eos_id = self.tokenizer.convert_tokens_to_ids("<|chunk_eos|>")
        self.chunk_tts_eos_id = self.tokenizer.convert_tokens_to_ids("<|chunk_tts_eos|>")
        self.turn_eos_id = self.tokenizer.convert_tokens_to_ids("<|turn_eos|>")
        self.speak_id = self.tokenizer.convert_tokens_to_ids("<|speak|>")

        self.special_token_ids = special_token_ids if special_token_ids is not None else []

        # Cache the special tokens (used for context sliding-window filtering)
        self._all_special_ids = set()
        self._all_special_tokens_text = set()
        if self.tokenizer:
            if hasattr(self.tokenizer, "all_special_ids"):
                self._all_special_ids = set(self.tokenizer.all_special_ids)
            if hasattr(self.tokenizer, "all_special_tokens"):
                self._all_special_tokens_text = set(self.tokenizer.all_special_tokens)

        custom_special_tokens = [
            "<unit>",
            "</unit>",
            "<image>",
            "</image>",
            "<slice>",
            "</slice>",
            "<|listen|>",
            "<|speak|>",
            "<|tts_bos|>",
            "<|tts_eos|>",
            "<|audio_start|>",
            "<|audio_end|>",
            "<|chunk_eos|>",
            "<|chunk_tts_eos|>",
            "<|turn_eos|>",
        ]
        self._all_special_tokens_text.update(custom_special_tokens)
        for token in custom_special_tokens:
            token_id = self.tokenizer.convert_tokens_to_ids(token)
            if token_id is not None and token_id != self.tokenizer.unk_token_id:
                self._all_special_ids.add(token_id)

        if forbidden_token_ids is None:
            self.forbidden_token_ids = []
        elif isinstance(forbidden_token_ids, int):
            self.forbidden_token_ids = [forbidden_token_ids]
        else:
            self.forbidden_token_ids = forbidden_token_ids
        self.forbidden_token_ids.append(self.chunk_eos_id)

        assert isinstance(self.forbidden_token_ids, list)

        self.cache = None
        self.context = ""
        self.generated_tokens = []  # track generated tokens
        self.generated_special_tokens = []  # track generated special tokens
        self.reset()
        self.embeds = None
        self.system_embeds = None

        # Sliding-window state
        self._unit_history: List[Dict[str, Any]] = []
        self._next_unit_id: int = 0
        self._pending_unit_id: Optional[int] = None
        self._pending_unit_start_cache_len: int = 0
        self._system_preserve_length: int = 0
        self._position_offset: int = 0
        self._window_config = DuplexWindowConfig()
        self._window_enabled: bool = True
        self._rope_inv_freq_cache: Dict[Tuple, torch.Tensor] = {}

        # Context-preserving sliding-window state.
        # Initial cache layout: [prefix] [suffix] [units...]
        # After the first sliding window: [prefix] [previous_marker + content] [suffix] [units...]
        # The prefix and suffix are fixed; the "previous" region between them is dynamic.
        self._preserve_prefix_length: int = 0  # original prefix length (fixed)
        self._previous_content_length: int = 0  # "previous" content length (dynamic, including the marker)
        self._suffix_token_ids: List[int] = []  # suffix token ids (e.g. <|im_end|>)

        # Previous marker (added dynamically on the first sliding window)
        self._previous_marker: str = "\n\nprevious: "  # fixed marker prefix
        self._previous_marker_token_ids: List[int] = []  # marker token ids (set in register_system_prompt_with_context)
        self._has_previous: bool = False  # whether the previous marker has been added

        # Previous content
        self._previous_text: str = ""  # accumulated generated text (without the marker)
        self._previous_token_ids: List[int] = []  # full "previous" token ids (including the marker)

        # Validation statistics
        self._sliding_event_count: int = 0  # sliding-window trigger count
        self._total_dropped_tokens: int = 0  # total dropped token count
        self._total_dropped_units: int = 0  # total dropped unit count

    def sliding_embeds(self):
        # TODO (stub): planned embedding-level sliding:
        #   tmp = system_embeds; after ~5 s, append the accumulated embeds to tmp,
        #   then reset the decoder state and re-feed tmp.
        pass

    def reset(self):
        self.context = ""
        self.cache = None
        self.generated_tokens = []
        self.generated_special_tokens = []
        self.embeds = None
        self.system_embeds = None

        # Sliding-window state reset
        self._unit_history = []
        self._next_unit_id = 0
        self._pending_unit_id = None
        self._pending_unit_start_cache_len = 0
        self._system_preserve_length = 0
        self._position_offset = 0
        self._rope_inv_freq_cache = {}

        # Context-preserving sliding-window state reset
        self._preserve_prefix_length = 0
        self._previous_content_length = 0
        self._suffix_token_ids = []
        self._previous_marker = "\n\nprevious: "
        self._previous_marker_token_ids = []
        self._has_previous = False
        self._previous_text = ""
        self._previous_token_ids = []

        # Validation statistics
        self._sliding_event_count = 0  # sliding-window trigger count
        self._total_dropped_tokens = 0  # total dropped token count
        self._total_dropped_units = 0  # total dropped unit count

    def get_cache_length(self) -> int:
        if self.cache is None:
            return 0
        if isinstance(self.cache, DynamicCache):
            if len(self.cache.key_cache) > 0 and self.cache.key_cache[0].numel() > 0:
                return self.cache.key_cache[0].shape[2]
            return 0
        # Tuple cache format
        return self.cache[0][0].shape[2]

    def get_total_generated_tokens(self) -> int:
        return sum(len(u.get("generated_tokens", [])) for u in self._unit_history)

    def register_unit_start(self) -> int:
        self._pending_unit_id = self._next_unit_id
        self._pending_unit_start_cache_len = self.get_cache_length()
        return self._pending_unit_id

    def register_unit_end(
        self,
        input_type: str,
        generated_tokens: Optional[List[int]] = None,
        is_listen: bool = False,
        generated_text: Optional[str] = None,
    ):
        """Record unit information when a unit ends.

        Should be called after feeding the </unit> token.

        Args:
            input_type: "audio" / "video" / "omni" / "system"
            generated_tokens: token ids generated by the unit
            is_listen: whether the unit is in the listen state
            generated_text: text generated by the unit (used in context-preserving mode)
        """
        if self._pending_unit_id is None:
            logger.warning("register_unit_end called without register_unit_start")
            return

        # Calculate the length of the unit
        current_cache_len = self.get_cache_length()
        unit_len = current_cache_len - self._pending_unit_start_cache_len

        if unit_len > 0:
            entry = {
                "unit_id": self._pending_unit_id,
                "length": unit_len,
                "type": input_type,
                "generated_tokens": generated_tokens or [],
                "generated_text": generated_text or "",  # used in context-preserving mode
                "is_listen": is_listen,
            }
            self._unit_history.append(entry)

        self._pending_unit_id = None
        self._pending_unit_start_cache_len = 0
        self._next_unit_id += 1

    def register_system_prompt(self):
        """Call after the system-prompt prefill to record the preserve length."""
        self._system_preserve_length = self.get_cache_length()

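    # Usage sketch for the unit bookkeeping above (illustrative; the real call
    # sites live in the duplex streaming loop):
    #
    #   decoder.register_system_prompt()      # after the system-prompt prefill
    #   uid = decoder.register_unit_start()   # before feeding a <unit>...</unit>
    #   ...prefill the unit, possibly generate...
    #   decoder.register_unit_end("audio", generated_tokens=toks, is_listen=False)
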
    # sliding window core methods

    def _get_rope_theta(self) -> float:
        """Get the model's rope_theta configuration."""
        return float(getattr(self.m.config, "rope_theta", 10000.0))

    def _drop_tokens_from_cache(self, length: int) -> bool:
        """Remove the given number of tokens from the cache (protecting the system prompt).

        Removes tokens in the range [preserve, preserve + length).
        Supports DynamicCache and tuple cache formats.
        """
        if self.cache is None or length <= 0:
            return False

        new_cache, new_offset, success = drop_tokens_from_cache(
            cache=self.cache,
            length=length,
            preserve=self._system_preserve_length,
            position_offset=self._position_offset,
            rope_theta=self._get_rope_theta(),
            inv_freq_cache=self._rope_inv_freq_cache,
        )
        if success:
            self.cache = new_cache  # for DynamicCache this is the same object (in-place)
            self._position_offset = new_offset

        return success

    def _drop_unit(self, unit_id: int) -> bool:
        """Remove the specified unit."""
        entries = [u for u in self._unit_history if u["unit_id"] == unit_id]
        if not entries:
            return False

        total_len = sum(e["length"] for e in entries)
        if total_len <= 0:
            for e in entries:
                self._unit_history.remove(e)
            return False

        if not self._drop_tokens_from_cache(total_len):
            return False

        for e in entries:
            self._unit_history.remove(e)

        return True

    def _drop_next_unit(self) -> bool:
        """Remove the earliest non-system unit."""
        for entry in self._unit_history:
            unit_id = entry.get("unit_id")
            if unit_id is None:
                continue
            # Skip system-type units
            if entry.get("type") == "system":
                continue
            if self._drop_unit(unit_id):
                return True
        return False

    def enforce_window(self) -> bool:
        """Enforce the sliding-window policy (as in single mode: only the cache length is considered).

        When the cache length exceeds the high watermark, the earliest units are
        removed in a loop until the cache length drops below the low watermark.
        """
        if not self._window_enabled:
            return False

        cfg = self._window_config
        cache_len_before = self.get_cache_length()

        if cache_len_before <= cfg.basic_window_high_tokens:
            return False  # below the high watermark, nothing to do

        dropped_count = 0
        cache_len = cache_len_before
        while cache_len > cfg.basic_window_low_tokens:
            if not self._drop_next_unit():
                break
            dropped_count += 1
            cache_len = self.get_cache_length()

        if dropped_count > 0:
            # Update the statistics counters
            self._sliding_event_count += 1
            self._total_dropped_tokens += cache_len_before - cache_len
            self._total_dropped_units += dropped_count

            # Consistency check
            expected = self._system_preserve_length + sum(u["length"] for u in self._unit_history)
            is_consistent = expected == cache_len
            if not is_consistent:
                logger.error(
                    "CONSISTENCY ERROR! preserve=%d + sum(units)=%d != cache=%d, offset=%d",
                    self._system_preserve_length,
                    sum(u["length"] for u in self._unit_history),
                    cache_len,
                    self._position_offset,
                )

        return dropped_count > 0

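    # Watermark sketch: with basic_window_high_tokens=8000 and
    # basic_window_low_tokens=6000, a cache that reaches 8500 tokens triggers
    # eviction of the oldest non-system units until the length is <= 6000; the
    # preserved region (system prompt) is never touched.
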
    # context preserving sliding window methods

    def register_system_prompt_with_context(
        self,
        suffix_token_ids: Optional[List[int]] = None,
        context_previous_marker: str = "\n\nprevious: ",
    ):
        """Register the system prompt (context-preserving mode).

        Initial cache layout: [prefix] [suffix] [units...]
        After the first sliding window: [prefix] [context_previous_marker + content] [suffix] [units...]

        When this method is called, the cache should contain only the prefix
        (without the previous marker); the suffix is fed in later.

        Args:
            suffix_token_ids: suffix token ids (e.g. the id of <|im_end|>)
            context_previous_marker: previous-marker prefix, e.g. "\\n\\nprevious: "
        """
        # prefix = current cache content (fixed, without the previous marker)
        self._preserve_prefix_length = self.get_cache_length()
        self._previous_content_length = 0  # initially no previous content
        self._suffix_token_ids = suffix_token_ids or []
        # Total preserve length = prefix + suffix (initially no previous)
        self._system_preserve_length = self._preserve_prefix_length + len(self._suffix_token_ids)

        # Initialize the previous-related state
        self._previous_marker = context_previous_marker
        self._previous_marker_token_ids = (
            self.tokenizer.encode(context_previous_marker, add_special_tokens=False) if self.tokenizer else []
        )
        self._has_previous = False
        self._previous_text = ""
        self._previous_token_ids = []

    def _extract_generated_text(self, units: List[Dict[str, Any]]) -> Tuple[str, List[int]]:
        """Extract generated text and token ids from units.

        Args:
            units: list of units to extract from

        Returns:
            (text, token_ids): concatenated text and token ids, with special tokens filtered out
        """
        text_parts = []
        token_ids = []

        for u in units:
            # Only keep the generated content of non-listen units
            if u.get("is_listen", False):
                continue
            gen_text = u.get("generated_text", "")
            gen_tokens = u.get("generated_tokens", [])

            # Filter special tokens out of the text
            if gen_text:
                clean_text = gen_text
                for st in self._all_special_tokens_text:
                    clean_text = clean_text.replace(st, "")
                if clean_text.strip():
                    text_parts.append(clean_text)

            # Filter special tokens out of the token ids
            if gen_tokens:
                filtered_tokens = [t for t in gen_tokens if t not in self._all_special_ids]
                token_ids.extend(filtered_tokens)

        return "".join(text_parts), token_ids

    def _rebuild_cache_with_previous(
        self,
        new_previous_tokens: List[int],
        units_to_keep_len: Optional[int] = None,
    ) -> bool:
        """Rebuild the cache, inserting the new "previous" content between prefix and suffix.

        Cache layout change:
        [prefix] [old_prev] [suffix] [old_units] → [prefix] [new_prev] [suffix] [remaining_units]

        Args:
            new_previous_tokens: new "previous" token ids
            units_to_keep_len: length of the units to keep (counted from the cache end);
                if None, it is derived from unit_history

        Returns:
            Whether the rebuild succeeded
        """
        if self.cache is None:
            return False

        old_previous_len = self._previous_content_length
        new_previous_len = len(new_previous_tokens)
        suffix_len = len(self._suffix_token_ids)
        total_cache_len = self.get_cache_length()

        # Calculate the length of the units to keep
        if units_to_keep_len is None:
            units_to_keep_len = sum(u["length"] for u in self._unit_history)

        # Special case: if "previous" is unchanged (both old and new are empty), the
        # prefix+suffix part of the cache does not need to be rebuilt, but the units
        # still need RoPE reindexing (a unit was deleted, so positions changed).
        if new_previous_len == 0 and old_previous_len == 0:
            # Cache layout: [prefix] [suffix] [units...]
            # Only keep prefix + suffix + remaining units
            preserve_len = self._preserve_prefix_length + suffix_len

            # Simply slice the cache: [prefix+suffix] + [remaining_units]
            # (the remaining units sit at the end of the cache)
            if units_to_keep_len > 0:
                # [0:preserve_len] + [total-units_to_keep_len:total]
                prefix_suffix_cache = self._slice_cache(0, preserve_len)
                units_cache = self._slice_cache(total_cache_len - units_to_keep_len, None)

                # Number of dropped tokens
                dropped_tokens = total_cache_len - preserve_len - units_to_keep_len

                # Reindex the units' RoPE: positions move from (preserve_len + dropped_tokens) to preserve_len.
                # Note: no position_offset here, because the cache positions are already compressed (starting from 0).
                if dropped_tokens > 0:
                    old_start = preserve_len + dropped_tokens
                    new_start = preserve_len
                    units_cache = self._reindex_rope_for_cache(units_cache, old_start, new_start, units_to_keep_len)

                self.cache = self._concat_caches(prefix_suffix_cache, units_cache)
            else:
                self.cache = self._slice_cache(0, preserve_len)

            return True

        # 1. Get the prefix cache (fixed)
        prefix_end = self._preserve_prefix_length
        prefix_cache = self._slice_cache(0, prefix_end)

        # 2. Get the units cache to keep (from the end)
        units_start_in_old_cache = total_cache_len - units_to_keep_len
        units_cache = None
        if units_to_keep_len > 0:
            units_cache = self._slice_cache(units_start_in_old_cache, None)

        # 3. Compute the new previous + suffix cache (needs a forward pass):
        # merge the previous tokens and the suffix tokens
        prev_suffix_tokens = new_previous_tokens + self._suffix_token_ids
        prev_suffix_len = len(prev_suffix_tokens)

        new_prefix_prev_suffix_cache = prefix_cache
        if prev_suffix_len > 0:
            # Embed the tokens
            prev_suffix_embeds = self.embed_tokens(prev_suffix_tokens)
            # Start position (right after the prefix)
            start_pos = self._preserve_prefix_length + self._position_offset

            # Forward pass to compute the KV cache
            with torch.no_grad():
                device = prev_suffix_embeds.device
                position_ids = torch.arange(
                    start_pos,
                    start_pos + prev_suffix_len,
                    device=device,
                ).unsqueeze(0)

                # Use the prefix cache as past_key_values
                outputs = self.m(
                    inputs_embeds=(
                        prev_suffix_embeds.unsqueeze(0) if prev_suffix_embeds.dim() == 2 else prev_suffix_embeds
                    ),
                    position_ids=position_ids,
                    past_key_values=prefix_cache,
                    use_cache=True,
                    return_dict=True,
                )
                # The new cache contains prefix + new_previous + suffix
                new_prefix_prev_suffix_cache = outputs.past_key_values

        # 4. Adjust the units cache RoPE.
        # New layout: [prefix] [new_prev] [suffix] [units]
        # Note: no position_offset here, because the cache positions are already compressed (starting from 0).
        new_system_total = prefix_end + new_previous_len + suffix_len
        if units_cache is not None and self._get_cache_len(units_cache) > 0:
            old_start = units_start_in_old_cache
            new_start = new_system_total

            if old_start != new_start:
                units_cache = self._reindex_rope_for_cache(units_cache, old_start, new_start, units_to_keep_len)

        # 5. Concatenate the new cache
        if units_cache is not None and self._get_cache_len(units_cache) > 0:
            self.cache = self._concat_caches(new_prefix_prev_suffix_cache, units_cache)
        else:
            self.cache = new_prefix_prev_suffix_cache

        # 6. Update the lengths
        self._previous_content_length = new_previous_len
        # Total preserve length = prefix + previous + suffix
        self._system_preserve_length = prefix_end + new_previous_len + suffix_len

        # Log the detailed cache layout
        prev_text_preview = self._previous_text[:50] + "..." if len(self._previous_text) > 50 else self._previous_text
        suffix_preview = self.tokenizer.decode(self._suffix_token_ids) if self._suffix_token_ids else ""
        logger.debug("Rebuilt cache: previous=%r, suffix=%r", prev_text_preview, suffix_preview)

        return True

    def _slice_cache(self, start: int, end: Optional[int], clone: bool = True):
        """Slice the cache.

        Args:
            start: start position
            end: end position (None means to the end)
            clone: whether to clone (default True, to avoid shared-memory issues)
        """
        if self.cache is None:
            return None
        if isinstance(self.cache, DynamicCache):
            new_key_cache = [
                k[:, :, start:end, :].clone() if clone else k[:, :, start:end, :] for k in self.cache.key_cache
            ]
            new_value_cache = [
                v[:, :, start:end, :].clone() if clone else v[:, :, start:end, :] for v in self.cache.value_cache
            ]
            new_cache = DynamicCache()
            new_cache.key_cache = new_key_cache
            new_cache.value_cache = new_value_cache
            return new_cache
        else:
            # Tuple cache
            if clone:
                return tuple(
                    (layer[0][:, :, start:end, :].clone(), layer[1][:, :, start:end, :].clone()) for layer in self.cache
                )
            else:
                return tuple((layer[0][:, :, start:end, :], layer[1][:, :, start:end, :]) for layer in self.cache)

    @staticmethod
    def _get_cache_len(cache) -> int:
        if cache is None:
            return 0
        if isinstance(cache, DynamicCache):
            if len(cache.key_cache) > 0 and cache.key_cache[0].numel() > 0:
                return cache.key_cache[0].shape[2]
            return 0

        if cache and cache[0] and cache[0][0] is not None:
            return cache[0][0].shape[2]
        return 0

    @staticmethod
    def _concat_caches(cache1, cache2):
        if cache1 is None:
            return cache2
        if cache2 is None:
            return cache1

        if isinstance(cache1, DynamicCache):
            new_cache = DynamicCache()
            new_cache.key_cache = [torch.cat([k1, k2], dim=2) for k1, k2 in zip(cache1.key_cache, cache2.key_cache)]
            new_cache.value_cache = [
                torch.cat([v1, v2], dim=2) for v1, v2 in zip(cache1.value_cache, cache2.value_cache)
            ]
            return new_cache
        else:
            return tuple(
                (
                    torch.cat([layer1[0], layer2[0]], dim=2),
                    torch.cat([layer1[1], layer2[1]], dim=2),
                )
                for layer1, layer2 in zip(cache1, cache2)
            )

| 1763 |
+
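    # Minimal standalone sketch of the tuple-cache helpers above (shapes
    # [batch, heads, seq, head_dim] are assumed; this is illustrative and not
    # executed by this module):
    #
    #   import torch
    #   cache = ((torch.randn(1, 2, 10, 4), torch.randn(1, 2, 10, 4)),)
    #   head = tuple((k[:, :, :6, :], v[:, :, :6, :]) for k, v in cache)
    #   tail = tuple((k[:, :, 6:, :], v[:, :, 6:, :]) for k, v in cache)
    #   merged = tuple(
    #       (torch.cat([k1, k2], dim=2), torch.cat([v1, v2], dim=2))
    #       for (k1, v1), (k2, v2) in zip(head, tail)
    #   )
    #   assert merged[0][0].shape[2] == 10  # slice then concat round-trips the length
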
    def _reindex_rope_for_cache(self, cache, old_start: int, new_start: int, length: int):
        """Re-apply RoPE so cached keys match their new positions after a shift."""
        if cache is None or length <= 0:
            return cache

        if isinstance(cache, DynamicCache):
            device = cache.key_cache[0].device if cache.key_cache else None
        else:
            device = cache[0][0].device if cache and cache[0] else None

        if device is None:
            return cache

        old_positions = torch.arange(old_start, old_start + length, device=device, dtype=torch.long)
        new_positions = torch.arange(new_start, new_start + length, device=device, dtype=torch.long)

        rope_theta = self._get_rope_theta()

        if isinstance(cache, DynamicCache):
            new_key_cache = []
            for k in cache.key_cache:
                new_k = realign_rotary_suffix(k, old_positions, new_positions, rope_theta, self._rope_inv_freq_cache)
                new_key_cache.append(new_k)
            cache.key_cache = new_key_cache
            return cache
        else:
            new_cache = []
            for layer in cache:
                new_k = realign_rotary_suffix(
                    layer[0], old_positions, new_positions, rope_theta, self._rope_inv_freq_cache
                )
                new_cache.append((new_k, layer[1]))
            return tuple(new_cache)

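    # One plausible shape for the key realignment that realign_rotary_suffix is
    # assumed to perform (standard half-split rotary layout as in HF Llama/Qwen;
    # the actual helper may differ). Because RoPE rotations compose, moving a
    # cached key from position p_old to p_new only needs one extra rotation by
    # delta = p_new - p_old (values are never rotated, which is why only
    # key_cache is rewritten above):
    #
    #   def rotate_half(x):
    #       x1, x2 = x.chunk(2, dim=-1)
    #       return torch.cat((-x2, x1), dim=-1)
    #
    #   d = k.shape[-1]
    #   inv_freq = 1.0 / rope_theta ** (torch.arange(0, d, 2, dtype=torch.float32) / d)
    #   delta = (new_positions - old_positions).float()       # uniform shift, [L]
    #   ang = delta[:, None] * inv_freq[None, :]              # [L, d/2]
    #   cos = torch.cat([ang.cos(), ang.cos()], dim=-1)       # [L, d]
    #   sin = torch.cat([ang.sin(), ang.sin()], dim=-1)
    #   k_new = k * cos + rotate_half(k) * sin                # broadcasts over [B, H, L, d]
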
    def _update_previous(
        self,
        new_text: str,
        new_tokens: List[int],
        max_tokens: int,
    ) -> None:
        """Update the previous context (and rebuild the cache accordingly).

        On the first sliding-window event this adds marker + text; subsequent
        events append text only. When the content exceeds max_tokens, the
        content is truncated from the left (the marker is always kept), and the
        cache is rebuilt to stay consistent.

        Args:
            new_text: new text
            new_tokens: new token ids
            max_tokens: maximum token count of the previous content (marker excluded)
        """
        marker_len = len(self._previous_marker_token_ids)

        # no new content: do not add the marker, but the cache still has to be
        # rebuilt (because a unit was deleted)
        if not new_tokens and not new_text:
            self._rebuild_cache_with_previous(self._previous_token_ids)
            return

        if not self._has_previous:
            # first time actual content arrives: add marker + text
            self._previous_text = new_text
            self._previous_token_ids = self._previous_marker_token_ids.copy() + new_tokens
            self._has_previous = True
        else:
            # subsequent sliding-window events: append text to previous
            self._previous_text += new_text
            self._previous_token_ids.extend(new_tokens)

        # token count of the content (marker excluded)
        content_token_count = len(self._previous_token_ids) - marker_len

        # truncate the content if needed (keep the marker)
        if content_token_count > max_tokens:
            # drop from the left, keep marker + the latest max_tokens of content
            tokens_to_drop = content_token_count - max_tokens
            content_tokens = self._previous_token_ids[marker_len + tokens_to_drop :]
            self._previous_token_ids = self._previous_marker_token_ids.copy() + content_tokens
            # re-decode the text (content part only)
            try:
                self._previous_text = self.tokenizer.decode(
                    content_tokens,
                    skip_special_tokens=True,
                )
            except Exception as e:
                logger.warning("_update_previous: decode failed: %s", e)

        # rebuild the cache
        self._rebuild_cache_with_previous(self._previous_token_ids)

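    # Worked example of the truncation rule (illustrative numbers): with a
    # 3-token marker, max_tokens=8 and 11 accumulated content tokens,
    # tokens_to_drop = 11 - 8 = 3, so the rebuilt sequence is the marker
    # (3 tokens) followed by the latest 8 content tokens; the marker itself
    # is never dropped.
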
    def _drop_unit_with_context(
        self,
        unit_id: int,
        max_previous_tokens: int,
    ) -> Tuple[bool, str, List[int]]:
        """Remove the specified unit and return its generated content (context preserving).

        Process:
        1. extract the unit's generated content
        2. remove the unit from the history (prefix + previous are untouched)
        3. append the generated content to previous
        4. rebuild the cache (inside _update_previous)

        Args:
            unit_id: ID of the unit to remove
            max_previous_tokens: maximum token count of previous

        Returns:
            (success, extracted_text, extracted_tokens)
        """
        entries = [u for u in self._unit_history if u["unit_id"] == unit_id]
        if not entries:
            return False, "", []

        # extract the generated content
        extracted_text, extracted_tokens = self._extract_generated_text(entries)

        # compute the total length
        total_len = sum(e["length"] for e in entries)
        if total_len <= 0:
            for e in entries:
                self._unit_history.remove(e)
            return False, extracted_text, extracted_tokens

        # remove from unit_history
        for e in entries:
            self._unit_history.remove(e)

        # note: _drop_tokens_from_cache is intentionally not called here,
        # because _update_previous rebuilds the entire cache

        # update previous (also rebuilds the cache)
        self._update_previous(extracted_text, extracted_tokens, max_previous_tokens)

        return True, extracted_text, extracted_tokens

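    # Example of the drop flow (hypothetical history): if _unit_history holds
    # units [3, 4, 5] and unit 3 produced 12 generated tokens, dropping unit 3
    # removes its entries from the history, appends those 12 tokens to
    # previous, and _update_previous rebuilds the cache as
    # prefix + previous + suffix + units [4, 5].
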
    def _drop_next_unit_with_context(self, max_previous_tokens: int) -> bool:
        """Remove the earliest non-system unit (with context preserving)."""
        for entry in self._unit_history:
            unit_id = entry.get("unit_id")
            if unit_id is None:
                continue
            if entry.get("type") == "system":
                continue
            success, _, _ = self._drop_unit_with_context(unit_id, max_previous_tokens)
            if success:
                return True
        return False

    def enforce_window_with_context(self) -> bool:
        """Run the context-preserving sliding window.

        When the unit count exceeds context_max_units, the earliest units are
        removed and their generated content is accumulated into previous. The
        cache is rebuilt automatically inside _update_previous.

        Returns:
            whether the sliding window was executed
        """
        if not self._window_enabled:
            return False

        cfg = self._window_config

        if cfg.sliding_window_mode != "context":
            # not in context mode: fall back to the basic sliding window
            return self.enforce_window()

        cache_len_before = self.get_cache_length()
        units_before = len(self._unit_history)

        # context-preserving mode only checks whether the unit count exceeds
        # the limit (when previous exceeds its own limit, _update_previous
        # truncates it from the left automatically)
        if units_before <= cfg.context_max_units:
            return False

        # sliding-window loop: remove units until count <= context_max_units
        dropped_count = 0
        while len(self._unit_history) > cfg.context_max_units:
            if not self._drop_next_unit_with_context(cfg.context_previous_max_tokens):
                break

            dropped_count += 1

        cache_len_after = self.get_cache_length()

        if dropped_count > 0:
            # update the statistics counters
            self._sliding_event_count += 1
            self._total_dropped_tokens += cache_len_before - cache_len_after
            self._total_dropped_units += dropped_count

        # consistency check
        if not self._verify_consistency():
            logger.warning("enforce_window_with_context: cache length does not match unit history")

        return dropped_count > 0

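    # Sketch of the intended call pattern (the instance name `session` and the
    # per-unit bookkeeping call are assumptions, not part of this file): with
    # context_max_units=4, after each new unit the caller would run
    #
    #   session.record_unit(...)                  # hypothetical bookkeeping call
    #   slid = session.enforce_window_with_context()
    #   if slid:
    #       stats = session.get_window_stats()    # inspect what was dropped
    #
    # so the cache never holds more than 4 units plus the preserved head.
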
    def get_previous_context(self) -> Tuple[str, List[int]]:
        """Get the currently accumulated previous context.

        Returns:
            (previous_text, previous_token_ids): accumulated text and token ids
        """
        return self._previous_text, self._previous_token_ids.copy()

    def get_window_stats(self) -> Dict[str, Any]:
        """Get sliding-window statistics."""
        unit_lengths = [u["length"] for u in self._unit_history]
        return {
            "cache_length": self.get_cache_length(),
            "unit_count": len(self._unit_history),
            "unit_lengths": unit_lengths,
            "unit_total_length": sum(unit_lengths),
            "system_preserve_length": self._system_preserve_length,
            "position_offset": self._position_offset,
            "window_enabled": self._window_enabled,
            "total_generated_tokens": self.get_total_generated_tokens(),
            "pending_unit_id": self._pending_unit_id,
            "next_unit_id": self._next_unit_id,
            "config": {
                "sliding_window_mode": self._window_config.sliding_window_mode,
                "basic_window_high_tokens": self._window_config.basic_window_high_tokens,
                "basic_window_low_tokens": self._window_config.basic_window_low_tokens,
                "context_previous_max_tokens": self._window_config.context_previous_max_tokens,
                "context_max_units": self._window_config.context_max_units,
            },
            # context-preserving fields
            "preserve_prefix_length": self._preserve_prefix_length,
            "previous_content_length": self._previous_content_length,
            "suffix_token_count": len(self._suffix_token_ids),
            "previous_text_length": len(self._previous_text),
            "previous_token_count": len(self._previous_token_ids),
            "has_system_template": self._system_prompt_template is not None,
        }

    def _verify_consistency(self) -> bool:
        """Verify that the unit history is consistent with the cache length."""
        expected = self._system_preserve_length + sum(u["length"] for u in self._unit_history)
        actual = self.get_cache_length()
        return expected == actual

    def print_verification_summary(self) -> Dict[str, Any]:
        """Build a verification summary (for comparing the off/basic/context modes).

        Returns:
            dictionary containing the key verification data
        """
        cfg = self._window_config

        # collect all generated text
        all_generated_text = []
        all_generated_tokens = []
        for u in self._unit_history:
            if not u.get("is_listen", False):
                gen_text = u.get("generated_text", "")
                gen_tokens = u.get("generated_tokens", [])
                if gen_text:
                    all_generated_text.append(gen_text)
                if gen_tokens:
                    all_generated_tokens.extend(gen_tokens)

        combined_text = "".join(all_generated_text)

        summary = {
            "mode": cfg.sliding_window_mode,
            "final_cache_length": self.get_cache_length(),
            "final_unit_count": len(self._unit_history),
            "sliding_event_count": self._sliding_event_count,
            "total_dropped_tokens": self._total_dropped_tokens,
            "total_dropped_units": self._total_dropped_units,
            "total_generated_tokens": len(all_generated_tokens),
            "generated_text": combined_text,
            "previous_text": self._previous_text,
            "previous_token_count": len(self._previous_token_ids),
            "position_offset": self._position_offset,
            "system_preserve_length": self._system_preserve_length,
        }

        return summary

    def set_window_config(self, config: DuplexWindowConfig) -> None:
        """Set the sliding-window configuration."""
        self._window_config = config

    def set_window_enabled(self, enabled: bool) -> None:
        """Enable/disable the sliding window."""
        self._window_enabled = enabled

    def get_context(self):
        return self.context

    def embed_token(self, tid):
        """Embed a single token id (int or tensor) -> [1, H]."""
        if isinstance(tid, int):
            tid = torch.tensor([tid], device=self.m.device)
        return self.m.model.embed_tokens(tid)

    def embed_tokens(self, token_ids: List[int]) -> torch.Tensor:
        """Batch-embed multiple tokens.

        Args:
            token_ids: list of token ids

        Returns:
            embeddings tensor [L, H]
        """
        if not token_ids:
            return torch.empty(0, self.m.config.hidden_size, device=self.m.device)
        tids = torch.tensor(token_ids, device=self.m.device)
        return self.m.model.embed_tokens(tids)

    @torch.no_grad()
    def feed(self, embeds: torch.Tensor, return_logits: bool = False):
        """Feed a new embedding sequence into the model in one forward pass.

        Args:
            embeds: [L, H] embedding sequence
        """
        L = embeds.size(0)
        device = embeds.device

        past_len = self.get_cache_length()
        pos_ids = torch.arange(past_len, past_len + L, device=device).unsqueeze(0)  # [1, L]

        out = self.m(
            inputs_embeds=embeds.unsqueeze(0),  # [1, L, H]
            position_ids=pos_ids,
            past_key_values=self.cache,
            # use_cache=True,
            return_dict=True,
            output_hidden_states=True,
            # attention_mask=attention_mask,
        )
        self.cache = out.past_key_values

        if return_logits:
            logits = self.m.lm_head(out.hidden_states[-1])[:, -1]  # [1, vocab]
            return logits, out.hidden_states[-1]

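    # Example of how positions advance across feed() calls (illustrative
    # numbers): if the cache already holds 120 tokens and embeds has L=5,
    # this forward pass uses position_ids [120..124] and afterwards
    # get_cache_length() returns 125.
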
    @torch.no_grad()
    def decode(
        self,
        logits,
        mode: Literal["sampling", "greedy"] = "sampling",
        temperature=0.7,
        top_k=20,
        top_p=0.8,
        listen_top_k=None,
        listen_prob_scale=1.0,
        text_repetition_penalty=1.05,
        text_repetition_window_size=512,
    ):
        """
        Args:
            logits: next-token logits, shape [1, vocab]
            mode: "sampling" or "greedy"
            temperature: sampling temperature
            top_k: keep only the top-k logits before sampling
            top_p: nucleus sampling threshold
            listen_top_k: if listen_id ranks within this top-k, emit it directly
            listen_prob_scale: scale factor applied to the listen_id logit (<1 decreases, >1 increases)
            text_repetition_penalty: repetition penalty coefficient; >1.0 reduces repetition, <1.0 encourages it
            text_repetition_window_size: window size for the repetition penalty

        Sampling strategy:
        1. first sample from the original logits
        2. if chunk_eos is sampled, return it directly (keep the model's own decision of when to stop)
        3. otherwise mask chunk_eos (set its logit to -inf) and continue sampling a text token
        4. apply repetition penalty, top-k, top-p, etc. for the final text-token sampling
        """

        logits = logits.clone()

        # 0. independently check chunk_eos before the main sampling
        eos_id = self.chunk_eos_id

        with torch.no_grad():
            if mode == "greedy":
                sampled_token = torch.argmax(logits[0]).item()
            else:
                original_probs = F.softmax(logits[0], dim=-1)
                sampled_token = torch.multinomial(original_probs, num_samples=1).item()

        # if chunk_eos was sampled, return it directly
        if sampled_token == eos_id:
            return torch.tensor([eos_id], device=logits.device)

        # chunk_eos was not sampled: mask it (and any other forbidden tokens) with -inf
        if self.forbidden_token_ids:
            logits[:, self.forbidden_token_ids] = float("-inf")

        # 1. apply the repetition penalty
        if text_repetition_penalty != 1.0 and len(self.generated_tokens) > 0:
            # recent tokens within the window (special and normal tokens alike)
            recent_tokens = self.generated_tokens[-text_repetition_window_size:]

            # deduplicate
            recent_tokens = list(set(recent_tokens))

            # penalize (or encourage) repeated tokens; sign-aware so negative
            # logits are also pushed in the intended direction
            for token_id in recent_tokens:
                if token_id < logits.size(-1):  # ensure token_id is within the vocabulary
                    score = logits[0, token_id]
                    if score > 0:
                        logits[0, token_id] = score / text_repetition_penalty
                    else:
                        logits[0, token_id] = score * text_repetition_penalty

        if listen_prob_scale != 1.0:  # scale the listen token's logit separately
            logits[0, self.listen_id] *= listen_prob_scale

        listen_rank = (logits[0] > logits[0, self.listen_id]).sum().item()

        if listen_top_k is not None and listen_rank < listen_top_k:  # listen_id is within the top-k: return it directly
            next_token_id = torch.tensor([self.listen_id], device=logits.device)
            next_token_str = self.tokenizer.decode(next_token_id)

            if next_token_str == "<|listen|>":
                self.context += " "
            else:
                self.context += next_token_str

            return next_token_id

        if mode == "greedy":
            next_token_id = torch.argmax(logits, dim=-1)
        elif mode == "sampling":
            logits = logits / temperature
            logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
            probs = F.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            raise ValueError(f"Unsupported decode mode: {mode}")

        if next_token_id.item() not in self.special_token_ids:
            self.generated_tokens.append(next_token_id.item())
        else:
            self.generated_special_tokens.append(next_token_id.item())

        return next_token_id


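# Sketch of a feed/decode streaming loop over the class above (illustrative;
# `session`, `prompt_embeds` and the 32-token budget are assumptions, not part
# of this file):
#
#   logits, _ = session.feed(prompt_embeds, return_logits=True)
#   for _ in range(32):                           # small generation budget
#       token_id = session.decode(logits, mode="sampling", temperature=0.7)
#       if token_id.item() == session.chunk_eos_id:
#           break                                 # the model chose to end the chunk
#       emb = session.embed_token(token_id.item())    # [1, H]
#       logits, _ = session.feed(emb, return_logits=True)

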
def _download_url_to_tempfile(url: str, suffix: str = "", timeout: int = 60) -> str:
    """
    Download a URL to a temporary file and return the path.

    Args:
        url: HTTP/HTTPS URL to download
        suffix: File suffix (e.g., ".jpg", ".wav", ".mp4")
        timeout: Download timeout in seconds

    Returns:
        Path to the downloaded temporary file
    """
    import tempfile

    import requests

    response = requests.get(url, timeout=timeout)
    response.raise_for_status()

    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
        f.write(response.content)
        return f.name


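# Usage sketch (hypothetical URL): the caller owns cleanup, since the temp file
# is created with delete=False:
#
#   import os
#   path = _download_url_to_tempfile("https://example.com/a.wav", suffix=".wav")
#   try:
#       ...  # read the file
#   finally:
#       os.unlink(path)

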
def _is_url(path: str) -> bool:
    """Return True if the path is an HTTP/HTTPS URL."""
    return path.startswith(("http://", "https://"))


def normalize_content_item(item) -> Union[str, Any, List[Any]]:
    """Normalize a structured content item to native format.

    Supports:
    - Native format: str, PIL.Image, np.ndarray (pass through)
    - OpenAI structured format:
        - {"type": "text", "text": "..."} -> str
        - {"type": "image_url", "image_url": {"url": "..."}} -> PIL.Image
        - {"type": "audio_url", "audio_url": {"url": "..."}} -> np.ndarray
        - {"type": "video_url", "video_url": {"url": "...", ...}} -> List[Image, ndarray, ...]

    URL formats supported:
    - Local file path: "/path/to/file.jpg"
    - HTTP/HTTPS URL: "https://example.com/image.jpg"

    Args:
        item: Content item to normalize

    Returns:
        Normalized item. For video_url, returns a tuple ("__video_contents__", list)
        that will be flattened by normalize_content().

    Raises:
        ValueError: If the content type is unknown or unsupported
    """
    import os

    import numpy as np
    from PIL import Image

    if isinstance(item, str):
        return item
    if isinstance(item, Image.Image):
        return item
    if isinstance(item, np.ndarray):
        return item

    if isinstance(item, dict):
        item_type = item.get("type")

        if item_type == "text":
            return item.get("text", "")

        elif item_type == "image_url":
            image_url_obj = item.get("image_url", {})
            url = image_url_obj.get("url", "") if isinstance(image_url_obj, dict) else image_url_obj

            if _is_url(url):
                # Download to a temp file
                temp_path = _download_url_to_tempfile(url, suffix=".jpg", timeout=30)
                img = Image.open(temp_path)
                img.load()  # force the lazy PIL read before the temp file is deleted
                os.unlink(temp_path)
                return img
            else:
                return Image.open(url)
        elif item_type == "audio_url":
            import librosa

            audio_url_obj = item.get("audio_url", {})
            url = audio_url_obj.get("url", "") if isinstance(audio_url_obj, dict) else audio_url_obj

            if _is_url(url):
                # Download to a temp file
                temp_path = _download_url_to_tempfile(url, suffix=".wav", timeout=60)
                audio_np, _ = librosa.load(temp_path, sr=16000, mono=True)
                os.unlink(temp_path)
                return audio_np
            else:
                audio_np, _ = librosa.load(url, sr=16000, mono=True)
                return audio_np
        elif item_type == "video_url":
            # Video processing - returns a LIST of items (frames + audio segments).
            # Note: unlike image_url/audio_url, which return single items,
            # video_url returns a list that will be flattened into the content.
            from minicpmo.utils import get_video_frame_audio_segments

            video_url_obj = item.get("video_url", {})
            if isinstance(video_url_obj, dict):
                video_url = video_url_obj.get("url", "")
                # Optional parameters from the video_url object (OpenAI style)
                stack_frames = video_url_obj.get("stack_frames", 1)
                use_ffmpeg = video_url_obj.get("use_ffmpeg", False)
                use_audio = video_url_obj.get("use_audio", True)
            else:
                video_url = video_url_obj
                stack_frames = 1
                use_ffmpeg = False
                use_audio = True

            # Handle HTTP/HTTPS URL - download to a temp file
            temp_video_path = None
            if _is_url(video_url):
                temp_video_path = _download_url_to_tempfile(video_url, suffix=".mp4", timeout=120)
                video_path = temp_video_path
            else:
                video_path = video_url

            # Extract frames and audio segments
            video_frames, audio_segments, stacked_frames = get_video_frame_audio_segments(
                video_path,
                stack_frames=stack_frames,
                use_ffmpeg=use_ffmpeg,
                use_audio=use_audio,
            )

            # Clean up the temp file if we downloaded one
            if temp_video_path is not None:
                os.unlink(temp_video_path)

            # Build omni_contents (interleaved frames and audio, or frames only)
            omni_contents = []
            for i in range(len(video_frames)):
                omni_contents.append(video_frames[i])
                if use_audio and audio_segments is not None:
                    omni_contents.append(audio_segments[i])
                if stacked_frames is not None and i < len(stacked_frames) and stacked_frames[i] is not None:
                    omni_contents.append(stacked_frames[i])

            # Return with a special marker, to be flattened later
            return "__video_contents__", omni_contents
        else:
            raise ValueError(f"Unknown content type: {item_type}")

    raise ValueError(f"Cannot normalize content item of type: {type(item)}")


def normalize_content(content) -> list:
    """Normalize message content to a list of native items.

    Input formats:
    - str: "hello" -> ["hello"]
    - list of native items: [str, Image, np.ndarray] -> pass through with normalization
    - list of structured items: [{"type": "text", ...}] -> normalize each
    - video_url type: automatically expanded into omni_contents
    - mixed: works too

    Args:
        content: Message content in any supported format

    Returns:
        List of native items (str, PIL.Image, np.ndarray)

    Examples:
        >>> normalize_content("hello")
        ["hello"]

        >>> normalize_content([{"type": "text", "text": "hi"}])
        ["hi"]

        >>> normalize_content([{"type": "video_url", "video_url": {"url": "/path/to/video.mp4"}}])
        [<PIL.Image>, <np.ndarray>, <PIL.Image>, <np.ndarray>, ...]
    """
    import numpy as np
    from PIL import Image

    if isinstance(content, str):
        return [content]

    if isinstance(content, list):
        result = []
        for item in content:
            normalized = normalize_content_item(item)
            # video content comes back as a marker tuple
            if isinstance(normalized, tuple) and len(normalized) == 2 and normalized[0] == "__video_contents__":
                # flatten the video contents into the result
                result.extend(normalized[1])
            else:
                result.append(normalized)
        return result

    # single non-list item (Image or np.ndarray)
    if isinstance(content, (Image.Image, np.ndarray)):
        return [content]

    normalized = normalize_content_item(content)
    if isinstance(normalized, tuple) and len(normalized) == 2 and normalized[0] == "__video_contents__":
        return normalized[1]
    return [normalized]

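# Usage sketch with an OpenAI-style message (the URL is hypothetical):
#
#   content = [
#       {"type": "text", "text": "Describe this image."},
#       {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
#   ]
#   items = normalize_content(content)   # -> ["Describe this image.", <PIL.Image>]
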
vocab.json
ADDED
The diff for this file is too large to render.
See raw diff