andrevp committed on
Commit 33b4062 · verified · 1 Parent(s): af85fd9

Upload MiniCPM-o 4.5 MLX 4-bit quantized model

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,137 @@
+ ---
+ license: apache-2.0
+ license_link: https://github.com/OpenBMB/MiniCPM-V/blob/main/LICENSE
+ base_model: openbmb/MiniCPM-o-4_5
+ tags:
+ - mlx
+ - vision
+ - multimodal
+ - vlm
+ - minicpm
+ - apple-silicon
+ - quantized
+ language:
+ - en
+ - zh
+ - id
+ - fr
+ - de
+ library_name: mlx
+ pipeline_tag: image-text-to-text
+ ---
+
+ # MiniCPM-o 4.5 — MLX 4-bit Quantized
+
+ 4-bit quantized [MLX](https://github.com/ml-explore/mlx) conversion of [openbmb/MiniCPM-o-4_5](https://huggingface.co/openbmb/MiniCPM-o-4_5) for fast inference on Apple Silicon (M1/M2/M3/M4).
+
+ ## Model Details
+
+ | | |
+ |---|---|
+ | **Base model** | [openbmb/MiniCPM-o-4_5](https://huggingface.co/openbmb/MiniCPM-o-4_5) |
+ | **Architecture** | SigLIP2 (27L) + Perceiver Resampler + Qwen3 LLM (36L) |
+ | **Parameters** | ~8B |
+ | **Quantization** | 4-bit (5.255 effective bits) — LLM quantized, vision encoder & resampler full precision |
+ | **Size on disk** | ~5.3 GB |
+ | **Framework** | [MLX](https://github.com/ml-explore/mlx) via [mlx-vlm](https://github.com/Blaizzy/mlx-vlm) |
+
+ ## Performance (M4 Pro, 24 GB RAM)
+
+ | Mode | Prompt Processing | Generation | Peak Memory |
+ |------|-------------------|------------|-------------|
+ | Text-only | ~100 tok/s | ~55 tok/s | ~5.8 GB |
+ | Image + Text | ~150 tok/s | ~51 tok/s | ~6.5 GB |
+
+ ## Capabilities
+
+ - Image understanding & description
+ - OCR / text extraction from images
+ - Chart & diagram analysis
+ - Math equation solving from images
+ - Visual reasoning & counting
+ - Code generation
+ - Multilingual (English, Chinese, Indonesian, French, German, etc.)
+
+ ## Requirements
+
+ - Apple Silicon Mac (M1 or later)
+ - Python 3.10+
+ - ~8 GB free RAM
+
+ ```bash
+ pip install mlx-vlm torch transformers Pillow
+ ```
+
+ ## Quick Start
+
+ ### Python API
+
+ ```python
+ from mlx_vlm import load
+ from mlx_vlm.generate import generate_step
+ import mlx.core as mx
+
+ model, processor = load("andrevp/MiniCPM-o-4_5-MLX", trust_remote_code=True)
+
+ # Text-only
+ text = "<|im_start|>user\nWhat is machine learning?<|im_end|>\n<|im_start|>assistant\n"
+ input_ids = mx.array(processor.tokenizer(text, return_tensors="np")["input_ids"])
+
+ tokens = []
+ for token, _ in generate_step(input_ids, model, None, None, temp=0.0):
+     tok_val = token.item()
+     tokens.append(tok_val)
+     if processor.tokenizer.decode([tok_val]) in ["<|im_end|>", "<|endoftext|>"]:
+         break
+
+ print(processor.tokenizer.decode(tokens, skip_special_tokens=True))
+ ```
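+
+ The example above hard-codes the ChatML markers. A minimal sketch of an alternative, assuming the processor wraps a standard `transformers` tokenizer that picks up the bundled `chat_template.jinja` (version-dependent, not verified against this exact checkpoint):
+
+ ```python
+ # Sketch only: build the generation prompt from the repo's chat template
+ # instead of hand-writing the <|im_start|>/<|im_end|> markers.
+ messages = [{"role": "user", "content": "What is machine learning?"}]
+ text = processor.tokenizer.apply_chat_template(
+     messages, tokenize=False, add_generation_prompt=True
+ )
+ input_ids = mx.array(processor.tokenizer(text, return_tensors="np")["input_ids"])
+ # ...then run the same generate_step loop shown above.
+ ```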
+
+ ### Chat Script
+
+ A standalone `chat_minicpmo.py` script is available in the [conversion repository](https://github.com/andrevp):
+
+ ```bash
+ # Single-shot with image
+ python chat_minicpmo.py photo.jpg -p "What's in this image?"
+
+ # Single-shot text-only
+ python chat_minicpmo.py -p "Explain quantum computing briefly."
+
+ # Interactive mode
+ python chat_minicpmo.py
+
+ # Interactive with pre-loaded image
+ python chat_minicpmo.py photo.jpg
+ ```
+
+ Interactive commands: `/image <path>` | `/clear` | `/quit`
+
+ ## Quantization Details
+
+ - **LLM layers**: 4-bit quantized (group_size=64, affine mode)
+ - **Vision encoder (SigLIP2)**: Full precision (not quantized)
+ - **Perceiver Resampler**: Full precision (not quantized)
+ - **Weight breakdown**: 907 LLM keys (quantized) + 437 vision keys + 17 resampler keys (full precision)
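+
+ This kind of split can be expressed with MLX's selective quantization hook. A minimal sketch, not the exact script used for this conversion; the path substrings `"vpm"` and `"resampler"` are assumptions about the parameter tree:
+
+ ```python
+ import mlx.nn as nn
+
+ # Sketch: quantize only the LLM's Linear layers to 4 bits (group_size=64),
+ # leaving vision-encoder and resampler weights in full precision.
+ def keep_full_precision(path: str) -> bool:
+     # Assumed module-path names; adjust to the actual weight layout.
+     return ("vpm" in path) or ("resampler" in path)
+
+ nn.quantize(
+     model,  # an mlx.nn.Module holding the converted weights
+     group_size=64,
+     bits=4,
+     class_predicate=lambda path, module: isinstance(module, nn.Linear)
+     and not keep_full_precision(path),
+ )
+ ```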
+
+ ## Limitations
+
+ - **Vision-language only**: Audio input (Whisper encoder) and TTS output (CosyVoice2) from the original model are not included in this conversion.
+ - **Single image per turn**: Processes one image at a time.
+ - Quantization may slightly reduce output quality compared to the full-precision model.
+
+ ## License
+
+ This model is released under the **Apache-2.0** license, following the original [openbmb/MiniCPM-o-4_5](https://huggingface.co/openbmb/MiniCPM-o-4_5) license.
+
+ See the [original license](https://github.com/OpenBMB/MiniCPM-V/blob/main/LICENSE) for full terms.
+
+ ## Disclaimer
+
+ > As an LMM, MiniCPM-o 4.5 generates content by learning from a large amount of multimodal corpora, but it cannot comprehend, express personal opinions or make value judgments. Anything generated by MiniCPM-o 4.5 does not represent the views and positions of the model developers. We will not be liable for any problems arising from the use of the MiniCPM-o models, including but not limited to data security issues, risk of public opinion, or any risks and problems arising from the misdirection, misuse, dissemination or misuse of the model.
+
+ ## Credits
+
+ - **Original model**: [OpenBMB](https://github.com/OpenBMB) — [MiniCPM-o 4.5](https://huggingface.co/openbmb/MiniCPM-o-4_5)
+ - **MLX framework**: [Apple ML Explore](https://github.com/ml-explore/mlx)
+ - **mlx-vlm**: [Prince Canuma](https://github.com/Blaizzy/mlx-vlm)
added_tokens.json ADDED
@@ -0,0 +1,107 @@
1
+ {
2
+ "</answer>": 151686,
3
+ "</box>": 151674,
4
+ "</focus>": 151688,
5
+ "</image>": 151670,
6
+ "</image_id>": 151682,
7
+ "</image_save_to>": 151696,
8
+ "</line>": 151690,
9
+ "</perception>": 151692,
10
+ "</point>": 151678,
11
+ "</quad>": 151676,
12
+ "</ref>": 151672,
13
+ "</slice>": 151680,
14
+ "</source_image>": 151694,
15
+ "</think>": 151668,
16
+ "</tool_call>": 151658,
17
+ "</tool_response>": 151666,
18
+ "</unit>": 151684,
19
+ "<answer>": 151685,
20
+ "<box>": 151673,
21
+ "<focus>": 151687,
22
+ "<image>": 151669,
23
+ "<image_id>": 151681,
24
+ "<image_save_to>": 151695,
25
+ "<line>": 151689,
26
+ "<perception>": 151691,
27
+ "<point>": 151677,
28
+ "<quad>": 151675,
29
+ "<ref>": 151671,
30
+ "<slice>": 151679,
31
+ "<source_image>": 151693,
32
+ "<think>": 151667,
33
+ "<tool_call>": 151657,
34
+ "<tool_response>": 151665,
35
+ "<unit>": 151683,
36
+ "<|audio_end|>": 151699,
37
+ "<|audio_start|>": 151697,
38
+ "<|audio|>": 151698,
39
+ "<|box_end|>": 151649,
40
+ "<|box_start|>": 151648,
41
+ "<|emotion_end|>": 151711,
42
+ "<|emotion_start|>": 151710,
43
+ "<|endoftext|>": 151643,
44
+ "<|file_sep|>": 151664,
45
+ "<|fim_middle|>": 151660,
46
+ "<|fim_pad|>": 151662,
47
+ "<|fim_prefix|>": 151659,
48
+ "<|fim_suffix|>": 151661,
49
+ "<|im_end|>": 151645,
50
+ "<|im_start|>": 151644,
51
+ "<|image_pad|>": 151655,
52
+ "<|interrupt|>": 151707,
53
+ "<|listen|>": 151705,
54
+ "<|object_ref_end|>": 151647,
55
+ "<|object_ref_start|>": 151646,
56
+ "<|pitch_end|>": 151715,
57
+ "<|pitch_start|>": 151714,
58
+ "<|quad_end|>": 151651,
59
+ "<|quad_start|>": 151650,
60
+ "<|repo_name|>": 151663,
61
+ "<|speak|>": 151706,
62
+ "<|speed_end|>": 151713,
63
+ "<|speed_start|>": 151712,
64
+ "<|spk_bos|>": 151700,
65
+ "<|spk_eos|>": 151702,
66
+ "<|spk|>": 151701,
67
+ "<|turn_bos|>": 151716,
68
+ "<|timbre_10|>": 151726,
69
+ "<|timbre_11|>": 151727,
70
+ "<|timbre_12|>": 151728,
71
+ "<|timbre_13|>": 151729,
72
+ "<|timbre_14|>": 151730,
73
+ "<|timbre_15|>": 151731,
74
+ "<|timbre_16|>": 151732,
75
+ "<|timbre_17|>": 151733,
76
+ "<|timbre_18|>": 151734,
77
+ "<|timbre_19|>": 151735,
78
+ "<|turn_eos|>": 151717,
79
+ "<|timbre_20|>": 151736,
80
+ "<|timbre_21|>": 151737,
81
+ "<|timbre_22|>": 151738,
82
+ "<|timbre_23|>": 151739,
83
+ "<|timbre_24|>": 151740,
84
+ "<|timbre_25|>": 151741,
85
+ "<|timbre_26|>": 151742,
86
+ "<|timbre_27|>": 151743,
87
+ "<|timbre_28|>": 151744,
88
+ "<|timbre_29|>": 151745,
89
+ "<|chunk_eos|>": 151718,
90
+ "<|timbre_30|>": 151746,
91
+ "<|timbre_31|>": 151747,
92
+ "<|chunk_bos|>": 151719,
93
+ "<|chunk_tts_bos|>": 151720,
94
+ "<|chunk_tts_eos|>": 151721,
95
+ "<|tts_pad|>": 151722,
96
+ "<|timbre_7|>": 151723,
97
+ "<|timbre_8|>": 151724,
98
+ "<|timbre_9|>": 151725,
99
+ "<|tts_bos|>": 151703,
100
+ "<|tts_eos|>": 151704,
101
+ "<|vad_end|>": 151709,
102
+ "<|vad_start|>": 151708,
103
+ "<|video_pad|>": 151656,
104
+ "<|vision_end|>": 151653,
105
+ "<|vision_pad|>": 151654,
106
+ "<|vision_start|>": 151652
107
+ }
chat_template.jinja ADDED
@@ -0,0 +1,88 @@
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {{- messages[0].content + '\n\n' }}
5
+ {%- endif %}
6
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
+ {%- for tool in tools %}
8
+ {{- "\n" }}
9
+ {{- tool | tojson }}
10
+ {%- endfor %}
11
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
+ {%- else %}
13
+ {%- if messages[0].role == 'system' %}
14
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
+ {%- endif %}
16
+ {%- endif %}
17
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
+ {%- for message in messages[::-1] %}
19
+ {%- set index = (messages|length - 1) - loop.index0 %}
20
+ {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
+ {%- set ns.multi_step_tool = false %}
22
+ {%- set ns.last_query_index = index %}
23
+ {%- endif %}
24
+ {%- endfor %}
25
+ {%- for message in messages %}
26
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
27
+ {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
28
+ {%- elif message.role == "assistant" %}
29
+ {%- set content = message.content %}
30
+ {%- set reasoning_content = '' %}
31
+ {%- if message.reasoning_content is defined and message.reasoning_content is not none %}
32
+ {%- set reasoning_content = message.reasoning_content %}
33
+ {%- else %}
34
+ {%- if '</think>' in message.content %}
35
+ {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
36
+ {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
37
+ {%- endif %}
38
+ {%- endif %}
39
+ {%- if loop.index0 > ns.last_query_index %}
40
+ {%- if loop.last or (not loop.last and reasoning_content) %}
41
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
42
+ {%- else %}
43
+ {{- '<|im_start|>' + message.role + '\n' + content }}
44
+ {%- endif %}
45
+ {%- else %}
46
+ {{- '<|im_start|>' + message.role + '\n' + content }}
47
+ {%- endif %}
48
+ {%- if message.tool_calls %}
49
+ {%- for tool_call in message.tool_calls %}
50
+ {%- if (loop.first and content) or (not loop.first) %}
51
+ {{- '\n' }}
52
+ {%- endif %}
53
+ {%- if tool_call.function %}
54
+ {%- set tool_call = tool_call.function %}
55
+ {%- endif %}
56
+ {{- '<tool_call>\n{"name": "' }}
57
+ {{- tool_call.name }}
58
+ {{- '", "arguments": ' }}
59
+ {%- if tool_call.arguments is string %}
60
+ {{- tool_call.arguments }}
61
+ {%- else %}
62
+ {{- tool_call.arguments | tojson }}
63
+ {%- endif %}
64
+ {{- '}\n</tool_call>' }}
65
+ {%- endfor %}
66
+ {%- endif %}
67
+ {{- '<|im_end|>\n' }}
68
+ {%- elif message.role == "tool" %}
69
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
70
+ {{- '<|im_start|>user' }}
71
+ {%- endif %}
72
+ {{- '\n<tool_response>\n' }}
73
+ {{- message.content }}
74
+ {{- '\n</tool_response>' }}
75
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
76
+ {{- '<|im_end|>\n' }}
77
+ {%- endif %}
78
+ {%- endif %}
79
+ {%- endfor %}
80
+ {%- if add_generation_prompt %}
81
+ {{- '<|im_start|>assistant\n' }}
82
+ {%- if enable_thinking is defined and enable_thinking is false %}
83
+ {{- '<think>\n\n</think>\n\n' }}
84
+ {%- endif %}
85
+ {%- if use_tts_template is defined and use_tts_template is true %}
86
+ {{- '<|tts_bos|>' }}
87
+ {%- endif %}
88
+ {%- endif %}
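
The template above is what a chat-template-aware tokenizer renders at generation time. A minimal sketch of rendering it directly with Jinja2 (mirroring how `transformers` compiles chat templates with `trim_blocks`/`lstrip_blocks`; `enable_thinking` is the switch defined in the template itself):

```python
from jinja2 import Environment

# Sketch only: render chat_template.jinja for a single user turn.
env = Environment(trim_blocks=True, lstrip_blocks=True)
template = env.from_string(open("chat_template.jinja").read())

prompt = template.render(
    messages=[{"role": "user", "content": "Describe this image."}],
    add_generation_prompt=True,
    enable_thinking=False,  # emits an empty <think>...</think> block
    tools=None,             # no tool schemas, so the tools branch is skipped
)
print(prompt)
```
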
config.json ADDED
@@ -0,0 +1,297 @@
1
+ {
2
+ "architectures": [
3
+ "MiniCPMO"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "audio_chunk_length": 1.0,
8
+ "audio_config": {
9
+ "_attn_implementation_autoset": true,
10
+ "_name_or_path": "openai/whisper-medium",
11
+ "activation_dropout": 0.0,
12
+ "activation_function": "gelu",
13
+ "apply_spec_augment": false,
14
+ "architectures": [
15
+ "MiniCPMWhisperEncoder"
16
+ ],
17
+ "attention_dropout": 0.0,
18
+ "begin_suppress_tokens": [
19
+ 220,
20
+ 50257
21
+ ],
22
+ "bos_token_id": 50257,
23
+ "classifier_proj_size": 256,
24
+ "d_model": 1024,
25
+ "decoder_attention_heads": 16,
26
+ "decoder_ffn_dim": 4096,
27
+ "decoder_layerdrop": 0.0,
28
+ "decoder_layers": 24,
29
+ "decoder_start_token_id": 50258,
30
+ "dropout": 0.0,
31
+ "encoder_attention_heads": 16,
32
+ "encoder_ffn_dim": 4096,
33
+ "encoder_layerdrop": 0.0,
34
+ "encoder_layers": 24,
35
+ "eos_token_id": 50257,
36
+ "forced_decoder_ids": [
37
+ [
38
+ 1,
39
+ 50259
40
+ ],
41
+ [
42
+ 2,
43
+ 50359
44
+ ],
45
+ [
46
+ 3,
47
+ 50363
48
+ ]
49
+ ],
50
+ "init_std": 0.02,
51
+ "mask_feature_length": 10,
52
+ "mask_feature_min_masks": 0,
53
+ "mask_feature_prob": 0.0,
54
+ "mask_time_length": 10,
55
+ "mask_time_min_masks": 2,
56
+ "mask_time_prob": 0.05,
57
+ "max_length": 448,
58
+ "max_source_positions": 1500,
59
+ "max_target_positions": 448,
60
+ "median_filter_width": 7,
61
+ "model_type": "whisper",
62
+ "num_hidden_layers": 24,
63
+ "num_mel_bins": 80,
64
+ "pad_token_id": 50257,
65
+ "scale_embedding": false,
66
+ "suppress_tokens": [
67
+ 1,
68
+ 2,
69
+ 7,
70
+ 8,
71
+ 9,
72
+ 10,
73
+ 14,
74
+ 25,
75
+ 26,
76
+ 27,
77
+ 28,
78
+ 29,
79
+ 31,
80
+ 58,
81
+ 59,
82
+ 60,
83
+ 61,
84
+ 62,
85
+ 63,
86
+ 90,
87
+ 91,
88
+ 92,
89
+ 93,
90
+ 359,
91
+ 503,
92
+ 522,
93
+ 542,
94
+ 873,
95
+ 893,
96
+ 902,
97
+ 918,
98
+ 922,
99
+ 931,
100
+ 1350,
101
+ 1853,
102
+ 1982,
103
+ 2460,
104
+ 2627,
105
+ 3246,
106
+ 3253,
107
+ 3268,
108
+ 3536,
109
+ 3846,
110
+ 3961,
111
+ 4183,
112
+ 4667,
113
+ 6585,
114
+ 6647,
115
+ 7273,
116
+ 9061,
117
+ 9383,
118
+ 10428,
119
+ 10929,
120
+ 11938,
121
+ 12033,
122
+ 12331,
123
+ 12562,
124
+ 13793,
125
+ 14157,
126
+ 14635,
127
+ 15265,
128
+ 15618,
129
+ 16553,
130
+ 16604,
131
+ 18362,
132
+ 18956,
133
+ 20075,
134
+ 21675,
135
+ 22520,
136
+ 26130,
137
+ 26161,
138
+ 26435,
139
+ 28279,
140
+ 29464,
141
+ 31650,
142
+ 32302,
143
+ 32470,
144
+ 36865,
145
+ 42863,
146
+ 47425,
147
+ 49870,
148
+ 50254,
149
+ 50258,
150
+ 50358,
151
+ 50359,
152
+ 50360,
153
+ 50361,
154
+ 50362
155
+ ],
156
+ "torch_dtype": "float32",
157
+ "use_cache": true,
158
+ "use_weighted_layer_sum": false,
159
+ "vocab_size": 51865
160
+ },
161
+ "audio_pool_step": 5,
162
+ "auto_map": {
163
+ "AutoConfig": "configuration_minicpmo.MiniCPMOConfig",
164
+ "AutoModel": "modeling_minicpmo.MiniCPMO",
165
+ "AutoModelForCausalLM": "modeling_minicpmo.MiniCPMO"
166
+ },
167
+ "batch_vision_input": true,
168
+ "bos_token_id": 151643,
169
+ "drop_vision_last_layer": false,
170
+ "eos_token_id": [
171
+ 151645,
172
+ 151643
173
+ ],
174
+ "head_dim": 128,
175
+ "hidden_act": "silu",
176
+ "hidden_size": 4096,
177
+ "image_size": 448,
178
+ "init_audio": true,
179
+ "init_tts": true,
180
+ "init_vision": true,
181
+ "initializer_range": 0.02,
182
+ "intermediate_size": 12288,
183
+ "listen_speak_type": "asr",
184
+ "max_position_embeddings": 40960,
185
+ "max_window_layers": 36,
186
+ "model_type": "minicpmo",
187
+ "num_attention_heads": 32,
188
+ "num_hidden_layers": 36,
189
+ "num_key_value_heads": 8,
190
+ "patch_size": 14,
191
+ "quantization": {
192
+ "group_size": 64,
193
+ "bits": 4,
194
+ "mode": "affine"
195
+ },
196
+ "quantization_config": {
197
+ "group_size": 64,
198
+ "bits": 4,
199
+ "mode": "affine"
200
+ },
201
+ "query_num": 64,
202
+ "rms_norm_eps": 1e-06,
203
+ "rope_scaling": null,
204
+ "rope_theta": 1000000,
205
+ "slice_config": {
206
+ "max_slice_nums": 1,
207
+ "model_type": "minicpmv",
208
+ "patch_size": 14,
209
+ "scale_resolution": 448
210
+ },
211
+ "slice_mode": true,
212
+ "sliding_window": null,
213
+ "stream_input": true,
214
+ "tie_word_embeddings": false,
215
+ "transformers_version": "4.51.0",
216
+ "tts_config": {
217
+ "_attn_implementation_autoset": true,
218
+ "attention_type": "full_attention",
219
+ "attn_implementation": "sdpa",
220
+ "audio_bos_token_id": 151687,
221
+ "audio_tokenizer_sample_rate": 16000,
222
+ "audio_tokenizer_type": "s3tokenizer",
223
+ "aug_layer_loss_weight": false,
224
+ "aug_loss_weight": false,
225
+ "backbone_model": "llama",
226
+ "condition_type": "hidden_text_merge",
227
+ "cosyvoice_config_path": null,
228
+ "cosyvoice_model_dir": null,
229
+ "filter_tts_loss": false,
230
+ "hidden_act": "silu",
231
+ "hidden_size": 768,
232
+ "interleaved": false,
233
+ "intermediate_size": 3072,
234
+ "llm_dim": 4096,
235
+ "llm_dim_model_base": 256,
236
+ "llm_down_scale": false,
237
+ "llm_hidden_size": 4096,
238
+ "llm_intermediate_size": 768,
239
+ "long_weight": 0.1,
240
+ "max_position_embeddings": 4096,
241
+ "model_type": "minicpmtts",
242
+ "normalize_projected_hidden": true,
243
+ "num_attention_heads": 12,
244
+ "num_audio_tokens": 6562,
245
+ "num_hidden_layers": 20,
246
+ "num_key_value_heads": 12,
247
+ "num_mel_bins": 100,
248
+ "num_text_tokens": 152064,
249
+ "num_vq": 1,
250
+ "projector_type": "mlp",
251
+ "recomputed_chunks": 1,
252
+ "s3_stream_chunk_size": 25,
253
+ "s3_stream_generate": false,
254
+ "s3_stream_n_timesteps": 10,
255
+ "s3_stream_prelook_size": 3,
256
+ "short_weight": 0.1,
257
+ "streaming": false,
258
+ "streaming_audio_chunk_size": 50,
259
+ "streaming_sliding_window": false,
260
+ "streaming_sliding_window_audio_frame_rate": 50,
261
+ "streaming_sliding_window_audio_init_text_length": 10,
262
+ "streaming_sliding_window_audio_window_size": 300,
263
+ "streaming_sliding_window_average_speed": 5,
264
+ "streaming_sliding_window_fast_speed": 7,
265
+ "streaming_sliding_window_max_text_len": 500,
266
+ "streaming_sliding_window_slow_speed": 3,
267
+ "streaming_sliding_window_text_window_size": 50,
268
+ "streaming_text_chunk_max": 7,
269
+ "streaming_text_chunk_min": 3,
270
+ "streaming_text_reserved_len": 300,
271
+ "text_eos_token_id": 151692,
272
+ "tts_filter_loss_fix": false,
273
+ "use_llm_hidden_state": false,
274
+ "use_text": true,
275
+ "window_size": 2
276
+ },
277
+ "use_cache": true,
278
+ "use_image_id": true,
279
+ "use_sliding_window": false,
280
+ "version": "4.5",
281
+ "vision_batch_size": 16,
282
+ "vision_config": {
283
+ "_attn_implementation_autoset": true,
284
+ "attention_dropout": 0.0,
285
+ "hidden_act": "gelu_pytorch_tanh",
286
+ "hidden_size": 1152,
287
+ "image_size": 980,
288
+ "intermediate_size": 4304,
289
+ "layer_norm_eps": 1e-06,
290
+ "model_type": "siglip_vision_model",
291
+ "num_attention_heads": 16,
292
+ "num_channels": 3,
293
+ "num_hidden_layers": 27,
294
+ "patch_size": 14
295
+ },
296
+ "vocab_size": 151748
297
+ }
configuration_minicpmo.py ADDED
@@ -0,0 +1,260 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright 2026 The OpenBMB Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+ from typing import Union
19
+
20
+ from transformers import PretrainedConfig
21
+ from transformers import Qwen3Config
22
+ from transformers import WhisperConfig
23
+ from transformers.utils import logging
24
+
25
+ from .modeling_navit_siglip import SiglipVisionConfig
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ class MiniCPMVSliceConfig(PretrainedConfig):
31
+ model_type = "minicpmv"
32
+
33
+ def __init__(
34
+ self,
35
+ patch_size=14,
36
+ max_slice_nums=9,
37
+ scale_resolution=448,
38
+ **kwargs,
39
+ ):
40
+ super().__init__(**kwargs)
41
+ self.patch_size = patch_size
42
+ self.max_slice_nums = max_slice_nums
43
+ self.scale_resolution = scale_resolution
44
+
45
+ @classmethod
46
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
47
+ cls._set_token_in_kwargs(kwargs)
48
+
49
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
50
+
51
+ if config_dict.get("model_type") == "minicpmv":
52
+ config_dict = config_dict["slice_config"]
53
+
54
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
55
+ logger.warning(
56
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
57
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
58
+ )
59
+
60
+ return cls.from_dict(config_dict, **kwargs)
61
+
62
+
63
+ class MiniCPMTTSConfig(PretrainedConfig):
64
+ model_type = "minicpmtts"
65
+
66
+ def __init__(
67
+ self,
68
+ llm_dim: int = 2560,
69
+ llm_intermediate_size: int = 768,
70
+ llm_down_scale: bool = False,
71
+ llm_dim_model_base: int = 256,
72
+ projector_type: str = "mlp",
73
+ hidden_act: str = "silu",
74
+ aug_loss_weight: bool = False,
75
+ aug_layer_loss_weight: bool = False,
76
+ filter_tts_loss: bool = False,
77
+ tts_filter_loss_fix: bool = False,
78
+ long_weight: float = 0.1,
79
+ short_weight: float = 0.1,
80
+ hidden_size: int = 768,
81
+ intermediate_size: int = 3072,
82
+ num_attention_heads: int = 12,
83
+ num_hidden_layers: int = 20,
84
+ num_key_value_heads: int = 12,
85
+ max_position_embeddings: int = 4096,
86
+ num_audio_tokens: int = 4097,
87
+ num_text_tokens: int = 21178,
88
+ num_mel_bins: int = 100,
89
+ num_vq: int = 1,
90
+ use_llm_hidden_state: bool = False,
91
+ audio_bos_token_id: int = 21132,
92
+ text_eos_token_id: int = 21133,
93
+ use_text: bool = True,
94
+ streaming: bool = False,
95
+ streaming_text_chunk_min: int = 3,
96
+ streaming_text_chunk_max: int = 7,
97
+ streaming_text_reserved_len: int = 300,
98
+ streaming_audio_chunk_size: int = 50,
99
+ attn_implementation: str = "sdpa",
100
+ condition_type: str = "llm_hidden",
101
+ backbone_model: str = "llama",
102
+ audio_tokenizer_type: str = "wavtokenizer",
103
+ audio_tokenizer_sample_rate: int = 24000,
104
+ streaming_sliding_window: bool = False,
105
+ streaming_sliding_window_max_text_len: int = 500,
106
+ streaming_sliding_window_average_speed: int = 5,
107
+ streaming_sliding_window_fast_speed: int = 7,
108
+ streaming_sliding_window_slow_speed: int = 3,
109
+ streaming_sliding_window_audio_frame_rate: int = 50,
110
+ streaming_sliding_window_audio_init_text_length: int = 10,
111
+ streaming_sliding_window_audio_window_size: int = 300,
112
+ normalize_projected_hidden: bool = False,
113
+ interleaved: bool = False,
114
+ attention_type: str = "sliding_recompute",
115
+ recomputed_chunks: int = 1,
116
+ window_size: int = 2,
117
+ **kwargs,
118
+ ):
119
+ super().__init__(**kwargs)
120
+
121
+ self.llm_dim = llm_dim
122
+ self.llm_hidden_size = llm_dim
123
+ self.llm_intermediate_size = llm_intermediate_size
124
+ self.llm_down_scale = llm_down_scale
125
+ self.llm_dim_model_base = llm_dim_model_base
126
+ self.projector_type = projector_type
127
+ self.aug_loss_weight = aug_loss_weight
128
+ self.aug_layer_loss_weight = aug_layer_loss_weight
129
+ self.tts_filter_loss_fix = tts_filter_loss_fix
130
+ self.filter_tts_loss = filter_tts_loss
131
+ self.long_weight = long_weight
132
+ self.short_weight = short_weight
133
+ self.hidden_act = hidden_act
134
+
135
+ self.hidden_size = hidden_size
136
+ self.intermediate_size = intermediate_size
137
+ self.num_attention_heads = num_attention_heads
138
+ self.num_hidden_layers = num_hidden_layers
139
+ self.num_key_value_heads = num_key_value_heads
140
+ self.max_position_embeddings = max_position_embeddings
141
+ self.num_audio_tokens = num_audio_tokens
142
+ self.num_text_tokens = num_text_tokens
143
+ self.num_mel_bins = num_mel_bins
144
+ self.num_vq = num_vq
145
+ self.use_llm_hidden_state = use_llm_hidden_state
146
+ self.audio_bos_token_id = audio_bos_token_id
147
+ self.text_eos_token_id = text_eos_token_id
148
+ self.use_text = use_text
149
+ self.streaming = streaming
150
+ self.streaming_text_chunk_min = streaming_text_chunk_min
151
+ self.streaming_text_chunk_max = streaming_text_chunk_max
152
+ self.streaming_text_reserved_len = streaming_text_reserved_len
153
+ self.streaming_audio_chunk_size = streaming_audio_chunk_size
154
+ self.attn_implementation = attn_implementation
155
+ self.condition_type = condition_type
156
+ self.backbone_model = backbone_model
157
+ self.audio_tokenizer_type = audio_tokenizer_type
158
+ self.audio_tokenizer_sample_rate = audio_tokenizer_sample_rate
159
+
160
+ self.streaming_sliding_window = streaming_sliding_window
161
+ self.streaming_sliding_window_max_text_len = streaming_sliding_window_max_text_len
162
+ self.streaming_sliding_window_average_speed = streaming_sliding_window_average_speed
163
+ self.streaming_sliding_window_fast_speed = streaming_sliding_window_fast_speed
164
+ self.streaming_sliding_window_slow_speed = streaming_sliding_window_slow_speed
165
+ self.streaming_sliding_window_audio_frame_rate = streaming_sliding_window_audio_frame_rate
166
+ self.streaming_sliding_window_audio_init_text_length = streaming_sliding_window_audio_init_text_length
167
+ self.streaming_sliding_window_audio_window_size = streaming_sliding_window_audio_window_size
168
+
169
+ self.normalize_projected_hidden = normalize_projected_hidden
170
+
171
+ self.interleaved = interleaved
172
+ self.attention_type = attention_type
173
+ self.recomputed_chunks = recomputed_chunks
174
+ self.window_size = window_size
175
+
176
+
177
+ class MiniCPMOConfig(Qwen3Config):
178
+ model_type = "minicpmo"
179
+ keys_to_ignore_at_inference = ["past_key_values"]
180
+
181
+ default_vision_config = {
182
+ "hidden_size": 1152,
183
+ "image_size": 980,
184
+ "intermediate_size": 4304,
185
+ "model_type": "siglip",
186
+ "num_attention_heads": 16,
187
+ "num_hidden_layers": 27,
188
+ "patch_size": 14,
189
+ }
190
+
191
+ def __init__(
192
+ self,
193
+ use_cache=True,
194
+ query_num=64,
195
+ image_size=448,
196
+ drop_vision_last_layer=True,
197
+ batch_vision_input=True,
198
+ slice_config=None,
199
+ vision_config=None,
200
+ audio_config=None,
201
+ tts_config=None,
202
+ use_image_id=True,
203
+ vision_batch_size=16,
204
+ audio_pool_step=5,
205
+ audio_chunk_length=1.0,
206
+ stream_input=False,
207
+ listen_speak_type="asr",
208
+ init_vision=True,
209
+ init_audio=True,
210
+ init_tts=True,
211
+ **kwargs,
212
+ ):
213
+ self.use_cache = use_cache
214
+ self.query_num = query_num
215
+ self.image_size = image_size
216
+ self.drop_vision_last_layer = drop_vision_last_layer
217
+ self.batch_vision_input = batch_vision_input
218
+ self.use_image_id = use_image_id
219
+ self.vision_batch_size = vision_batch_size
220
+ self.audio_pool_step = audio_pool_step
221
+ self.audio_chunk_length = audio_chunk_length
222
+ self.stream_input = stream_input
223
+ self.listen_speak_type = listen_speak_type
224
+
225
+ self.init_vision = init_vision
226
+ self.init_audio = init_audio
227
+ self.init_tts = init_tts
228
+
229
+ if slice_config is None:
230
+ self.slice_config = MiniCPMVSliceConfig(max_slice_nums=1)
231
+ else:
232
+ self.slice_config = MiniCPMVSliceConfig(**slice_config)
233
+ self.slice_mode = True
234
+
235
+ # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit add tgt_sizes
236
+ if vision_config is None:
237
+ self.vision_config = SiglipVisionConfig(**self.default_vision_config)
238
+ logger.info("vision_config is None, using default vision config")
239
+ elif isinstance(vision_config, dict):
240
+ self.vision_config = SiglipVisionConfig(**vision_config)
241
+ elif isinstance(vision_config, SiglipVisionConfig):
242
+ self.vision_config = vision_config
243
+
244
+ if audio_config is None:
245
+ self.audio_config = WhisperConfig()
246
+ elif isinstance(audio_config, dict):
247
+ self.audio_config = WhisperConfig(**audio_config)
248
+ elif isinstance(audio_config, WhisperConfig):
249
+ self.audio_config = audio_config
250
+
251
+ if tts_config is None:
252
+ self.tts_config = MiniCPMTTSConfig()
253
+ elif isinstance(tts_config, dict):
254
+ self.tts_config = MiniCPMTTSConfig(**tts_config)
255
+ elif isinstance(tts_config, MiniCPMTTSConfig):
256
+ self.tts_config = tts_config
257
+
258
+ self.patch_size = self.vision_config.patch_size
259
+
260
+ super().__init__(**kwargs)
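
Because `config.json` maps `AutoConfig` to this class through `auto_map`, the nested sub-configs can be inspected without loading any weights. A minimal sketch, assuming `trust_remote_code=True` and access to the Hub:

```python
from transformers import AutoConfig

# Sketch only: instantiate MiniCPMOConfig from the repo's config.json.
cfg = AutoConfig.from_pretrained("andrevp/MiniCPM-o-4_5-MLX", trust_remote_code=True)

print(type(cfg).__name__)                   # MiniCPMOConfig
print(cfg.num_hidden_layers)                # 36 (Qwen3 LLM layers)
print(cfg.vision_config.num_hidden_layers)  # 27 (SigLIP vision layers)
print(cfg.query_num, cfg.slice_config.max_slice_nums)
```
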
generation_config.json ADDED
@@ -0,0 +1,12 @@
+ {
+   "bos_token_id": 151643,
+   "do_sample": true,
+   "eos_token_id": [
+     151645,
+     151643
+   ],
+   "pad_token_id": 151643,
+   "temperature": 0.6,
+   "top_k": 20,
+   "top_p": 0.95
+ }
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5e0d2dbc6eacf34177bcf3badf36ead8719f48e1b4c8adb2a2754be5a73218e6
+ size 5092993723
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2b35835d4c370c057a6c11d87cb91565f54e973c09a9c6b8ee5564140ce34d2a
+ size 527444905
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_minicpmo.py ADDED
The diff for this file is too large to render. See raw diff
 
modeling_navit_siglip.py ADDED
@@ -0,0 +1,981 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Siglip model."""
16
+ # Copied from HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes
17
+
18
+
19
+ import math
20
+ import os
21
+ import warnings
22
+ from dataclasses import dataclass
23
+ from typing import Optional
24
+ from typing import Tuple
25
+ from typing import Union
26
+
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn.functional as F
30
+ import torch.utils.checkpoint
31
+ from torch import nn
32
+ from torch.nn.init import _calculate_fan_in_and_fan_out
33
+ from transformers.activations import ACT2FN
34
+ from transformers.configuration_utils import PretrainedConfig
35
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
36
+ from transformers.modeling_outputs import BaseModelOutput
37
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
38
+ from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.utils import add_start_docstrings
40
+ from transformers.utils import add_start_docstrings_to_model_forward
41
+ from transformers.utils import is_flash_attn_2_available
42
+ from transformers.utils import logging
43
+ from transformers.utils import ModelOutput
44
+ from transformers.utils import replace_return_docstrings
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+
49
+ class SiglipVisionConfig(PretrainedConfig):
50
+ r"""
51
+ This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
52
+ Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
53
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
54
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
55
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
56
+ documentation from [`PretrainedConfig`] for more information.
57
+ Args:
58
+ hidden_size (`int`, *optional*, defaults to 768):
59
+ Dimensionality of the encoder layers and the pooler layer.
60
+ intermediate_size (`int`, *optional*, defaults to 3072):
61
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
62
+ num_hidden_layers (`int`, *optional*, defaults to 12):
63
+ Number of hidden layers in the Transformer encoder.
64
+ num_attention_heads (`int`, *optional*, defaults to 12):
65
+ Number of attention heads for each attention layer in the Transformer encoder.
66
+ num_channels (`int`, *optional*, defaults to 3):
67
+ Number of channels in the input images.
68
+ image_size (`int`, *optional*, defaults to 224):
69
+ The size (resolution) of each image.
70
+ patch_size (`int`, *optional*, defaults to 16):
71
+ The size (resolution) of each patch.
72
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
73
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
74
+ `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
75
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
76
+ The epsilon used by the layer normalization layers.
77
+ attention_dropout (`float`, *optional*, defaults to 0.0):
78
+ The dropout ratio for the attention probabilities.
79
+ Example:
80
+ ```python
81
+ >>> from transformers import SiglipVisionConfig, SiglipVisionModel
82
+ >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
83
+ >>> configuration = SiglipVisionConfig()
84
+ >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
85
+ >>> model = SiglipVisionModel(configuration)
86
+ >>> # Accessing the model configuration
87
+ >>> configuration = model.config
88
+ ```"""
89
+
90
+ model_type = "siglip_vision_model"
91
+
92
+ def __init__(
93
+ self,
94
+ hidden_size=768,
95
+ intermediate_size=3072,
96
+ num_hidden_layers=12,
97
+ num_attention_heads=12,
98
+ num_channels=3,
99
+ image_size=224,
100
+ patch_size=16,
101
+ hidden_act="gelu_pytorch_tanh",
102
+ layer_norm_eps=1e-6,
103
+ attention_dropout=0.0,
104
+ **kwargs,
105
+ ):
106
+ super().__init__(**kwargs)
107
+
108
+ self.hidden_size = hidden_size
109
+ self.intermediate_size = intermediate_size
110
+ self.num_hidden_layers = num_hidden_layers
111
+ self.num_attention_heads = num_attention_heads
112
+ self.num_channels = num_channels
113
+ self.patch_size = patch_size
114
+ self.image_size = image_size
115
+ self.attention_dropout = attention_dropout
116
+ self.layer_norm_eps = layer_norm_eps
117
+ self.hidden_act = hidden_act
118
+
119
+ @classmethod
120
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
121
+ cls._set_token_in_kwargs(kwargs)
122
+
123
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
124
+
125
+ # get the vision config dict if we are loading from SiglipConfig
126
+ if config_dict.get("model_type") == "siglip":
127
+ config_dict = config_dict["vision_config"]
128
+
129
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
130
+ logger.warning(
131
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
132
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
133
+ )
134
+
135
+ return cls.from_dict(config_dict, **kwargs)
136
+
137
+
138
+ _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
139
+
140
+ SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
141
+ "google/siglip-base-patch16-224",
142
+ # See all SigLIP models at https://huggingface.co/models?filter=siglip
143
+ ]
144
+
145
+ if is_flash_attn_2_available():
146
+ from flash_attn import flash_attn_func
147
+ from flash_attn import flash_attn_varlen_func
148
+ from flash_attn.bert_padding import index_first_axis # noqa
149
+ from flash_attn.bert_padding import pad_input
150
+ from flash_attn.bert_padding import unpad_input
151
+
152
+
153
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
154
+ def _get_unpad_data(attention_mask):
155
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
156
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
157
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
158
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
159
+ return (
160
+ indices,
161
+ cu_seqlens,
162
+ max_seqlen_in_batch,
163
+ )
164
+
165
+
166
+ def _trunc_normal_(tensor, mean, std, a, b):
167
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
168
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
169
+ def norm_cdf(x):
170
+ # Computes standard normal cumulative distribution function
171
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
172
+
173
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
174
+ warnings.warn(
175
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
176
+ "The distribution of values may be incorrect.",
177
+ stacklevel=2,
178
+ )
179
+
180
+ # Values are generated by using a truncated uniform distribution and
181
+ # then using the inverse CDF for the normal distribution.
182
+ # Get upper and lower cdf values
183
+ l = norm_cdf((a - mean) / std)
184
+ u = norm_cdf((b - mean) / std)
185
+
186
+ # Uniformly fill tensor with values from [l, u], then translate to
187
+ # [2l-1, 2u-1].
188
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
189
+
190
+ # Use inverse cdf transform for normal distribution to get truncated
191
+ # standard normal
192
+ if tensor.dtype in [torch.float16, torch.bfloat16]:
193
+ # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
194
+ og_dtype = tensor.dtype
195
+ tensor = tensor.to(torch.float32)
196
+ tensor.erfinv_()
197
+ tensor = tensor.to(og_dtype)
198
+ else:
199
+ tensor.erfinv_()
200
+
201
+ # Transform to proper mean, std
202
+ tensor.mul_(std * math.sqrt(2.0))
203
+ tensor.add_(mean)
204
+
205
+ # Clamp to ensure it's in the proper range
206
+ if tensor.dtype == torch.float16:
207
+ # The `clamp_` op is not (yet?) defined in float16+cpu
208
+ tensor = tensor.to(torch.float32)
209
+ tensor.clamp_(min=a, max=b)
210
+ tensor = tensor.to(torch.float16)
211
+ else:
212
+ tensor.clamp_(min=a, max=b)
213
+
214
+
215
+ def trunc_normal_tf_(
216
+ tensor: torch.Tensor,
217
+ mean: float = 0.0,
218
+ std: float = 1.0,
219
+ a: float = -2.0,
220
+ b: float = 2.0,
221
+ ) -> torch.Tensor:
222
+ """Fills the input Tensor with values drawn from a truncated
223
+ normal distribution. The values are effectively drawn from the
224
+ normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
225
+ with values outside :math:`[a, b]` redrawn until they are within
226
+ the bounds. The method used for generating the random values works
227
+ best when :math:`a \\leq \text{mean} \\leq b`.
228
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
229
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
230
+ and the result is subsquently scaled and shifted by the mean and std args.
231
+ Args:
232
+ tensor: an n-dimensional `torch.Tensor`
233
+ mean: the mean of the normal distribution
234
+ std: the standard deviation of the normal distribution
235
+ a: the minimum cutoff value
236
+ b: the maximum cutoff value
237
+ """
238
+ with torch.no_grad():
239
+ _trunc_normal_(tensor, 0, 1.0, a, b)
240
+ tensor.mul_(std).add_(mean)
241
+
242
+
243
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
244
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
245
+ if mode == "fan_in":
246
+ denom = fan_in
247
+ elif mode == "fan_out":
248
+ denom = fan_out
249
+ elif mode == "fan_avg":
250
+ denom = (fan_in + fan_out) / 2
251
+
252
+ variance = scale / denom
253
+
254
+ if distribution == "truncated_normal":
255
+ # constant is stddev of standard normal truncated to (-2, 2)
256
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
257
+ elif distribution == "normal":
258
+ with torch.no_grad():
259
+ tensor.normal_(std=math.sqrt(variance))
260
+ elif distribution == "uniform":
261
+ bound = math.sqrt(3 * variance)
262
+ with torch.no_grad():
263
+ tensor.uniform_(-bound, bound)
264
+ else:
265
+ raise ValueError(f"invalid distribution {distribution}")
266
+
267
+
268
+ def lecun_normal_(tensor):
269
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
270
+
271
+
272
+ def default_flax_embed_init(tensor):
273
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
274
+
275
+
276
+ @dataclass
277
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
278
+ class SiglipVisionModelOutput(ModelOutput):
279
+ """
280
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
281
+ Args:
282
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
283
+ The image embeddings obtained by applying the projection layer to the pooler_output.
284
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
285
+ Sequence of hidden-states at the output of the last layer of the model.
286
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
287
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
288
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
289
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
290
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
291
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
292
+ sequence_length)`.
293
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
294
+ heads.
295
+ """
296
+
297
+ image_embeds: Optional[torch.FloatTensor] = None
298
+ last_hidden_state: torch.FloatTensor = None
299
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
300
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
301
+
302
+
303
+ class SiglipVisionEmbeddings(nn.Module):
304
+ def __init__(self, config: SiglipVisionConfig):
305
+ super().__init__()
306
+ self.config = config
307
+ self.embed_dim = config.hidden_size
308
+ self.image_size = config.image_size
309
+ self.patch_size = config.patch_size
310
+
311
+ self.patch_embedding = nn.Conv2d(
312
+ in_channels=config.num_channels,
313
+ out_channels=self.embed_dim,
314
+ kernel_size=self.patch_size,
315
+ stride=self.patch_size,
316
+ padding="valid",
317
+ )
318
+
319
+ self.num_patches_per_side = self.image_size // self.patch_size
320
+ self.num_patches = self.num_patches_per_side**2
321
+ self.num_positions = self.num_patches
322
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
323
+
324
+ def forward(
325
+ self,
326
+ pixel_values: torch.FloatTensor,
327
+ patch_attention_mask: torch.BoolTensor,
328
+ tgt_sizes: Optional[torch.IntTensor] = None,
329
+ ) -> torch.Tensor:
330
+ batch_size = pixel_values.size(0)
331
+
332
+ patch_embeds = self.patch_embedding(pixel_values)
333
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
334
+
335
+ max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
336
+ max_nb_patches_h, max_nb_patches_w = (
337
+ max_im_h // self.patch_size,
338
+ max_im_w // self.patch_size,
339
+ )
340
+ boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
341
+ position_ids = torch.full(
342
+ size=(
343
+ batch_size,
344
+ max_nb_patches_h * max_nb_patches_w,
345
+ ),
346
+ fill_value=0,
347
+ )
348
+
349
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
350
+ if tgt_sizes is not None:
351
+ nb_patches_h = tgt_sizes[batch_idx][0]
352
+ nb_patches_w = tgt_sizes[batch_idx][1]
353
+ else:
354
+ nb_patches_h = p_attn_mask[:, 0].sum()
355
+ nb_patches_w = p_attn_mask[0].sum()
356
+
357
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
358
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
359
+
360
+ bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
361
+ bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
362
+
363
+ pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
364
+ position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
365
+
366
+ position_ids = position_ids.to(self.position_embedding.weight.device)
367
+
368
+ embeddings = embeddings + self.position_embedding(position_ids)
369
+ return embeddings
370
+
371
+
372
+ class SiglipAttention(nn.Module):
373
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
374
+
375
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
376
+ def __init__(self, config):
377
+ super().__init__()
378
+ self.config = config
379
+ self.embed_dim = config.hidden_size
380
+ self.num_heads = config.num_attention_heads
381
+ self.head_dim = self.embed_dim // self.num_heads
382
+ if self.head_dim * self.num_heads != self.embed_dim:
383
+ raise ValueError(
384
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
385
+ f" {self.num_heads})."
386
+ )
387
+ self.scale = self.head_dim**-0.5
388
+ self.dropout = config.attention_dropout
389
+
390
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
391
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
392
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
393
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
394
+
395
+ def forward(
396
+ self,
397
+ hidden_states: torch.Tensor,
398
+ attention_mask: Optional[torch.Tensor] = None,
399
+ output_attentions: Optional[bool] = False,
400
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
401
+ """Input shape: Batch x Time x Channel"""
402
+
403
+ batch_size, q_len, _ = hidden_states.size()
404
+
405
+ query_states = self.q_proj(hidden_states)
406
+ key_states = self.k_proj(hidden_states)
407
+ value_states = self.v_proj(hidden_states)
408
+
409
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
410
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
411
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
412
+
413
+ k_v_seq_len = key_states.shape[-2]
414
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
415
+
416
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
417
+ raise ValueError(
418
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
419
+ f" {attn_weights.size()}"
420
+ )
421
+
422
+ if attention_mask is not None:
423
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
424
+ raise ValueError(
425
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
426
+ )
427
+ attn_weights = attn_weights + attention_mask
428
+
429
+ # upcast attention to fp32
430
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
431
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
432
+ attn_output = torch.matmul(attn_weights, value_states)
433
+
434
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
435
+ raise ValueError(
436
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
437
+ f" {attn_output.size()}"
438
+ )
439
+
440
+ attn_output = attn_output.transpose(1, 2).contiguous()
441
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
442
+
443
+ attn_output = self.out_proj(attn_output)
444
+
445
+ return attn_output, attn_weights
446
+
447
+
448
+ class SiglipFlashAttention2(SiglipAttention):
449
+ """
450
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
451
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
452
+ flash attention and deal with padding tokens in case the input contains any of them.
453
+ """
454
+
455
+ def __init__(self, *args, **kwargs):
456
+ super().__init__(*args, **kwargs)
457
+ self.is_causal = False # Hack to make sure we don't use a causal mask
458
+
459
+ def forward(
460
+ self,
461
+ hidden_states: torch.Tensor,
462
+ attention_mask: Optional[torch.LongTensor] = None,
463
+ position_ids: Optional[torch.LongTensor] = None,
464
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
465
+ output_attentions: bool = False,
466
+ use_cache: bool = False,
467
+ **kwargs,
468
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
469
+ output_attentions = False
470
+
471
+ bsz, q_len, _ = hidden_states.size()
472
+
473
+ query_states = self.q_proj(hidden_states)
474
+ key_states = self.k_proj(hidden_states)
475
+ value_states = self.v_proj(hidden_states)
476
+
477
+ # Flash attention requires the input to have the shape
478
+ # batch_size x seq_length x head_dim x hidden_dim
479
+ # therefore we just need to keep the original shape
480
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
481
+ key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
482
+ value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
483
+
484
+ kv_seq_len = key_states.shape[-2]
485
+ if past_key_value is not None:
486
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
487
+ # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
488
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
489
+
490
+ # if past_key_value is not None:
491
+ # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
492
+ # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
493
+
494
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
495
+ # to be able to avoid many of these transpose/reshape/view.
496
+ query_states = query_states.transpose(1, 2)
497
+ key_states = key_states.transpose(1, 2)
498
+ value_states = value_states.transpose(1, 2)
499
+
500
+ dropout_rate = self.dropout if self.training else 0.0
501
+
502
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
503
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
504
+ # cast them back in the correct dtype just to be sure everything works as expected.
505
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
506
+ # in fp32. (LlamaRMSNorm handles it correctly)
507
+
508
+ input_dtype = query_states.dtype
509
+ if input_dtype == torch.float32:
510
+ if torch.is_autocast_enabled():
511
+ target_dtype = torch.get_autocast_gpu_dtype()
512
+ # Handle the case where the model is quantized
513
+ elif hasattr(self.config, "_pre_quantization_dtype"):
514
+ target_dtype = self.config._pre_quantization_dtype
515
+ else:
516
+ target_dtype = self.q_proj.weight.dtype
517
+
518
+ logger.warning_once(
519
+ "The input hidden states seems to be silently casted in float32, this might be related to the fact"
520
+ " you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
521
+ f" {target_dtype}."
522
+ )
523
+
524
+ query_states = query_states.to(target_dtype)
525
+ key_states = key_states.to(target_dtype)
526
+ value_states = value_states.to(target_dtype)
527
+
528
+ attn_output = self._flash_attention_forward(
529
+ query_states,
530
+ key_states,
531
+ value_states,
532
+ attention_mask,
533
+ q_len,
534
+ dropout=dropout_rate,
535
+ )
536
+
537
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
538
+ attn_output = self.out_proj(attn_output)
539
+
540
+ if not output_attentions:
541
+ attn_weights = None
542
+
543
+ return attn_output, attn_weights
544
+
545
+ def _flash_attention_forward(
546
+ self,
547
+ query_states,
548
+ key_states,
549
+ value_states,
550
+ attention_mask,
551
+ query_length,
552
+ dropout=0.0,
553
+ softmax_scale=None,
554
+ ):
555
+ """
556
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
557
+ first unpad the input, then computes the attention scores and pad the final attention scores.
558
+ Args:
559
+ query_states (`torch.Tensor`):
560
+ Input query states to be passed to Flash Attention API
561
+ key_states (`torch.Tensor`):
562
+ Input key states to be passed to Flash Attention API
563
+ value_states (`torch.Tensor`):
564
+ Input value states to be passed to Flash Attention API
565
+ attention_mask (`torch.Tensor`):
566
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
567
+ position of padding tokens and 1 for the position of non-padding tokens.
568
+ dropout (`float`, *optional*):
569
+ Attention dropout
570
+ softmax_scale (`float`, *optional*):
571
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
572
+ """
573
+
574
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
575
+ causal = self.is_causal and query_length != 1
576
+
577
+ # Contains at least one padding token in the sequence
578
+ if attention_mask is not None:
579
+ batch_size = query_states.shape[0]
580
+ (
581
+ query_states,
582
+ key_states,
583
+ value_states,
584
+ indices_q,
585
+ cu_seq_lens,
586
+ max_seq_lens,
587
+ ) = self._upad_input(query_states, key_states, value_states, attention_mask, query_length)
588
+
589
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
590
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
591
+
592
+ attn_output_unpad = flash_attn_varlen_func(
593
+ query_states,
594
+ key_states,
595
+ value_states,
596
+ cu_seqlens_q=cu_seqlens_q,
597
+ cu_seqlens_k=cu_seqlens_k,
598
+ max_seqlen_q=max_seqlen_in_batch_q,
599
+ max_seqlen_k=max_seqlen_in_batch_k,
600
+ dropout_p=dropout,
601
+ softmax_scale=softmax_scale,
602
+ causal=causal,
603
+ )
604
+
605
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
606
+ else:
607
+ attn_output = flash_attn_func(
608
+ query_states,
609
+ key_states,
610
+ value_states,
611
+ dropout,
612
+ softmax_scale=softmax_scale,
613
+ causal=causal,
614
+ )
615
+
616
+ return attn_output
617
+
618
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
619
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
620
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
621
+
622
+ key_layer = index_first_axis(
623
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
624
+ indices_k,
625
+ )
626
+ value_layer = index_first_axis(
627
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
628
+ indices_k,
629
+ )
630
+ if query_length == kv_seq_len:
631
+ query_layer = index_first_axis(
632
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
633
+ indices_k,
634
+ )
635
+ cu_seqlens_q = cu_seqlens_k
636
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
637
+ indices_q = indices_k
638
+ elif query_length == 1:
639
+ max_seqlen_in_batch_q = 1
640
+ cu_seqlens_q = torch.arange(
641
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
642
+ ) # There is a memcpy here, that is very bad.
643
+ indices_q = cu_seqlens_q[:-1]
644
+ query_layer = query_layer.squeeze(1)
645
+ else:
646
+ # The -q_len: slice assumes left padding.
647
+ attention_mask = attention_mask[:, -query_length:]
648
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
649
+
650
+ return (
651
+ query_layer,
652
+ key_layer,
653
+ value_layer,
654
+ indices_q,
655
+ (cu_seqlens_q, cu_seqlens_k),
656
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
657
+ )
658
+
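+ # Note on the unpadded ("varlen") layout used above (illustrative): tokens from all sequences in the
+ # batch are packed into a single dimension and `cu_seqlens` marks the boundaries. For example, per-sample
+ # key lengths [3, 5] give cu_seqlens_k = [0, 3, 8] and max_seqlen_in_batch_k = 5.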
659
+
660
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
661
+ class SiglipMLP(nn.Module):
662
+ def __init__(self, config):
663
+ super().__init__()
664
+ self.config = config
665
+ self.activation_fn = ACT2FN[config.hidden_act]
666
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
667
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
668
+
669
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
670
+ hidden_states = self.fc1(hidden_states)
671
+ hidden_states = self.activation_fn(hidden_states)
672
+ hidden_states = self.fc2(hidden_states)
673
+ return hidden_states
674
+
675
+
676
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
677
+ class SiglipEncoderLayer(nn.Module):
678
+ def __init__(self, config: SiglipVisionConfig):
679
+ super().__init__()
680
+ self.embed_dim = config.hidden_size
681
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
682
+ self.self_attn = SiglipAttention(config) if not self._use_flash_attention_2 else SiglipFlashAttention2(config)
683
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
684
+ self.mlp = SiglipMLP(config)
685
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
686
+
687
+ def forward(
688
+ self,
689
+ hidden_states: torch.Tensor,
690
+ attention_mask: torch.Tensor,
691
+ output_attentions: Optional[bool] = False,
692
+ ) -> Tuple[torch.FloatTensor]:
693
+ """
694
+ Args:
695
+ hidden_states (`torch.FloatTensor`):
696
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
697
+ attention_mask (`torch.FloatTensor`):
698
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
699
+ output_attentions (`bool`, *optional*, defaults to `False`):
700
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
701
+ returned tensors for more detail.
702
+ """
703
+ residual = hidden_states
704
+
705
+ hidden_states = self.layer_norm1(hidden_states)
706
+ hidden_states, attn_weights = self.self_attn(
707
+ hidden_states=hidden_states,
708
+ attention_mask=attention_mask,
709
+ output_attentions=output_attentions,
710
+ )
711
+ hidden_states = residual + hidden_states
712
+
713
+ residual = hidden_states
714
+ hidden_states = self.layer_norm2(hidden_states)
715
+ hidden_states = self.mlp(hidden_states)
716
+ hidden_states = residual + hidden_states
717
+
718
+ outputs = (hidden_states,)
719
+
720
+ if output_attentions:
721
+ outputs += (attn_weights,)
722
+
723
+ return outputs
724
+
725
+
726
+ class SiglipPreTrainedModel(PreTrainedModel):
727
+ """
728
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
729
+ models.
730
+ """
731
+
732
+ config_class = SiglipVisionConfig
733
+ base_model_prefix = "siglip"
734
+ supports_gradient_checkpointing = True
735
+
736
+ def _init_weights(self, module):
737
+ """Initialize the weights"""
738
+
739
+ if isinstance(module, SiglipVisionEmbeddings):
740
+ width = self.config.hidden_size
741
+ nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
742
+ elif isinstance(module, nn.Embedding):
743
+ default_flax_embed_init(module.weight)
744
+ elif isinstance(module, SiglipAttention):
745
+ nn.init.normal_(module.q_proj.weight)
746
+ nn.init.normal_(module.k_proj.weight)
747
+ nn.init.normal_(module.v_proj.weight)
748
+ nn.init.normal_(module.out_proj.weight)
749
+ nn.init.zeros_(module.q_proj.bias)
750
+ nn.init.zeros_(module.k_proj.bias)
751
+ nn.init.zeros_(module.v_proj.bias)
752
+ nn.init.zeros_(module.out_proj.bias)
753
+ elif isinstance(module, SiglipMLP):
754
+ nn.init.normal_(module.fc1.weight)
755
+ nn.init.normal_(module.fc2.weight)
756
+ nn.init.normal_(module.fc1.bias, std=1e-6)
757
+ nn.init.normal_(module.fc2.bias, std=1e-6)
758
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
759
+ lecun_normal_(module.weight)
760
+ if module.bias is not None:
761
+ nn.init.zeros_(module.bias)
762
+ elif isinstance(module, nn.LayerNorm):
763
+ module.bias.data.zero_()
764
+ module.weight.data.fill_(1.0)
765
+
766
+
767
+ SIGLIP_START_DOCSTRING = r"""
768
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
769
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
770
+ etc.)
771
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
772
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
773
+ and behavior.
774
+ Parameters:
775
+ config ([`SiglipVisionConfig`]): Model configuration class with all the parameters of the model.
776
+ Initializing with a config file does not load the weights associated with the model, only the
777
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
778
+ """
779
+
780
+
781
+ SIGLIP_VISION_INPUTS_DOCSTRING = r"""
782
+ Args:
783
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
784
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
785
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
786
+ output_attentions (`bool`, *optional*):
787
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
788
+ tensors for more detail.
789
+ output_hidden_states (`bool`, *optional*):
790
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
791
+ more detail.
792
+ return_dict (`bool`, *optional*):
793
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
794
+ """
795
+
796
+
797
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
798
+ class SiglipEncoder(nn.Module):
799
+ """
800
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
801
+ [`SiglipEncoderLayer`].
802
+ Args:
803
+ config: SiglipConfig
804
+ """
805
+
806
+ def __init__(self, config: SiglipVisionConfig):
807
+ super().__init__()
808
+ self.config = config
809
+ self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
810
+ self.gradient_checkpointing = False
811
+
812
+ # Ignore copy
813
+ def forward(
814
+ self,
815
+ inputs_embeds,
816
+ attention_mask: Optional[torch.Tensor] = None,
817
+ output_attentions: Optional[bool] = None,
818
+ output_hidden_states: Optional[bool] = None,
819
+ return_dict: Optional[bool] = None,
820
+ ) -> Union[Tuple, BaseModelOutput]:
821
+ r"""
822
+ Args:
823
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
824
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
825
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
826
+ than the model's internal embedding lookup matrix.
827
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
828
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
829
+ - 1 for tokens that are **not masked**,
830
+ - 0 for tokens that are **masked**.
831
+ [What are attention masks?](../glossary#attention-mask)
832
+ output_attentions (`bool`, *optional*):
833
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
834
+ returned tensors for more detail.
835
+ output_hidden_states (`bool`, *optional*):
836
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
837
+ for more detail.
838
+ return_dict (`bool`, *optional*):
839
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
840
+ """
841
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
842
+ output_hidden_states = (
843
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
844
+ )
845
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
846
+
847
+ encoder_states = () if output_hidden_states else None
848
+ all_attentions = () if output_attentions else None
849
+
850
+ hidden_states = inputs_embeds
851
+ for encoder_layer in self.layers:
852
+ if output_hidden_states:
853
+ encoder_states = encoder_states + (hidden_states,)
854
+ if self.gradient_checkpointing and self.training:
855
+ layer_outputs = self._gradient_checkpointing_func(
856
+ encoder_layer.__call__,
857
+ hidden_states,
858
+ attention_mask,
859
+ output_attentions,
860
+ )
861
+ else:
862
+ layer_outputs = encoder_layer(
863
+ hidden_states,
864
+ attention_mask,
865
+ output_attentions=output_attentions,
866
+ )
867
+
868
+ hidden_states = layer_outputs[0]
869
+
870
+ if output_attentions:
871
+ all_attentions = all_attentions + (layer_outputs[1],)
872
+
873
+ if output_hidden_states:
874
+ encoder_states = encoder_states + (hidden_states,)
875
+
876
+ if not return_dict:
877
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
878
+ return BaseModelOutput(
879
+ last_hidden_state=hidden_states,
880
+ hidden_states=encoder_states,
881
+ attentions=all_attentions,
882
+ )
883
+
884
+
885
+ @add_start_docstrings(
886
+ """The vision model from SigLIP without any head or projection on top.""",
887
+ SIGLIP_START_DOCSTRING,
888
+ )
889
+ class SiglipVisionTransformer(SiglipPreTrainedModel):
890
+ config_class = SiglipVisionConfig
891
+ main_input_name = "pixel_values"
892
+ _supports_flash_attn_2 = True
893
+ _no_split_modules = []
894
+
895
+ def __init__(self, config: SiglipVisionConfig):
896
+ super().__init__(config)
897
+ self.config = config
898
+ embed_dim = config.hidden_size
899
+
900
+ self.embeddings = SiglipVisionEmbeddings(config)
901
+ self.encoder = SiglipEncoder(config)
902
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
903
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
904
+
905
+ # Initialize weights and apply final processing
906
+ self.post_init()
907
+
908
+ def get_input_embeddings(self) -> nn.Module:
909
+ return self.embeddings.patch_embedding
910
+
911
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
912
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
913
+ def forward(
914
+ self,
915
+ pixel_values,
916
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
917
+ tgt_sizes: Optional[torch.IntTensor] = None,
918
+ output_attentions: Optional[bool] = None,
919
+ output_hidden_states: Optional[bool] = None,
920
+ return_dict: Optional[bool] = None,
921
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
922
+ r"""
923
+ Returns:
924
+ """
925
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
926
+ output_hidden_states = (
927
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
928
+ )
929
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
930
+
931
+ batch_size = pixel_values.size(0)
932
+ if patch_attention_mask is None:
933
+ patch_attention_mask = torch.ones(
934
+ size=(
935
+ batch_size,
936
+ pixel_values.size(2) // self.config.patch_size,
937
+ pixel_values.size(3) // self.config.patch_size,
938
+ ),
939
+ dtype=torch.bool,
940
+ device=pixel_values.device,
941
+ )
942
+
943
+ hidden_states = self.embeddings(
944
+ pixel_values=pixel_values,
945
+ patch_attention_mask=patch_attention_mask,
946
+ tgt_sizes=tgt_sizes,
947
+ )
948
+
949
+ patch_attention_mask = patch_attention_mask.view(batch_size, -1)
950
+ # The call to `_upad_input` in `_flash_attention_forward` is expensive
951
+ # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
952
+ # we avoid passing the attention_mask, which is equivalent to attending to the full sequence
953
+ if not torch.any(~patch_attention_mask):
954
+ attention_mask = None
955
+ else:
956
+ attention_mask = (
957
+ _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
958
+ if not self._use_flash_attention_2
959
+ else patch_attention_mask
960
+ )
961
+
962
+ encoder_outputs = self.encoder(
963
+ inputs_embeds=hidden_states,
964
+ attention_mask=attention_mask,
965
+ output_attentions=output_attentions,
966
+ output_hidden_states=output_hidden_states,
967
+ return_dict=return_dict,
968
+ )
969
+
970
+ last_hidden_state = encoder_outputs[0]
971
+ last_hidden_state = self.post_layernorm(last_hidden_state)
972
+
973
+ if not return_dict:
974
+ return (last_hidden_state, None) + encoder_outputs[1:]
975
+
976
+ return BaseModelOutputWithPooling(
977
+ last_hidden_state=last_hidden_state,
978
+ pooler_output=None,
979
+ hidden_states=encoder_outputs.hidden_states,
980
+ attentions=encoder_outputs.attentions,
981
+ )
preprocessor_config.json ADDED
@@ -0,0 +1,35 @@
1
+ {
2
+ "image_processor_type": "MiniCPMVImageProcessor",
3
+ "feature_extractor_type": "MiniCPMAAudioProcessor",
4
+ "auto_map": {
5
+ "AutoProcessor": "processing_minicpmo.MiniCPMOProcessor",
6
+ "AutoImageProcessor": "processing_minicpmo.MiniCPMVImageProcessor",
7
+ "AutoFeatureExtractor": "processing_minicpmo.MiniCPMAAudioProcessor"
8
+ },
9
+ "processor_class": "MiniCPMOProcessor",
10
+ "max_slice_nums": 9,
11
+ "scale_resolution": 448,
12
+ "patch_size": 14,
13
+ "use_image_id": true,
14
+ "image_feature_size": 64,
15
+ "im_start": "<image>",
16
+ "im_end": "</image>",
17
+ "slice_start": "<slice>",
18
+ "slice_end": "</slice>",
19
+ "unk": "<unk>",
20
+ "im_id_start": "<image_id>",
21
+ "im_id_end": "</image_id>",
22
+ "slice_mode": true,
23
+ "audio_pool_step": 5,
24
+ "norm_mean": [
25
+ 0.5,
26
+ 0.5,
27
+ 0.5
28
+ ],
29
+ "norm_std": [
30
+ 0.5,
31
+ 0.5,
32
+ 0.5
33
+ ],
34
+ "version": 4.5
35
+ }
processing_minicpmo.py ADDED
@@ -0,0 +1,1665 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright 2026 The OpenBMB Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import copy
18
+ import math
19
+ import re
20
+ from typing import Any
21
+ from typing import Dict
22
+ from typing import List
23
+ from typing import Optional
24
+ from typing import Tuple
25
+ from typing import Union
26
+
27
+ import numpy as np
28
+ import torch
29
+ from PIL import Image
30
+ from transformers import AutoImageProcessor
31
+ from transformers.audio_utils import spectrogram
32
+ from transformers.audio_utils import window_function
33
+ from transformers.image_processing_utils import BaseImageProcessor
34
+ from transformers.image_processing_utils import BatchFeature
35
+ from transformers.image_transforms import to_channel_dimension_format
36
+ from transformers.image_utils import ChannelDimension
37
+ from transformers.image_utils import ImageInput
38
+ from transformers.image_utils import infer_channel_dimension_format
39
+ from transformers.image_utils import is_torch_tensor
40
+ from transformers.image_utils import to_numpy_array
41
+ from transformers.image_utils import valid_images
42
+ from transformers.models.whisper.feature_extraction_whisper import WhisperFeatureExtractor
43
+ from transformers.processing_utils import ProcessorMixin
44
+ from transformers.tokenization_utils_base import PreTokenizedInput
45
+ from transformers.tokenization_utils_base import TextInput
46
+ from transformers.utils import is_torch_device
47
+ from transformers.utils import is_torch_dtype
48
+ from transformers.utils import requires_backends
49
+ from transformers.utils import TensorType
50
+
51
+
52
+ def recursive_converter(converter, value):
53
+ if isinstance(value, list):
54
+ new_value = []
55
+ for v in value:
56
+ new_value += [recursive_converter(converter, v)]
57
+ return new_value
58
+ else:
59
+ return converter(value)
60
+
61
+
62
+ class MiniCPMOBatchFeature(BatchFeature):
63
+ """Extend from BatchFeature for supporting various image size"""
64
+
65
+ def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
66
+ super().__init__(data)
67
+ self.convert_to_tensors(tensor_type=tensor_type)
68
+
69
+ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None, **kwargs):
70
+ if tensor_type is None:
71
+ return self
72
+
73
+ is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
74
+
75
+ def converter(value):
76
+ try:
77
+ if not is_tensor(value):
78
+ tensor = as_tensor(value)
79
+ return tensor
80
+ except: # noqa E722
81
+ if key == "overflowing_values":
82
+ raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
83
+ raise ValueError(
84
+ "Unable to create tensor, you should probably activate padding "
85
+ "with 'padding=True' to have batched tensors with the same length."
86
+ )
87
+
88
+ for key, value in self.items():
89
+ self[key] = recursive_converter(converter, value)
90
+ return self
91
+
92
+ def to(self, *args, **kwargs) -> "MiniCPMOBatchFeature":
93
+ requires_backends(self, ["torch"])
94
+ import torch
95
+
96
+ def cast_tensor(v):
97
+ if not torch.is_tensor(v):
98
+ return v
99
+
100
+ if torch.is_floating_point(v):
101
+ return v.to(*args, **kwargs)
102
+ elif device is not None:
103
+ return v.to(device=device)
104
+ else:
105
+ return v
106
+
107
+ new_data = {}
108
+ device = kwargs.get("device")
109
+ if device is None and len(args) > 0:
110
+ arg = args[0]
111
+ if is_torch_dtype(arg):
112
+ pass
113
+ elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
114
+ device = arg
115
+ else:
116
+ raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
117
+
118
+ # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
119
+ for k, v in self.items():
120
+ new_data[k] = recursive_converter(cast_tensor, v)
121
+ self.data = new_data
122
+ return self
123
+
124
+
125
+ class MiniCPMVImageProcessor(BaseImageProcessor):
126
+ model_input_names = ["pixel_values"]
127
+
128
+ def __init__(self, max_slice_nums=9, scale_resolution=448, patch_size=14, **kwargs):
129
+ super().__init__(**kwargs)
130
+ self.max_slice_nums = max_slice_nums
131
+ self.scale_resolution = scale_resolution
132
+ self.patch_size = patch_size
133
+ self.use_image_id = kwargs.pop("use_image_id", False)
134
+ self.image_feature_size = kwargs.pop("image_feature_size", 64)
135
+ self.im_start_token = kwargs.pop("im_start", "<image>")
136
+ self.im_end_token = kwargs.pop("im_end", "</image>")
137
+ self.slice_start_token = kwargs.pop("slice_start", "<slice>")
138
+ self.slice_end_token = kwargs.pop("slice_end", "</slice>")
139
+ self.unk_token = kwargs.pop("unk", "<unk>")
140
+ self.im_id_start = kwargs.pop("im_id_start", "<image_id>")
141
+ self.im_id_end = kwargs.pop("im_id_end", "</image_id>")
142
+ self.slice_mode = kwargs.pop("slice_mode", True)
143
+
144
+ self.mean = np.array(kwargs.pop("norm_mean", [0.5, 0.5, 0.5]))
145
+ self.std = np.array(kwargs.pop("norm_std", [0.5, 0.5, 0.5]))
146
+ self.version = kwargs.pop("version", 2.0)
147
+
148
+ @staticmethod
149
+ def ensure_divide(length, patch_size):
150
+ return max(round(length / patch_size) * patch_size, patch_size)
151
+
152
+ def find_best_resize(self, original_size, scale_resolution, patch_size, allow_upscale=False):
153
+ width, height = original_size
154
+ if (width * height > scale_resolution * scale_resolution) or allow_upscale:
155
+ r = width / height
156
+ height = int(scale_resolution / math.sqrt(r))
157
+ width = int(height * r)
158
+ best_width = self.ensure_divide(width, patch_size)
159
+ best_height = self.ensure_divide(height, patch_size)
160
+ return best_width, best_height
161
+
162
+ def get_refine_size(self, original_size, grid, scale_resolution, patch_size, allow_upscale=False):
163
+ width, height = original_size
164
+ grid_x, grid_y = grid
165
+
166
+ refine_width = self.ensure_divide(width, grid_x)
167
+ refine_height = self.ensure_divide(height, grid_y)
168
+
169
+ grid_width = refine_width / grid_x
170
+ grid_height = refine_height / grid_y
171
+
172
+ best_grid_size = self.find_best_resize(
173
+ (grid_width, grid_height), scale_resolution, patch_size, allow_upscale=allow_upscale
174
+ )
175
+ refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
176
+ return refine_size
177
+
178
+ @staticmethod
179
+ def split_to_patches(image, grid):
180
+ patches = []
181
+ width, height = image.size
182
+ grid_x = int(width / grid[0])
183
+ grid_y = int(height / grid[1])
184
+ for i in range(0, height, grid_y):
185
+ images = []
186
+ for j in range(0, width, grid_x):
187
+ box = (j, i, j + grid_x, i + grid_y)
188
+ patch = image.crop(box)
189
+ images.append(patch)
190
+ patches.append(images)
191
+ return patches
192
+
193
+ def slice_image(self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False):
194
+ original_size = image.size
195
+ source_image = None
196
+ best_grid = self.get_sliced_grid(original_size, max_slice_nums, never_split)
197
+ patches = []
198
+
199
+ if best_grid is None:
200
+ # don't need to slice, upsample
201
+ best_size = self.find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True)
202
+ source_image = image.resize(best_size, resample=Image.Resampling.BICUBIC)
203
+ else:
204
+ # source image, down-sampling and ensure divided by patch_size
205
+ best_resize = self.find_best_resize(original_size, scale_resolution, patch_size)
206
+ source_image = image.copy().resize(best_resize, resample=Image.Resampling.BICUBIC)
207
+ refine_size = self.get_refine_size(
208
+ original_size, best_grid, scale_resolution, patch_size, allow_upscale=True
209
+ )
210
+ refine_image = image.resize(refine_size, resample=Image.Resampling.BICUBIC)
211
+ patches = self.split_to_patches(refine_image, best_grid)
212
+
213
+ return source_image, patches, best_grid
214
+
215
+ def get_grid_placeholder(self, grid):
216
+ if grid is None:
217
+ return ""
218
+ slice_image_placeholder = (
219
+ self.slice_start_token + self.unk_token * self.image_feature_size + self.slice_end_token
220
+ )
221
+
222
+ cols = grid[0]
223
+ rows = grid[1]
224
+ slices = []
225
+ for i in range(rows):
226
+ lines = []
227
+ for j in range(cols):
228
+ lines.append(slice_image_placeholder)
229
+ slices.append("".join(lines))
230
+
231
+ slice_placeholder = "\n".join(slices)
232
+ return slice_placeholder
233
+
234
+ def get_image_id_placeholder(self, idx=0):
235
+ return f"{self.im_id_start}{idx}{self.im_id_end}"
236
+
237
+ def get_sliced_images(self, image, max_slice_nums=None):
238
+ slice_images = []
239
+
240
+ if not self.slice_mode:
241
+ return [image]
242
+
243
+ max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
244
+ assert max_slice_nums > 0
245
+ source_image, patches, sliced_grid = self.slice_image(
246
+ image, max_slice_nums, self.scale_resolution, self.patch_size # default: 9 # default: 448 # default: 14
247
+ )
248
+
249
+ slice_images.append(source_image)
250
+ if len(patches) > 0:
251
+ for i in range(len(patches)):
252
+ for j in range(len(patches[0])):
253
+ slice_images.append(patches[i][j])
254
+ return slice_images
255
+
256
+ def get_sliced_grid(self, image_size, max_slice_nums, nerver_split=False):
257
+ original_width, original_height = image_size
258
+ log_ratio = math.log(original_width / original_height)
259
+ ratio = original_width * original_height / (self.scale_resolution * self.scale_resolution)
260
+ multiple = min(math.ceil(ratio), max_slice_nums)
261
+ if multiple <= 1 or nerver_split:
262
+ return None
263
+ candidate_split_grids_nums = []
264
+ for i in [multiple - 1, multiple, multiple + 1]:
265
+ if i == 1 or i > max_slice_nums:
266
+ continue
267
+ candidate_split_grids_nums.append(i)
268
+
269
+ candidate_grids = []
270
+ for split_grids_nums in candidate_split_grids_nums:
271
+ m = 1
272
+ while m <= split_grids_nums:
273
+ if split_grids_nums % m == 0:
274
+ candidate_grids.append([m, split_grids_nums // m])
275
+ m += 1
276
+
277
+ best_grid = [1, 1]
278
+ min_error = float("inf")
279
+ for grid in candidate_grids:
280
+ error = abs(log_ratio - math.log(grid[0] / grid[1]))
281
+ if error < min_error:
282
+ best_grid = grid
283
+ min_error = error
284
+
285
+ return best_grid
286
+
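+ # Worked example for the grid search above (numbers are illustrative): a 1920x1080 image with
+ # scale_resolution=448 and max_slice_nums=9 gives ratio ~ 10.3, so multiple = min(ceil(10.3), 9) = 9.
+ # Candidate grid counts are [8, 9] (10 exceeds max_slice_nums); among their factor pairs,
+ # [4, 2] minimises |log(1920/1080) - log(cols/rows)|, so best_grid = [4, 2] (4 columns x 2 rows).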
287
+ def get_slice_image_placeholder(self, image_size, image_idx=0, max_slice_nums=None, use_image_id=None):
288
+ max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
289
+ assert max_slice_nums > 0
290
+ grid = self.get_sliced_grid(image_size=image_size, max_slice_nums=max_slice_nums)
291
+
292
+ image_placeholder = self.im_start_token + self.unk_token * self.image_feature_size + self.im_end_token
293
+ use_image_id = self.use_image_id if use_image_id is None else bool(use_image_id)
294
+ if use_image_id:
295
+ final_placeholder = self.get_image_id_placeholder(image_idx) + image_placeholder
296
+ else:
297
+ final_placeholder = image_placeholder
298
+
299
+ if self.slice_mode:
300
+ final_placeholder = final_placeholder + self.get_grid_placeholder(grid=grid)
301
+ return final_placeholder
302
+
303
+ @staticmethod
304
+ def to_pil_image(image, rescale=None) -> Image.Image:
305
+ """Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back
306
+ as the last axis if needed.
307
+
308
+ Args:
309
+ image (`Image.Image` or `numpy.ndarray` or `torch.Tensor`):
310
+ The image to convert to the PIL Image format.
311
+ rescale (`bool`, *optional*):
312
+ whether to apply the scaling factor (to make pixel values integers between 0 and 255). Will
313
+ default to `True` if the image type is a floating type, `False` otherwise.
314
+ """
315
+ if isinstance(image, Image.Image):
316
+ return image
317
+ if is_torch_tensor(image):
318
+ image = image.numpy()
319
+
320
+ if isinstance(image, np.ndarray):
321
+ if rescale is None:
322
+ # rescale default to the array being of floating type.
323
+ rescale = isinstance(image.flat[0], np.floating)
324
+ # If the channel has been moved to the first dim, we put it back at the end.
325
+ if image.ndim == 3 and image.shape[0] in [1, 3]:
326
+ image = image.transpose(1, 2, 0)
327
+ if rescale:
328
+ image = image * 255
329
+ image = image.astype(np.uint8)
330
+ return Image.fromarray(image)
331
+ return image
332
+
333
+ def reshape_by_patch(self, image):
334
+ image = torch.from_numpy(image)
335
+ patch_size = self.patch_size
336
+ patches = torch.nn.functional.unfold(image, (patch_size, patch_size), stride=(patch_size, patch_size))
337
+
338
+ patches = patches.reshape(image.size(0), patch_size, patch_size, -1)
339
+ patches = patches.permute(0, 1, 3, 2).reshape(image.size(0), patch_size, -1)
340
+ return patches.numpy()
341
+
342
+ def preprocess(
343
+ self,
344
+ images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]],
345
+ do_pad: Optional[bool] = True,
346
+ max_slice_nums: int = None,
347
+ return_tensors: Optional[Union[str, TensorType]] = None,
348
+ **kwargs,
349
+ ) -> MiniCPMOBatchFeature:
350
+ if isinstance(images, Image.Image):
351
+ images_list = [[images]]
352
+ elif isinstance(images[0], Image.Image):
353
+ images_list = [images]
354
+ else:
355
+ images_list = images
356
+
357
+ new_images_list = []
358
+ image_sizes_list = []
359
+ tgt_sizes_list = []
360
+
361
+ for _images in images_list:
362
+ if _images is None or len(_images) == 0:
363
+ new_images_list.append([])
364
+ image_sizes_list.append([])
365
+ tgt_sizes_list.append([])
366
+ continue
367
+ if not valid_images(_images):
368
+ raise ValueError(
369
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
370
+ "torch.Tensor, tf.Tensor or jax.ndarray."
371
+ )
372
+
373
+ _images = [self.to_pil_image(image).convert("RGB") for image in _images]
374
+ input_data_format = infer_channel_dimension_format(np.array(_images[0]))
375
+
376
+ new_images = []
377
+ image_sizes = [image.size for image in _images]
378
+ tgt_sizes = []
379
+ for image in _images:
380
+ image_patches = self.get_sliced_images(image, max_slice_nums)
381
+ image_patches = [to_numpy_array(image).astype(np.float32) / 255 for image in image_patches]
382
+ image_patches = [
383
+ self.normalize(image=image, mean=self.mean, std=self.std, input_data_format=input_data_format)
384
+ for image in image_patches
385
+ ]
386
+ image_patches = [
387
+ to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
388
+ for image in image_patches
389
+ ]
390
+ for slice_image in image_patches:
391
+ new_images.append(self.reshape_by_patch(slice_image))
392
+ tgt_sizes.append(
393
+ np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size))
394
+ )
395
+
396
+ if tgt_sizes:
397
+ tgt_sizes = np.vstack(tgt_sizes)
398
+
399
+ new_images_list.append(new_images)
400
+ image_sizes_list.append(image_sizes)
401
+ tgt_sizes_list.append(tgt_sizes)
402
+ return MiniCPMOBatchFeature(
403
+ data={"pixel_values": new_images_list, "image_sizes": image_sizes_list, "tgt_sizes": tgt_sizes_list},
404
+ tensor_type=return_tensors,
405
+ )
406
+
407
+
408
+ AutoImageProcessor.register("MiniCPMVImageProcessor", MiniCPMVImageProcessor)
409
+
410
+
411
+ def chunk_audio(audio: np.ndarray, max_duration_seconds: int = 30, sample_rate: int = 16000) -> List[np.ndarray]:
412
+ """split long audio into chunks
413
+
414
+ Args:
415
+ audio:
416
+ max_duration_seconds:
417
+ sample_rate:
418
+
419
+ Returns:
420
+ chunks
421
+ """
422
+ max_len = int(max_duration_seconds * sample_rate)
423
+
424
+ if len(audio) <= max_len:
425
+ return [audio]
426
+
427
+ chunks = []
428
+ for i in range(0, len(audio), max_len):
429
+ chunk = audio[i : i + max_len]
430
+ chunks.append(chunk)
431
+
432
+ return chunks
433
+
434
+
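+ # Usage sketch (hypothetical variable names): for a 70 s clip at 16 kHz,
+ #     chunks = chunk_audio(audio, max_duration_seconds=30, sample_rate=16000)
+ # returns three arrays of 480_000, 480_000 and 160_000 samples (30 s + 30 s + 10 s).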
435
+ def process_audio_batch(
436
+ audios: Union[np.ndarray, List[np.ndarray], List[List[np.ndarray]]],
437
+ feature_extractor,
438
+ sampling_rate: int = 16000,
439
+ max_duration_seconds: int = 30,
440
+ return_attention_mask: bool = True,
441
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
442
+ """extract audio mel features
443
+
444
+ Args:
445
+ audios:
446
+ feature_extractor: WhisperFeatureExtractor
447
+ sampling_rate:
448
+ max_duration_seconds:
449
+ return_attention_mask:
450
+
451
+ Returns:
452
+ (audio_features, audio_feature_lens)
453
+ audio_features: [batch_size, n_mels, max_frames]
454
+ audio_feature_lens:
455
+ """
456
+ if isinstance(audios, np.ndarray):
457
+ audios_list = [[audios]]
458
+ elif len(audios) > 0 and isinstance(audios[0], np.ndarray):
459
+ audios_list = [audios]
460
+ else:
461
+ audios_list = audios
462
+
463
+ audio_features_all = []
464
+ audio_feature_lens_list = []
465
+
466
+ for batch_audios in audios_list:
467
+ batch_lens = []
468
+
469
+ for audio in batch_audios:
470
+ chunks = chunk_audio(audio, max_duration_seconds, sampling_rate)
471
+
472
+ for chunk in chunks:
473
+ audio_input = feature_extractor(
474
+ chunk,
475
+ sampling_rate=sampling_rate,
476
+ return_tensors="pt",
477
+ padding="max_length",
478
+ return_attention_mask=return_attention_mask,
479
+ )
480
+
481
+ audio_feature = audio_input["input_features"] # [1, 80, frames]
482
+
483
+ if return_attention_mask:
484
+ actual_len = audio_input["attention_mask"].sum(dim=1) # Tensor([frames])
485
+ audio_feature = audio_feature[:, :, : actual_len[0]]
486
+ batch_lens.append(actual_len[0])
487
+ else:
488
+ batch_lens.append(torch.tensor(audio_feature.shape[2]))
489
+
490
+ audio_features_all.append(audio_feature.squeeze(0)) # [80, frames]
491
+
492
+ if len(batch_lens) > 0:
493
+ audio_feature_lens_list.append(torch.hstack(batch_lens))
494
+ else:
495
+ audio_feature_lens_list.append(torch.tensor([]))
496
+
497
+ # pad to same length
498
+ if audio_features_all:
499
+ audio_features = torch.nn.utils.rnn.pad_sequence(
500
+ [feat.transpose(0, 1) for feat in audio_features_all], batch_first=True, padding_value=0.0
501
+ ).transpose(
502
+ 1, 2
503
+ ) # [batch, 80, max_frames]
504
+ else:
505
+ audio_features = torch.tensor([])
506
+
507
+ return audio_features, audio_feature_lens_list
508
+
509
+
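+ # Usage sketch (hypothetical variable names):
+ #     feats, lens = process_audio_batch(audio_list, whisper_fe, sampling_rate=16000)
+ # where `whisper_fe` is a WhisperFeatureExtractor-style extractor; `feats` comes back as
+ # [batch, n_mels, max_frames] (80 mel bins here) and `lens` holds the per-item frame counts.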
510
+ def regroup_audio_features(
511
+ audio_features: torch.Tensor, audio_feature_lens: List[torch.Tensor], regroup_seconds: int, fps: int = 100
512
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
513
+ """regroup audio features to fixed duration
514
+
515
+ Args:
516
+ audio_features: [batch, n_mels, frames]
517
+ audio_feature_lens: each batch's actual length
518
+ regroup_seconds: regroup duration (seconds)
519
+ fps: frames per second
520
+
521
+ Returns:
522
+ (regrouped_features, regrouped_lens)
523
+ """
524
+ # flatten to continuous frames sequence
525
+ all_lens = []
526
+ for lens in audio_feature_lens:
527
+ if isinstance(lens, torch.Tensor):
528
+ all_lens.extend(lens.tolist())
529
+ elif isinstance(lens, list):
530
+ all_lens.extend([int(x) for x in lens])
531
+
532
+ if len(all_lens) == 0:
533
+ return torch.tensor([]), []
534
+
535
+ # concatenate all valid features
536
+ flat_slices = [audio_features[i, :, :L] for i, L in enumerate(all_lens)] # [n_mels, L]
537
+
538
+ if len(flat_slices) == 1:
539
+ full_feat = flat_slices[0]
540
+ else:
541
+ full_feat = torch.cat(flat_slices, dim=1) # [n_mels, total_frames]
542
+
543
+ # split to fixed frames
544
+ frames_per_seg = int(regroup_seconds * fps)
545
+ segments = []
546
+
547
+ for start in range(0, full_feat.size(1), frames_per_seg):
548
+ seg = full_feat[:, start : start + frames_per_seg]
549
+ if seg.size(1) > 0:
550
+ segments.append(seg)
551
+
552
+ if len(segments) == 0:
553
+ return torch.tensor([]), []
554
+
555
+ # pad and convert to batch
556
+ seg_lens = [s.size(1) for s in segments]
557
+ segs_transposed = [s.transpose(0, 1) for s in segments]
558
+
559
+ padded = torch.nn.utils.rnn.pad_sequence(segs_transposed, batch_first=True, padding_value=0.0) # [N, max_T, n_mels]
560
+
561
+ padded = padded.transpose(1, 2) # [N, n_mels, max_T]
562
+ lens_tensor = torch.tensor(seg_lens, dtype=torch.int32, device=padded.device)
563
+
564
+ return padded, [lens_tensor]
565
+
566
+
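+ # Example for the regrouping above (illustrative numbers): with regroup_seconds=30 and fps=100,
+ # frames_per_seg = 3000, so 7500 valid frames become segments of 3000, 3000 and 1500 frames,
+ # returned as a padded [3, n_mels, 3000] tensor plus a one-element list holding lengths [3000, 3000, 1500].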
567
+ class MiniCPMAAudioProcessor(WhisperFeatureExtractor):
568
+ """
569
+ On top of WhisperFeatureExtractor:
570
+ - support dynamic_log_norm (original max-8dB, adjustable dynamic_range_db)
571
+ - or fixed log_floor_db (e.g. -10dB)
572
+ - this is because we need to do streaming scheme, in which we can't do dynamic setting
573
+ - this can be modified in the middle, through set_dynamic_log_norm
574
+ Two paths (torch / numpy) keep consistent clipping and scaling order:
575
+ log10 -> (dynamic/fixed lower limit clipping) -> (+4)/4
576
+ """
577
+
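+ # Worked example of the normalization described above (illustrative values): for a mel bin with
+ # log10 value -6.2 and an utterance maximum of -0.5, the dynamic threshold is -0.5 - 8.0 = -8.5,
+ # so -6.2 is kept and scaled to (-6.2 + 4) / 4 = -0.55, while a very quiet bin at -9.3 is clipped to -8.5.
+ # With the fixed floor (log_floor_db = -10.0), the -9.3 bin is kept and becomes (-9.3 + 4) / 4 = -1.325.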
578
+ def __init__(
579
+ self,
580
+ *args,
581
+ dynamic_log_norm: bool = True,
582
+ dynamic_range_db: float = 8.0,
583
+ log_floor_db: float = -10.0,
584
+ **kwargs,
585
+ ):
586
+ super().__init__(*args, **kwargs)
587
+ self.dynamic_log_norm = bool(dynamic_log_norm)
588
+ self.dynamic_range_db = float(dynamic_range_db)
589
+ self.log_floor_db = float(log_floor_db)
590
+
591
+ def set_spac_log_norm(
592
+ self,
593
+ dynamic_range_db: Optional[float] = None,
594
+ log_floor_db: Optional[float] = None,
595
+ *,
596
+ inplace: bool = True,
597
+ ) -> "MiniCPMAAudioProcessor":
598
+ """Hot update dynamic/fixed lower limit strategy.
599
+
600
+ Args:
601
+ enabled: True=use dynamic threshold (max - dynamic_range_db), False=use fixed lower limit log_floor_db.
602
+ None means keep unchanged.
603
+ dynamic_range_db: dynamic range (dB), only effective when enabled=True. None means keep unchanged.
604
+ log_floor_db: fixed log floor (dB, usually <= 0), only effective when enabled=False. None means keep unchanged.
605
+ inplace: True directly modify current instance; False return a shallow copy and modify on it.
606
+
607
+ Returns:
608
+ self or new instance (when inplace=False).
609
+ """
610
+
611
+ target = self if inplace else copy.copy(self)
612
+
613
+ if dynamic_range_db is not None:
614
+ val = float(dynamic_range_db)
615
+ if val < 0:
616
+ raise ValueError("dynamic_range_db must be >= 0.")
617
+ target.dynamic_log_norm = True # explicitly set the value to dynamic mode
618
+ target.dynamic_range_db = val
619
+
620
+ if log_floor_db is not None:
621
+ val = float(log_floor_db)
622
+ # usually log10(mel) maximum is not more than ~0dB, floor should be <= 0; here do loose validation
623
+ if val > 0:
624
+ raise ValueError("log_floor_db should be <= 0 (log10 scale).")
625
+ target.dynamic_log_norm = False # explicitly set the value to fixed lower limit mode
626
+ target.log_floor_db = val
627
+
628
+ return target
629
+
630
+ def _np_extract_fbank_features(self, waveform_batch: np.ndarray, device: str) -> np.ndarray:
631
+ """NumPy version consistent with upstream, but replace max-8dB with configurable dynamic/fixed lower limit clipping."""
632
+ if device != "cpu":
633
+ raise ValueError(
634
+ f"Got device `{device}` for feature extraction, but feature extraction on CUDA accelerator "
635
+ "devices requires torch. Set device='cpu' or install torch."
636
+ )
637
+
638
+ log_spec_batch: List[np.ndarray] = []
639
+ for waveform in waveform_batch:
640
+ # generate log10 Mel
641
+ log_spec = spectrogram(
642
+ waveform,
643
+ window_function(self.n_fft, "hann"),
644
+ frame_length=self.n_fft,
645
+ hop_length=self.hop_length,
646
+ power=2.0,
647
+ dither=self.dither,
648
+ mel_filters=self.mel_filters,
649
+ log_mel="log10",
650
+ )
651
+ # consistent with upstream: remove the last frame
652
+ log_spec = log_spec[:, :-1]
653
+
654
+ # dynamic/fixed clipping
655
+ if self.dynamic_log_norm:
656
+ threshold = log_spec.max() - self.dynamic_range_db
657
+ log_spec = np.maximum(log_spec, threshold)
658
+ else:
659
+ log_spec = np.maximum(log_spec, self.log_floor_db)
660
+
661
+ # consistent with Whisper linear scaling
662
+ log_spec = (log_spec + 4.0) / 4.0
663
+
664
+ log_spec_batch.append(log_spec)
665
+
666
+ return np.array(log_spec_batch)
667
+
668
+ def _torch_extract_fbank_features(self, waveform: np.ndarray, device: str = "cpu") -> np.ndarray:
669
+ if torch is None:
670
+ raise RuntimeError("PyTorch is not installed, cannot compute STFT on GPU.")
671
+
672
+ waveform = torch.from_numpy(waveform).to(device, torch.float32)
673
+ window = torch.hann_window(self.n_fft, device=device)
674
+
675
+ if self.dither != 0.0:
676
+ waveform = waveform + self.dither * torch.randn_like(waveform)
677
+
678
+ stft = torch.stft(waveform, n_fft=self.n_fft, hop_length=self.hop_length, window=window, return_complex=True)
679
+ magnitudes = stft[..., :-1].abs() ** 2
680
+
681
+ mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) # [n_mels, 1+n_fft//2]
682
+ mel_spec = mel_filters.T @ magnitudes # [..., n_mels, T]
683
+
684
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10() # <= 0
685
+
686
+ if self.dynamic_log_norm:
687
+ if waveform.dim() == 2:
688
+ max_val_t = log_spec.max(dim=2, keepdim=True)[0] # over T
689
+ max_val_bt = max_val_t.max(dim=1, keepdim=True)[0] # over mel
690
+ threshold = max_val_bt - self.dynamic_range_db
691
+ log_spec = torch.maximum(log_spec, threshold)
692
+ else:
693
+ threshold = log_spec.max() - self.dynamic_range_db
694
+ log_spec = torch.maximum(log_spec, threshold)
695
+ else:
696
+ floor_tensor = torch.tensor(self.log_floor_db, dtype=log_spec.dtype, device=log_spec.device)
697
+ log_spec = torch.maximum(log_spec, floor_tensor)
698
+
699
+ log_spec = (log_spec + 4.0) / 4.0
700
+
701
+ if device != "cpu":
702
+ log_spec = log_spec.detach().cpu()
703
+ return log_spec.numpy()
704
+
705
+ def process(self, *args, **kwargs):
706
+ """Alias of __call__ for convenience."""
707
+ return self.__call__(*args, **kwargs)
708
+
709
+
710
+ class StreamingMelProcessorExact:
711
+ """Strictly offline equivalent streaming Mel processor.
712
+
713
+ - accumulate all historical audio into buffer; use the same feature_extractor to calculate the entire mel after each addition.
714
+ - only output "stable" frames: the frame center does not depend on future (right) context, i.e. center + n_fft//2 <= current buffer length.
715
+ - output the last batch of frames at the end (flush), ensuring complete consistency with offline full-calculation.
716
+
717
+ Cost: Each call performs feature extraction on the accumulated buffer (can be optimized to incremental if needed).
718
+ """
719
+
720
+ def __init__(
721
+ self,
722
+ feature_extractor: MiniCPMAAudioProcessor,
723
+ chunk_ms: int = 100,
724
+ first_chunk_ms: Optional[int] = None,
725
+ sample_rate: int = 16000,
726
+ n_fft: int = 400,
727
+ hop_length: int = 160,
728
+ n_mels: int = 80,
729
+ cnn_redundancy_ms: int = 10, # (given in ms, usually 10ms=1 frame)
730
+ # sliding window parameters
731
+ enable_sliding_window: bool = False, # whether to enable sliding window
732
+ slide_trigger_seconds: float = 30.0, # trigger threshold for sliding window in seconds
733
+ slide_stride_seconds: float = 10.0, # stride for sliding window in seconds
734
+ ):
735
+ self.feature_extractor = feature_extractor
736
+ self.chunk_ms = chunk_ms
737
+ self.first_chunk_ms = first_chunk_ms if first_chunk_ms is not None else chunk_ms
738
+ self.sample_rate = sample_rate
739
+ self.n_fft = n_fft
740
+ self.hop_length = hop_length
741
+ self.n_mels = n_mels
742
+
743
+ self.chunk_samples = int(round(chunk_ms * sample_rate / 1000))
744
+ self.chunk_frames = self.chunk_samples // hop_length
745
+ # align to hop_length to avoid frame boundary issues
746
+ hop = self.hop_length
747
+ raw_first_samples = int(round(self.first_chunk_ms * sample_rate / 1000))
748
+ aligned_first = max(hop, (raw_first_samples // hop) * hop)
749
+ self.first_chunk_samples = aligned_first
750
+ self.half_window = n_fft // 2 # required right context
751
+
752
+ # redundancy frames (in frames), <=1 frame: 10ms → 1 frame
753
+ self.cnn_redundancy_ms = cnn_redundancy_ms
754
+ self.cnn_redundancy_samples = int(cnn_redundancy_ms * sample_rate / 1000)
755
+ self.cnn_redundancy_frames = max(0, self.cnn_redundancy_samples // hop_length)
756
+
757
+ # sliding window configuration (Trigger mode)
758
+ self.enable_sliding_window = enable_sliding_window
759
+ self.trigger_seconds = slide_trigger_seconds
760
+ self.slide_seconds = slide_stride_seconds
761
+
762
+ # shift/base (global frame coordinates)
763
+ self.left_samples_dropped = 0 # samples dropped from the left
764
+ self.base_T = 0 # index of the "global frame" corresponding to mel_full[:, :, 0]
765
+
766
+ self.reset()
767
+
768
+ def reset(self):
769
+ self.buffer = np.zeros(0, dtype=np.float32)
770
+ self.last_emitted_T = 0
771
+ self.total_samples_processed = 0
772
+ self.chunk_count = 0
773
+ self.is_first = True
774
+ self.left_samples_dropped = 0
775
+ self.base_T = 0
776
+
777
+ def get_chunk_size(self) -> int:
778
+ return self.first_chunk_samples if self.is_first else self.chunk_samples
779
+
780
+ def get_expected_output_frames(self) -> int:
781
+ raise NotImplementedError("get_expected_output_frames is not implemented")
782
+
783
+ def _extract_full(self) -> torch.Tensor:
784
+ # when buffer length is less than n_fft, Whisper's internal STFT will raise an error in center=True and pad mode
785
+ # (pad is greater than input length). At that point there is no stable frame to output yet.
786
+ if len(self.buffer) < self.n_fft:
787
+ raise ValueError(f"buffer length is shorter than n_fft {len(self.buffer)} < {self.n_fft}")
788
+ # if buffer length is less than 5s, use set_spac_log_norm(log_floor_db=-10) or the last cached result
789
+ if len(self.buffer) < 5 * self.sample_rate:
790
+ # TODO: ideally this threshold should be chosen via experiments; the current value is based on experience (see MiniCPMAAudioProcessor).
791
+ self.feature_extractor.set_spac_log_norm(log_floor_db=-10)
792
+ # if buffer length is greater than 5s, use set_spac_log_norm(dynamic_range_db=8)
793
+ else:
794
+ self.feature_extractor.set_spac_log_norm(dynamic_range_db=8)
795
+ feats = self.feature_extractor(
796
+ self.buffer,
797
+ sampling_rate=self.sample_rate,
798
+ return_tensors="pt",
799
+ padding=False,
800
+ )
801
+ return feats.input_features # [1, 80, T]
802
+
803
+ def _stable_frames_count(self) -> int:
804
+ # number of stable frames = floor((len(buffer) - half_window) / hop) + 1, minimum is 0
805
+ L = int(self.buffer.shape[0])
806
+ if L <= 0:
807
+ return 0
808
+ if L < self.half_window:
809
+ return 0
810
+ return max(0, (L - self.half_window) // self.hop_length + 1)
811
+
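+ # Example: with n_fft=400 (half_window=200) and hop_length=160, a 1600-sample buffer yields
+ # (1600 - 200) // 160 + 1 = 9 stable frames; a buffer shorter than 200 samples yields 0.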
812
+ def _maybe_slide_buffer(self):
813
+ """Trigger mode sliding window: when the buffer reaches the trigger threshold, slide a fixed length window."""
814
+ if not self.enable_sliding_window:
815
+ return
816
+
817
+ sr = self.sample_rate
818
+ hop = self.hop_length
819
+ L = len(self.buffer)
820
+
821
+ # convert seconds to samples
822
+ trigger_samples = int(self.trigger_seconds * sr)
823
+ stride_samples = int(self.slide_seconds * sr)
824
+
825
+ # check if the trigger threshold is reached
826
+ if L < trigger_samples:
827
+ return
828
+
829
+ # calculate the number of samples to drop (fixed sliding stride_samples)
830
+ drop = stride_samples
831
+
832
+ # cannot drop the left context that is still needed for subsequent emission
833
+ # in trigger mode, we only need to protect the minimum necessary data
834
+ # i.e. ensure that we do not discard frames that may be needed in the future
835
+ last_emitted_local = self.last_emitted_T - self.base_T
836
+
837
+ # only protect necessary context (e.g. the most recent 1 second data)
838
+ min_keep_seconds = 1.0 # keep at least 1 second of data to ensure continuity
839
+ min_keep_samples = int(min_keep_seconds * sr)
840
+
841
+ # guard_samples are the minimum samples we must keep
842
+ guard_samples = min(min_keep_samples, L - drop)
843
+
844
+ # limit: do not exceed the safe boundary; and align hop
845
+ max_allowed_drop = max(0, L - guard_samples)
846
+ drop = min(drop, max_allowed_drop)
847
+ drop = (drop // hop) * hop
848
+
849
+ if drop <= 0:
850
+ return
851
+
852
+ # truly drop & update base
853
+ self.buffer = self.buffer[drop:]
854
+ self.left_samples_dropped += drop
855
+ self.base_T += drop // hop
856
+
857
+ def process(self, audio_chunk: np.ndarray, is_last_chunk: bool = False) -> Tuple[torch.Tensor, Dict]:
858
+ self.chunk_count += 1
859
+ # append to buffer
860
+ if len(self.buffer) == 0:
861
+ self.buffer = audio_chunk.astype(np.float32, copy=True)
862
+ else:
863
+ self.buffer = np.concatenate([self.buffer, audio_chunk.astype(np.float32, copy=True)])
864
+
865
+ # sliding window processing
866
+ self._maybe_slide_buffer()
867
+
868
+ # full extraction (for the current window)
869
+ mel_full = self._extract_full()
870
+ T_full = mel_full.shape[-1] # local frames in the current window
871
+ stable_T = min(T_full, self._stable_frames_count()) # local stable frames
872
+ stable_T_global = self.base_T + stable_T # map to global frame coordinates
873
+
874
+ # plan the core frames for the current emission (global coordinates)
875
+ core_start_g = self.last_emitted_T
876
+ core_end_g = core_start_g + self.chunk_frames
877
+ required_stable_g = core_end_g + self.cnn_redundancy_frames
878
+
879
+ if stable_T_global >= required_stable_g or is_last_chunk:
880
+ emit_start_g = max(0, core_start_g - self.cnn_redundancy_frames)
881
+ emit_end_g = core_end_g + self.cnn_redundancy_frames
882
+
883
+ # global -> local index
884
+ emit_start = max(0, emit_start_g - self.base_T)
885
+ emit_end = emit_end_g - self.base_T
886
+ emit_start = max(0, min(emit_start, T_full))
887
+ emit_end = max(emit_start, min(emit_end, T_full))
888
+
889
+ mel_output = mel_full[:, :, emit_start:emit_end]
890
+ self.last_emitted_T = core_end_g # only advance the core frame pointer (global)
891
+ else:
892
+ mel_output = mel_full[:, :, 0:0]
893
+
894
+ self.total_samples_processed += len(audio_chunk)
895
+ self.is_first = False
896
+
897
+ info = {
898
+ "type": "exact_chunk",
899
+ "chunk_number": self.chunk_count,
900
+ "emitted_frames": mel_output.shape[-1],
901
+ "stable_T": stable_T,
902
+ "T_full": T_full,
903
+ "base_T": self.base_T,
904
+ "stable_T_global": stable_T_global,
905
+ "buffer_len_samples": int(self.buffer.shape[0]),
906
+ "left_samples_dropped": self.left_samples_dropped,
907
+ "core_start": core_start_g, # if keep the original field name, use the global value here
908
+ "core_end": core_end_g, # same as above
909
+ }
910
+ return mel_output, info
911
+
912
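# Illustrative usage sketch (not part of this file): feeding a waveform to the exact
# streaming mel processor in 100 ms chunks and collecting the emitted slices. `proc` is
# assumed to be a StreamingMelProcessorExact configured as in
# MiniCPMOProcessor._init_streaming_processor below; `wav` is a float32 16 kHz array.
import numpy as np
import torch

def stream_mels(proc, wav: np.ndarray, chunk_samples: int = 1600) -> torch.Tensor:
    pieces = []
    for start in range(0, len(wav), chunk_samples):
        chunk = wav[start:start + chunk_samples]
        last = start + chunk_samples >= len(wav)
        mel, _info = proc.process(chunk, is_last_chunk=last)
        if mel.shape[-1] > 0:
            pieces.append(mel)
    tail = proc.flush()                   # emit any frames still held back
    if tail.shape[-1] > 0:
        pieces.append(tail)
    return torch.cat(pieces, dim=-1)      # [1, 80, total_frames]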
+ def flush(self) -> torch.Tensor:
913
+ """Called when the stream ends, output the remaining unemitted frames, ensuring consistency with offline (calculated by global coordinates)."""
914
+ if len(self.buffer) == 0:
915
+ return torch.zeros(1, 80, 0)
916
+
917
+ mel_full = self._extract_full()
918
+ T_local = mel_full.shape[-1]
919
+ T_global = self.base_T + T_local
920
+
921
+ if self.last_emitted_T < T_global:
922
+ start_l = max(0, self.last_emitted_T - self.base_T)
923
+ tail = mel_full[:, :, start_l:]
924
+ self.last_emitted_T = T_global
925
+ return tail
926
+ return mel_full[:, :, 0:0]
927
+
928
+ def get_config(self) -> Dict:
929
+ return {
930
+ "chunk_ms": self.chunk_ms,
931
+ "first_chunk_ms": self.first_chunk_ms,
932
+ "effective_first_chunk_ms": self.first_chunk_samples / self.sample_rate * 1000.0,
933
+ "sample_rate": self.sample_rate,
934
+ "n_fft": self.n_fft,
935
+ "hop_length": self.hop_length,
936
+ "cnn_redundancy_ms": self.cnn_redundancy_ms,
937
+ "cnn_redundancy_frames": self.cnn_redundancy_frames,
938
+ "enable_sliding_window": self.enable_sliding_window,
939
+ "trigger_seconds": self.trigger_seconds,
940
+ "slide_seconds": self.slide_seconds,
941
+ }
942
+
943
+ def get_state(self) -> Dict:
944
+ return {
945
+ "chunk_count": self.chunk_count,
946
+ "last_emitted_T": self.last_emitted_T,
947
+ "total_samples_processed": self.total_samples_processed,
948
+ "buffer_len": int(self.buffer.shape[0]),
949
+ "base_T": self.base_T,
950
+ "left_samples_dropped": self.left_samples_dropped,
951
+ }
952
+
953
+ def get_snapshot(self) -> Dict:
954
+ """Get a complete state snapshot (including buffer), used for recovery from a fast start.
955
+
956
+ Returns:
957
+ A dictionary containing the complete state; pass it to restore_snapshot to restore.
958
+ """
959
+ buffer_copy = self.buffer.copy()
960
+ snapshot = {
961
+ "chunk_count": self.chunk_count,
962
+ "last_emitted_T": self.last_emitted_T,
963
+ "total_samples_processed": self.total_samples_processed,
964
+ "buffer": buffer_copy,
965
+ "base_T": self.base_T,
966
+ "left_samples_dropped": self.left_samples_dropped,
967
+ "is_first": self.is_first,
968
+ # save the feature_extractor state (crucial for keeping mel feature extraction deterministic)
969
+ "fe_dynamic_log_norm": getattr(self.feature_extractor, "dynamic_log_norm", None),
970
+ "fe_dynamic_range_db": getattr(self.feature_extractor, "dynamic_range_db", None),
971
+ "fe_log_floor_db": getattr(self.feature_extractor, "log_floor_db", None),
972
+ }
973
+
974
+ return snapshot
975
+
976
+ def restore_snapshot(self, snapshot: Dict) -> None:
977
+ """Restore state from a snapshot
978
+
979
+ Args:
980
+ snapshot: the snapshot dictionary returned by get_snapshot
981
+ """
982
+ # record the state before restoration
983
+ prev_state = {
984
+ "chunk_count": self.chunk_count,
985
+ "last_emitted_T": self.last_emitted_T,
986
+ "buffer_len": len(self.buffer),
987
+ }
988
+
989
+ # restore state
990
+ self.chunk_count = snapshot["chunk_count"]
991
+ self.last_emitted_T = snapshot["last_emitted_T"]
992
+ self.total_samples_processed = snapshot["total_samples_processed"]
993
+ self.buffer = snapshot["buffer"].copy() # copy buffer
994
+ self.base_T = snapshot["base_T"]
995
+ self.left_samples_dropped = snapshot["left_samples_dropped"]
996
+ self.is_first = snapshot["is_first"]
997
+
998
+ # restore the feature_extractor state (crucial for keeping mel feature extraction deterministic)
999
+ if snapshot.get("fe_dynamic_log_norm") is not None:
1000
+ self.feature_extractor.dynamic_log_norm = snapshot["fe_dynamic_log_norm"]
1001
+ if snapshot.get("fe_dynamic_range_db") is not None:
1002
+ self.feature_extractor.dynamic_range_db = snapshot["fe_dynamic_range_db"]
1003
+ if snapshot.get("fe_log_floor_db") is not None:
1004
+ self.feature_extractor.log_floor_db = snapshot["fe_log_floor_db"]
1005
+
1006
+
1007
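# Illustrative sketch (not part of this file): the speculative-rollback pattern built on
# get_snapshot / restore_snapshot above. The names `mel_proc` and `run_generation` are
# hypothetical; only the snapshot round-trip mirrors the methods defined above.
snap = mel_proc.get_snapshot()            # full copy of buffer + feature-extractor state
speculative_ok = run_generation()         # e.g. start answering at a tentative VAD pause
if not speculative_ok:                    # the user kept talking: undo the speculation
    mel_proc.restore_snapshot(snap)       # buffer, base_T and extractor state restored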
+ class MiniCPMOProcessor(ProcessorMixin):
1008
+ attributes = ["image_processor", "audio_processor", "tokenizer"]
1009
+ audio_processor_class = "AutoFeatureExtractor"
1010
+ image_processor_class = "AutoImageProcessor"
1011
+ tokenizer_class = "AutoTokenizer"
1012
+
1013
+ def __init__(self, image_processor=None, audio_processor=None, tokenizer=None, **kwargs):
1014
+ super().__init__(image_processor, audio_processor, tokenizer)
1015
+
1016
+ self.version = image_processor.version if image_processor else None
1017
+ # audio feature pooling step, needs to be consistent with config.audio_pool_step
1018
+ self.pool_step = kwargs.get("audio_pool_step", 5)
1019
+
1020
+ # initialize the streaming audio processor
1021
+ self._streaming_mel_processor = None
1022
+ if audio_processor is not None:
1023
+ self._init_streaming_processor()
1024
+
1025
+ def get_audio_placeholder(
1026
+ self,
1027
+ audio_lens: int,
1028
+ chunk_input: bool = True,
1029
+ chunk_length: int = 1,
1030
+ ) -> str:
1031
+ """
1032
+ Public method to get audio placeholder string for vLLM integration.
1033
+
1034
+ Args:
1035
+ audio_lens: Length of audio in samples
1036
+ chunk_input: Whether to use chunked processing
1037
+ chunk_length: Chunk length in seconds
1038
+
1039
+ Returns:
1040
+ Audio placeholder string
1041
+ """
1042
+ pool_step = self.pool_step
1043
+ feature_lens = math.ceil(audio_lens / self.audio_processor.hop_length)
1044
+
1045
+ feature_lens = (feature_lens - 1) // 2 + 1
1046
+ output_lens = (feature_lens - pool_step) // pool_step + 1
1047
+
1048
+ if chunk_input:
1049
+ fbank_feat_in_chunk = int(chunk_length * 100)
1050
+ cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1
1051
+ audio_embeds_in_chunk = (cnn_feat_in_chunk - pool_step) // pool_step + 1
1052
+ num_audio_chunks = (output_lens + audio_embeds_in_chunk - 1) // audio_embeds_in_chunk
1053
+
1054
+ place_holders = ""
1055
+ total_unk_len = 0
1056
+ for _ in range(num_audio_chunks):
1057
+ unk_len = min(audio_embeds_in_chunk, output_lens - total_unk_len)
1058
+ place_holders += self.tokenizer.audio_start + "<unk>" * unk_len + self.tokenizer.audio_end
1059
+ total_unk_len += unk_len
1060
+ audio_placeholder = place_holders
1061
+ else:
1062
+ audio_placeholder = self.tokenizer.audio_start + "<unk>" * output_lens + self.tokenizer.audio_end
1063
+
1064
+ return audio_placeholder
1065
+
1066
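# Worked example (illustrative, not part of this file) of the placeholder length
# arithmetic above, with hop_length=160 and pool_step=5: one second of 16 kHz audio gives
#   feature_lens = ceil(16000 / 160)   = 100   # mel frames
#   after CNN:   (100 - 1) // 2 + 1    = 50    # 2x temporal downsample
#   output_lens: (50 - 5) // 5 + 1     = 10    # pooled audio embeddings
# so each second of audio maps to 10 <unk> placeholder tokens between
# tokenizer.audio_start and tokenizer.audio_end.
import math
assert math.ceil(16000 / 160) == 100
assert (100 - 1) // 2 + 1 == 50
assert (50 - 5) // 5 + 1 == 10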
+ def _init_streaming_processor(
1067
+ self,
1068
+ chunk_ms: int = 100,
1069
+ cnn_redundancy_ms: int = 0,
1070
+ *,
1071
+ mode: str = "exact",
1072
+ first_chunk_ms: Optional[int] = None,
1073
+ enable_sliding_window: bool = False,
1074
+ slide_trigger_seconds: float = 30.0,
1075
+ slide_stride_seconds: float = 10.0,
1076
+ ):
1077
+ """Initialize the streaming processor
1078
+
1079
+ Args:
1080
+ chunk_ms: Chunk size in milliseconds, also the sliding step.
1081
+ cnn_redundancy_ms: CNN boundary redundancy in milliseconds (before and after), 0 means standard mode.
1082
+ mode: streaming processing mode, currently only supports "exact"
1083
+ first_chunk_ms: the size of the first chunk (milliseconds), if not specified, it is the same as chunk_ms
1084
+ enable_sliding_window: whether to enable sliding window (trigger mode)
1085
+ slide_trigger_seconds: trigger threshold for sliding window in seconds
1086
+ slide_stride_seconds: stride for sliding window in seconds
1087
+ """
1088
+ if mode == "exact":
1089
+ self._streaming_mel_processor = StreamingMelProcessorExact(
1090
+ feature_extractor=self.audio_processor,
1091
+ chunk_ms=chunk_ms,
1092
+ first_chunk_ms=first_chunk_ms,
1093
+ sample_rate=16000,
1094
+ cnn_redundancy_ms=cnn_redundancy_ms,
1095
+ enable_sliding_window=enable_sliding_window,
1096
+ slide_trigger_seconds=slide_trigger_seconds,
1097
+ slide_stride_seconds=slide_stride_seconds,
1098
+ )
1099
+ else:
1100
+ raise ValueError(f"Unsupported mode: {mode}, only 'exact' is supported")
1101
+ self._streaming_mode = "exact"  # only "exact" is supported; any other mode raised above
1102
+
1103
+ def set_streaming_mode(
1104
+ self,
1105
+ mode: str = "exact",
1106
+ chunk_ms: int = 100,
1107
+ cnn_redundancy_ms: int = 0,
1108
+ *,
1109
+ first_chunk_ms: Optional[int] = None,
1110
+ enable_sliding_window: bool = False,
1111
+ slide_trigger_seconds: float = 30.0,
1112
+ slide_stride_seconds: float = 10.0,
1113
+ ):
1114
+ """Set streaming processing mode
1115
+
1116
+ Args:
1117
+ mode: streaming processing mode, currently only supports "exact"
1118
+ chunk_ms: chunk size in milliseconds, also the sliding step.
1119
+ cnn_redundancy_ms: CNN boundary redundancy in milliseconds (before and after), 0 means standard mode.
1120
+ first_chunk_ms: the size of the first chunk (milliseconds), if not specified, it is the same as chunk_ms
1121
+ enable_sliding_window: whether to enable sliding window (trigger mode)
1122
+ slide_trigger_seconds: trigger threshold for sliding window in seconds
1123
+ slide_stride_seconds: stride for sliding window in seconds
1124
+ """
1125
+ if self.audio_processor is None:
1126
+ raise ValueError("audio_processor is not set, cannot initialize the streaming processor")
1127
+ self._init_streaming_processor(
1128
+ chunk_ms=chunk_ms,
1129
+ cnn_redundancy_ms=cnn_redundancy_ms,
1130
+ mode=mode,
1131
+ first_chunk_ms=first_chunk_ms,
1132
+ enable_sliding_window=enable_sliding_window,
1133
+ slide_trigger_seconds=slide_trigger_seconds,
1134
+ slide_stride_seconds=slide_stride_seconds,
1135
+ )
1136
+
1137
+ def process_image(
1138
+ self,
1139
+ images: Optional[ImageInput] = None,
1140
+ do_pad: bool = True,
1141
+ max_slice_nums: int = 1,
1142
+ return_tensors: str = "pt",
1143
+ ) -> MiniCPMOBatchFeature:
1144
+ """Process image data
1145
+
1146
+ Args:
1147
+ images: input images
1148
+ do_pad: whether to pad
1149
+ max_slice_nums: maximum number of slices
1150
+ return_tensors: return tensor type
1151
+ Returns:
1152
+ MiniCPMOBatchFeature object
1153
+ """
1154
+ if images is None:
1155
+ return MiniCPMOBatchFeature(data={"pixel_values": [[]], "image_sizes": [[]], "tgt_sizes": [[]]})
1156
+
1157
+ result = self.image_processor(
1158
+ images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors
1159
+ )
1160
+
1161
+ model_inputs = {
1162
+ "pixel_values": result.get("pixel_values", [[]]),
1163
+ "image_sizes": result.get("image_sizes", [[]]),
1164
+ "tgt_sizes": result.get("tgt_sizes", [[]]),
1165
+ }
1166
+
1167
+ return MiniCPMOBatchFeature(data=model_inputs)
1168
+
1169
+ def process_audio(
1170
+ self,
1171
+ audios: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
1172
+ sampling_rate: int = 16000,
1173
+ regroup_to_seconds: Optional[int] = None,
1174
+ fps: int = 100,
1175
+ ) -> MiniCPMOBatchFeature:
1176
+ """Process audio data in batch
1177
+
1178
+ Args:
1179
+ audios: audio data
1180
+ sampling_rate: sampling rate
1181
+ regroup_to_seconds: regroup duration in seconds
1182
+ fps: frames per second
1183
+ Returns:
1184
+ MiniCPMOBatchFeature object
1185
+ """
1186
+ if audios is None:
1187
+ return MiniCPMOBatchFeature(data={"audio_features": [], "audio_feature_lens": []})
1188
+
1189
+ audio_features, audio_feature_lens = process_audio_batch(
1190
+ audios=audios,
1191
+ feature_extractor=self.audio_processor,
1192
+ sampling_rate=sampling_rate,
1193
+ max_duration_seconds=30,
1194
+ return_attention_mask=True,
1195
+ )
1196
+
1197
+ if regroup_to_seconds is not None and len(audio_features) > 0:
1198
+ audio_features, audio_feature_lens = regroup_audio_features(
1199
+ audio_features=audio_features,
1200
+ audio_feature_lens=audio_feature_lens,
1201
+ regroup_seconds=regroup_to_seconds,
1202
+ fps=fps,
1203
+ )
1204
+
1205
+ model_inputs = {"audio_features": audio_features, "audio_feature_lens": audio_feature_lens}
1206
+
1207
+ return MiniCPMOBatchFeature(data=model_inputs)
1208
+
1209
+ def process_audio_streaming(
1210
+ self,
1211
+ audio_chunk: np.ndarray,
1212
+ reset: bool = False,
1213
+ return_batch_feature: bool = False,
1214
+ is_last_chunk: bool = False,
1215
+ ) -> Union[Tuple[torch.Tensor, dict], MiniCPMOBatchFeature]:
1216
+ """Process audio chunk in streaming
1217
+
1218
+ Args:
1219
+ audio_chunk: audio data chunk of arbitrary length (e.g. a 125 ms first chunk followed by 100 ms chunks)
1220
+ reset: whether to reset the processor state
1221
+ return_batch_feature: whether to return MiniCPMOBatchFeature format (consistent with process_audio)
+ is_last_chunk: whether this is the final chunk of the stream (forces emission of the remaining frames)
1222
+ Returns:
1223
+ If return_batch_feature=False:
1224
+ (audio_features, info)
1225
+ - audio_features: [1, 80, n_frames] mel features
1226
+ - info: processing information dictionary
1227
+ If return_batch_feature=True:
1228
+ MiniCPMOBatchFeature object, containing:
1229
+ - audio_features: [1, 80, n_frames] mel features
1230
+ - audio_feature_lens: [tensor([n_frames])]
1231
+ - info: processing information (as an extra attribute)
1232
+ """
1233
+ if self._streaming_mel_processor is None:
1234
+ raise ValueError("Streaming processor not initialized, please ensure audio_processor is set")
1235
+
1236
+ if reset:
1237
+ self._streaming_mel_processor.reset()
1238
+
1239
+ # process chunk
1240
+ mel_features, info = self._streaming_mel_processor.process(audio_chunk, is_last_chunk=is_last_chunk)
1241
+
1242
+ # determine the return format based on the parameters
1243
+ if return_batch_feature:
1244
+ # return the format consistent with process_audio
1245
+ # note: info returns emitted_frames, which represents the actual output frames
1246
+ n_frames = info.get("emitted_frames", mel_features.shape[-1])
1247
+ model_inputs = {
1248
+ "audio_features": mel_features,
1249
+ "audio_feature_lens": [torch.tensor([n_frames])],
1250
+ "streaming_info": info, # add streaming processing information
1251
+ }
1252
+ return MiniCPMOBatchFeature(data=model_inputs)
1253
+ else:
1254
+ return mel_features, info
1255
+
1256
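# Illustrative usage sketch (not part of this file): the processor-level streaming path.
# `processor` is assumed to be a loaded MiniCPMOProcessor and `chunks` a list of float32
# 16 kHz numpy arrays (e.g. 100 ms each).
processor.set_streaming_mode(mode="exact", chunk_ms=100)
processor.reset_streaming()
for i, chunk in enumerate(chunks):
    batch = processor.process_audio_streaming(
        chunk,
        return_batch_feature=True,
        is_last_chunk=(i == len(chunks) - 1),
    )
    mel = batch["audio_features"]          # [1, 80, emitted_frames] (may be empty)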
+ def reset_streaming(self):
1257
+ if self._streaming_mel_processor is not None:
1258
+ self._streaming_mel_processor.reset()
1259
+
1260
+ def get_streaming_chunk_size(self) -> int:
1261
+ if self._streaming_mel_processor is None:
1262
+ raise ValueError("Streaming processor not initialized")
1263
+ return self._streaming_mel_processor.get_chunk_size()
1264
+
1265
+ def configure_streaming(
1266
+ self,
1267
+ chunk_ms: int = 100,
1268
+ enable_sliding_window: bool = False,
1269
+ slide_trigger_seconds: float = 30.0,
1270
+ slide_stride_seconds: float = 10.0,
1271
+ ):
1272
+ """Configure streaming processor parameters
1273
+
1274
+ Args:
1275
+ chunk_ms: chunk size in milliseconds
1276
+ enable_sliding_window: whether to enable sliding window (trigger mode)
1277
+ slide_trigger_seconds: trigger threshold for sliding window in seconds
1278
+ slide_stride_seconds: stride for sliding window in seconds
1279
+ """
1280
+ if self.audio_processor is None:
1281
+ raise ValueError("audio_processor is not set")
1282
+
1283
+ self._init_streaming_processor(
1284
+ chunk_ms=chunk_ms,
1285
+ enable_sliding_window=enable_sliding_window,
1286
+ slide_trigger_seconds=slide_trigger_seconds,
1287
+ slide_stride_seconds=slide_stride_seconds,
1288
+ )
1289
+
1290
+ def get_streaming_config(self) -> dict:
1291
+ if self._streaming_mel_processor is None:
1292
+ return {}
1293
+ return self._streaming_mel_processor.get_config()
1294
+
1295
+ def get_streaming_state(self) -> dict:
1296
+ if self._streaming_mel_processor is None:
1297
+ return {}
1298
+ return self._streaming_mel_processor.get_state()
1299
+
1300
+ def get_streaming_snapshot(self) -> dict:
1301
+ if self._streaming_mel_processor is None:
1302
+ return {}
1303
+ return self._streaming_mel_processor.get_snapshot()
1304
+
1305
+ def restore_streaming_snapshot(self, snapshot: dict) -> None:
1306
+ if self._streaming_mel_processor is None:
1307
+ return
1308
+ if not snapshot:
1309
+ return
1310
+ self._streaming_mel_processor.restore_snapshot(snapshot)
1311
+
1312
+ def __call__(
1313
+ self,
1314
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
1315
+ images: ImageInput = None,
1316
+ audios: Union[np.ndarray, List[np.ndarray], List[List[np.ndarray]]] = None,
1317
+ audio_parts: Optional[list] = None,
1318
+ max_length: Optional[int] = None,
1319
+ do_pad: Optional[bool] = True,
1320
+ max_slice_nums: int = None,
1321
+ use_image_id: bool = True,
1322
+ stream_input: bool = False,
1323
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
1324
+ sampling_rate: Optional[int] = 16000,
1325
+ online_streaming: bool = False,
1326
+ audio_chunk_idx: int = 0,
1327
+ is_last_chunk: bool = False,
1328
+ **kwargs,
1329
+ ) -> MiniCPMOBatchFeature:
1330
+ if images is not None:
1331
+ image_inputs = self.process_image(
1332
+ images=images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors
1333
+ )
1334
+ else:
1335
+ image_inputs = None
1336
+
1337
+ audio_features, audio_feature_lens, audio_phs = self.audio_feature_extract(
1338
+ audios,
1339
+ audio_parts,
1340
+ stream_input,
1341
+ sampling_rate,
1342
+ online_streaming=online_streaming,
1343
+ is_last_chunk=is_last_chunk,
1344
+ )
1345
+
1346
+ model_inputs = self._convert_omni_to_inputs(
1347
+ image_inputs,
1348
+ audio_phs,
1349
+ text,
1350
+ max_slice_nums=max_slice_nums,
1351
+ use_image_id=use_image_id,
1352
+ max_length=max_length,
1353
+ **kwargs,
1354
+ )
1355
+
1356
+ model_inputs["audio_features"] = audio_features
1357
+ model_inputs["audio_feature_lens"] = audio_feature_lens
1358
+
1359
+ result = MiniCPMOBatchFeature(data={**model_inputs})
1360
+
1361
+ if online_streaming:
1362
+ result.use_extra_context = True
1363
+ result.prefix_extra_frames = 0 if audio_chunk_idx == 0 else 2
1364
+ result.suffix_extra_frames = 2
1365
+ result.chunk_idx = audio_chunk_idx
1366
+
1367
+ return result
1368
+
1369
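# Illustrative usage sketch (not part of this file): a single offline call mixing text and
# one audio clip. The "<audio>./</audio>" marker is the pattern expanded by
# _convert_omni_to_inputs below; `processor` and `wav` (float32, 16 kHz) are assumed, and
# the chat markup simply follows the tokenizer's bos/eos tokens from special_tokens_map.json.
inputs = processor(
    text="<|im_start|>user\n<audio>./</audio>\nTranscribe this clip.<|im_end|>\n"
         "<|im_start|>assistant\n",
    audios=wav,
    sampling_rate=16000,
)
# inputs["input_ids"], inputs["audio_features"], inputs["audio_bounds"], ... can now be
# passed to the model.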
+ def audio_feature_extract(
1370
+ self,
1371
+ audios: Union[np.ndarray, List[np.ndarray], List[List[np.ndarray]], None] = None,
1372
+ audio_parts: Optional[list] = None,
1373
+ stream_input: Optional[bool] = False,
1374
+ sampling_rate: Optional[int] = None,
1375
+ chunk_length: Optional[int] = 1,
1376
+ online_streaming: bool = False,
1377
+ is_last_chunk: bool = False,
1378
+ **kwargs,
1379
+ ):
1380
+ if audios is None:
1381
+ return [], [], []
1382
+
1383
+ if isinstance(audios, np.ndarray):
1384
+ audios_list = [[audios]]
1385
+ elif isinstance(audios[0], np.ndarray):
1386
+ audios_list = [audios]
1387
+ else:
1388
+ audios_list = audios
1389
+
1390
+ if audio_parts is not None:
1391
+ assert len(audio_parts) == len(audios_list)
1392
+ for parts, audios in zip(audio_parts, audios_list):
1393
+ assert len(parts) == len(audios)
1394
+
1395
+ audio_feature_lens_list = []
1396
+ audio_ph_list = []
1397
+ audio_features_all = []
1398
+
1399
+ # audio placeholder not dependent on audio_parts
1400
+ for audios in audios_list:
1401
+ if audios:
1402
+ audio_ph_list.append(
1403
+ [
1404
+ self.get_audio_placeholder(len(a), chunk_input=stream_input, chunk_length=chunk_length)
1405
+ for a in audios
1406
+ ]
1407
+ )
1408
+ else:
1409
+ audio_ph_list.append([])
1410
+
1411
+ for idx, audios in enumerate(audios_list):
1412
+ if audio_parts is not None:
1413
+ # merge consecutive audio segments that belong to the same part
1414
+ audio_part = audio_parts[idx]
1415
+ merge_audio = []
1416
+ cur_audio = []
1417
+ for aid, (part, audio) in enumerate(zip(audio_part, audios)):
1418
+ if aid == 0 or audio_part[aid] == audio_part[aid - 1]:
1419
+ cur_audio.append(audio)
1420
+ else:
1421
+ merge_audio.append(np.hstack(cur_audio))
1422
+ cur_audio = [audio]
1423
+ if cur_audio:
1424
+ merge_audio.append(np.hstack(cur_audio))
1425
+ else:
1426
+ merge_audio = audios
1427
+
1428
+ # If the audio exceeds 30 seconds, split it into chunks every 30 seconds.
1429
+ final_merge_audio = []
1430
+ max_audio_inp_len = 30 * sampling_rate
1431
+ for audio in merge_audio:
1432
+ if len(audio) <= max_audio_inp_len:
1433
+ final_merge_audio.append(audio)
1434
+ else:
1435
+ for i in range(math.ceil(len(audio) / max_audio_inp_len)):
1436
+ final_merge_audio.append(audio[i * max_audio_inp_len : (i + 1) * max_audio_inp_len])
1437
+
1438
+ audio_feature_lens = []
1439
+
1440
+ if audios:
1441
+ if online_streaming:
1442
+ # online streaming: only a single audio is supported; use the process_audio_streaming return format directly
1443
+ assert (
1444
+ len(final_merge_audio) == 1
1445
+ ), f"online streaming mode only supports single audio, currently there are {len(final_merge_audio)}"
1446
+ audio = final_merge_audio[0]
1447
+ result = self.process_audio_streaming(
1448
+ audio, reset=False, return_batch_feature=True, is_last_chunk=is_last_chunk
1449
+ )
1450
+ audio_features_all.append(
1451
+ result["audio_features"].squeeze(0)
1452
+ ) # [1, 80, T] -> [80, T], keep consistent with batch processing
1453
+ audio_feature_lens_list.append(result["audio_feature_lens"][0])
1454
+ else:
1455
+ # batch processing
1456
+ audio_inputs = self.audio_processor(
1457
+ final_merge_audio,
1458
+ sampling_rate=sampling_rate,
1459
+ return_attention_mask=True,
1460
+ padding="max_length",
1461
+ return_tensors="pt",
1462
+ **kwargs,
1463
+ )
1464
+ audio_feature = audio_inputs["input_features"]
1465
+ actual_lens = audio_inputs["attention_mask"].sum(dim=1)
1466
+
1467
+ for feat, lens in zip(audio_feature, actual_lens):
1468
+ audio_features_all.append(feat[:, :lens])
1469
+ audio_feature_lens.append(lens)
1470
+
1471
+ audio_feature_lens = torch.hstack(audio_feature_lens)
1472
+ audio_feature_lens_list.append(audio_feature_lens)
1473
+ else:
1474
+ audio_feature_lens_list.append([])
1475
+
1476
+ if audio_features_all:
1477
+ audio_features = [i.permute(1, 0) for i in audio_features_all]
1478
+ audio_features = torch.nn.utils.rnn.pad_sequence(
1479
+ audio_features, batch_first=True, padding_value=0.0
1480
+ ).permute(0, 2, 1)
1481
+ else:
1482
+ audio_features = []
1483
+
1484
+ return audio_features, audio_feature_lens_list, audio_ph_list
1485
+
1486
+ def _convert(self, input_str, max_inp_length: Optional[int] = None):
1487
+ old_input_ids = self.tokenizer.encode(input_str)
1488
+
1489
+ listen_token_id = self.tokenizer.convert_tokens_to_ids("<|listen|>")
1490
+ input_ids = []
1491
+ for token in old_input_ids:
1492
+ if token != listen_token_id:
1493
+ input_ids.append(token)
1494
+
1495
+ if max_inp_length is not None:
1496
+ input_ids = input_ids[:max_inp_length]
1497
+ input_ids = torch.tensor(input_ids, dtype=torch.int32)
1498
+
1499
+ ## image bound
1500
+ start_cond = (input_ids == self.tokenizer.im_start_id) | (input_ids == self.tokenizer.slice_start_id)
1501
+ end_cond = (input_ids == self.tokenizer.im_end_id) | (input_ids == self.tokenizer.slice_end_id)
1502
+
1503
+ image_start_idx = torch.where(start_cond)[0]
1504
+ image_start_idx += 1
1505
+ image_end_idx = torch.where(end_cond)[0]
1506
+
1507
+ valid_image_nums = max(len(image_start_idx), len(image_end_idx))
1508
+
1509
+ image_bounds = torch.hstack(
1510
+ [
1511
+ image_start_idx[:valid_image_nums].unsqueeze(-1),
1512
+ image_end_idx[:valid_image_nums].unsqueeze(-1),
1513
+ ]
1514
+ )
1515
+
1516
+ ## audio bound
1517
+ audio_start_idx = torch.where(input_ids == self.tokenizer.audio_start_id)[0]
1518
+ audio_end_idx = torch.where(input_ids == self.tokenizer.audio_end_id)[0]
1519
+ assert len(audio_start_idx) == len(audio_end_idx)
1520
+ audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
1521
+
1522
+ spk_start_idx = torch.where(input_ids == self.tokenizer.spk_start_id)[0]
1523
+ spk_end_idx = torch.where(input_ids == self.tokenizer.spk_end_id)[0]
1524
+ assert len(spk_start_idx) == len(spk_end_idx)
1525
+ spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
1526
+
1527
+ return input_ids, image_bounds, audio_bounds, spk_bounds
1528
+
1529
+ def _convert_omni_to_inputs(
1530
+ self,
1531
+ images,
1532
+ audio_phs,
1533
+ texts: Union[str, List[str]],
1534
+ truncation=None,
1535
+ max_length=None,
1536
+ max_slice_nums=None,
1537
+ use_image_id=None,
1538
+ return_tensors=None,
1539
+ **kwargs,
1540
+ ):
1541
+ if images is None and audio_phs is None:
1542
+ model_inputs = self.tokenizer(
1543
+ texts, return_tensors=return_tensors, truncation=truncation, max_length=max_length, **kwargs
1544
+ )
1545
+ return MiniCPMOBatchFeature(data={**model_inputs})
1546
+
1547
+ image_pattern = "<image>./</image>"
1548
+ audio_pattern = "<audio>./</audio>"
1549
+ split_pattern = f"({image_pattern}|{audio_pattern})"
1550
+
1551
+ if isinstance(texts, str):
1552
+ texts = [texts]
1553
+
1554
+ bs = len(texts)
1555
+ if images is not None:
1556
+ images, image_sizes, tgt_sizes = images["pixel_values"], images["image_sizes"], images["tgt_sizes"]
1557
+ else:
1558
+ images, image_sizes, tgt_sizes = [[]] * bs, [[]] * bs, [[]] * bs
1559
+
1560
+ input_ids_list = []
1561
+ image_bounds_list = []
1562
+ audio_bounds_list = []
1563
+ spk_bounds_list = []
1564
+
1565
+ for index, text in enumerate(texts):
1566
+ text_chunks = re.split(split_pattern, text)
1567
+
1568
+ image_tags = re.findall(image_pattern, text)
1569
+ audio_tags = re.findall(audio_pattern, text)
1570
+
1571
+ if image_tags:
1572
+ assert images is not None
1573
+ assert len(image_tags) == len(image_sizes[index])
1574
+ if audio_tags:
1575
+ assert audio_phs is not None
1576
+ assert len(audio_tags) == len(audio_phs[index])
1577
+
1578
+ image_id = 0
1579
+ audio_id = 0
1580
+ for i, chunk in enumerate(text_chunks):
1581
+ if chunk == image_pattern:
1582
+ image_placeholder = self.image_processor.get_slice_image_placeholder(
1583
+ image_sizes[index][image_id], image_id, max_slice_nums, use_image_id
1584
+ )
1585
+ image_id += 1
1586
+ text_chunks[i] = image_placeholder
1587
+ elif chunk == audio_pattern:
1588
+ audio_placeholder = audio_phs[index][audio_id]
1589
+ audio_id += 1
1590
+ text_chunks[i] = audio_placeholder
1591
+
1592
+ final_text = "".join(text_chunks)
1593
+ input_ids, image_bounds, audio_bounds, spk_bounds = self._convert(final_text, max_length)
1594
+
1595
+ input_ids_list.append(input_ids)
1596
+ image_bounds_list.append(image_bounds)
1597
+ audio_bounds_list.append(audio_bounds)
1598
+ spk_bounds_list.append(spk_bounds)
1599
+
1600
+ padded_input_ids, padding_lengths = self.pad(input_ids_list, padding_side="left")
1601
+ attention_mask = torch.ones_like(padded_input_ids, dtype=torch.bool)
1602
+ for i, length in enumerate(padding_lengths):
1603
+ image_bounds_list[i] = image_bounds_list[i] + length
1604
+ audio_bounds_list[i] = audio_bounds_list[i] + length
1605
+ spk_bounds_list[i] = spk_bounds_list[i] + length
1606
+ attention_mask[i, :length] = False
1607
+
1608
+ data = {
1609
+ "input_ids": padded_input_ids,
1610
+ "attention_mask": attention_mask,
1611
+ "pixel_values": images,
1612
+ "image_sizes": image_sizes,
1613
+ "image_bound": image_bounds_list,
1614
+ "tgt_sizes": tgt_sizes,
1615
+ "audio_bounds": audio_bounds_list,
1616
+ "spk_bounds": spk_bounds_list,
1617
+ }
1618
+
1619
+ return data
1620
+
1621
+ def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
1622
+ items = []
1623
+ if isinstance(inputs[0], list):
1624
+ assert isinstance(inputs[0][0], torch.Tensor)
1625
+ for it in inputs:
1626
+ for tr in it:
1627
+ items.append(tr)
1628
+ else:
1629
+ assert isinstance(inputs[0], torch.Tensor)
1630
+ items = inputs
1631
+
1632
+ batch_size = len(items)
1633
+ shape = items[0].shape
1634
+ dim = len(shape)
1635
+ assert dim <= 2
1636
+ if max_length is None:
1637
+ max_length = 0
1638
+ max_length = max(max_length, max(item.shape[-1] for item in items))
1639
+ min_length = min(item.shape[-1] for item in items)
1640
+ dtype = items[0].dtype
1641
+
1642
+ if dim == 0:
1643
+ return torch.stack([item for item in items], dim=0), [0]
1644
+ elif dim == 1:
1645
+ if max_length == min_length:
1646
+ return torch.stack([item for item in items], dim=0), [0] * batch_size
1647
+ tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
1648
+ else:
1649
+ tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
1650
+
1651
+ padding_length = []
1652
+ for i, item in enumerate(items):
1653
+ if dim == 1:
1654
+ if padding_side == "left":
1655
+ tensor[i, -len(item) :] = item.clone()
1656
+ else:
1657
+ tensor[i, : len(item)] = item.clone()
1658
+ elif dim == 2:
1659
+ if padding_side == "left":
1660
+ tensor[i, -len(item) :, :] = item.clone()
1661
+ else:
1662
+ tensor[i, : len(item), :] = item.clone()
1663
+ padding_length.append(tensor.shape[-1] - len(item))
1664
+
1665
+ return tensor, padding_length
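# Illustrative sketch (not part of this file): left-padding as implemented by pad() above.
# Two 1-D id tensors of lengths 3 and 5 are padded to length 5 on the left, and the returned
# padding lengths are the per-row offsets later added to the image/audio/spk bound indices
# in _convert_omni_to_inputs.
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6, 7, 8])
# expected result of pad([a, b], padding_value=0, padding_side="left"):
#   tensor([[0, 0, 1, 2, 3],
#           [4, 5, 6, 7, 8]])   with padding_lengths == [2, 0]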
processor_config.json ADDED
@@ -0,0 +1,102 @@
1
+ {
2
+ "audio_processor": {
3
+ "audio_pool_step": 5,
4
+ "auto_map": {
5
+ "AutoFeatureExtractor": "processing_minicpmo.MiniCPMAAudioProcessor",
6
+ "AutoImageProcessor": "processing_minicpmo.MiniCPMVImageProcessor",
7
+ "AutoProcessor": "processing_minicpmo.MiniCPMOProcessor"
8
+ },
9
+ "chunk_length": 30,
10
+ "dither": 0.0,
11
+ "dynamic_log_norm": true,
12
+ "dynamic_range_db": 8.0,
13
+ "feature_extractor_type": "MiniCPMAAudioProcessor",
14
+ "feature_size": 80,
15
+ "hop_length": 160,
16
+ "im_end": "</image>",
17
+ "im_id_end": "</image_id>",
18
+ "im_id_start": "<image_id>",
19
+ "im_start": "<image>",
20
+ "image_feature_size": 64,
21
+ "image_processor_type": "MiniCPMVImageProcessor",
22
+ "log_floor_db": -10.0,
23
+ "max_slice_nums": 9,
24
+ "n_fft": 400,
25
+ "n_samples": 480000,
26
+ "nb_max_frames": 3000,
27
+ "norm_mean": [
28
+ 0.5,
29
+ 0.5,
30
+ 0.5
31
+ ],
32
+ "norm_std": [
33
+ 0.5,
34
+ 0.5,
35
+ 0.5
36
+ ],
37
+ "padding_side": "right",
38
+ "padding_value": 0.0,
39
+ "patch_size": 14,
40
+ "return_attention_mask": false,
41
+ "sampling_rate": 16000,
42
+ "scale_resolution": 448,
43
+ "slice_end": "</slice>",
44
+ "slice_mode": true,
45
+ "slice_start": "<slice>",
46
+ "unk": "<unk>",
47
+ "use_image_id": true,
48
+ "version": 4.5
49
+ },
50
+ "auto_map": {
51
+ "AutoProcessor": "processing_minicpmo.MiniCPMOProcessor"
52
+ },
53
+ "image_processor": {
54
+ "audio_pool_step": 5,
55
+ "auto_map": {
56
+ "AutoFeatureExtractor": "processing_minicpmo.MiniCPMAAudioProcessor",
57
+ "AutoImageProcessor": "processing_minicpmo.MiniCPMVImageProcessor",
58
+ "AutoProcessor": "processing_minicpmo.MiniCPMOProcessor"
59
+ },
60
+ "im_end": "</image>",
61
+ "im_end_token": "</image>",
62
+ "im_id_end": "</image_id>",
63
+ "im_id_start": "<image_id>",
64
+ "im_start": "<image>",
65
+ "im_start_token": "<image>",
66
+ "image_feature_size": 64,
67
+ "image_processor_type": "MiniCPMVImageProcessor",
68
+ "max_slice_nums": 9,
69
+ "mean": [
70
+ 0.5,
71
+ 0.5,
72
+ 0.5
73
+ ],
74
+ "norm_mean": [
75
+ 0.5,
76
+ 0.5,
77
+ 0.5
78
+ ],
79
+ "norm_std": [
80
+ 0.5,
81
+ 0.5,
82
+ 0.5
83
+ ],
84
+ "patch_size": 14,
85
+ "scale_resolution": 448,
86
+ "slice_end": "</slice>",
87
+ "slice_end_token": "</slice>",
88
+ "slice_mode": true,
89
+ "slice_start": "<slice>",
90
+ "slice_start_token": "<slice>",
91
+ "std": [
92
+ 0.5,
93
+ 0.5,
94
+ 0.5
95
+ ],
96
+ "unk": "<unk>",
97
+ "unk_token": "<unk>",
98
+ "use_image_id": true,
99
+ "version": 4.5
100
+ },
101
+ "processor_class": "MiniCPMOProcessor"
102
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,88 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<unk>",
4
+ "<image>",
5
+ "</image>",
6
+ "<ref>",
7
+ "</ref>",
8
+ "<box>",
9
+ "</box>",
10
+ "<quad>",
11
+ "</quad>",
12
+ "<point>",
13
+ "</point>",
14
+ "<slice>",
15
+ "</slice>",
16
+ "<image_id>",
17
+ "</image_id>",
18
+ "<unit>",
19
+ "</unit>",
20
+ "<answer>",
21
+ "</answer>",
22
+ "<focus>",
23
+ "</focus>",
24
+ "<line>",
25
+ "</line>",
26
+ "<perception>",
27
+ "</perception>",
28
+ "<source_image>",
29
+ "</source_image>",
30
+ "<image_save_to>",
31
+ "</image_save_to>",
32
+ "<|audio_start|>",
33
+ "<|audio|>",
34
+ "<|audio_end|>",
35
+ "<|spk_bos|>",
36
+ "<|spk|>",
37
+ "<|spk_eos|>",
38
+ "<|tts_bos|>",
39
+ "<|tts_eos|>",
40
+ "<|listen|>",
41
+ "<|speak|>",
42
+ "<|interrupt|>",
43
+ "<|vad_start|>",
44
+ "<|vad_end|>",
45
+ "<|emotion_start|>",
46
+ "<|emotion_end|>",
47
+ "<|speed_start|>",
48
+ "<|speed_end|>",
49
+ "<|pitch_start|>",
50
+ "<|pitch_end|>",
51
+ "<|turn_bos|>",
52
+ "<|turn_eos|>",
53
+ "<|chunk_eos|>",
54
+ "<|chunk_bos|>",
55
+ "<|chunk_tts_bos|>",
56
+ "<|chunk_tts_eos|>",
57
+ "<|tts_pad|>",
58
+ "<|timbre_7|>",
59
+ "<|timbre_8|>",
60
+ "<|timbre_9|>",
61
+ "<|timbre_10|>",
62
+ "<|timbre_11|>",
63
+ "<|timbre_12|>",
64
+ "<|timbre_13|>",
65
+ "<|timbre_14|>",
66
+ "<|timbre_15|>",
67
+ "<|timbre_16|>",
68
+ "<|timbre_17|>",
69
+ "<|timbre_18|>",
70
+ "<|timbre_19|>",
71
+ "<|timbre_20|>",
72
+ "<|timbre_21|>",
73
+ "<|timbre_22|>",
74
+ "<|timbre_23|>",
75
+ "<|timbre_24|>",
76
+ "<|timbre_25|>",
77
+ "<|timbre_26|>",
78
+ "<|timbre_27|>",
79
+ "<|timbre_28|>",
80
+ "<|timbre_29|>",
81
+ "<|timbre_30|>",
82
+ "<|timbre_31|>"
83
+ ],
84
+ "bos_token": "<|im_start|>",
85
+ "eos_token": "<|im_end|>",
86
+ "pad_token": "<|endoftext|>",
87
+ "unk_token": "<unk>"
88
+ }
tokenization_minicpmo_fast.py ADDED
@@ -0,0 +1,120 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright 2026 The OpenBMB Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from typing import List
18
+
19
+ from transformers import Qwen2TokenizerFast
20
+
21
+
22
+ class MiniCPMOTokenizerFast(Qwen2TokenizerFast):
23
+ def __init__(self, **kwargs):
24
+ self._bad_token_ids = kwargs.pop("bad_token_ids", [])
25
+
26
+ super().__init__(**kwargs)
27
+
28
+ # image
29
+ self.im_start = "<image>"
30
+ self.im_end = "</image>"
31
+ self.ref_start = "<ref>"
32
+ self.ref_end = "</ref>"
33
+ self.box_start = "<box>"
34
+ self.box_end = "</box>"
35
+ self.quad_start = "<quad>"
36
+ self.quad_end = "</quad>"
37
+ self.slice_start = "<slice>"
38
+ self.slice_end = "</slice>"
39
+ self.im_id_start = "<image_id>"
40
+ self.im_id_end = "</image_id>"
41
+
42
+ # audio
43
+ self.audio_start = "<|audio_start|>"
44
+ self.audio_end = "<|audio_end|>"
45
+ self.spk_start = "<|spk_bos|>"
46
+ self.spk_end = "<|spk_eos|>"
47
+ self.tts_start = "<|tts_bos|>"
48
+ self.tts_end = "<|tts_eos|>"
49
+
50
+ @property
51
+ def eos_id(self):
52
+ return self.eos_token_id
53
+
54
+ @property
55
+ def bos_id(self):
56
+ return self.bos_token_id
57
+
58
+ @property
59
+ def unk_id(self):
60
+ return self.unk_token_id
61
+
62
+ @property
63
+ def im_start_id(self):
64
+ return self.convert_tokens_to_ids(self.im_start)
65
+
66
+ @property
67
+ def im_end_id(self):
68
+ return self.convert_tokens_to_ids(self.im_end)
69
+
70
+ @property
71
+ def slice_start_id(self):
72
+ return self.convert_tokens_to_ids(self.slice_start)
73
+
74
+ @property
75
+ def slice_end_id(self):
76
+ return self.convert_tokens_to_ids(self.slice_end)
77
+
78
+ @property
79
+ def im_id_start_id(self):
80
+ return self.convert_tokens_to_ids(self.im_id_start)
81
+
82
+ @property
83
+ def im_id_end_id(self):
84
+ return self.convert_tokens_to_ids(self.im_id_end)
85
+
86
+ @property
87
+ def audio_start_id(self):
88
+ return self.convert_tokens_to_ids(self.audio_start)
89
+
90
+ @property
91
+ def audio_end_id(self):
92
+ return self.convert_tokens_to_ids(self.audio_end)
93
+
94
+ @property
95
+ def spk_start_id(self):
96
+ return self.convert_tokens_to_ids(self.spk_start)
97
+
98
+ @property
99
+ def spk_end_id(self):
100
+ return self.convert_tokens_to_ids(self.spk_end)
101
+
102
+ @property
103
+ def tts_start_id(self):
104
+ return self.convert_tokens_to_ids(self.tts_start)
105
+
106
+ @property
107
+ def tts_end_id(self):
108
+ return self.convert_tokens_to_ids(self.tts_end)
109
+
110
+ @staticmethod
111
+ def escape(text: str) -> str:
112
+ return text
113
+
114
+ @staticmethod
115
+ def unescape(text: str) -> str:
116
+ return text
117
+
118
+ @property
119
+ def bad_token_ids(self) -> List[int]:
120
+ return self._bad_token_ids
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66664f87759d9e829e7ef0ded96976727374dcd7ca6f3ae9bfe89bbda541e5af
3
+ size 11437708
tokenizer_config.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "auto_map": {
4
+ "AutoProcessor": "processing_minicpmo.MiniCPMOProcessor",
5
+ "AutoTokenizer": [
6
+ null,
7
+ "tokenization_minicpmo_fast.MiniCPMOTokenizerFast"
8
+ ]
9
+ },
10
+ "backend": "tokenizers",
11
+ "bos_token": "<|im_start|>",
12
+ "clean_up_tokenization_spaces": false,
13
+ "eos_token": "<|im_end|>",
14
+ "errors": "replace",
15
+ "is_local": true,
16
+ "model_max_length": 131072,
17
+ "pad_token": "<|endoftext|>",
18
+ "processor_class": "MiniCPMOProcessor",
19
+ "split_special_tokens": false,
20
+ "tokenizer_class": "MiniCPMOTokenizer",
21
+ "unk_token": "<unk>"
22
+ }
utils.py ADDED
@@ -0,0 +1,2417 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright 2026 The OpenBMB Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import logging
18
+ from dataclasses import dataclass
19
+ from typing import Any
20
+ from typing import Dict
21
+ from typing import List
22
+ from typing import Literal
23
+ from typing import Optional
24
+ from typing import Tuple
25
+ from typing import Union
26
+
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import torch.nn.utils.parametrize as P
30
+ from transformers.cache_utils import DynamicCache
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ # text
36
+ @dataclass
37
+ class GenerateChunkOutput:
38
+ chunk_token_ids: torch.Tensor
39
+ current_inputs_embeds: torch.Tensor
40
+ input_last_hidden_states: Optional[torch.Tensor] # for tts use_speaker_embedding
41
+ last_hidden_states: Optional[torch.Tensor] # for tts input feature (projector_semantic)
42
+ past_key_values: Optional[torch.Tensor]
43
+ finished: bool
44
+
45
+
46
+ class ChunkPrefillChunkGenerate:
47
+ def __init__(self, model, tokenizer, terminators):
48
+ self.tokenizer = tokenizer
49
+ self.model = model
50
+ self.terminators = terminators
51
+ self.terminators_ids = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
52
+ self.embedding_layer = self.model.get_input_embeddings()
53
+
54
+ self.forbidden_tokens = [
55
+ ":",
56
+ ":",
57
+ ";",
58
+ "#",
59
+ "“",
60
+ "”",
61
+ "‘",
62
+ "’",
63
+ "@",
64
+ "*",
65
+ "【",
66
+ "】",
67
+ "「",
68
+ "」",
69
+ "(",
70
+ ")",
71
+ "(",
72
+ ")",
73
+ "[",
74
+ "]",
75
+ "&",
76
+ "/",
77
+ "$",
78
+ ]
79
+
80
+ self.forbidden_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in self.forbidden_tokens]
81
+ bad_token_ids = getattr(tokenizer, "bad_token_ids", [])
82
+ if bad_token_ids:
83
+ self.forbidden_token_ids.extend(bad_token_ids)
84
+
85
+ @staticmethod
86
+ def prepare_generation_config(do_sample, max_new_tokens=50, min_new_tokens=0, **kwargs):
87
+ num_beams = kwargs.get("num_beams", 3)
88
+ generation_config = {
89
+ "num_beams": num_beams,
90
+ "top_p": 0.8,
91
+ "top_k": 100,
92
+ "temperature": 0.7,
93
+ "do_sample": True,
94
+ "repetition_penalty": 1.05,
95
+ }
96
+
97
+ if do_sample:
98
+ generation_config.update(
99
+ {
100
+ "top_p": 0.8,
101
+ "top_k": 100,
102
+ "temperature": 0.7,
103
+ "do_sample": True,
104
+ "repetition_penalty": 1.05,
105
+ }
106
+ )
107
+ elif num_beams > 1:
108
+ generation_config.update({"num_beams": num_beams, "repetition_penalty": 1.2, "do_sample": False})
109
+ else:
110
+ generation_config.update({"do_sample": False, "repetition_penalty": 1.05})
111
+
112
+ generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())
113
+ generation_config["min_new_tokens"] = min_new_tokens
114
+ generation_config["max_new_tokens"] = max_new_tokens
115
+
116
+ return generation_config
117
+
118
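# Illustrative check (not part of this file): how prepare_generation_config above resolves
# its branches and lets caller-supplied kwargs override the defaults.
cfg = ChunkPrefillChunkGenerate.prepare_generation_config(do_sample=False, num_beams=3)
assert cfg["do_sample"] is False and cfg["repetition_penalty"] == 1.2   # beam-search branch
cfg = ChunkPrefillChunkGenerate.prepare_generation_config(do_sample=True, temperature=0.5)
assert cfg["do_sample"] is True and cfg["temperature"] == 0.5           # caller overrides win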
+ def chunk_generate(
119
+ self,
120
+ inputs_embeds: torch.Tensor,
121
+ past_key_values,
122
+ is_first_generate_chunk: bool,
123
+ chunk_size: int,
124
+ return_hidden_states: bool,
125
+ do_sample: bool,
126
+ temperature: float,
127
+ top_p: float,
128
+ top_k: int,
129
+ repetition_penalty: float = 1.05,
130
+ length_penalty: float = 1.0,
131
+ all_input_ids: Optional[torch.Tensor] = None,
132
+ ) -> GenerateChunkOutput:
133
+ """
134
+ Args:
135
+ inputs_embeds: [1, seq_len, hidden_dim], Input embeddings of current chunk.
136
+ past_key_values: [num_layers, 2, batch_size, num_heads, seq_len, head_dim], Past key values for llm.
137
+ is_first_generate_chunk: bool, Whether this is the first generate chunk.
138
+ chunk_size: int, The size of the current chunk, default is 10, and it is fixed during training.
139
+ return_hidden_states: bool Whether to return the hidden states, default is True.
140
+ do_sample: bool Whether to sample from the model, default is True.
141
+ temperature: float The temperature for the model, default is 0.7.
142
+ top_p: float The top-p for the model, default is 0.8.
143
+ top_k: int The top-k for the model, default is 100.
144
+ repetition_penalty: float, The repetition penalty for the model, default is 1.05.
145
+ length_penalty: float, The length penalty for the model, default is 1.0. Higher value means more detailed generation.
146
+ all_input_ids: Optional[torch.Tensor], The input ids for the current chunk.
147
+ """
148
+
149
+ finished = False
150
+ current_inputs_embeds = inputs_embeds.clone()
151
+ input_last_hidden_states = []
152
+ last_hidden_states = []
153
+ generated_tokens = []
154
+
155
+ for token_idx in range(chunk_size):
156
+ if is_first_generate_chunk and token_idx == 0:
157
+ # first generate chunk, prefill inputs_embeds
158
+ model_inputs = {
159
+ "inputs_embeds": current_inputs_embeds,
160
+ "past_key_values": past_key_values,
161
+ "use_cache": True,
162
+ "output_hidden_states": return_hidden_states,
163
+ }
164
+ else: # for all other cases: prefill the latest generated token
165
+ model_inputs = {
166
+ "inputs_embeds": current_inputs_embeds[:, -1:, :],
167
+ "past_key_values": past_key_values,
168
+ "use_cache": True,
169
+ "output_hidden_states": return_hidden_states,
170
+ }
171
+
172
+ with torch.no_grad():
173
+ outputs = self.model(**model_inputs)
174
+
175
+ # last token's logits
176
+ logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=inputs_embeds.device)
177
+
178
+ # suppress specific tokens during decoding (equivalent to suppress_tokens in model.generate)
179
+ if self.forbidden_token_ids:
180
+ logits[:, self.forbidden_token_ids] = float("-inf")
181
+
182
+ past_key_values = outputs.past_key_values
183
+
184
+ PENALTY_WINDOW_SIZE = 128
185
+
186
+ # apply repetition penalty
187
+ if repetition_penalty != 1.0:
188
+ # get token ids for repetition penalty
189
+ if all_input_ids is not None:
190
+ # use global input ids (including original input and generated part)
191
+ if len(generated_tokens) > 0:
192
+ generated_token_ids = torch.cat(generated_tokens, dim=1)
193
+ current_sequence = torch.cat(
194
+ [
195
+ all_input_ids[:, -PENALTY_WINDOW_SIZE:],
196
+ generated_token_ids,
197
+ ],
198
+ dim=1,
199
+ )
200
+ else:
201
+ current_sequence = all_input_ids[:, -PENALTY_WINDOW_SIZE:]
202
+ unique_token_ids = torch.unique(current_sequence.squeeze(0))
203
+ elif len(generated_tokens) > 0:
204
+ # revert to original logic: only use generated tokens
205
+ generated_token_ids = torch.cat(generated_tokens, dim=1).squeeze(0)
206
+ unique_token_ids = torch.unique(generated_token_ids)
207
+ else:
208
+ unique_token_ids = torch.tensor([], dtype=torch.long, device=logits.device)
209
+
210
+ # apply repetition penalty
211
+ for token_id in unique_token_ids:
212
+ if logits[0, token_id] > 0:
213
+ logits[0, token_id] = logits[0, token_id] / repetition_penalty
214
+ else:
215
+ logits[0, token_id] = logits[0, token_id] * repetition_penalty
216
+
217
+ # apply length penalty, higher value means more detailed generation
218
+ if length_penalty != 1.0:
219
+ for eos_token_id in self.terminators_ids:
220
+ if logits[0, eos_token_id] > 0:
221
+ logits[0, eos_token_id] = logits[0, eos_token_id] / length_penalty
222
+ else:
223
+ logits[0, eos_token_id] = logits[0, eos_token_id] * length_penalty
224
+
225
+ # apply temperature
226
+ if temperature != 1.0:
227
+ logits = logits / temperature
228
+
229
+ if do_sample:
230
+ # Top-k filtering
231
+ if top_k > 0:
232
+ top_k_logits, top_k_indices = torch.topk(logits, min(top_k, logits.size(-1)))
233
+ logits_filtered = torch.full_like(logits, float("-inf"))
234
+ logits_filtered.scatter_(1, top_k_indices, top_k_logits)
235
+ logits = logits_filtered
236
+
237
+ # Top-p filtering
238
+ if top_p < 1.0:
239
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
240
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
241
+
242
+ # remove tokens with cumulative probability greater than top_p
243
+ sorted_indices_to_remove = cumulative_probs > top_p
244
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
245
+ sorted_indices_to_remove[..., 0] = 0
246
+
247
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
248
+ logits[indices_to_remove] = float("-inf")
249
+
250
+ # sampling
251
+ probs = F.softmax(logits, dim=-1)
252
+ next_token = torch.multinomial(probs, num_samples=1)
253
+ else:
254
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
255
+
256
+ if return_hidden_states:
257
+ if is_first_generate_chunk and token_idx == 0:
258
+ input_last_hidden_states.append(outputs.hidden_states[-1])
259
+ else:
260
+ last_hidden_states.append(outputs.hidden_states[-1])
261
+
262
+ # if terminator token, stop generating
263
+ if next_token.item() in self.terminators_ids:
264
+ finished = True
265
+ break
266
+
267
+ generated_tokens.append(next_token)
268
+
269
+ # convert new token to embeddings and concatenate
270
+ next_token_embed = self.embedding_layer(next_token)
271
+
272
+ # update inputs_embeds, add one
273
+ current_inputs_embeds = torch.cat([current_inputs_embeds, next_token_embed], dim=1)
274
+
275
+ if len(generated_tokens) > 0:
276
+ chunk_token_ids = torch.cat(generated_tokens, dim=1)
277
+ else:
278
+ # special case: the very first prediction of this chunk was an EOS token, so no new tokens were produced; return an empty tensor of shape (1, 0)
279
+ if finished:
280
+ chunk_token_ids = torch.zeros((1, 0), dtype=torch.long, device=current_inputs_embeds.device)
281
+ else:
282
+ raise Exception("this should not happen")
283
+
284
+ if len(last_hidden_states) > 0:
285
+ last_hidden_states = torch.cat(last_hidden_states, dim=1)
286
+ else:
287
+ # special case: no per-token hidden states were collected here (either the first
+ # prediction of the chunk was an EOS token, or hidden states were not requested);
+ # torch.cat on an empty list would raise, so return None and let the caller reuse
+ # the previous chunk's last hidden state.
+ last_hidden_states = None
292
+
293
+ if len(input_last_hidden_states) > 0:
294
+ input_last_hidden_states = torch.cat(input_last_hidden_states, dim=1)
295
+ else:
296
+ input_last_hidden_states = None
297
+
298
+ return GenerateChunkOutput(
299
+ chunk_token_ids=chunk_token_ids,
300
+ current_inputs_embeds=current_inputs_embeds,
301
+ input_last_hidden_states=input_last_hidden_states,
302
+ last_hidden_states=last_hidden_states,
303
+ past_key_values=past_key_values,
304
+ finished=finished,
305
+ )
306
+
307
+
308
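# Illustrative usage sketch (not part of this file): driving chunk_generate in a loop so
# tokens stream out in fixed-size chunks. `gen` is assumed to be a ChunkPrefillChunkGenerate
# instance; `prompt_embeds` and `cache` come from an earlier prefill of the multimodal prompt.
def generate_stream(gen, prompt_embeds, cache, chunk_size=10, max_chunks=50):
    embeds, first = prompt_embeds, True
    for _ in range(max_chunks):
        out = gen.chunk_generate(
            inputs_embeds=embeds,
            past_key_values=cache,
            is_first_generate_chunk=first,
            chunk_size=chunk_size,
            return_hidden_states=True,
            do_sample=True,
            temperature=0.7,
            top_p=0.8,
            top_k=100,
        )
        yield out.chunk_token_ids, out.finished
        embeds, cache, first = out.current_inputs_embeds, out.past_key_values, False
        if out.finished:
            break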
+ def streaming_token_decoder(token_iterator, tokenizer, skip_special_tokens=False):
309
+ """
310
+ Incrementally decode tokens from an iterator, handling partial multi-byte characters.
311
+
312
+ When streaming tokens, multi-byte characters (like Chinese) may be split across multiple
313
+ tokens. Decoding partial tokens results in replacement characters (U+FFFD). This function
314
+ buffers tokens and only yields complete characters.
315
+
316
+ Args:
317
+ token_iterator: An iterator yielding (token_ids, is_finished) tuples.
318
+ token_ids can be torch.Tensor or any iterable of integers.
319
+ tokenizer: The tokenizer to use for decoding.
320
+ skip_special_tokens: Whether to skip special tokens during decoding.
321
+
322
+ Yields:
323
+ (decoded_text, is_finished) tuples where decoded_text is the new text since last yield.
324
+ """
325
+ accumulated_token_ids = []
326
+ yielded_text_len = 0
327
+
328
+ for token_ids, is_finished in token_iterator:
329
+ # Accumulate token IDs
330
+ if torch.is_tensor(token_ids):
331
+ accumulated_token_ids.extend(token_ids.reshape(-1).tolist())
332
+ else:
333
+ accumulated_token_ids.extend(list(token_ids) if hasattr(token_ids, "__iter__") else [token_ids])
334
+
335
+ # Decode all accumulated tokens
336
+ full_decoded = tokenizer.decode(accumulated_token_ids, skip_special_tokens=skip_special_tokens)
337
+
338
+ if is_finished:
339
+ # Final chunk - yield all remaining text
340
+ new_text = full_decoded[yielded_text_len:]
341
+ yield new_text, is_finished
342
+ else:
343
+ # Find safe prefix without incomplete multi-byte characters
344
+ # The replacement character '�' (U+FFFD) indicates incomplete decoding
345
+ new_text = full_decoded[yielded_text_len:]
346
+
347
+ # Hold back text ending with replacement character (incomplete UTF-8 sequence)
348
+ safe_end = len(new_text)
349
+ while safe_end > 0 and new_text[safe_end - 1] == "\ufffd":
350
+ safe_end -= 1
351
+
352
+ safe_text = new_text[:safe_end] if safe_end > 0 else ""
353
+ yielded_text_len += len(safe_text)
354
+ yield safe_text, is_finished
355
+
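+ # --- illustrative usage sketch (not part of the upstream file) ----------------
+ # Minimal example of driving streaming_token_decoder; `tok` stands for any
+ # Hugging Face tokenizer and `_toy_chunks` is a hypothetical helper that emits
+ # one token id at a time, so multi-byte characters arrive split across chunks.
+ #
+ # def _toy_chunks(tok, text):
+ #     ids = tok.encode(text, add_special_tokens=False)
+ #     for i, tid in enumerate(ids):
+ #         yield [tid], i == len(ids) - 1
+ #
+ # for piece, done in streaming_token_decoder(_toy_chunks(tok, "你好, world"), tok):
+ #     print(piece, end="", flush=True)   # only complete characters are printed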
356
+
357
+ def torch_clone_recursive(obj):
358
+ """Recursively clone nested containers of torch.Tensors.
359
+
360
+ Supported container types: dict, list, tuple. Tensors are cloned; any other
361
+ type raises a ValueError.
362
+ """
363
+ if torch.is_tensor(obj):
364
+ return obj.clone()
365
+ elif isinstance(obj, dict):
366
+ return {k: torch_clone_recursive(v) for k, v in obj.items()}
367
+ elif isinstance(obj, list):
368
+ return [torch_clone_recursive(v) for v in obj]
369
+ elif isinstance(obj, tuple):
370
+ return tuple(torch_clone_recursive(v) for v in obj)
371
+ else:
372
+ raise ValueError(f"Unsupported type: {type(obj)}")
373
+
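+ # Example (illustrative, not part of the upstream file): deep-copying a legacy
+ # tuple-style KV cache of (key, value) pairs so the copy can be mutated safely.
+ #
+ # kv = ((torch.zeros(1, 8, 4, 64), torch.zeros(1, 8, 4, 64)),)   # one layer (K, V)
+ # kv_copy = torch_clone_recursive(kv)                            # independent tensors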
374
+
375
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
376
+ """Rotate half the hidden dims of the input for RoPE."""
377
+ dim = x.shape[-1]
378
+ x1 = x[..., : dim // 2]
379
+ x2 = x[..., dim // 2 :]
380
+ return torch.cat((-x2, x1), dim=-1)
381
+
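+ # Worked example (illustrative, not part of the upstream file): for a last-dim
+ # vector [x1, x2, x3, x4], rotate_half returns [-x3, -x4, x1, x2]. RoPE is then
+ # applied as q = x * cos + rotate_half(x) * sin, with the cos/sin halves
+ # duplicated so that rotate_half(rotate_half(x)) == -x makes the rotation
+ # invertible (see realign_rotary_suffix below).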
382
+
383
+ @dataclass
384
+ class SpeculativeSnapshot:
385
+ """Speculative snapshot for VAD speculative rollback.
386
+
387
+ Used in VAD speculative execution: creates a snapshot after streaming_prefill
388
+ and before streaming_generate. If speculation fails (user continues speaking),
389
+ the state can be restored to continue streaming_prefill.
390
+
391
+ Implementation:
392
+ - LLM KV Cache: only record length, restore by truncation (zero extra VRAM)
393
+ - Audio KV Cache: requires cloning, as generate sets it to None
394
+ - Mel processor: save full state snapshot (including buffer)
395
+ """
396
+
397
+ # KV Cache length (for truncation recovery)
398
+ llm_cache_length: int
399
+ audio_cache_length: int
400
+
401
+ # session state
402
+ new_user_msg: bool
403
+ llm_generated: bool
404
+ llm_generate_completed: bool
405
+
406
+ # Round management
407
+ next_round_id: int
408
+ pending_round_id: Optional[int]
409
+ omni_chunk_history_length: int
410
+
411
+ # TTS state (requires cloning, but usually small)
412
+ tts_last_turn_tokens: Optional[torch.Tensor]
413
+
414
+ # Streaming processor state
415
+ audio_chunk_idx: int
416
+
417
+ # Mel processor state snapshot (including buffer)
418
+ mel_processor_snapshot: Optional[dict] = None
419
+
420
+ # Audio encoder KV cache (requires cloning to ensure determinism after recovery)
421
+ audio_past_key_values: Optional[tuple] = None
422
+
423
+ # timestamp (for debugging)
424
+ timestamp: float = 0.0
425
+
426
+ # debug field: for verifying correctness of recovery
427
+ llm_cache_checksum: Optional[float] = None # LLM KV Cache first layer K sum
428
+ audio_cache_checksum: Optional[float] = None # Audio KV Cache first layer K sum
429
+ mel_buffer_checksum: Optional[float] = None # Mel buffer sum
430
+
431
+ # RNG state (key: for ensuring determinism of dithering etc. after recovery)
432
+ rng_state_cpu: Optional[torch.Tensor] = None # torch CPU RNG state
433
+ rng_state_cuda: Optional[torch.Tensor] = None # torch CUDA RNG state (if on GPU)
434
+
435
+ def summary(self) -> str:
436
+ mel_buf_len = 0
437
+ if self.mel_processor_snapshot:
438
+ buf = self.mel_processor_snapshot.get("buffer")
439
+ if buf is not None:
440
+ mel_buf_len = len(buf)
441
+ return (
442
+ f"llm_cache={self.llm_cache_length}, "
443
+ f"audio_cache={self.audio_cache_length}, "
444
+ f"audio_chunk_idx={self.audio_chunk_idx}, "
445
+ f"mel_buffer={mel_buf_len}, "
446
+ f"history_len={self.omni_chunk_history_length}, "
447
+ f"new_user_msg={self.new_user_msg}, "
448
+ f"llm_generated={self.llm_generated}"
449
+ )
450
+
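+ # --- illustrative sketch (not part of the upstream file) ----------------------
+ # How a snapshot might be captured before speculative generation; `session` and
+ # its attributes are hypothetical stand-ins for the caller's streaming state.
+ #
+ # snap = SpeculativeSnapshot(
+ #     llm_cache_length=get_kv_cache_length(session.llm_past_key_values),
+ #     audio_cache_length=get_kv_cache_length(session.audio_past_key_values),
+ #     new_user_msg=session.new_user_msg,
+ #     llm_generated=session.llm_generated,
+ #     llm_generate_completed=session.llm_generate_completed,
+ #     next_round_id=session.next_round_id,
+ #     pending_round_id=session.pending_round_id,
+ #     omni_chunk_history_length=len(session.omni_chunk_history),
+ #     tts_last_turn_tokens=None,
+ #     audio_chunk_idx=session.audio_chunk_idx,
+ #     audio_past_key_values=torch_clone_recursive(session.audio_past_key_values),
+ # )
+ # # If speculation fails, truncate the LLM KV cache back to snap.llm_cache_length
+ # # and restore the cloned audio cache / mel-processor state before resuming prefill.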
451
+
452
+ # tts
453
+ @dataclass
454
+ class TTSSamplingParams:
455
+ top_p: float = 0.85
456
+ min_p: float = 0.01
457
+ top_k: int = 25
458
+ repetition_penalty: float = 1.05
459
+ temperature: float = 0.8
460
+ win_size: int = 16
461
+ tau_r: float = 0.1
462
+
463
+
464
+ class TTSStreamingGenerator:
465
+ """
466
+ Streaming generator for TTS that processes chunks and yields audio tokens in real-time.
467
+
468
+ Supported attention types:
469
+ - full_attention: Full attention, all tokens can attend to each other
470
+ - sliding_window: Sliding window attention, KV cache is truncated to fixed size (token_window_size)
471
+ - sliding_recompute: Sliding recompute, only keep previous chunk and recompute with current chunk
472
+ - reindex: Keep first chunk as sink, reindex sliding window positions via RoPE rotation
473
+ """
474
+
475
+ def __init__(
476
+ self,
477
+ model,
478
+ temperature: float,
479
+ eos_token: Union[int, torch.Tensor],
480
+ chunk_size: int = 25, # s3tokenizer 1s = 25token
481
+ tts_last_turn_tokens: torch.Tensor = None,
482
+ logits_processors=None,
483
+ logits_warpers=None,
484
+ ):
485
+ self.tts = model
486
+ self.device = model.device
487
+ self.temperature = torch.tensor([temperature], dtype=torch.float, device=self.device)
488
+ self.eos_token = (
489
+ torch.tensor(eos_token, device=self.device) if isinstance(eos_token, int) else eos_token.to(self.device)
490
+ )
491
+
492
+ self.num_vq = model.num_vq
493
+ self.num_audio_tokens = model.num_audio_tokens
494
+ self.recomputed_chunks = model.recomputed_chunks
495
+ self.emb_code = model.emb_code
496
+ self.head_code = model.head_code
497
+
498
+ # Attention type and window sizes
499
+ self.attention_type = model.attention_type # "full_attention", "sliding_window", "sliding_recompute", "reindex"
500
+ self.chunk_window_size = model.chunk_window_size # chunk-level window for sliding_recompute (default 2)
501
+ self.token_window_size = model.token_window_size # token-level window for sliding_window/reindex (default 300)
502
+
503
+ # RoPE config (for reindex mode)
504
+ self.rope_theta = model.model.config.rope_theta
505
+ self.head_dim = model.model.config.hidden_size // model.model.config.num_attention_heads
506
+
507
+ # Logits processors
508
+ self.logits_processors = logits_processors if logits_processors is not None else []
509
+ # Logits warpers (like TopP/TopK), separate from processors
510
+ self.logits_warpers = logits_warpers if logits_warpers is not None else []
511
+
512
+ # initialize state
513
+ self.past_key_values = None
514
+ self.text_start_pos = 0
515
+ self.idx = -1 # start from -1, become 0 when first called
516
+ self.all_conditions = []
517
+ self.all_generated_tokens = []
518
+ self.tts_last_turn_tokens = tts_last_turn_tokens
519
+ self.spk_emb = None
520
+
521
+ audio_bos = [self.tts.audio_bos_token_id]
522
+ audio_bos = torch.Tensor(audio_bos).to(self.tts.emb_text.weight.device, dtype=torch.long)
523
+
524
+ self.audio_bos_embeds = self.tts.emb_text(audio_bos).unsqueeze(0)
525
+ self.text_eos_embed = self.tts.emb_text(
526
+ torch.tensor(
527
+ [self.tts.config.text_eos_token_id],
528
+ device=self.tts.emb_text.weight.device,
529
+ dtype=torch.long,
530
+ )
531
+ ).unsqueeze(0)
532
+
533
+ # token buffer: accumulates generated tokens until chunk_size is reached, then they are yielded to the caller
534
+ self.chunk_size = chunk_size
535
+ self._token_buffer: List[torch.Tensor] = []
536
+
537
+ # Chunk info tracking for sliding_recompute and reindex
538
+ self._chunk_info: List[dict] = []
539
+ self._total_seq_len = 0
540
+
541
+ # Reindex mode: track sink (first chunk) length
542
+ self._sink_kv_len = 0
543
+
544
+ def _build_recompute_inputs(self, current_condition: torch.Tensor) -> torch.Tensor:
545
+ """Build recompute inputs for sliding_recompute mode."""
546
+ if len(self._chunk_info) == 0:
547
+ return current_condition
548
+
549
+ prev_chunk = self._chunk_info[-1]
550
+ prev_condition = prev_chunk["condition"]
551
+ prev_audio_tokens = prev_chunk["audio_tokens"]
552
+
553
+ recompute_list = [prev_condition]
554
+ if len(prev_audio_tokens) > 0:
555
+ prev_audio_embeds = torch.cat([self.emb_code[0](tok) for tok in prev_audio_tokens], dim=1)
556
+ recompute_list.append(prev_audio_embeds)
557
+
558
+ recompute_list.append(current_condition)
559
+ return torch.cat(recompute_list, dim=1)
560
+
561
+ def _truncate_kv_cache_sliding_window(self):
562
+ """Truncate KV cache for sliding_window mode."""
563
+ if self.past_key_values is None:
564
+ return
565
+
566
+ if hasattr(self.past_key_values, "get_seq_length"):
567
+ current_kv_len = self.past_key_values.get_seq_length()
568
+ else:
569
+ current_kv_len = self.past_key_values[0][0].shape[2]
570
+
571
+ if current_kv_len <= self.token_window_size:
572
+ return
573
+
574
+ new_cache = DynamicCache()
575
+ num_layers = (
576
+ len(self.past_key_values.key_cache)
577
+ if hasattr(self.past_key_values, "key_cache")
578
+ else len(self.past_key_values)
579
+ )
580
+
581
+ for layer_idx in range(num_layers):
582
+ if hasattr(self.past_key_values, "key_cache"):
583
+ key = self.past_key_values.key_cache[layer_idx][:, :, -self.token_window_size :, :]
584
+ value = self.past_key_values.value_cache[layer_idx][:, :, -self.token_window_size :, :]
585
+ else:
586
+ key = self.past_key_values[layer_idx][0][:, :, -self.token_window_size :, :]
587
+ value = self.past_key_values[layer_idx][1][:, :, -self.token_window_size :, :]
588
+ new_cache.update(key, value, layer_idx)
589
+
590
+ self.past_key_values = new_cache
591
+
592
+ @staticmethod
593
+ def _apply_rope_rotation(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
594
+ """Apply RoPE rotation to tensor."""
595
+ return x * cos + rotate_half(x) * sin
596
+
597
+ def _compute_rope_cos_sin(self, positions: torch.Tensor, device: torch.device, dtype: torch.dtype):
598
+ """Compute RoPE cos and sin for given positions."""
599
+ dim_half = self.head_dim // 2
600
+ freq_seq = torch.arange(0, dim_half, dtype=torch.float32, device=device)
601
+ inv_freq = 1.0 / (self.rope_theta ** (freq_seq / dim_half))
602
+
603
+ # positions: [seq_len]
604
+ angles = positions.float().unsqueeze(-1) * inv_freq.unsqueeze(0) # [seq_len, dim_half]
605
+ angles = torch.cat([angles, angles], dim=-1) # [seq_len, head_dim]
606
+
607
+ cos = angles.cos().to(dtype)
608
+ sin = angles.sin().to(dtype)
609
+ return cos, sin
610
+
611
+ def _reindex_kv_cache(self):
612
+ """
613
+ Reindex KV cache for reindex mode:
614
+ 1. Keep first chunk as attention sink
615
+ 2. Keep last chunk
616
+ 3. Discard middle chunks
617
+ 4. Reindex the last chunk's key positions to be right after sink via RoPE rotation
618
+ """
619
+ if self.past_key_values is None or len(self._chunk_info) < 2:
620
+ return
621
+
622
+ # Get current KV cache length
623
+ if hasattr(self.past_key_values, "get_seq_length"):
624
+ current_kv_len = self.past_key_values.get_seq_length()
625
+ else:
626
+ current_kv_len = self.past_key_values[0][0].shape[2]
627
+
628
+ # Calculate sink length (first chunk)
629
+ sink_len = self._chunk_info[0]["condition_len"] + self._chunk_info[0]["audio_token_count"]
630
+
631
+ # Last chunk length
632
+ last_chunk = self._chunk_info[-1]
633
+ last_chunk_len = last_chunk["condition_len"] + last_chunk["audio_token_count"]
634
+
635
+ keep_len = sink_len + last_chunk_len
636
+
637
+ # Get device and dtype
638
+ device = self.past_key_values.key_cache[0].device
639
+ dtype = self.past_key_values.key_cache[0].dtype
640
+
641
+ if current_kv_len <= keep_len:
642
+ last_chunk_kv_len = current_kv_len - sink_len
643
+ if last_chunk_kv_len <= 0:
644
+ return
645
+ self.text_start_pos = current_kv_len
646
+ return
647
+
648
+ # Step 1: Truncate KV cache - keep sink and last chunk
649
+ new_cache = DynamicCache()
650
+ num_layers = len(self.past_key_values.key_cache)
651
+
652
+ original_start_pos = current_kv_len - last_chunk_len
653
+ new_start_pos = sink_len
654
+ delta = new_start_pos - original_start_pos # This is a scalar constant
655
+ delta_positions = torch.full((last_chunk_len,), delta, dtype=torch.float32, device=device)
656
+
657
+ # Compute rotation cos/sin
658
+ cos, sin = self._compute_rope_cos_sin(delta_positions, device, dtype)
659
+ cos = cos.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, head_dim]
660
+ sin = sin.unsqueeze(0).unsqueeze(0)
661
+
662
+ for layer_idx in range(num_layers):
663
+ key_full = self.past_key_values.key_cache[layer_idx]
664
+ value_full = self.past_key_values.value_cache[layer_idx]
665
+
666
+ # Extract sink and last chunk
667
+ key_sink = key_full[:, :, :sink_len, :]
668
+ value_sink = value_full[:, :, :sink_len, :]
669
+ key_last = key_full[:, :, -last_chunk_len:, :]
670
+ value_last = value_full[:, :, -last_chunk_len:, :]
671
+
672
+ # Apply RoPE rotation to reindex key positions
673
+ key_last_reindexed = self._apply_rope_rotation(key_last, cos, sin)
674
+
675
+ # Concatenate sink and reindexed last chunk
676
+ key = torch.cat([key_sink, key_last_reindexed], dim=2)
677
+ value = torch.cat([value_sink, value_last], dim=2)
678
+
679
+ new_cache.update(key, value, layer_idx)
680
+
681
+ self.past_key_values = new_cache
682
+
683
+ # Update text_start_pos to reflect new positions
684
+ self.text_start_pos = sink_len + last_chunk_len
685
+
686
+ @torch.inference_mode()
687
+ def generate_with_buffer(
688
+ self,
689
+ condition: torch.Tensor,
690
+ text_finished: bool = False,
691
+ max_new_token: int = 500,
692
+ ):
693
+ """input a condition embedding chunk, generate audio token each time,
694
+ and accumulating them in a buffer; yield only once the buffer reaches chunk_size.
695
+
696
+ Yields:
697
+ (tokens, finished) tuples, where tokens has shape [1, chunk_size] (the final chunk may be shorter or empty).
698
+ """
699
+ self.idx += 1
700
+ self.device = self.tts.device
701
+
702
+ # if text finished, first concatenate Text EOS
703
+ if text_finished:
704
+ condition = torch.cat([condition, self.text_eos_embed], dim=1)
705
+
706
+ # always concatenate Audio BOS
707
+ condition = torch.cat([condition, self.audio_bos_embeds], dim=1).to(self.device)
708
+
709
+ self.all_conditions.append(condition)
710
+
711
+ # Initialize current chunk info
712
+ current_chunk_info = {
713
+ "condition_len": condition.shape[1],
714
+ "audio_token_count": 0,
715
+ "condition": condition.clone(),
716
+ "audio_tokens": [],
717
+ }
718
+
719
+ # Handle different attention types
720
+ if self.attention_type == "sliding_recompute" and self.idx >= 1:
721
+ # sliding_recompute: discard KV cache, recompute with previous + current chunk
722
+ self.past_key_values = None
723
+ current_condition = self._build_recompute_inputs(condition)
724
+ self.text_start_pos = 0
725
+ elif self.attention_type == "reindex" and self.idx >= 1:
726
+ # reindex: truncate KV cache keeping sink + last chunk, reindex positions via RoPE
727
+ self._reindex_kv_cache()
728
+ current_condition = condition
729
+ # Always update text_start_pos based on actual KV cache length (like reference code)
730
+ if self.past_key_values is not None:
731
+ if hasattr(self.past_key_values, "get_seq_length"):
732
+ kv_len = self.past_key_values.get_seq_length()
733
+ else:
734
+ kv_len = self.past_key_values[0][0].shape[2]
735
+ self.text_start_pos = kv_len
736
+ else:
737
+ current_condition = condition
738
+
739
+ condition_length = current_condition.shape[1]
740
+ prefill_len = condition_length
741
+ finished = torch.zeros(1, dtype=torch.bool, device=self.device)
742
+ chunk_generated_tokens = []
743
+
744
+ for t in range(max_new_token):
745
+ if t == 0:
746
+ inputs_embeds = current_condition
747
+ pos_ids = torch.arange(
748
+ self.text_start_pos,
749
+ self.text_start_pos + condition_length,
750
+ dtype=torch.long,
751
+ device=self.device,
752
+ ).unsqueeze(0)
753
+ else:
754
+ last = self.all_generated_tokens[-1]
755
+ # last: [1,1], directly as code id
756
+ inputs_embeds = self.emb_code[0](last)
757
+ pos_ids = torch.tensor(
758
+ [self.text_start_pos + prefill_len + t - 1],
759
+ dtype=torch.long,
760
+ device=self.device,
761
+ ).unsqueeze(0)
762
+
763
+ outputs = self.tts.model(
764
+ position_ids=pos_ids,
765
+ past_key_values=self.past_key_values,
766
+ inputs_embeds=inputs_embeds,
767
+ use_cache=True,
768
+ )
769
+ hidden_states = outputs.last_hidden_state
770
+
771
+ # Handle KV cache based on attention type
772
+ if self.attention_type == "sliding_window":
773
+ self.past_key_values = outputs.past_key_values
774
+ self._truncate_kv_cache_sliding_window()
775
+ else:
776
+ self.past_key_values = outputs.past_key_values
777
+
778
+ with P.cached():
779
+ logits = torch.empty(
780
+ hidden_states.size(0),
781
+ hidden_states.size(1),
782
+ self.num_audio_tokens,
783
+ self.num_vq,
784
+ dtype=torch.float,
785
+ device=self.device,
786
+ )
787
+ for num_vq_iter in range(self.num_vq):
788
+ x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
789
+ logits[..., num_vq_iter] = x
790
+ del x
791
+
792
+ del hidden_states
793
+
794
+ logits = logits[:, -1].float()
795
+
796
+ logits = logits.permute(0, 2, 1)
797
+ logits = logits.reshape(-1, logits.size(2))
798
+
799
+ logits /= self.temperature
800
+
801
+ audio_bos = len(self.all_generated_tokens) == 0 and t == 0
802
+
803
+ if not audio_bos:
804
+ # use all tokens generated so far as input for processors/warpers (aligned with modeling_minicpmo)
805
+ all_generated_tokens = torch.cat(self.all_generated_tokens, dim=1).to(self.device) # [1, T]
806
+ for processor in self.logits_processors:
807
+ logits = processor(all_generated_tokens, logits)
808
+
809
+ for warper in self.logits_warpers:
810
+ logits = warper(all_generated_tokens, logits)
811
+ del all_generated_tokens
812
+
813
+ # sample next token (only use first codebook, same as generate)
814
+ scores = F.softmax(logits, dim=-1)
815
+ idx_next = torch.multinomial(scores, num_samples=1) # [(B*num_vq), 1]
816
+ next_id = idx_next.view(-1, self.num_vq)[:, 0:1] # only take first codebook → [B, 1]
817
+ del scores
818
+
819
+ if next_id.eq(
820
+ self.eos_token
821
+ ).any(): # an audio EOS token was generated: this chunk is finished, stop generating new tokens
822
+ finished[:] = True
823
+ else: # the EOS token is never added to the buffer; it contributes no audio.
824
+ # convert next_id to correct shape [1, 1], no num_vq dimension
825
+ if next_id.dim() == 0: # if scalar
826
+ next_tok = next_id.unsqueeze(0).unsqueeze(0) # [1, 1]
827
+ elif next_id.dim() == 1: # if 1D [1]
828
+ next_tok = next_id.unsqueeze(0) # [1, 1]
829
+ else:
830
+ next_tok = next_id
831
+
832
+ self.all_generated_tokens.append(next_tok)
833
+ chunk_generated_tokens.append(next_tok)
834
+
835
+ # Update chunk info for sliding_recompute
836
+ current_chunk_info["audio_tokens"].append(next_tok.clone())
837
+ current_chunk_info["audio_token_count"] += 1
838
+
839
+ self._token_buffer.append(next_tok)
840
+
841
+ if len(self._token_buffer) == 0:
842
+ # case 1: if this is the last text chunk, yield an empty tensor with finished=True
843
+ if text_finished:
844
+ yield torch.empty(1, 0, dtype=torch.long, device=self.device), True
845
+ break
846
+ # case 2: if not last text chunk, break directly
847
+ else:
848
+ break
849
+ else: # buffer has something
850
+ # case 1: if buffer is larger/equal to chunk_size, yield out
851
+ if len(self._token_buffer) >= self.chunk_size:
852
+ batch = torch.cat(self._token_buffer[: self.chunk_size], dim=1) # [1, chunk_size]
853
+ yield batch, False # → [1, chunk_size]
854
+ # discard yielded part
855
+ self._token_buffer = self._token_buffer[self.chunk_size :]
856
+
857
+ # case 2: if buffer is smaller than chunk_size
858
+ else:
859
+ # if generation finished, and is the last text chunk, yield all remaining tokens, then break
860
+ if finished.all():
861
+ if text_finished:
862
+ batch = torch.cat(self._token_buffer, dim=1) # [1, remaining] (fewer than chunk_size)
863
+ yield batch, True # → [1, remaining], finished
864
+ self._token_buffer = []
865
+ break
866
+ else:
867
+ # not the last text chunk: wait for the next text chunk to fill the buffer, so this call ends here
868
+ break
869
+ else: # generation of this audio chunk is not finished, continue generating
870
+ continue
871
+
872
+ # Save current chunk info for sliding_recompute and reindex
873
+ self._chunk_info.append(current_chunk_info)
874
+ self._total_seq_len += condition.shape[1] + len(chunk_generated_tokens)
875
+
876
+ # Update text_start_pos based on attention type
877
+ if self.attention_type == "sliding_recompute":
878
+ # sliding_recompute: will be reset at next chunk start, update normally here
879
+ self.text_start_pos += prefill_len + len(chunk_generated_tokens)
880
+ elif self.attention_type == "reindex":
881
+ # reindex: position based on actual KV cache length (positions have been reindexed to be continuous)
882
+ if self.past_key_values is not None:
883
+ if hasattr(self.past_key_values, "get_seq_length"):
884
+ self.text_start_pos = self.past_key_values.get_seq_length()
885
+ else:
886
+ self.text_start_pos = self.past_key_values[0][0].shape[2]
887
+ else:
888
+ self.text_start_pos += condition.shape[1] + len(chunk_generated_tokens)
889
+ else:
890
+ self.text_start_pos += condition.shape[1] + len(chunk_generated_tokens)
891
+ # note: remaining tokens in buffer will be kept, and accumulated next time
892
+
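+ # --- illustrative usage sketch (not part of the upstream file) ----------------
+ # Rough shape of how the streaming TTS generator is driven; `tts_model`,
+ # `condition_chunks` (a list of [1, T, hidden] condition embeddings), `eos_id`
+ # and `vocode` are hypothetical placeholders.
+ #
+ # gen = TTSStreamingGenerator(model=tts_model, temperature=0.8, eos_token=eos_id)
+ # for i, cond in enumerate(condition_chunks):
+ #     last = i == len(condition_chunks) - 1
+ #     for audio_tokens, finished in gen.generate_with_buffer(cond, text_finished=last):
+ #         vocode(audio_tokens)   # [1, chunk_size] codec ids (last chunk may be shorter)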
893
+
894
+ # sliding window
895
+ @dataclass
896
+ class StreamingWindowConfig:
897
+ text_window_high_tokens: int = 8000
898
+ text_window_low_tokens: int = 6000
899
+
900
+
901
+ @dataclass
902
+ class DuplexWindowConfig:
903
+ """duplex sliding window configuration
904
+
905
+ sliding window mode:
906
+ - "off": disable sliding window
907
+ - "basic": basic sliding window (trigger by cache length)
908
+ - "context": sliding window with context (trigger by unit number, preserve generated text to previous)
909
+ """
910
+
911
+ # sliding window mode
912
+ sliding_window_mode: str = "off" # "off" / "basic" / "context"
913
+
914
+ # basic sliding window parameters
915
+ basic_window_high_tokens: int = 8000 # high watermark: trigger sliding window when exceeded
916
+ basic_window_low_tokens: int = 6000 # low watermark: keep to this value after sliding window
917
+
918
+ # context sliding window parameters
919
+ context_previous_max_tokens: int = 500 # maximum number of tokens kept as "previous" context
920
+ context_max_units: int = 24 # maximum number of units (sliding window triggers when exceeded)
921
+
922
+ # verification mode (for comparison test)
923
+ verify_mode: bool = False # whether to enable verification log
924
+
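+ # Example (illustrative, not part of the upstream file): a context-preserving
+ # window that starts evicting whole units once more than 24 are cached, while
+ # keeping at most ~500 tokens of their generated text as "previous" context.
+ #
+ # window_cfg = DuplexWindowConfig(
+ #     sliding_window_mode="context",
+ #     context_previous_max_tokens=500,
+ #     context_max_units=24,
+ # )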
925
+
926
+ def as_dynamic_cache(past_key_values):
927
+ """Convert legacy tuple cache to DynamicCache if needed."""
928
+ if isinstance(past_key_values, DynamicCache):
929
+ return past_key_values
930
+
931
+ if isinstance(past_key_values, tuple):
932
+ return DynamicCache.from_legacy_cache(past_key_values)
933
+
934
+ return past_key_values
935
+
936
+
937
+ def get_kv_cache_length(cache) -> int:
938
+ """Get the sequence length of a KV cache.
939
+
940
+ Args:
941
+ cache: DynamicCache or tuple-based cache
942
+
943
+ Returns:
944
+ The number of tokens in the cache
945
+ """
946
+ if cache is None:
947
+ return 0
948
+
949
+ if isinstance(cache, DynamicCache):
950
+ if not cache.key_cache or not cache.key_cache[0].numel():
951
+ return 0
952
+ return cache.key_cache[0].shape[-2]
953
+
954
+ if isinstance(cache, tuple):
955
+ return cache[0][0].shape[2]
956
+
957
+ return 0
958
+
959
+
960
+ def get_rotary_cos_sin(
961
+ head_dim: int,
962
+ positions: torch.Tensor,
963
+ device: torch.device,
964
+ dtype: torch.dtype,
965
+ rope_theta: float = 10000.0,
966
+ inv_freq_cache: Optional[Dict[Tuple, torch.Tensor]] = None,
967
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
968
+ """Compute RoPE cos and sin components for given positions.
969
+
970
+ Args:
971
+ head_dim: Dimension of each attention head
972
+ positions: Position indices tensor
973
+ device: Target device
974
+ dtype: Target dtype
975
+ rope_theta: RoPE base frequency (default 10000.0)
976
+ inv_freq_cache: Optional cache dict for inverse frequencies
977
+
978
+ Returns:
979
+ Tuple of (cos, sin) tensors with shape [1, 1, seq_len, head_dim]
980
+ """
981
+ cache_key = (head_dim, device)
982
+
983
+ inv_freq = inv_freq_cache.get(cache_key) if inv_freq_cache is not None else None
984
+ if inv_freq is None or inv_freq.device != device or inv_freq.shape[0] != head_dim // 2:
985
+ exponent = torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim
986
+ inv_freq = 1.0 / (rope_theta**exponent)
987
+ if inv_freq_cache is not None:
988
+ inv_freq_cache[cache_key] = inv_freq
989
+
990
+ positions = positions.to(device=device, dtype=torch.float32)
991
+ angles = torch.einsum("i,j->ij", positions, inv_freq)
992
+ cos = torch.cos(angles)
993
+ sin = torch.sin(angles)
994
+
995
+ # Use cat instead of repeat_interleave, consistent with model's original RotaryEmbedding
996
+ # Original: emb = torch.cat((freqs, freqs), dim=-1) -> [f0, f1, ..., f_{d/2}, f0, f1, ..., f_{d/2}]
997
+ cos_full = torch.cat([cos, cos], dim=-1).to(dtype=dtype)
998
+ sin_full = torch.cat([sin, sin], dim=-1).to(dtype=dtype)
999
+ cos_full = cos_full.unsqueeze(0).unsqueeze(0)
1000
+ sin_full = sin_full.unsqueeze(0).unsqueeze(0)
1001
+ return cos_full, sin_full
1002
+
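+ # Example (illustrative, not part of the upstream file): cos/sin for three
+ # positions of a 64-dim head come back as [1, 1, 3, 64] tensors, ready to
+ # broadcast over [batch, heads, seq, head_dim] key tensors.
+ #
+ # cos, sin = get_rotary_cos_sin(64, torch.arange(3), torch.device("cpu"), torch.float32)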
1003
+
1004
+ def realign_rotary_suffix(
1005
+ suffix_keys: torch.Tensor,
1006
+ old_positions: torch.Tensor,
1007
+ new_positions: torch.Tensor,
1008
+ rope_theta: float = 10000.0,
1009
+ inv_freq_cache: Optional[Dict[Tuple, torch.Tensor]] = None,
1010
+ ) -> torch.Tensor:
1011
+ """Realign RoPE position encoding after cache eviction.
1012
+
1013
+ When tokens are dropped from the middle of a cache, the suffix tokens
1014
+ need their RoPE embeddings recalculated with new position indices.
1015
+
1016
+ Args:
1017
+ suffix_keys: Key tensor to realign, shape [batch, heads, seq_len, head_dim]
1018
+ old_positions: Original position indices
1019
+ new_positions: New position indices after eviction
1020
+ rope_theta: RoPE base frequency
1021
+ inv_freq_cache: Optional cache dict for inverse frequencies
1022
+
1023
+ Returns:
1024
+ Realigned key tensor with same shape as input
1025
+ """
1026
+ if suffix_keys.numel() == 0:
1027
+ return suffix_keys
1028
+
1029
+ head_dim = suffix_keys.shape[-1]
1030
+ device = suffix_keys.device
1031
+ dtype = suffix_keys.dtype
1032
+
1033
+ # Compute old position cos/sin
1034
+ cos_old, sin_old = get_rotary_cos_sin(head_dim, old_positions, device, dtype, rope_theta, inv_freq_cache)
1035
+
1036
+ # Inverse transform: recover original key
1037
+ base = cos_old * suffix_keys - sin_old * rotate_half(suffix_keys)
1038
+
1039
+ # Compute new position cos/sin
1040
+ cos_new, sin_new = get_rotary_cos_sin(head_dim, new_positions, device, dtype, rope_theta, inv_freq_cache)
1041
+
1042
+ # Forward transform: re-encode with new positions
1043
+ return cos_new * base + sin_new * rotate_half(base)
1044
+
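+ # Why the inverse/forward transform above works (note, not part of the upstream
+ # file): with RoPE applied as k = x*cos(θ) + rotate_half(x)*sin(θ) and cos/sin
+ # built by duplicating the half-dim angles, rotate_half(rotate_half(x)) == -x,
+ # so x = k*cos(θ) - rotate_half(k)*sin(θ) recovers the unrotated key exactly;
+ # re-applying cos/sin at the new positions then yields the realigned key.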
1045
+
1046
+ def drop_tokens_from_cache(
1047
+ cache: Optional[DynamicCache | Tuple],
1048
+ length: int,
1049
+ preserve: int,
1050
+ position_offset: int,
1051
+ rope_theta: float = 10000.0,
1052
+ inv_freq_cache: Optional[Dict[Tuple, torch.Tensor]] = None,
1053
+ ) -> Tuple[Optional[DynamicCache], int, bool]:
1054
+ """Drop tokens from a KV cache while preserving system prompt.
1055
+
1056
+ Removes tokens in the range [preserve, preserve + length) from the cache,
1057
+ realigning RoPE embeddings for the suffix.
1058
+
1059
+ Args:
1060
+ cache: DynamicCache or tuple-based cache (will be converted to DynamicCache)
1061
+ length: Number of tokens to drop
1062
+ preserve: Number of tokens to preserve at the start (system prompt)
1063
+ position_offset: Current position offset for RoPE calculation
1064
+ rope_theta: RoPE base frequency
1065
+ inv_freq_cache: Optional cache dict for inverse frequencies
1066
+
1067
+ Returns:
1068
+ Tuple of (cache, new_position_offset, success)
1069
+ Note: Tuple cache will be converted to DynamicCache. Modification is in-place.
1070
+ """
1071
+ if cache is None or length <= 0:
1072
+ return cache, position_offset, False
1073
+
1074
+ cache = as_dynamic_cache(cache)
1075
+
1076
+ total_len = get_kv_cache_length(cache)
1077
+ if total_len <= 0:
1078
+ return cache, position_offset, False
1079
+
1080
+ preserve = min(preserve, total_len)
1081
+ available = total_len - preserve
1082
+
1083
+ if available < length:
1084
+ logger.warning(
1085
+ "Cannot drop %d tokens: only %d available (total=%d, preserve=%d)",
1086
+ length,
1087
+ available,
1088
+ total_len,
1089
+ preserve,
1090
+ )
1091
+ return cache, position_offset, False
1092
+
1093
+ suffix_len = total_len - preserve - length
1094
+ # note: after RoPE reindexing, cache positions are already compacted (starting right after the preserved prefix),
1095
+ # so position_offset must not be added here; use the actual layout of the current cache
1096
+ suffix_offset = preserve + length # suffix current position in cache
1097
+ prefix_offset = preserve # new start position of the suffix (immediately after the preserved prefix)
1098
+
1099
+ # Prepare position tensors for RoPE realignment
1100
+ old_positions = None
1101
+ new_positions = None
1102
+ if suffix_len > 0:
1103
+ device = cache.key_cache[0].device
1104
+ old_positions = torch.arange(
1105
+ suffix_offset,
1106
+ suffix_offset + suffix_len,
1107
+ device=device,
1108
+ dtype=torch.long,
1109
+ )
1110
+ new_positions = torch.arange(
1111
+ prefix_offset,
1112
+ prefix_offset + suffix_len,
1113
+ device=device,
1114
+ dtype=torch.long,
1115
+ )
1116
+
1117
+ keep_len = total_len - length
1118
+
1119
+ # Process each layer (in-place modification)
1120
+ for layer_idx in range(len(cache.key_cache)):
1121
+ key_tensor = cache.key_cache[layer_idx]
1122
+ value_tensor = cache.value_cache[layer_idx]
1123
+
1124
+ if not key_tensor.numel():
1125
+ continue
1126
+
1127
+ # Preserve prefix (system prompt)
1128
+ prefix_keys = key_tensor[:, :, :preserve, :]
1129
+ prefix_values = value_tensor[:, :, :preserve, :]
1130
+
1131
+ if suffix_len > 0:
1132
+ # Keep and realign suffix
1133
+ suffix_keys = key_tensor[:, :, preserve + length :, :]
1134
+ suffix_values = value_tensor[:, :, preserve + length :, :]
1135
+
1136
+ if old_positions is not None and new_positions is not None and suffix_keys.numel():
1137
+ suffix_keys = realign_rotary_suffix(
1138
+ suffix_keys,
1139
+ old_positions,
1140
+ new_positions,
1141
+ rope_theta,
1142
+ inv_freq_cache,
1143
+ )
1144
+
1145
+ cache.key_cache[layer_idx] = torch.cat([prefix_keys, suffix_keys], dim=-2).contiguous()
1146
+ cache.value_cache[layer_idx] = torch.cat([prefix_values, suffix_values], dim=-2).contiguous()
1147
+ else:
1148
+ cache.key_cache[layer_idx] = prefix_keys.contiguous()
1149
+ cache.value_cache[layer_idx] = prefix_values.contiguous()
1150
+
1151
+ cache.crop(keep_len)
1152
+ cache._seen_tokens = max(keep_len, 0)
1153
+
1154
+ new_offset = position_offset + length
1155
+ logger.debug("Dropped %d tokens from cache, new length=%d", length, keep_len)
1156
+
1157
+ return cache, new_offset, True
1158
+
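+ # --- illustrative usage sketch (not part of the upstream file) ----------------
+ # Dropping the 128 oldest non-system tokens while keeping the first `system_len`
+ # tokens intact; `cache`, `system_len` and `pos_offset` are hypothetical
+ # placeholders for the caller's state.
+ #
+ # cache, pos_offset, ok = drop_tokens_from_cache(
+ #     cache=cache,
+ #     length=128,
+ #     preserve=system_len,
+ #     position_offset=pos_offset,
+ #     rope_theta=10000.0,
+ # )
+ # if not ok:
+ #     logger.warning("nothing dropped: fewer than 128 evictable tokens in cache")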
1159
+
1160
+ # stream decoder
1161
+ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
1162
+ logits = logits.clone()
1163
+
1164
+ # Top-k filtering
1165
+ if top_k > 0:
1166
+ top_k = min(top_k, logits.size(-1))
1167
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
1168
+ logits[indices_to_remove] = filter_value
1169
+
1170
+ # Top-p (nucleus) filtering
1171
+ if top_p > 0.0:
1172
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
1173
+ probs = F.softmax(sorted_logits, dim=-1)
1174
+ cumulative_probs = torch.cumsum(probs, dim=-1)
1175
+
1176
+ sorted_indices_to_remove = cumulative_probs > top_p
1177
+ # keep the first token that exceeds top_p
1178
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
1179
+ sorted_indices_to_remove[..., 0] = 0
1180
+
1181
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
1182
+ logits[0, indices_to_remove] = filter_value
1183
+
1184
+ return logits
1185
+
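+ # --- illustrative usage sketch (not part of the upstream file) ----------------
+ # Typical sampling step with this filter, assuming `logits` is a [1, vocab] row.
+ #
+ # filtered = top_k_top_p_filtering(logits, top_k=50, top_p=0.9)
+ # probs = F.softmax(filtered / 0.8, dim=-1)        # 0.8 acts as the temperature
+ # next_id = torch.multinomial(probs, num_samples=1)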
1186
+
1187
+ class StreamDecoder:
1188
+ def __init__(self, llm, tokenizer, special_token_ids=None, forbidden_token_ids=None):
1189
+ self.m = llm
1190
+ self.tokenizer = tokenizer
1191
+ self.listen_id = self.tokenizer.eos_token_id
1192
+
1193
+ self.chunk_eos_id = self.tokenizer.convert_tokens_to_ids("<|chunk_eos|>")
1194
+ self.chunk_tts_eos_id = self.tokenizer.convert_tokens_to_ids("<|chunk_tts_eos|>")
1195
+ self.turn_eos_id = self.tokenizer.convert_tokens_to_ids("<|turn_eos|>")
1196
+ self.speak_id = self.tokenizer.convert_tokens_to_ids("<|speak|>")
1197
+
1198
+ self.special_token_ids = special_token_ids if special_token_ids is not None else []
1199
+
1200
+ # cache special tokens (used for context sliding window filtering)
1201
+ self._all_special_ids = set()
1202
+ self._all_special_tokens_text = set()
1203
+ if self.tokenizer:
1204
+ if hasattr(self.tokenizer, "all_special_ids"):
1205
+ self._all_special_ids = set(self.tokenizer.all_special_ids)
1206
+ if hasattr(self.tokenizer, "all_special_tokens"):
1207
+ self._all_special_tokens_text = set(self.tokenizer.all_special_tokens)
1208
+
1209
+ custom_special_tokens = [
1210
+ "<unit>",
1211
+ "</unit>",
1212
+ "<image>",
1213
+ "</image>",
1214
+ "<slice>",
1215
+ "</slice>",
1216
+ "<|listen|>",
1217
+ "<|speak|>",
1218
+ "<|tts_bos|>",
1219
+ "<|tts_eos|>",
1220
+ "<|audio_start|>",
1221
+ "<|audio_end|>",
1222
+ "<|chunk_eos|>",
1223
+ "<|chunk_tts_eos|>",
1224
+ "<|turn_eos|>",
1225
+ "<|audio_start|>",
1226
+ "<|audio_end|>",
1227
+ ]
1228
+ self._all_special_tokens_text.update(custom_special_tokens)
1229
+ for token in custom_special_tokens:
1230
+ token_id = self.tokenizer.convert_tokens_to_ids(token)
1231
+ if token_id is not None and token_id != self.tokenizer.unk_token_id:
1232
+ self._all_special_ids.add(token_id)
1233
+
1234
+ if forbidden_token_ids is None:
1235
+ self.forbidden_token_ids = []
1236
+ elif isinstance(forbidden_token_ids, int):
1237
+ self.forbidden_token_ids = [forbidden_token_ids]
1238
+ else:
1239
+ self.forbidden_token_ids = forbidden_token_ids
1240
+ self.forbidden_token_ids.append(self.chunk_eos_id)
1241
+
1242
+ assert isinstance(self.forbidden_token_ids, list)
1243
+
1244
+ self.cache = None
1245
+ self.context = ""
1246
+ self.generated_tokens = [] # track generated tokens
1247
+ self.generated_special_tokens = [] # track generated special tokens
1248
+ self.reset()
1249
+ self.embeds = None
1250
+ self.system_embeds = None
1251
+
1252
+ # sliding window related states
1253
+ self._unit_history: List[Dict[str, Any]] = []
1254
+ self._next_unit_id: int = 0
1255
+ self._pending_unit_id: Optional[int] = None
1256
+ self._pending_unit_start_cache_len: int = 0
1257
+ self._system_preserve_length: int = 0
1258
+ self._position_offset: int = 0
1259
+ self._window_config = DuplexWindowConfig()
1260
+ self._window_enabled: bool = True
1261
+ self._rope_inv_freq_cache: Dict[Tuple, torch.Tensor] = {}
1262
+
1263
+ # context preserving sliding window states
1264
+ # initial cache layout: [prefix] [suffix] [units...]
1265
+ # after first sliding window: [prefix] [previous_marker + content] [suffix] [units...]
1266
+ # (prefix is fixed, previous content is dynamic, suffix is fixed, units form the sliding region)
1267
+ self._preserve_prefix_length: int = 0 # original prefix length (fixed)
1268
+ self._previous_content_length: int = 0 # previous content length (dynamic, including marker)
1269
+ self._suffix_token_ids: List[int] = [] # suffix token ids (e.g. <|im_end|>)
1270
+
1271
+ # previous marker (added dynamically after first sliding window)
1272
+ self._previous_marker: str = "\n\nprevious: " # fixed prefix marker
1273
+ self._previous_marker_token_ids: List[int] = [] # marker token ids (initialized)
1274
+ self._has_previous: bool = False # whether previous marker has been added
1275
+
1276
+ # previous content
1277
+ self._previous_text: str = "" # accumulated generated text (without marker)
1278
+ self._previous_token_ids: List[int] = [] # previous full token ids (including marker)
1279
+
1280
+ # validation statistics
1281
+ self._sliding_event_count: int = 0 # sliding window trigger count
1282
+ self._total_dropped_tokens: int = 0 # total dropped token count
1283
+ self._total_dropped_units: int = 0 # total dropped unit count
1284
+
1285
+ def sliding_embeds(self):
1286
+ # tmp = system_embeds
1287
+ # tmp -> embeds after 5s
1288
+ # reset
1289
+ # feed
1290
+ pass
1291
+
1292
+ def reset(self):
1293
+ self.context = ""
1294
+ self.cache = None
1295
+ self.generated_tokens = []
1296
+ self.generated_special_tokens = []
1297
+ self.embeds = None
1298
+ self.system_embeds = None
1299
+
1300
+ # sliding window state reset
1301
+ old_unit_count = len(self._unit_history) if hasattr(self, "_unit_history") else 0
1302
+ self._unit_history = []
1303
+ self._next_unit_id = 0
1304
+ self._pending_unit_id = None
1305
+ self._pending_unit_start_cache_len = 0
1306
+ self._system_preserve_length = 0
1307
+ self._position_offset = 0
1308
+ self._rope_inv_freq_cache = {}
1309
+
1310
+ # context preserving sliding window state reset
1311
+ self._preserve_prefix_length = 0
1312
+ self._previous_content_length = 0
1313
+ self._suffix_token_ids = []
1314
+ self._previous_marker = "\n\nprevious: "
1315
+ self._previous_marker_token_ids = []
1316
+ self._has_previous = False
1317
+ self._previous_text = ""
1318
+ self._previous_token_ids = []
1319
+
1320
+ # validation statistics
1321
+ self._sliding_event_count = 0 # sliding window trigger count
1322
+ self._total_dropped_tokens = 0 # total dropped token count
1323
+ self._total_dropped_units = 0 # total dropped unit count
1324
+
1325
+ def get_cache_length(self) -> int:
1326
+ if self.cache is None:
1327
+ return 0
1328
+ if isinstance(self.cache, DynamicCache):
1329
+ if len(self.cache.key_cache) > 0 and self.cache.key_cache[0].numel() > 0:
1330
+ return self.cache.key_cache[0].shape[2]
1331
+ return 0
1332
+ # Tuple cache format
1333
+ return self.cache[0][0].shape[2]
1334
+
1335
+ def get_total_generated_tokens(self) -> int:
1336
+ return sum(len(u.get("generated_tokens", [])) for u in self._unit_history)
1337
+
1338
+ def register_unit_start(self) -> int:
1339
+ self._pending_unit_id = self._next_unit_id
1340
+ self._pending_unit_start_cache_len = self.get_cache_length()
1341
+ return self._pending_unit_id
1342
+
1343
+ def register_unit_end(
1344
+ self,
1345
+ input_type: str,
1346
+ generated_tokens: Optional[List[int]] = None,
1347
+ is_listen: bool = False,
1348
+ generated_text: Optional[str] = None,
1349
+ ):
1350
+ """Call when unit ends, record unit information
1351
+
1352
+ Should be called after feeding </unit> token
1353
+
1354
+ Args:
1355
+ input_type: "audio" / "video" / "omni" / "system"
1356
+ generated_tokens: tokens generated by the unit (token ids)
1357
+ is_listen: whether the unit is in listen state
1358
+ generated_text: text generated by the unit (used for context preserving mode)
1359
+ """
1360
+ if self._pending_unit_id is None:
1361
+ logger.warning("register_unit_end called without register_unit_start")
1362
+ return
1363
+
1364
+ # calculate the length of the unit
1365
+ current_cache_len = self.get_cache_length()
1366
+ unit_len = current_cache_len - self._pending_unit_start_cache_len
1367
+
1368
+ if unit_len > 0:
1369
+ entry = {
1370
+ "unit_id": self._pending_unit_id,
1371
+ "length": unit_len,
1372
+ "type": input_type,
1373
+ "generated_tokens": generated_tokens or [],
1374
+ "generated_text": generated_text or "", # used for context preserving mode
1375
+ "is_listen": is_listen,
1376
+ }
1377
+ self._unit_history.append(entry)
1378
+
1379
+ self._pending_unit_id = None
1380
+ self._pending_unit_start_cache_len = 0
1381
+ self._next_unit_id += 1
1382
+
1383
+ def register_system_prompt(self):
1384
+ """Call after system prompt prefill, record preserve length"""
1385
+ self._system_preserve_length = self.get_cache_length()
1386
+
1387
+ # sliding window core methods
1388
+
1389
+ def _get_rope_theta(self) -> float:
1390
+ """get model rope_theta configuration"""
1391
+ return float(getattr(self.m.config, "rope_theta", 10000.0))
1392
+
1393
+ def _drop_tokens_from_cache(self, length: int) -> bool:
1394
+ """remove specified number of tokens from cache (protect system prompt)
1395
+
1396
+ remove tokens in the range [preserve, preserve + length)
1397
+ supports DynamicCache and tuple cache formats
1398
+ """
1399
+ if self.cache is None or length <= 0:
1400
+ return False
1401
+
1402
+ cache_type = "DynamicCache" if isinstance(self.cache, DynamicCache) else "TupleCache"
1403
+ cache_len_before = self.get_cache_length()
1404
+ offset_before = self._position_offset
1405
+
1406
+ new_cache, new_offset, success = drop_tokens_from_cache(
1407
+ cache=self.cache,
1408
+ length=length,
1409
+ preserve=self._system_preserve_length,
1410
+ position_offset=self._position_offset,
1411
+ rope_theta=self._get_rope_theta(),
1412
+ inv_freq_cache=self._rope_inv_freq_cache,
1413
+ )
1414
+ if success:
1415
+ self.cache = new_cache # For DynamicCache this is the same object (in-place)
1416
+ self._position_offset = new_offset
1417
+
1418
+ return success
1419
+
1420
+ def _drop_unit(self, unit_id: int) -> bool:
1421
+ """remove specified unit"""
1422
+ entries = [u for u in self._unit_history if u["unit_id"] == unit_id]
1423
+ if not entries:
1424
+ return False
1425
+
1426
+ total_len = sum(e["length"] for e in entries)
1427
+ if total_len <= 0:
1428
+ for e in entries:
1429
+ self._unit_history.remove(e)
1430
+ return False
1431
+
1432
+ if not self._drop_tokens_from_cache(total_len):
1433
+ return False
1434
+
1435
+ for e in entries:
1436
+ self._unit_history.remove(e)
1437
+
1438
+ return True
1439
+
1440
+ def _drop_next_unit(self) -> bool:
1441
+ """remove the earliest non-system unit"""
1442
+ for entry in self._unit_history:
1443
+ unit_id = entry.get("unit_id")
1444
+ if unit_id is None:
1445
+ continue
1446
+ # skip system type
1447
+ if entry.get("type") == "system":
1448
+ continue
1449
+ if self._drop_unit(unit_id):
1450
+ return True
1451
+ return False
1452
+
1453
+ def enforce_window(self) -> bool:
1454
+ """enforce sliding window strategy (same as single-mode, only look at cache length)
1455
+
1456
+ When the cache length exceeds the high watermark, repeatedly drop the earliest unit
1457
+ until the cache length falls below the low watermark.
1458
+ """
1459
+ if not self._window_enabled:
1460
+ return False
1461
+
1462
+ cfg = self._window_config
1463
+ cache_len_before = self.get_cache_length()
1464
+
1465
+ if cache_len_before <= cfg.basic_window_high_tokens:
1466
+ return False # not above the high watermark; nothing to drop
1467
+
1468
+ dropped_count = 0
1469
+ cache_len = cache_len_before
1470
+ while cache_len > cfg.basic_window_low_tokens:
1471
+ if not self._drop_next_unit():
1472
+ break
1473
+ dropped_count += 1
1474
+ cache_len = self.get_cache_length()
1475
+
1476
+ if dropped_count > 0:
1477
+ # update statistics counters
1478
+ self._sliding_event_count += 1
1479
+ self._total_dropped_tokens += cache_len_before - cache_len
1480
+ self._total_dropped_units += dropped_count
1481
+
1482
+ # consistency check
1483
+ expected = self._system_preserve_length + sum(u["length"] for u in self._unit_history)
1484
+ is_consistent = expected == cache_len
1485
+ if not is_consistent:
1486
+ logger.error(
1487
+ "CONSISTENCY ERROR! preserve=%d + sum(units)=%d != cache=%d, offset=%d",
1488
+ self._system_preserve_length,
1489
+ sum(u["length"] for u in self._unit_history),
1490
+ cache_len,
1491
+ self._position_offset,
1492
+ )
1493
+
1494
+ return dropped_count > 0
1495
+
1496
+ # context preserving sliding window methods
1497
+
1498
+ def register_system_prompt_with_context(
1499
+ self,
1500
+ suffix_token_ids: Optional[List[int]] = None,
1501
+ context_previous_marker: str = "\n\nprevious: ",
1502
+ ):
1503
+ """register system prompt (with context preserving mode)
1504
+
1505
+ initial cache layout: [prefix] [suffix] [units...]
1506
+ after first sliding window: [prefix] [context_previous_marker + content] [suffix] [units...]
1507
+
1508
+ when calling this method, cache should only have prefix (without previous marker)
1509
+ suffix will be fed in later
1510
+
1511
+ Args:
1512
+ suffix_token_ids: suffix token ids (e.g. id of <|im_end|>)
1513
+ context_previous_marker: previous marker prefix, e.g. "\\n\\nprevious: "
1514
+ """
1515
+ # prefix = current cache content (fixed, without previous marker)
1516
+ self._preserve_prefix_length = self.get_cache_length()
1517
+ self._previous_content_length = 0 # initially no previous content
1518
+ self._suffix_token_ids = suffix_token_ids or []
1519
+ # total preserve length = prefix + suffix (initially no previous)
1520
+ self._system_preserve_length = self._preserve_prefix_length + len(self._suffix_token_ids)
1521
+
1522
+ # initialize previous related states
1523
+ self._previous_marker = context_previous_marker
1524
+ self._previous_marker_token_ids = (
1525
+ self.tokenizer.encode(context_previous_marker, add_special_tokens=False) if self.tokenizer else []
1526
+ )
1527
+ self._has_previous = False
1528
+ self._previous_text = ""
1529
+ self._previous_token_ids = []
1530
+
1531
+ def _extract_generated_text(self, units: List[Dict[str, Any]]) -> Tuple[str, List[int]]:
1532
+ """extract generated text and token ids from units
1533
+
1534
+ Args:
1535
+ units: list of units to extract
1536
+
1537
+ Returns:
1538
+ (text, token_ids): concatenated text and token ids (filtered out special tokens)
1539
+ """
1540
+ text_parts = []
1541
+ token_ids = []
1542
+
1543
+ for u in units:
1544
+ # only keep generated content of non-listen units
1545
+ if u.get("is_listen", False):
1546
+ continue
1547
+ gen_text = u.get("generated_text", "")
1548
+ gen_tokens = u.get("generated_tokens", [])
1549
+
1550
+ # filter out special tokens from text
1551
+ if gen_text:
1552
+ clean_text = gen_text
1553
+ for st in self._all_special_tokens_text:
1554
+ clean_text = clean_text.replace(st, "")
1555
+ if clean_text.strip():
1556
+ text_parts.append(clean_text)
1557
+
1558
+ # filter out special tokens
1559
+ if gen_tokens:
1560
+ filtered_tokens = [t for t in gen_tokens if t not in self._all_special_ids]
1561
+ token_ids.extend(filtered_tokens)
1562
+
1563
+ return "".join(text_parts), token_ids
1564
+
1565
+ def _rebuild_cache_with_previous(
1566
+ self,
1567
+ new_previous_tokens: List[int],
1568
+ units_to_keep_len: Optional[int] = None,
1569
+ ) -> bool:
1570
+ """rebuild cache, insert new previous content between prefix and suffix
1571
+
1572
+ cache layout change:
1573
+ [prefix] [old_prev] [suffix] [old_units] → [prefix] [new_prev] [suffix] [remaining_units]
1574
+
1575
+ Args:
1576
+ new_previous_tokens: new previous token ids
1577
+ units_to_keep_len: length of units to keep (from cache end backwards)
1578
+ if None, calculate based on unit_history
1579
+
1580
+ Returns:
1581
+ whether successful rebuild
1582
+ """
1583
+ if self.cache is None:
1584
+ return False
1585
+
1586
+ old_previous_len = self._previous_content_length
1587
+ new_previous_len = len(new_previous_tokens)
1588
+ suffix_len = len(self._suffix_token_ids)
1589
+ total_cache_len = self.get_cache_length()
1590
+
1591
+ # calculate length of units to keep
1592
+ if units_to_keep_len is None:
1593
+ units_to_keep_len = sum(u["length"] for u in self._unit_history)
1594
+
1595
+ # special case: if "previous" is unchanged (both old and new are empty), the prefix+suffix part of the cache need not be rebuilt,
1596
+ # but the units' RoPE still has to be reindexed (a unit was deleted, so positions shifted)
1597
+ if new_previous_len == 0 and old_previous_len == 0:
1598
+ # cache layout: [prefix(7)] [suffix(1)] [units...]
1599
+ # only keep prefix + suffix + remaining_units
1600
+ preserve_len = self._preserve_prefix_length + suffix_len
1601
+
1602
+ # simply slice cache: [prefix+suffix] + [remaining_units]
1603
+ # remaining_units in cache end
1604
+ if units_to_keep_len > 0:
1605
+ # [0:preserve_len] + [total-units_to_keep_len:total]
1606
+ prefix_suffix_cache = self._slice_cache(0, preserve_len)
1607
+ units_cache = self._slice_cache(total_cache_len - units_to_keep_len, None)
1608
+
1609
+ # calculate number of dropped tokens
1610
+ dropped_tokens = total_cache_len - preserve_len - units_to_keep_len
1611
+
1612
+ # reindex units RoPE: position from (preserve_len + dropped_tokens) to preserve_len
1613
+ # note: no position_offset, because cache position has been compressed (from 0 start)
1614
+ if dropped_tokens > 0:
1615
+ old_start = preserve_len + dropped_tokens
1616
+ new_start = preserve_len
1617
+ units_cache = self._reindex_rope_for_cache(units_cache, old_start, new_start, units_to_keep_len)
1618
+
1619
+ self.cache = self._concat_caches(prefix_suffix_cache, units_cache)
1620
+ else:
1621
+ self.cache = self._slice_cache(0, preserve_len)
1622
+
1623
+ return True
1624
+
1625
+ # 1. get prefix cache (fixed)
1626
+ prefix_end = self._preserve_prefix_length
1627
+ prefix_cache = self._slice_cache(0, prefix_end)
1628
+
1629
+ # 2. get units cache to keep (from end)
1630
+ units_start_in_old_cache = total_cache_len - units_to_keep_len
1631
+ units_cache = None
1632
+ if units_to_keep_len > 0:
1633
+ units_cache = self._slice_cache(units_start_in_old_cache, None)
1634
+
1635
+ # 3. calculate new previous + suffix cache (needs forward)
1636
+ # merge previous tokens and suffix tokens
1637
+ prev_suffix_tokens = new_previous_tokens + self._suffix_token_ids
1638
+ prev_suffix_len = len(prev_suffix_tokens)
1639
+
1640
+ new_prefix_prev_suffix_cache = prefix_cache
1641
+ if prev_suffix_len > 0:
1642
+ # Embed tokens
1643
+ prev_suffix_embeds = self.embed_tokens(prev_suffix_tokens)
1644
+ # calculate start position (after prefix)
1645
+ start_pos = self._preserve_prefix_length + self._position_offset
1646
+
1647
+ # forward calculate KV cache
1648
+ with torch.no_grad():
1649
+ device = prev_suffix_embeds.device
1650
+ position_ids = torch.arange(
1651
+ start_pos,
1652
+ start_pos + prev_suffix_len,
1653
+ device=device,
1654
+ ).unsqueeze(0)
1655
+
1656
+ # use prefix cache as past_key_values
1657
+ outputs = self.m(
1658
+ inputs_embeds=(
1659
+ prev_suffix_embeds.unsqueeze(0) if prev_suffix_embeds.dim() == 2 else prev_suffix_embeds
1660
+ ),
1661
+ position_ids=position_ids,
1662
+ past_key_values=prefix_cache,
1663
+ use_cache=True,
1664
+ return_dict=True,
1665
+ )
1666
+ # new cache contains prefix + new_previous + suffix
1667
+ new_prefix_prev_suffix_cache = outputs.past_key_values
1668
+
1669
+ # 4. adjust units cache RoPE
1670
+ # new layout: [prefix] [new_prev] [suffix] [units]
1671
+ # note: no position_offset, because cache position has been compressed (from 0 start)
1672
+ new_system_total = prefix_end + new_previous_len + suffix_len
1673
+ if units_cache is not None and self._get_cache_len(units_cache) > 0:
1674
+ old_start = units_start_in_old_cache
1675
+ new_start = new_system_total
1676
+
1677
+ if old_start != new_start:
1678
+ units_cache = self._reindex_rope_for_cache(units_cache, old_start, new_start, units_to_keep_len)
1679
+
1680
+ # 5. concatenate new cache
1681
+ if units_cache is not None and self._get_cache_len(units_cache) > 0:
1682
+ self.cache = self._concat_caches(new_prefix_prev_suffix_cache, units_cache)
1683
+ else:
1684
+ self.cache = new_prefix_prev_suffix_cache
1685
+
1686
+ # 6. update length
1687
+ self._previous_content_length = new_previous_len
1688
+ # total preserve length = prefix + previous + suffix
1689
+ self._system_preserve_length = prefix_end + new_previous_len + suffix_len
1690
+
1691
+ # build preview strings describing the new cache layout (handy for debug logging)
1692
+ prev_text_preview = self._previous_text[:50] + "..." if len(self._previous_text) > 50 else self._previous_text
1693
+ suffix_preview = self.tokenizer.decode(self._suffix_token_ids) if self._suffix_token_ids else ""
1694
+ return True
1695
+
1696
+ def _slice_cache(self, start: int, end: Optional[int], clone: bool = True):
1697
+ """slice cache
1698
+
1699
+ Args:
1700
+ start: start position
1701
+ end: end position (None means to end)
1702
+ clone: whether to clone (default True, to prevent shared memory issues)
1703
+ """
1704
+ if self.cache is None:
1705
+ return None
1706
+ if isinstance(self.cache, DynamicCache):
1707
+ # DynamicCache
1708
+ new_key_cache = [
1709
+ k[:, :, start:end, :].clone() if clone else k[:, :, start:end, :] for k in self.cache.key_cache
1710
+ ]
1711
+ new_value_cache = [
1712
+ v[:, :, start:end, :].clone() if clone else v[:, :, start:end, :] for v in self.cache.value_cache
1713
+ ]
1714
+ new_cache = DynamicCache()
1715
+ new_cache.key_cache = new_key_cache
1716
+ new_cache.value_cache = new_value_cache
1717
+ return new_cache
1718
+ else:
1719
+ # Tuple cache
1720
+ if clone:
1721
+ return tuple(
1722
+ (layer[0][:, :, start:end, :].clone(), layer[1][:, :, start:end, :].clone()) for layer in self.cache
1723
+ )
1724
+ else:
1725
+ return tuple((layer[0][:, :, start:end, :], layer[1][:, :, start:end, :]) for layer in self.cache)
1726
+
1727
+ @staticmethod
1728
+ def _get_cache_len(cache) -> int:
1729
+ if cache is None:
1730
+ return 0
1731
+ if isinstance(cache, DynamicCache):
1732
+ if len(cache.key_cache) > 0 and cache.key_cache[0].numel() > 0:
1733
+ return cache.key_cache[0].shape[2]
1734
+ return 0
1735
+
1736
+ if cache and cache[0] and cache[0][0] is not None:
1737
+ return cache[0][0].shape[2]
1738
+ return 0
1739
+
1740
+ @staticmethod
1741
+ def _concat_caches(cache1, cache2):
1742
+ if cache1 is None:
1743
+ return cache2
1744
+ if cache2 is None:
1745
+ return cache1
1746
+
1747
+ if isinstance(cache1, DynamicCache):
1748
+ new_cache = DynamicCache()
1749
+ new_cache.key_cache = [torch.cat([k1, k2], dim=2) for k1, k2 in zip(cache1.key_cache, cache2.key_cache)]
1750
+ new_cache.value_cache = [
1751
+ torch.cat([v1, v2], dim=2) for v1, v2 in zip(cache1.value_cache, cache2.value_cache)
1752
+ ]
1753
+ return new_cache
1754
+ else:
1755
+ return tuple(
1756
+ (
1757
+ torch.cat([layer1[0], layer2[0]], dim=2),
1758
+ torch.cat([layer1[1], layer2[1]], dim=2),
1759
+ )
1760
+ for layer1, layer2 in zip(cache1, cache2)
1761
+ )
1762
+
1763
+ def _reindex_rope_for_cache(self, cache, old_start: int, new_start: int, length: int):
1764
+ """reindex RoPE position for cache"""
1765
+ if cache is None or length <= 0:
1766
+ return cache
1767
+
1768
+ if isinstance(cache, DynamicCache):
1769
+ device = cache.key_cache[0].device if cache.key_cache else None
1770
+ else:
1771
+ device = cache[0][0].device if cache and cache[0] else None
1772
+
1773
+ if device is None:
1774
+ return cache
1775
+
1776
+ old_positions = torch.arange(old_start, old_start + length, device=device, dtype=torch.long)
1777
+ new_positions = torch.arange(new_start, new_start + length, device=device, dtype=torch.long)
1778
+
1779
+ rope_theta = self._get_rope_theta()
1780
+
1781
+ if isinstance(cache, DynamicCache):
1782
+ new_key_cache = []
1783
+ for k in cache.key_cache:
1784
+ new_k = realign_rotary_suffix(k, old_positions, new_positions, rope_theta, self._rope_inv_freq_cache)
1785
+ new_key_cache.append(new_k)
1786
+ cache.key_cache = new_key_cache
1787
+ return cache
1788
+ else:
1789
+ new_cache = []
1790
+ for layer in cache:
1791
+ new_k = realign_rotary_suffix(
1792
+ layer[0], old_positions, new_positions, rope_theta, self._rope_inv_freq_cache
1793
+ )
1794
+ new_cache.append((new_k, layer[1]))
1795
+ return tuple(new_cache)
1796
+
1797
+ def _update_previous(
1798
+ self,
1799
+ new_text: str,
1800
+ new_tokens: List[int],
1801
+ max_tokens: int,
1802
+ ) -> None:
1803
+ """update previous context (also update cache)
1804
+
1805
+ On the first sliding-window event the marker + text is added dynamically; later events only append text.
1806
+ When the content exceeds max_tokens it is truncated from the left (the marker is kept).
1807
+ rebuild cache to maintain consistency
1808
+
1809
+ Args:
1810
+ new_text: new text
1811
+ new_tokens: new token ids
1812
+ max_tokens: maximum token count of the previous content (excluding the marker)
1813
+ """
1814
+ marker_len = len(self._previous_marker_token_ids)
1815
+ tokens_to_drop = 0
1816
+
1817
+ # if no new content, do not add marker, but still need to rebuild cache
1818
+ if not new_tokens and not new_text:
1819
+ # still need to rebuild cache (because a unit was deleted)
1820
+ self._rebuild_cache_with_previous(self._previous_token_ids)
1821
+ return
1822
+
1823
+ if not self._has_previous:
1824
+ # first time there is actual content: add marker + text
1825
+ self._previous_text = new_text
1826
+ self._previous_token_ids = self._previous_marker_token_ids.copy() + new_tokens
1827
+ self._has_previous = True
1828
+ else:
1829
+ # subsequent sliding window: append text to previous
1830
+ self._previous_text += new_text
1831
+ self._previous_token_ids.extend(new_tokens)
1832
+
1833
+ # calculate token count of content (without marker)
1834
+ content_token_count = len(self._previous_token_ids) - marker_len
1835
+
1836
+ # check if need to truncate content (keep marker)
1837
+ if content_token_count > max_tokens:
1838
+ # truncate left content, keep marker + latest max_tokens content
1839
+ tokens_to_drop = content_token_count - max_tokens
1840
+ old_text = self._previous_text
1841
+ # keep marker + truncated content
1842
+ content_tokens = self._previous_token_ids[marker_len + tokens_to_drop :]
1843
+ self._previous_token_ids = self._previous_marker_token_ids.copy() + content_tokens
1844
+ # redecode text (only decode content part)
1845
+ try:
1846
+ self._previous_text = self.tokenizer.decode(
1847
+ content_tokens,
1848
+ skip_special_tokens=True,
1849
+ )
1850
+ except Exception as e:
1851
+ logger.warning("_update_previous: decode failed: %s", e)
1852
+
1853
+ # rebuild cache
1854
+ self._rebuild_cache_with_previous(self._previous_token_ids)
1855
+
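The truncation bookkeeping above amounts to simple list slicing: the marker tokens are always kept and the content is trimmed from the left. A toy illustration with made-up token ids:

```python
marker = [101, 102]                       # hypothetical marker token ids
previous = marker + [1, 2, 3, 4, 5, 6, 7]
max_tokens = 4

content_count = len(previous) - len(marker)       # 7
if content_count > max_tokens:
    drop = content_count - max_tokens             # 3
    content = previous[len(marker) + drop:]       # [4, 5, 6, 7]
    previous = marker + content
print(previous)                                   # [101, 102, 4, 5, 6, 7]
```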
1856
+ def _drop_unit_with_context(
1857
+ self,
1858
+ unit_id: int,
1859
+ max_previous_tokens: int,
1860
+ ) -> Tuple[bool, str, List[int]]:
1861
+ """remove specified unit and return its generated content (for context preserving)
1862
+
1863
+ process:
1864
+ 1. extract generated content of unit
1865
+ 2. remove the unit from the cache (the prefix and previous context are left untouched)
1866
+ 3. append generated content to previous
1867
+ 4. rebuild cache (in _update_previous)
1868
+
1869
+ Args:
1870
+ unit_id: unit ID to remove
1871
+ max_previous_tokens: previous maximum token count
1872
+
1873
+ Returns:
1874
+ (success, extracted_text, extracted_tokens): whether successful, extracted text and tokens
1875
+ """
1876
+ entries = [u for u in self._unit_history if u["unit_id"] == unit_id]
1877
+ if not entries:
1878
+ return False, "", []
1879
+
1880
+ # extract generated content
1881
+ extracted_text, extracted_tokens = self._extract_generated_text(entries)
1882
+
1883
+ # calculate total length
1884
+ total_len = sum(e["length"] for e in entries)
1885
+ if total_len <= 0:
1886
+ for e in entries:
1887
+ self._unit_history.remove(e)
1888
+ return False, extracted_text, extracted_tokens
1889
+
1890
+ cache_before = self.get_cache_length()
1891
+
1892
+ # remove from unit_history (record for later processing)
1893
+ for e in entries:
1894
+ self._unit_history.remove(e)
1895
+
1896
+ # note: we no longer call _drop_tokens_from_cache here,
1897
+ # because _update_previous rebuilds the entire cache
1898
+
1899
+ # update previous (also rebuild cache)
1900
+ self._update_previous(extracted_text, extracted_tokens, max_previous_tokens)
1901
+
1902
+ return True, extracted_text, extracted_tokens
1903
+
1904
+ def _drop_next_unit_with_context(self, max_previous_tokens: int) -> bool:
1905
+ """remove the earliest non-system unit (with context preserving)"""
1906
+ for entry in self._unit_history:
1907
+ unit_id = entry.get("unit_id")
1908
+ if unit_id is None:
1909
+ continue
1910
+ if entry.get("type") == "system":
1911
+ continue
1912
+ success, _, _ = self._drop_unit_with_context(unit_id, max_previous_tokens)
1913
+ if success:
1914
+ return True
1915
+ return False
1916
+
1917
+ def enforce_window_with_context(self) -> bool:
1918
+ """context preserving sliding window execution
1919
+
1920
+ when unit count exceeds max_units, remove the earliest unit,
1921
+ and accumulate its generated content to previous.
1922
+ Cache will be automatically rebuilt in _update_previous.
1923
+
1924
+ Returns:
1925
+ whether sliding window is executed
1926
+ """
1927
+ if not self._window_enabled:
1928
+ return False
1929
+
1930
+ cfg = self._window_config
1931
+
1932
+ if cfg.sliding_window_mode != "context":
1933
+ # not context mode: fall back to the basic sliding window
1934
+ return self.enforce_window()
1935
+
1936
+ cache_len_before = self.get_cache_length()
1937
+ units_before = len(self._unit_history)
1938
+
1939
+ # context preserving mode: only check if unit count exceeds limit
1940
+ # (if previous exceeds its limit, _update_previous truncates it from the left automatically)
1941
+ if units_before <= cfg.context_max_units:
1942
+ return False
1943
+
1944
+ # sliding window loop: remove unit until count ≤ max_units
1945
+ dropped_count = 0
1946
+ while len(self._unit_history) > cfg.context_max_units:
1947
+ if not self._drop_next_unit_with_context(cfg.context_previous_max_tokens):
1948
+ break
1949
+
1950
+ dropped_count += 1
1951
+
1952
+ cache_len_after = self.get_cache_length()
1953
+
1954
+ if dropped_count > 0:
1955
+ # update statistics counter
1956
+ self._sliding_event_count += 1
1957
+ self._total_dropped_tokens += cache_len_before - cache_len_after
1958
+ self._total_dropped_units += dropped_count
1959
+
1960
+ # consistency check: cache length must equal system prefix + remaining unit lengths
1961
+ if not self._verify_consistency():
+ logger.warning("enforce_window_with_context: cache length inconsistent after sliding window")
1962
+
1963
+ return dropped_count > 0
1964
+
1965
+ def get_previous_context(self) -> Tuple[str, List[int]]:
1966
+ """get current accumulated previous context
1967
+
1968
+ Returns:
1969
+ (previous_text, previous_token_ids): current accumulated text and token ids
1970
+ """
1971
+ return self._previous_text, self._previous_token_ids.copy()
1972
+
1973
+ def get_window_stats(self) -> Dict[str, Any]:
1974
+ """get sliding window statistics"""
1975
+ unit_lengths = [u["length"] for u in self._unit_history]
1976
+ return {
1977
+ "cache_length": self.get_cache_length(),
1978
+ "unit_count": len(self._unit_history),
1979
+ "unit_lengths": unit_lengths,
1980
+ "unit_total_length": sum(unit_lengths),
1981
+ "system_preserve_length": self._system_preserve_length,
1982
+ "position_offset": self._position_offset,
1983
+ "window_enabled": self._window_enabled,
1984
+ "total_generated_tokens": self.get_total_generated_tokens(),
1985
+ "pending_unit_id": self._pending_unit_id,
1986
+ "next_unit_id": self._next_unit_id,
1987
+ "config": {
1988
+ "sliding_window_mode": self._window_config.sliding_window_mode,
1989
+ "basic_window_high_tokens": self._window_config.basic_window_high_tokens,
1990
+ "basic_window_low_tokens": self._window_config.basic_window_low_tokens,
1991
+ "context_previous_max_tokens": self._window_config.context_previous_max_tokens,
1992
+ "context_max_units": self._window_config.context_max_units,
1993
+ },
1994
+ # context preserving related
1995
+ "preserve_prefix_length": self._preserve_prefix_length,
1996
+ "previous_content_length": self._previous_content_length,
1997
+ "suffix_token_count": len(self._suffix_token_ids),
1998
+ "previous_text_length": len(self._previous_text),
1999
+ "previous_token_count": len(self._previous_token_ids),
2000
+ "has_system_template": self._system_prompt_template is not None,
2001
+ }
2002
+
2003
+ def _verify_consistency(self) -> bool:
2004
+ """verify unit history and cache length consistency"""
2005
+ expected = self._system_preserve_length + sum(u["length"] for u in self._unit_history)
2006
+ actual = self.get_cache_length()
2007
+ return expected == actual
2008
+
2009
+ def print_verification_summary(self) -> Dict[str, Any]:
2010
+ """print verification summary (for comparing off/basic/context mode)
2011
+
2012
+ Returns:
2013
+ dictionary containing key verification data
2014
+ """
2015
+ cfg = self._window_config
2016
+
2017
+ # collect all generated text
2018
+ all_generated_text = []
2019
+ all_generated_tokens = []
2020
+ for u in self._unit_history:
2021
+ if not u.get("is_listen", False):
2022
+ gen_text = u.get("generated_text", "")
2023
+ gen_tokens = u.get("generated_tokens", [])
2024
+ if gen_text:
2025
+ all_generated_text.append(gen_text)
2026
+ if gen_tokens:
2027
+ all_generated_tokens.extend(gen_tokens)
2028
+
2029
+ combined_text = "".join(all_generated_text)
2030
+
2031
+ summary = {
2032
+ "mode": cfg.sliding_window_mode,
2033
+ "final_cache_length": self.get_cache_length(),
2034
+ "final_unit_count": len(self._unit_history),
2035
+ "sliding_event_count": self._sliding_event_count,
2036
+ "total_dropped_tokens": self._total_dropped_tokens,
2037
+ "total_dropped_units": self._total_dropped_units,
2038
+ "total_generated_tokens": len(all_generated_tokens),
2039
+ "generated_text": combined_text,
2040
+ "previous_text": self._previous_text,
2041
+ "previous_token_count": len(self._previous_token_ids),
2042
+ "position_offset": self._position_offset,
2043
+ "system_preserve_length": self._system_preserve_length,
2044
+ }
2045
+
2046
+ return summary
2047
+
2048
+ def set_window_config(self, config: DuplexWindowConfig) -> None:
2049
+ """set sliding window configuration"""
2050
+ self._window_config = config
2051
+
2052
+ def set_window_enabled(self, enabled: bool) -> None:
2053
+ """enable/disable sliding window"""
2054
+ old_enabled = self._window_enabled
2055
+ self._window_enabled = enabled
2056
+
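A hedged usage sketch for enabling the context-preserving window. The keyword arguments mirror the fields read back in `get_window_stats` above, but the exact `DuplexWindowConfig` constructor signature and the `session` variable are assumptions:

```python
# `session` is assumed to be an instance of this duplex session class.
cfg = DuplexWindowConfig(
    sliding_window_mode="context",
    context_max_units=8,
    context_previous_max_tokens=512,
)
session.set_window_config(cfg)
session.set_window_enabled(True)

# After each completed unit, trigger the window and inspect the bookkeeping.
slid = session.enforce_window_with_context()
stats = session.get_window_stats()
print(slid, stats["unit_count"], stats["cache_length"], stats["previous_token_count"])
```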
2057
+ def get_context(self):
2058
+ return self.context
2059
+
2060
+ def embed_token(self, tid):
2061
+ if isinstance(tid, int):
2062
+ tid = torch.tensor([tid], device=self.m.device)
2063
+ return self.m.model.embed_tokens(tid)
2064
+
2065
+ def embed_tokens(self, token_ids: List[int]) -> torch.Tensor:
2066
+ """batch embed multiple tokens
2067
+
2068
+ Args:
2069
+ token_ids: list of token ids
2070
+
2071
+ Returns:
2072
+ embeddings tensor [L, H]
2073
+ """
2074
+ if not token_ids:
2075
+ return torch.empty(0, self.m.config.hidden_size, device=self.m.device)
2076
+ tids = torch.tensor(token_ids, device=self.m.device)
2077
+ return self.m.model.embed_tokens(tids)
2078
+
2079
+ @torch.no_grad()
2080
+ def feed(self, embeds: torch.Tensor, return_logits: bool = False):
2081
+ """
2082
+ embeds : [L, H] —— new embedding sequence fed into model at once
2083
+ """
2084
+ L = embeds.size(0)
2085
+ device = embeds.device
2086
+
2087
+ past_len = self.get_cache_length()
2088
+ pos_ids = torch.arange(past_len, past_len + L, device=device).unsqueeze(0) # [1, L]
2089
+
2090
+ out = self.m(
2091
+ inputs_embeds=embeds.unsqueeze(0), # [1, L, H]
2092
+ position_ids=pos_ids,
2093
+ past_key_values=self.cache,
2094
+ # use_cache = True,
2095
+ return_dict=True,
2096
+ output_hidden_states=True,
2097
+ # attention_mask=attention_mask
2098
+ )
2099
+ self.cache = out.past_key_values
2100
+
2101
+ if return_logits:
2102
+ logits = self.m.lm_head(out.hidden_states[-1])[:, -1] # [1, vocab]
2103
+ return logits, out.hidden_states[-1]
2104
+
2105
+ @torch.no_grad()
2106
+ def decode(
2107
+ self,
2108
+ logits,
2109
+ mode: Literal["sampling", "greedy"] = "sampling",
2110
+ temperature=0.7,
2111
+ top_k=20,
2112
+ top_p=0.8,
2113
+ listen_top_k=None,
2114
+ listen_prob_scale=1.0,
2115
+ text_repetition_penalty=1.05,
2116
+ text_repetition_window_size=512,
2117
+ ):
2118
+ """
2119
+ Args:
2120
+ logits:
2121
+ mode: sampling or greedy
2122
+ temperature:
2123
+ top_k:
2124
+ top_p:
2125
+ listen_top_k: force listen_id to be in top-k to keep
2126
+ listen_prob_scale: multiply listen_id probability by a weight (<1 means decrease, >1 means increase)
2127
+ text_repetition_penalty: repetition penalty coefficient, >1.0 means decrease repetition, <1.0 means increase repetition
2128
+ text_repetition_window_size: repetition penalty window size
2129
+
2130
+ Sampling strategy:
2131
+ 1. first sample all tokens with original logits (apply temperature)
2132
+ 2. if sampled chunk_eos, return directly (keep the original model's decision of when to stop)
2133
+ 3. if not sampled chunk_eos, mask it (set logit to -inf), continue sampling text tokens
2134
+ 4. apply repetition penalty, top-k, top-p, etc. to the text tokens for the final sampling
2135
+ """
2136
+
2137
+ logits = logits.clone()
2138
+
2139
+ # 0. independently check chunk_eos before sampling
2140
+ eos_id = self.chunk_eos_id
2141
+
2142
+ with torch.no_grad():
2143
+ if mode == "greedy":
2144
+ sampled_token = torch.argmax(logits[0]).item()
2145
+ else:
2146
+ original_probs = F.softmax(logits[0], dim=-1)
2147
+ sampled_token = torch.multinomial(original_probs, num_samples=1).item()
2148
+
2149
+ # if sampled chunk_eos, return directly
2150
+ if sampled_token == eos_id:
2151
+ next_token_id = torch.tensor([eos_id], device=logits.device)
2152
+ next_token_str = self.tokenizer.decode(next_token_id)
2153
+
2154
+ return next_token_id
2155
+
2156
+ # chunk_eos was not sampled: mask it (and any other forbidden tokens) by setting their logits to -inf
2157
+ if self.forbidden_token_ids:
2158
+ logits[:, self.forbidden_token_ids] = float("-inf")
2159
+
2160
+ # 1. apply repetition penalty
2161
+ if text_repetition_penalty != 1.0 and len(self.generated_tokens) > 0:
2162
+ # get recently generated (non-special) tokens within the penalty window
2163
+ recent_tokens = self.generated_tokens[-text_repetition_window_size:]
2164
+
2165
+ # deduplicate so each token is penalized once
2166
+ recent_tokens = list(set(recent_tokens))
2167
+
2168
+ # apply penalty to repeated tokens
2169
+ for token_id in recent_tokens:
2170
+ if token_id < logits.size(-1): # ensure token_id is in vocabulary range
2171
+ if text_repetition_penalty > 1.0:
2172
+ # penalize repetition: decrease logits
2173
+ logits[0, token_id] /= text_repetition_penalty
2174
+ else:
2175
+ # encourage repetition: increase logits
2176
+ logits[0, token_id] *= 1.0 / text_repetition_penalty
2177
+
2178
+ if listen_prob_scale != 1.0: # modify listen token logit separately
2179
+ logits[0, self.listen_id] *= listen_prob_scale
2180
+
2181
+ listen_rank = (logits[0] > logits[0, self.listen_id]).sum().item()
2182
+
2183
+ if listen_top_k is not None and listen_rank < listen_top_k: # listen_id is in top-k, return directly
2184
+ next_token_id = torch.tensor([self.listen_id], device=logits.device)
2185
+ next_token_str = self.tokenizer.decode(next_token_id)
2186
+
2187
+ if next_token_str == "<|listen|>":
2188
+ self.context += " "
2189
+ else:
2190
+ self.context += next_token_str
2191
+
2192
+ return next_token_id
2193
+
2194
+ if mode == "greedy":
2195
+ next_token_id = torch.argmax(logits, dim=-1)
2196
+ elif mode == "sampling":
2197
+ logits = logits / temperature
2198
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
2199
+ probs = F.softmax(logits, dim=-1)
2200
+ next_token_id = torch.multinomial(probs, num_samples=1).squeeze(1)
2201
+ else:
2202
+ raise ValueError(f"Unsupported decode mode: {mode}")
2203
+
2204
+ if next_token_id.item() not in self.special_token_ids:
2205
+ self.generated_tokens.append(next_token_id.item())
2206
+ else:
2207
+ self.generated_special_tokens.append(next_token_id.item())
2208
+
2209
+ return next_token_id
2210
+
2211
+
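Taken together, `embed_tokens`, `feed`, and `decode` form the inner streaming loop. A minimal sketch, assuming `session` is an instance of this class and `prompt_ids` is a list of already-tokenized input ids:

```python
# Prefill: embed the prompt and run one forward pass to get next-token logits.
embeds = session.embed_tokens(prompt_ids)             # [L, H]
logits, _ = session.feed(embeds, return_logits=True)  # [1, vocab]

generated = []
for _ in range(64):  # illustrative cap on chunk length
    next_id = session.decode(logits, mode="sampling", temperature=0.7, top_k=20, top_p=0.8)
    if next_id.item() == session.chunk_eos_id:
        break
    generated.append(next_id.item())
    # Feed the chosen token back in to advance the cache by one position.
    logits, _ = session.feed(session.embed_token(next_id.item()), return_logits=True)

print(session.tokenizer.decode(generated))
```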
2212
+ def _download_url_to_tempfile(url: str, suffix: str = "", timeout: int = 60) -> str:
2213
+ """
2214
+ Download a URL to a temporary file and return the path.
2215
+
2216
+ Args:
2217
+ url: HTTP/HTTPS URL to download
2218
+ suffix: File suffix (e.g., ".jpg", ".wav", ".mp4")
2219
+ timeout: Download timeout in seconds
2220
+
2221
+ Returns:
2222
+ Path to the downloaded temporary file
2223
+ """
2224
+ import tempfile
2225
+
2226
+ import requests
2227
+
2228
+ response = requests.get(url, timeout=timeout)
2229
+ response.raise_for_status()
2230
+
2231
+ with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
2232
+ f.write(response.content)
2233
+ return f.name
2234
+
2235
+
2236
+ def _is_url(path: str) -> bool:
2237
+ return path.startswith(("http://", "https://"))
2238
+
2239
+
2240
+ def normalize_content_item(item) -> Union[str, Any, List[Any]]:
2241
+ """Normalize structured content item to native format.
2242
+
2243
+ Supports:
2244
+ - Native format: str, PIL.Image, np.ndarray (pass through)
2245
+ - OpenAI structured format:
2246
+ - {"type": "text", "text": "..."} -> str
2247
+ - {"type": "image_url", "image_url": {"url": "..."}} -> PIL.Image
2248
+ - {"type": "audio_url", "audio_url": {"url": "..."}} -> np.ndarray
2249
+ - {"type": "video_url", "video_url": {"url": "...", ...}} -> List[Image, ndarray, ...]
2250
+
2251
+ URL formats supported:
2252
+ - Local file path: "/path/to/file.jpg"
2253
+ - HTTP/HTTPS URL: "https://example.com/image.jpg"
2254
+
2255
+ Args:
2256
+ item: Content item to normalize
2257
+
2258
+ Returns:
2259
+ Normalized item. For video_url, returns a tuple ("__video_contents__", list)
2260
+ that will be flattened by normalize_content().
2261
+
2262
+ Raises:
2263
+ ValueError: If content type is unknown or unsupported
2264
+ """
2265
+ import os
2266
+
2267
+ import numpy as np
2268
+ from PIL import Image
2269
+
2270
+ if isinstance(item, str):
2271
+ return item
2272
+ if isinstance(item, Image.Image):
2273
+ return item
2274
+ if isinstance(item, np.ndarray):
2275
+ return item
2276
+
2277
+ if isinstance(item, dict):
2278
+ item_type = item.get("type")
2279
+
2280
+ if item_type == "text":
2281
+ return item.get("text", "")
2282
+
2283
+ elif item_type == "image_url":
2284
+ image_url_obj = item.get("image_url", {})
2285
+ url = image_url_obj.get("url", "") if isinstance(image_url_obj, dict) else image_url_obj
2286
+
2287
+ if _is_url(url):
2288
+ # Download to temp file
2289
+ temp_path = _download_url_to_tempfile(url, suffix=".jpg", timeout=30)
2290
+ img = Image.open(temp_path)
2291
+ img.load()  # force PIL to read the pixel data before the temp file is removed
+ os.unlink(temp_path)
2292
+ return img
2293
+ else:
2294
+ return Image.open(url)
2295
+ elif item_type == "audio_url":
2296
+ import librosa
2297
+
2298
+ audio_url_obj = item.get("audio_url", {})
2299
+ url = audio_url_obj.get("url", "") if isinstance(audio_url_obj, dict) else audio_url_obj
2300
+
2301
+ if _is_url(url):
2302
+ # Download to temp file
2303
+ temp_path = _download_url_to_tempfile(url, suffix=".wav", timeout=60)
2304
+ audio_np, _ = librosa.load(temp_path, sr=16000, mono=True)
2305
+ os.unlink(temp_path)
2306
+ return audio_np
2307
+ else:
2308
+ audio_np, _ = librosa.load(url, sr=16000, mono=True)
2309
+ return audio_np
2310
+ elif item_type == "video_url":
2311
+ # Video processing - returns a LIST of items (frames + audio segments)
2312
+ # Note: Unlike image_url/audio_url which return single items,
2313
+ # video_url returns a list that will be flattened into the content
2314
+ from minicpmo.utils import get_video_frame_audio_segments
2315
+
2316
+ video_url_obj = item.get("video_url", {})
2317
+ if isinstance(video_url_obj, dict):
2318
+ video_url = video_url_obj.get("url", "")
2319
+ # Get optional parameters from video_url object (OpenAI style)
2320
+ stack_frames = video_url_obj.get("stack_frames", 1)
2321
+ use_ffmpeg = video_url_obj.get("use_ffmpeg", False)
2322
+ use_audio = video_url_obj.get("use_audio", True)
2323
+ else:
2324
+ video_url = video_url_obj
2325
+ stack_frames = 1
2326
+ use_ffmpeg = False
2327
+ use_audio = True
2328
+
2329
+ # Handle HTTP/HTTPS URL - download to temp file
2330
+ temp_video_path = None
2331
+ if _is_url(video_url):
2332
+ temp_video_path = _download_url_to_tempfile(video_url, suffix=".mp4", timeout=120)
2333
+ video_path = temp_video_path
2334
+ else:
2335
+ video_path = video_url
2336
+
2337
+ # Extract frames and audio segments
2338
+ video_frames, audio_segments, stacked_frames = get_video_frame_audio_segments(
2339
+ video_path,
2340
+ stack_frames=stack_frames,
2341
+ use_ffmpeg=use_ffmpeg,
2342
+ use_audio=use_audio
2343
+ )
2344
+
2345
+ # Clean up temp file if downloaded
2346
+ if temp_video_path is not None:
2347
+ os.unlink(temp_video_path)
2348
+
2349
+ # Build omni_contents (interleaved frames and audio, or frames only)
2350
+ omni_contents = []
2351
+ for i in range(len(video_frames)):
2352
+ omni_contents.append(video_frames[i])
2353
+ if use_audio and audio_segments is not None:
2354
+ omni_contents.append(audio_segments[i])
2355
+ if stacked_frames is not None and i < len(stacked_frames) and stacked_frames[i] is not None:
2356
+ omni_contents.append(stacked_frames[i])
2357
+
2358
+ # Return as a special marker to be flattened later
2359
+ return "__video_contents__", omni_contents
2360
+ else:
2361
+ raise ValueError(f"Unknown content type: {item_type}")
2362
+
2363
+ raise ValueError(f"Cannot normalize content item of type: {type(item)}")
2364
+
2365
+
2366
+ def normalize_content(content) -> list:
2367
+ """Normalize message content to list of native items.
2368
+
2369
+ Input formats:
2370
+ - str: "hello" -> ["hello"]
2371
+ - list of native items: [str, Image, np.ndarray] -> pass through with normalization
2372
+ - list of structured items: [{"type": "text", ...}] -> normalize each
2373
+ - video_url type: automatically expanded to omni_contents (interleaved frames and audio)
2374
+ - mixed: works too
2375
+
2376
+ Args:
2377
+ content: Message content in any supported format
2378
+
2379
+ Returns:
2380
+ List of native items (str, PIL.Image, np.ndarray)
2381
+
2382
+ Examples:
2383
+ >>> normalize_content("hello")
2384
+ ["hello"]
2385
+
2386
+ >>> normalize_content([{"type": "text", "text": "hi"}])
2387
+ ["hi"]
2388
+
2389
+ >>> normalize_content([{"type": "video", "video": "/path/to/video.mp4"}])
2390
+ [<PIL.Image>, <np.ndarray>, <PIL.Image>, <np.ndarray>, ...]
2391
+ """
2392
+ import numpy as np
2393
+ from PIL import Image
2394
+
2395
+ if isinstance(content, str):
2396
+ return [content]
2397
+
2398
+ if isinstance(content, list):
2399
+ result = []
2400
+ for item in content:
2401
+ normalized = normalize_content_item(item)
2402
+ # Handle video content (returns tuple with marker)
2403
+ if isinstance(normalized, tuple) and len(normalized) == 2 and normalized[0] == "__video_contents__":
2404
+ # Flatten video contents into result
2405
+ result.extend(normalized[1])
2406
+ else:
2407
+ result.append(normalized)
2408
+ return result
2409
+
2410
+ # Single non-list item (Image or np.ndarray)
2411
+ if isinstance(content, (Image.Image, np.ndarray)):
2412
+ return [content]
2413
+
2414
+ normalized = normalize_content_item(content)
2415
+ if isinstance(normalized, tuple) and len(normalized) == 2 and normalized[0] == "__video_contents__":
2416
+ return normalized[1]
2417
+ return [normalized]
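A short usage sketch showing how an OpenAI-style structured message maps onto native items; the image URL below is a placeholder, and resolving it requires network access:

```python
content = [
    {"type": "text", "text": "Describe this image."},
    {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},  # placeholder URL
]
items = normalize_content(content)
# items -> ["Describe this image.", <PIL.Image.Image ...>]
```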
vocab.json ADDED
The diff for this file is too large to render. See raw diff