Files changed (2)
  1. README.md +192 -226
  2. modeling_minicpm.py +1026 -31
README.md CHANGED
@@ -1,226 +1,192 @@
1
- ---
2
- license: apache-2.0
3
- language:
4
- - zh
5
- - en
6
- pipeline_tag: text-generation
7
- library_name: transformers
8
- ---
9
- <div align="center">
10
- <img src="https://github.com/OpenBMB/MiniCPM/blob/main/assets/minicpm_logo.png?raw=true" width="500em" ></img>
11
- </div>
12
-
13
- <p align="center">
14
- <a href="https://github.com/OpenBMB/MiniCPM/" target="_blank">GitHub Repo</a> |
15
- <a href="https://github.com/OpenBMB/MiniCPM/tree/main/report/MiniCPM_4_Technical_Report.pdf" target="_blank">Technical Report</a>
16
- </p>
17
- <p align="center">
18
- 👋 Join us on <a href="https://discord.gg/3cGQn9b3YM" target="_blank">Discord</a> and <a href="https://github.com/OpenBMB/MiniCPM/blob/main/assets/wechat.jpg" target="_blank">WeChat</a>
19
- </p>
20
-
21
- ## What's New
22
- - [2025.06.06] The **MiniCPM4** series is released! It achieves ultimate efficiency improvements while maintaining optimal performance at its scale, delivering over 5x generation acceleration on typical end-side chips! You can find the technical report [here](https://github.com/OpenBMB/MiniCPM/tree/main/report/MiniCPM_4_Technical_Report.pdf).🔥🔥🔥
23
-
24
- ## MiniCPM4 Series
25
- The MiniCPM4 series consists of highly efficient large language models (LLMs) designed explicitly for end-side devices, achieving this efficiency through systematic innovation in four key dimensions: model architecture, training data, training algorithms, and inference systems.
26
- - [MiniCPM4-8B](https://huggingface.co/openbmb/MiniCPM4-8B): The flagship of MiniCPM4, with 8B parameters, trained on 8T tokens.
27
- - [MiniCPM4-0.5B](https://huggingface.co/openbmb/MiniCPM4-0.5B): The small version of MiniCPM4, with 0.5B parameters, trained on 1T tokens. (**<-- you are here**)
28
- - [MiniCPM4-8B-Eagle-FRSpec](https://huggingface.co/openbmb/MiniCPM4-8B-Eagle-FRSpec): Eagle head for FRSpec, accelerating speculative inference for MiniCPM4-8B.
29
- - [MiniCPM4-8B-Eagle-FRSpec-QAT-cpmcu](https://huggingface.co/openbmb/MiniCPM4-8B-Eagle-FRSpec-QAT-cpmcu): Eagle head trained with QAT for FRSpec, efficiently integrating speculation and quantization to achieve ultra acceleration for MiniCPM4-8B.
30
- - [MiniCPM4-8B-Eagle-vLLM](https://huggingface.co/openbmb/MiniCPM4-8B-Eagle-vLLM): Eagle head in vLLM format, accelerating speculative inference for MiniCPM4-8B.
31
- - [MiniCPM4-8B-marlin-Eagle-vLLM](https://huggingface.co/openbmb/MiniCPM4-8B-marlin-Eagle-vLLM): Quantized Eagle head for vLLM format, accelerating speculative inference for MiniCPM4-8B.
32
- - [BitCPM4-0.5B](https://huggingface.co/openbmb/BitCPM4-0.5B): Extreme ternary quantization applied to MiniCPM4-0.5B compresses model parameters into ternary values, achieving a 90% reduction in bit width.
33
- - [BitCPM4-1B](https://huggingface.co/openbmb/BitCPM4-1B): Extreme ternary quantization applied to MiniCPM3-1B compresses model parameters into ternary values, achieving a 90% reduction in bit width.
34
- - [MiniCPM4-Survey](https://huggingface.co/openbmb/MiniCPM4-Survey): Based on MiniCPM4-8B, accepts users' queries as input and autonomously generates trustworthy, long-form survey papers.
35
- - [MiniCPM4-MCP](https://huggingface.co/openbmb/MiniCPM4-MCP): Based on MiniCPM4-8B, accepts users' queries and available MCP tools as input and autonomously calls relevant MCP tools to satisfy users' requirements.
36
-
37
- ## Introduction
38
- MiniCPM 4 is an extremely efficient edge-side large model, systematically optimized across four dimensions: model architecture, learning algorithms, training data, and inference systems.
39
-
40
- - 🏗️ **Efficient Model Architecture:**
41
- - InfLLM v2 -- Trainable Sparse Attention Mechanism: Adopts a trainable sparse attention architecture in which each token only computes relevance with less than 5% of the tokens when processing 128K-long texts, significantly reducing the computational overhead of long texts
42
-
43
- - 🧠 **Efficient Learning Algorithms:**
44
- - Model Wind Tunnel 2.0 -- Efficient Predictable Scaling: Introduces scaling prediction methods for performance of downstream tasks, enabling more precise model training configuration search
45
- - BitCPM -- Ultimate Ternary Quantization: Compresses model parameters to ternary values, achieving an extreme 90% reduction in bit width
46
- - Efficient Training Engineering Optimization: Adopts FP8 low-precision computing technology combined with Multi-token Prediction training strategy
47
-
48
- - 📚 **High-Quality Training Data:**
49
- - UltraClean -- High-quality Pre-training Data Filtering and Generation: Builds iterative data cleaning strategies based on efficient data verification, open-sourcing high-quality Chinese and English pre-training dataset [UltraFinweb](https://huggingface.co/datasets/openbmb/Ultra-FineWeb)
50
- - UltraChat v2 -- High-quality Supervised Fine-tuning Data Generation: Constructs large-scale high-quality supervised fine-tuning datasets covering multiple dimensions including knowledge-intensive data, reasoning-intensive data, instruction-following data, long text understanding data, and tool calling data
51
-
52
- - ⚡ **Efficient Inference System:**
53
- - CPM.cu -- Lightweight and Efficient CUDA Inference Framework: Integrates sparse attention, model quantization, and speculative sampling to achieve efficient prefilling and decoding
54
- - ArkInfer -- Cross-platform Deployment System: Supports efficient deployment across multiple backend environments, providing flexible cross-platform adaptation capabilities
55
-
56
- ## Usage
57
- ### Inference with Transformers
58
- ```python
59
- from transformers import AutoModelForCausalLM, AutoTokenizer
60
- import torch
61
- torch.manual_seed(0)
62
-
63
- path = 'openbmb/MiniCPM4-0.5B'
64
- device = "cuda"
65
- tokenizer = AutoTokenizer.from_pretrained(path)
66
- model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True)
67
-
68
- # User can directly use the chat interface
69
- responds, history = model.chat(tokenizer, "Write an article about Artificial Intelligence.", temperature=0.7, top_p=0.7)
70
- print(responds)
71
-
72
- # User can also use the generate interface
73
- # messages = [
74
- # {"role": "user", "content": "Write an article about Artificial Intelligence."},
75
- # ]
76
- # prompt_text = tokenizer.apply_chat_template(
77
- # messages,
78
- # tokenize=False,
79
- # add_generation_prompt=True,
80
- # )
81
- # model_inputs = tokenizer([prompt_text], return_tensors="pt").to(device)
82
-
83
- # model_outputs = model.generate(
84
- # **model_inputs,
85
- # max_new_tokens=1024,
86
- # top_p=0.7,
87
- # temperature=0.7
88
- # )
89
- # output_token_ids = [
90
- # model_outputs[i][len(model_inputs[i]):] for i in range(len(model_inputs['input_ids']))
91
- # ]
92
-
93
- # responses = tokenizer.batch_decode(output_token_ids, skip_special_tokens=True)[0]
94
- # print(responses)
95
- ```
96
-
97
- ### Inference with [SGLang](https://github.com/sgl-project/sglang)
98
-
99
- For now, you need to install our forked version of SGLang.
100
- ```bash
101
- git clone -b openbmb https://github.com/OpenBMB/sglang.git
102
- cd sglang
103
-
104
- pip install --upgrade pip
105
- pip install -e "python[all]"
106
- ```
107
-
108
- You can start the inference server by running the following command:
109
- ```bash
110
- python -m sglang.launch_server --model openbmb/MiniCPM4-0.5B --trust-remote-code --port 30000 --chat-template chatml
111
- ```
112
-
113
- Then you can use the chat interface by running the following command:
114
- ```python
115
- import openai
116
-
117
- client = openai.Client(base_url=f"http://localhost:30000/v1", api_key="None")
118
-
119
- response = client.chat.completions.create(
120
- model="openbmb/MiniCPM4-0.5B",
121
- messages=[
122
- {"role": "user", "content": "Write an article about Artificial Intelligence."},
123
- ],
124
- temperature=0.7,
125
- max_tokens=1024,
126
- )
127
-
128
- print(response.choices[0].message.content)
129
- ```
130
-
131
- ### Inference with [vLLM](https://github.com/vllm-project/vllm)
132
- For now, you need to install the latest version of vLLM.
133
- ```
134
- pip install -U vllm \
135
- --pre \
136
- --extra-index-url https://wheels.vllm.ai/nightly
137
- ```
138
-
139
- Then you can run inference on MiniCPM4-0.5B with vLLM:
140
- ```python
141
- from transformers import AutoTokenizer
142
- from vllm import LLM, SamplingParams
143
-
144
- model_name = "openbmb/MiniCPM4-0.5B"
145
- prompt = [{"role": "user", "content": "Please recommend 5 tourist attractions in Beijing. "}]
146
-
147
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
148
- input_text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
149
-
150
- llm = LLM(
151
- model=model_name,
152
- trust_remote_code=True,
153
- max_num_batched_tokens=32768,
154
- dtype="bfloat16",
155
- gpu_memory_utilization=0.8,
156
- )
157
- sampling_params = SamplingParams(top_p=0.7, temperature=0.7, max_tokens=1024, repetition_penalty=1.02)
158
-
159
- outputs = llm.generate(prompts=input_text, sampling_params=sampling_params)
160
-
161
- print(outputs[0].outputs[0].text)
162
- ```
163
-
164
- Also, you can start the inference server by running the following command:
165
- > **Note**: In vLLM's chat API, `add_special_tokens` is `False` by default. This means important special tokens—such as the beginning-of-sequence (BOS) token—will not be added automatically. To ensure the input prompt is correctly formatted for the model, you should explicitly set `extra_body={"add_special_tokens": True}`.
166
-
167
- ```bash
168
- vllm serve openbmb/MiniCPM4-0.5B
169
- ```
170
-
171
- Then you can use the chat interface by running the following code:
172
-
173
- ```python
174
- import openai
175
-
176
- client = openai.Client(base_url="http://localhost:8000/v1", api_key="EMPTY")
177
-
178
- response = client.chat.completions.create(
179
- model="openbmb/MiniCPM4-0.5B",
180
- messages=[
181
- {"role": "user", "content": "Write an article about Artificial Intelligence."},
182
- ],
183
- temperature=0.7,
184
- max_tokens=1024,
185
- extra_body=dict(add_special_tokens=True), # Ensures special tokens are added for chat template
186
-
187
- )
188
-
189
- print(response.choices[0].message.content)
190
- ```
191
-
192
-
193
- ## Evaluation Results
194
- On two typical end-side chips, Jetson AGX Orin and RTX 4090, MiniCPM4 demonstrates significantly faster processing speed compared to similar-size models in long text processing tasks. As text length increases, MiniCPM4's efficiency advantage becomes more pronounced. On the Jetson AGX Orin platform, compared to Qwen3-8B, MiniCPM4 achieves approximately 7x decoding speed improvement.
195
-
196
- ![benchmark](https://github.com/OpenBMB/MiniCPM/blob/main/assets/minicpm4/efficiency.png?raw=true)
197
-
198
- #### Comprehensive Evaluation
199
- MiniCPM4 launches end-side versions with 8B and 0.5B parameter scales, both achieving best-in-class performance in their respective categories.
200
-
201
- ![benchmark](https://github.com/OpenBMB/MiniCPM/blob/main/assets/minicpm4/benchmark.png?raw=true)
202
-
203
- #### Long Text Evaluation
204
- MiniCPM4 is pre-trained with a 32K context length and extends to longer contexts through YaRN. On the 128K needle-in-a-haystack task, MiniCPM4 demonstrates outstanding performance.
205
-
206
- ![long-niah](https://github.com/OpenBMB/MiniCPM/blob/main/assets/minicpm4/128k-niah.png?raw=true)
207
-
208
- ## Statement
209
- - As a language model, MiniCPM generates content by learning from a vast amount of text.
210
- - However, it does not possess the ability to comprehend or express personal opinions or value judgments.
211
- - Any content generated by MiniCPM does not represent the viewpoints or positions of the model developers.
212
- - Therefore, when using content generated by MiniCPM, users should take full responsibility for evaluating and verifying it on their own.
213
-
214
- ## LICENSE
215
- - This repository and MiniCPM models are released under the [Apache-2.0](https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE) License.
216
-
217
- ## Citation
218
- - Please cite our [paper](https://github.com/OpenBMB/MiniCPM/tree/main/report/MiniCPM_4_Technical_Report.pdf) if you find our work valuable.
219
-
220
- ```bibtex
221
- @article{minicpm4,
222
- title={{MiniCPM4}: Ultra-Efficient LLMs on End Devices},
223
- author={MiniCPM Team},
224
- year={2025}
225
- }
226
- ```
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - zh
5
+ - en
6
+ pipeline_tag: text-generation
7
+ library_name: transformers
8
+ ---
9
+ <div align="center">
10
+ <img src="https://github.com/OpenBMB/MiniCPM/blob/main/assets/minicpm_logo.png?raw=true" width="500em" ></img>
11
+ </div>
12
+
13
+ <p align="center">
14
+ <a href="https://github.com/OpenBMB/MiniCPM/" target="_blank">GitHub Repo</a> |
15
+ <a href="https://github.com/OpenBMB/MiniCPM/tree/main/report/MiniCPM_4_Technical_Report.pdf" target="_blank">Technical Report</a>
16
+ </p>
17
+ <p align="center">
18
+ 👋 Join us on <a href="https://discord.gg/3cGQn9b3YM" target="_blank">Discord</a> and <a href="https://github.com/OpenBMB/MiniCPM/blob/main/assets/wechat.jpg" target="_blank">WeChat</a>
19
+ </p>
20
+
21
+ ## What's New
22
+ - [2025.06.06] The **MiniCPM4** series is released! It achieves ultimate efficiency improvements while maintaining optimal performance at its scale, delivering over 5x generation acceleration on typical end-side chips! You can find the technical report [here](https://github.com/OpenBMB/MiniCPM/tree/main/report/MiniCPM_4_Technical_Report.pdf).🔥🔥🔥
23
+
24
+ ## MiniCPM4 Series
25
+ The MiniCPM4 series consists of highly efficient large language models (LLMs) designed explicitly for end-side devices, achieving this efficiency through systematic innovation in four key dimensions: model architecture, training data, training algorithms, and inference systems.
26
+ - [MiniCPM4-8B](https://huggingface.co/openbmb/MiniCPM4-8B): The flagship of MiniCPM4, with 8B parameters, trained on 8T tokens.
27
+ - [MiniCPM4-0.5B](https://huggingface.co/openbmb/MiniCPM4-0.5B): The small version of MiniCPM4, with 0.5B parameters, trained on 1T tokens. (**<-- you are here**)
28
+ - [MiniCPM4-8B-Eagle-FRSpec](https://huggingface.co/openbmb/MiniCPM4-8B-Eagle-FRSpec): Eagle head for FRSpec, accelerating speculative inference for MiniCPM4-8B.
29
+ - [MiniCPM4-8B-Eagle-FRSpec-QAT-cpmcu](https://huggingface.co/openbmb/MiniCPM4-8B-Eagle-FRSpec-QAT-cpmcu): Eagle head trained with QAT for FRSpec, efficiently integrating speculation and quantization to achieve ultra acceleration for MiniCPM4-8B.
30
+ - [MiniCPM4-8B-Eagle-vLLM](https://huggingface.co/openbmb/MiniCPM4-8B-Eagle-vLLM): Eagle head in vLLM format, accelerating speculative inference for MiniCPM4-8B.
31
+ - [MiniCPM4-8B-marlin-Eagle-vLLM](https://huggingface.co/openbmb/MiniCPM4-8B-marlin-Eagle-vLLM): Quantized Eagle head for vLLM format, accelerating speculative inference for MiniCPM4-8B.
32
+ - [BitCPM4-0.5B](https://huggingface.co/openbmb/BitCPM4-0.5B): Extreme ternary quantization applied to MiniCPM4-0.5B compresses model parameters into ternary values, achieving a 90% reduction in bit width.
33
+ - [BitCPM4-1B](https://huggingface.co/openbmb/BitCPM4-1B): Extreme ternary quantization applied to MiniCPM3-1B compresses model parameters into ternary values, achieving a 90% reduction in bit width.
34
+ - [MiniCPM4-Survey](https://huggingface.co/openbmb/MiniCPM4-Survey): Based on MiniCPM4-8B, accepts users' queries as input and autonomously generates trustworthy, long-form survey papers.
35
+ - [MiniCPM4-MCP](https://huggingface.co/openbmb/MiniCPM4-MCP): Based on MiniCPM4-8B, accepts users' queries and available MCP tools as input and autonomously calls relevant MCP tools to satisfy users' requirements.
36
+
37
+ ## Introduction
38
+ MiniCPM 4 is an extremely efficient edge-side large model, systematically optimized across four dimensions: model architecture, learning algorithms, training data, and inference systems.
39
+
40
+ - 🏗️ **Efficient Model Architecture:**
41
+ - InfLLM v2 -- Trainable Sparse Attention Mechanism: Adopts a trainable sparse attention architecture in which each token only computes relevance with less than 5% of the tokens when processing 128K-long texts, significantly reducing the computational overhead of long texts (a simplified sketch of this idea follows this list)
42
+
43
+ - 🧠 **Efficient Learning Algorithms:**
44
+ - Model Wind Tunnel 2.0 -- Efficient Predictable Scaling: Introduces scaling prediction methods for performance of downstream tasks, enabling more precise model training configuration search
45
+ - BitCPM -- Ultimate Ternary Quantization: Compresses model parameters to ternary values, achieving an extreme 90% reduction in bit width
46
+ - Efficient Training Engineering Optimization: Adopts FP8 low-precision computing technology combined with Multi-token Prediction training strategy
47
+
48
+ - 📚 **High-Quality Training Data:**
49
+ - UltraClean -- High-quality Pre-training Data Filtering and Generation: Builds iterative data cleaning strategies based on efficient data verification, open-sourcing high-quality Chinese and English pre-training dataset [UltraFinweb](https://huggingface.co/datasets/openbmb/Ultra-FineWeb)
50
+ - UltraChat v2 -- High-quality Supervised Fine-tuning Data Generation: Constructs large-scale high-quality supervised fine-tuning datasets covering multiple dimensions including knowledge-intensive data, reasoning-intensive data, instruction-following data, long text understanding data, and tool calling data
51
+
52
+ - ⚡ **Efficient Inference System:**
53
+ - CPM.cu -- Lightweight and Efficient CUDA Inference Framework: Integrates sparse attention, model quantization, and speculative sampling to achieve efficient prefilling and decoding
54
+ - ArkInfer -- Cross-platform Deployment System: Supports efficient deployment across multiple backend environments, providing flexible cross-platform adaptation capabilities
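The sparse-attention idea behind InfLLM v2 can be conveyed with a minimal, hedged sketch: keys are grouped into fixed-size blocks, each query scores a compressed summary of every block, and full attention is computed only over the top-k selected blocks. This is not the InfLLM v2 kernel; the mean-pooled block summaries and the `block_size`/`topk` values below are illustrative assumptions.

```python
# Minimal sketch of top-k block-sparse attention selection (single head, no batching).
# NOT the InfLLM v2 kernel: mean-pooled block summaries, block_size and topk are
# simplifying assumptions chosen for illustration only.
import torch

def topk_block_attention(q, k, v, block_size=64, topk=4):
    kv_len, d = k.shape
    n_blocks = kv_len // block_size
    k_blocks = k[: n_blocks * block_size].view(n_blocks, block_size, d)
    v_blocks = v[: n_blocks * block_size].view(n_blocks, block_size, d)

    # Each query scores a compressed (mean-pooled) summary of every key block.
    block_summary = k_blocks.mean(dim=1)                              # [n_blocks, d]
    block_scores = (q @ block_summary.T) / d ** 0.5                   # [q_len, n_blocks]
    top_idx = block_scores.topk(min(topk, n_blocks), dim=-1).indices  # [q_len, topk]

    # Full attention runs only over the selected blocks (far fewer than kv_len keys).
    out = torch.zeros_like(q)
    for i in range(q.shape[0]):
        sel_k = k_blocks[top_idx[i]].reshape(-1, d)
        sel_v = v_blocks[top_idx[i]].reshape(-1, d)
        attn = torch.softmax((q[i] @ sel_k.T) / d ** 0.5, dim=-1)
        out[i] = attn @ sel_v
    return out

q, k, v = torch.randn(8, 64), torch.randn(1024, 64), torch.randn(1024, 64)
print(topk_block_attention(q, k, v).shape)  # torch.Size([8, 64])
```

In the released implementation the block scoring, pooling, and sparse attention are fused CUDA kernels (`infllmv2_attn_stage1`, `max_pooling_1d`, `infllmv2_attn_varlen_func` in `modeling_minicpm.py`); the loop above only conveys the selection logic.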
55
+
56
+ ## Usage
57
+ ### Inference with Transformers
58
+ ```python
59
+ from transformers import AutoModelForCausalLM, AutoTokenizer
60
+ import torch
61
+ torch.manual_seed(0)
62
+
63
+ path = 'openbmb/MiniCPM4-0.5B'
64
+ device = "cuda"
65
+ tokenizer = AutoTokenizer.from_pretrained(path)
66
+ model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True)
67
+
68
+ # User can directly use the chat interface
69
+ responds, history = model.chat(tokenizer, "Write an article about Artificial Intelligence.", temperature=0.7, top_p=0.7)
70
+ print(responds)
71
+
72
+ # User can also use the generate interface
73
+ # messages = [
74
+ # {"role": "user", "content": "Write an article about Artificial Intelligence."},
75
+ # ]
76
+ # model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(device)
77
+
78
+ # model_outputs = model.generate(
79
+ # model_inputs,
80
+ # max_new_tokens=1024,
81
+ # top_p=0.7,
82
+ # temperature=0.7
83
+ # )
84
+ # output_token_ids = [
85
+ # model_outputs[i][len(model_inputs[i]):] for i in range(len(model_inputs))
86
+ # ]
87
+
88
+ # responses = tokenizer.batch_decode(output_token_ids, skip_special_tokens=True)[0]
89
+ # print(responses)
90
+ ```
91
+
92
+ ### Inference with [SGLang](https://github.com/sgl-project/sglang)
93
+
94
+ For now, you need to install our forked version of SGLang.
95
+ ```bash
96
+ git clone -b openbmb https://github.com/OpenBMB/sglang.git
97
+ cd sglang
98
+
99
+ pip install --upgrade pip
100
+ pip install -e "python[all]"
101
+ ```
102
+
103
+ You can start the inference server by running the following command:
104
+ ```bash
105
+ python -m sglang.launch_server --model openbmb/MiniCPM4-8B --trust-remote-code --port 30000 --chat-template chatml
106
+ ```
107
+
108
+ Then you can use the chat interface by running the following command:
109
+ ```python
110
+ import openai
111
+
112
+ client = openai.Client(base_url=f"http://localhost:30000/v1", api_key="None")
113
+
114
+ response = client.chat.completions.create(
115
+ model="openbmb/MiniCPM4-8B",
116
+ messages=[
117
+ {"role": "user", "content": "Write an article about Artificial Intelligence."},
118
+ ],
119
+ temperature=0.7,
120
+ max_tokens=1024,
121
+ )
122
+
123
+ print(response.choices[0].message.content)
124
+ ```
125
+
126
+ ### Inference with [vLLM](https://github.com/vllm-project/vllm)
127
+ For now, you need to install the latest version of vLLM.
128
+ ```
129
+ pip install -U vllm \
130
+ --pre \
131
+ --extra-index-url https://wheels.vllm.ai/nightly
132
+ ```
133
+
134
+ Then you can run inference on MiniCPM4-8B with vLLM:
135
+ ```python
136
+ from transformers import AutoTokenizer
137
+ from vllm import LLM, SamplingParams
138
+
139
+ model_name = "openbmb/MiniCPM4-8B"
140
+ prompt = [{"role": "user", "content": "Please recommend 5 tourist attractions in Beijing. "}]
141
+
142
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
143
+ input_text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
144
+
145
+ llm = LLM(
146
+ model=model_name,
147
+ trust_remote_code=True,
148
+ max_num_batched_tokens=32768,
149
+ dtype="bfloat16",
150
+ gpu_memory_utilization=0.8,
151
+ )
152
+ sampling_params = SamplingParams(top_p=0.7, temperature=0.7, max_tokens=1024, repetition_penalty=1.02)
153
+
154
+ outputs = llm.generate(prompts=input_text, sampling_params=sampling_params)
155
+
156
+ print(outputs[0].outputs[0].text)
157
+ ```
158
+
159
+ ## Evaluation Results
160
+ On two typical end-side chips, Jetson AGX Orin and RTX 4090, MiniCPM4 demonstrates significantly faster processing speed compared to similar-size models in long text processing tasks. As text length increases, MiniCPM4's efficiency advantage becomes more pronounced. On the Jetson AGX Orin platform, compared to Qwen3-8B, MiniCPM4 achieves approximately 7x decoding speed improvement.
161
+
162
+ ![benchmark](https://github.com/OpenBMB/MiniCPM/blob/main/assets/minicpm4/efficiency.png?raw=true)
163
+
164
+ #### Comprehensive Evaluation
165
+ MiniCPM4 launches end-side versions with 8B and 0.5B parameter scales, both achieving best-in-class performance in their respective categories.
166
+
167
+ ![benchmark](https://github.com/OpenBMB/MiniCPM/blob/main/assets/minicpm4/benchmark.png?raw=true)
168
+
169
+ #### Long Text Evaluation
170
+ MiniCPM4 is pre-trained with a 32K context length and extends to longer contexts through YaRN. On the 128K needle-in-a-haystack task, MiniCPM4 demonstrates outstanding performance.
171
+
172
+ ![long-niah](https://github.com/OpenBMB/MiniCPM/blob/main/assets/minicpm4/128k-niah.png?raw=true)
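Length extension via YaRN is typically applied as a rope-scaling override on the model configuration. A hedged sketch of what this can look like with `transformers` is shown below: `rope_scaling` is a standard config field, but the `rope_type` and `factor` values here are placeholders, not MiniCPM4's actual long-context settings; consult the repository's `config.json` for the real values.

```python
# Hedged sketch: overriding rope scaling when loading the model with transformers.
# The "yarn" rope_type and factor=4.0 below are placeholders for illustration only;
# MiniCPM4's actual long-context configuration lives in its config.json.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

path = "openbmb/MiniCPM4-0.5B"
config = AutoConfig.from_pretrained(path, trust_remote_code=True)
config.rope_scaling = {"rope_type": "yarn", "factor": 4.0}  # placeholder values
model = AutoModelForCausalLM.from_pretrained(
    path, config=config, torch_dtype=torch.bfloat16, trust_remote_code=True
)
```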
173
+
174
+ ## Statement
175
+ - As a language model, MiniCPM generates content by learning from a vast amount of text.
176
+ - However, it does not possess the ability to comprehend or express personal opinions or value judgments.
177
+ - Any content generated by MiniCPM does not represent the viewpoints or positions of the model developers.
178
+ - Therefore, when using content generated by MiniCPM, users should take full responsibility for evaluating and verifying it on their own.
179
+
180
+ ## LICENSE
181
+ - This repository and MiniCPM models are released under the [Apache-2.0](https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE) License.
182
+
183
+ ## Citation
184
+ - Please cite our [paper](https://github.com/OpenBMB/MiniCPM/tree/main/report/MiniCPM_4_Technical_Report.pdf) if you find our work valuable.
185
+
186
+ ```bibtex
187
+ @article{minicpm4,
188
+ title={{MiniCPM4}: Ultra-Efficient LLMs on End Devices},
189
+ author={MiniCPM Team},
190
+ year={2025}
191
+ }
192
+ ```
modeling_minicpm.py CHANGED
@@ -24,7 +24,7 @@ import torch.utils.checkpoint
24
  from torch import nn
25
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
  from transformers.activations import ACT2FN
27
- from transformers.cache_utils import Cache, DynamicCache, CacheLayerMixin, DynamicLayer
28
  from transformers.modeling_attn_mask_utils import (
29
  AttentionMaskConverter,
30
  _prepare_4d_attention_mask,
@@ -52,9 +52,493 @@ from .configuration_minicpm import MiniCPMConfig
52
  try:
53
  from flash_attn import flash_attn_func, flash_attn_varlen_func
54
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
55
  except:
56
  pass
57
58
 
59
 
60
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
@@ -83,6 +567,22 @@ def _get_unpad_data(attention_mask):
83
  )
84
 
85
 
86
 
87
 
88
  # @torch.jit.script # type: ignore
@@ -296,21 +796,6 @@ class MiniCPMMLP(nn.Module):
296
 
297
  return down_proj
298
 
299
- def _unpad_one_tensor(hidden_states, attention_mask):
300
- # Unpad the hidden states using the indices
301
- indices, cu_seqlens, max_seqlen_in_batch = _get_unpad_data(attention_mask)
302
- batch_size, seq_len = hidden_states.shape[:2]
303
-
304
- # Get the remaining dimensions
305
- remaining_dims = hidden_states.shape[2:]
306
-
307
- # Reshape to (batch_size * seq_len, *remaining_dims)
308
- reshaped_states = hidden_states.reshape(batch_size * seq_len, *remaining_dims)
309
-
310
- # Apply unpadding using indices
311
- unpadded_states = index_first_axis(reshaped_states, indices)
312
-
313
- return unpadded_states, indices, cu_seqlens, max_seqlen_in_batch
314
 
315
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
316
  """
@@ -442,7 +927,15 @@ class MiniCPMAttention(nn.Module):
442
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
443
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
444
 
445
- kv_seq_len = position_ids.max().item() + 1
446
  cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
447
 
448
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -544,7 +1037,9 @@ class MiniCPMFlashAttention2(MiniCPMAttention):
544
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
545
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
546
 
547
- kv_seq_len = position_ids.max().item() + 1
548
  cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
549
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
550
 
@@ -692,6 +1187,504 @@ class MiniCPMFlashAttention2(MiniCPMAttention):
692
  )
693
 
694
 
695
  class MiniCPMSdpaAttention(MiniCPMAttention):
696
  """
697
  MiniCPM attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -734,7 +1727,9 @@ class MiniCPMSdpaAttention(MiniCPMAttention):
734
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
735
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
736
 
737
- kv_seq_len = position_ids.max().item() + 1
738
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
739
 
740
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -788,7 +1783,10 @@ class MiniCPMDecoderLayer(nn.Module):
788
  def __init__(self, config: MiniCPMConfig, layer_idx: int):
789
  super().__init__()
790
  self.hidden_size = config.hidden_size
791
- self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
792
 
793
  self.mlp = MiniCPMMLP(config)
794
  self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -1054,10 +2052,11 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
1054
  raise ValueError(
1055
  'You must use the new past_key_values format, such as the Cache class, instead of the old tuple format.'
1056
  )
 
1057
 
1058
- # Calculate the usable length of past key values
1059
- past_key_values_length = past_key_values.get_seq_length() if isinstance(past_key_values, Cache) else 0
1060
-
1061
 
1062
  if position_ids is None:
1063
  device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1283,16 +2282,12 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
1283
  ):
1284
  if past_key_values is not None:
1285
  if isinstance(past_key_values, Cache):
1286
- # Use the new Cache class methods
1287
  cache_length = past_key_values.get_seq_length()
1288
-
1289
-
1290
- past_length = cache_length
1291
- max_cache_length = None
1292
  else:
1293
- raise ValueError(
1294
- 'You must use the new past_key_values format, such as the Cache class, instead of the old tuple format.'
1295
- )
1296
 
1297
  # Keep only the unprocessed tokens:
1298
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
 
24
  from torch import nn
25
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
  from transformers.activations import ACT2FN
27
+ from transformers.cache_utils import Cache, DynamicCache
28
  from transformers.modeling_attn_mask_utils import (
29
  AttentionMaskConverter,
30
  _prepare_4d_attention_mask,
 
52
  try:
53
  from flash_attn import flash_attn_func, flash_attn_varlen_func
54
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
55
+ from infllm_v2 import (
56
+ infllmv2_attn_stage1,
57
+ infllmv2_attn_varlen_func,
58
+ infllmv2_attn_with_kvcache,
59
+ max_pooling_1d,
60
+ )
61
  except:
62
  pass
63
 
64
+ from functools import lru_cache
65
+
66
+
67
+ def compressed_attention(
68
+ q: torch.Tensor,
69
+ k: torch.Tensor,
70
+ v: torch.Tensor,
71
+ kernel_size: int,
72
+ kernel_stride: int,
73
+ block_size: int,
74
+ topk: int,
75
+ cu_seqlens_q: torch.Tensor,
76
+ cu_seqlens_k: torch.Tensor,
77
+ max_seqlen_q: int,
78
+ max_seqlen_k: int,
79
+ sm_scale: float = None,
80
+ init_blocks: int = 1,
81
+ local_blocks: int = 2,
82
+ parallel_topk_compute: Union[str, bool] = 'auto',
83
+ total_seq_lens=-1,
84
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
85
+ """Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention.
86
+
87
+ Args:
88
+ q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
89
+ k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
90
+ v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
91
+ kernel_size (int): kernel size in compress_key_value
92
+ kernel_stride (int): stride of compress_key_value
93
+ block_size (int): key value block size for topk sparse attention.
94
+ topk (int): number of blocks for each query.
95
+ cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
96
+ cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen.
97
+ max_seqlen_q (int): max q len of the batch.
98
+ max_seqlen_k (int): max k len of the batch.
99
+ sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
100
+ init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
101
+ local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
102
+ parallel_topk_compute (str, optional): Only set it to False when the sequence length is too long. This can avoid a current bug.
103
+ We'll fix this issue later. Defaults to auto, it will be set to False when the sequence length is greater than 32k and True otherwise.
104
+
105
+ Returns:
106
+ torch.Tensor: topk_idx used in topk_sparse_attention
107
+ """
108
+ with torch.no_grad():
109
+ cache_len = 0
110
+ batch_size = cu_seqlens_q.shape[0] - 1
111
+ if total_seq_lens == -1:
112
+ total_seq_lens = max_seqlen_q
113
+ q_idx = torch.cat(
114
+ [
115
+ torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device) + total_seq_lens - (cu_seqlens_q[i + 1] - cu_seqlens_q[i])
116
+ for i in range(batch_size)
117
+ ],
118
+ dim=0,
119
+ )
120
+ q_idx = q_idx // block_size
121
+
122
+ else:
123
+ cache_len = total_seq_lens - max_seqlen_q
124
+ assert batch_size == 1, 'batch_size must be 1 when total_seq_lens is set'
125
+ q_idx = torch.tensor([total_seq_lens - 1], device=q.device, dtype=torch.int32) // block_size
126
+
127
+ score = infllmv2_attn_stage1(
128
+ q.contiguous(),
129
+ k.contiguous(),
130
+ v.contiguous(),
131
+ cu_seqlens_q=cu_seqlens_q,
132
+ cu_seqlens_k=cu_seqlens_k,
133
+ max_seqlen_q=max_seqlen_q,
134
+ max_seqlen_k=max_seqlen_k,
135
+ causal=q_idx.shape[0] > 1)
136
+ score = score[:, :q_idx.shape[0], :]
137
+
138
+ # Replace transform_score with max_pooling_1d
139
+ block_score = max_pooling_1d(
140
+ score.contiguous(),
141
+ cache_len=cache_len,
142
+ local_blocks=local_blocks,
143
+ init_blocks=init_blocks,
144
+ block_size=block_size,
145
+ stride=kernel_stride,
146
+ )
147
+ # get topk
148
+ topk = min(topk, block_score.shape[-1])
149
+ topk_idx = block_score.topk(topk, dim=-1).indices.sort(-1).values
150
+ topk_idx[topk_idx >= q_idx[None, :, None]] = -1
151
+ topk_idx = topk_idx.to(torch.int32)
152
+
153
+ return topk_idx
154
+
155
+
156
+ @lru_cache(maxsize=16)
157
+ def calc_chunks_with_stride(cu_seqlen, chunk_size, kernel_stride):
158
+ """
159
+ Compute the chunks that require Sparse attention, with stride support.
160
+
161
+ Args:
162
+ cu_seqlen (torch.Tensor): Cumulative sequence lengths for each sample.
163
+ chunk_size (int): Chunk size used for Sparse attention.
164
+ kernel_stride (int): Stride size when sliding over the sequence.
165
+
166
+ Returns:
167
+ filtered_indices (torch.Tensor): Indices used to directly index into the key/value tensors.
168
+ cu_seqlens_compressed (torch.Tensor): Cumulative sequence lengths after compression.
169
+ """
170
+ # 1. Compute the length of each sequence
171
+ batch_sizes = cu_seqlen[1:] - cu_seqlen[:-1]
172
+
173
+ # 2. Compute the start positions of chunks for each sequence (with stride)
174
+ max_seq_len = torch.max(batch_sizes)
175
+ max_num_chunks_per_seq = (max_seq_len - chunk_size) // kernel_stride + 1
176
+ chunk_start_offsets = torch.arange(0, max_num_chunks_per_seq * kernel_stride, kernel_stride, device=cu_seqlen.device)
177
+ seq_starts = cu_seqlen[:-1]
178
+ chunk_start_in_seq = seq_starts[:, None] + chunk_start_offsets[None, :] # [batch_size, max_num_chunks_per_seq]
179
+
180
+ # 3. Filter out chunks that exceed sequence length or are smaller than the full chunk size
181
+ chunk_end_in_seq = chunk_start_in_seq + chunk_size
182
+ valid_chunk_mask = (chunk_end_in_seq <= (seq_starts[:, None] + batch_sizes[:, None]))
183
+
184
+ # 4. Filter valid chunk start positions using the valid_chunk_mask
185
+ valid_chunk_starts = chunk_start_in_seq[valid_chunk_mask] # [num_valid_chunks]
186
+ del chunk_start_in_seq
187
+ # 5. Generate filtered_indices
188
+ chunk_indices = torch.arange(
189
+ 0, chunk_size, device=cu_seqlen.device
190
+ )[None, :] # [1, chunk_size]
191
+ filtered_indices = valid_chunk_starts[:, None] + chunk_indices # [num_valid_chunks, chunk_size]
192
+ filtered_indices = filtered_indices.view(-1) # Flatten to 1D indices
193
+
194
+ # 6. Compute compressed cumulative sequence lengths
195
+ num_filtered_chunks_per_batch = valid_chunk_mask.sum(dim=1) # Number of valid chunks per batch
196
+ cu_seqlens_compressed = torch.zeros(
197
+ len(cu_seqlen), dtype=torch.int32, device=cu_seqlen.device
198
+ )
199
+ cu_seqlens_compressed[1:] = num_filtered_chunks_per_batch.cumsum(dim=0)
200
+ del num_filtered_chunks_per_batch, chunk_start_offsets, seq_starts, chunk_end_in_seq, valid_chunk_mask, chunk_indices
201
+ return filtered_indices, cu_seqlens_compressed
202
+
203
+
204
+ class CompressK(torch.nn.Module):
205
+ def __init__(self, head_num_k, head_dim, kernel_size, kernel_stride=16):
206
+ """
207
+ Module for compressing key (K) representations.
208
+
209
+ Args:
210
+ head_num_k (int): Number of key attention heads.
211
+ head_dim (int): Dimension of each attention head.
212
+ kernel_size (int): Size of each chunk used for compression.
213
+ kernel_stride (int, optional): Stride used when dividing input into chunks. Default is 16.
214
+ """
215
+ super().__init__()
216
+ self.kernel_size = kernel_size
217
+ self.head_num_k = head_num_k
218
+ self.head_dim = head_dim
219
+ self.kernel_stride = kernel_stride
220
+
221
+ def forward(self, k: torch.Tensor, cu_seqlens):
222
+ """
223
+ Forward pass for compressing the key (K) tensor.
224
+
225
+ Args:
226
+ k (torch.Tensor): Input key tensor of shape (total_seq_len, num_heads, head_dim).
227
+ cu_seqlens (torch.Tensor): Cumulative sequence lengths for each sample in the batch, typically used for handling variable-length sequences.
228
+
229
+ Returns:
230
+ compress_k (torch.Tensor): Compressed key tensor.
231
+ cu_seqlens_compressed (torch.Tensor): Updated cumulative sequence lengths after compression.
232
+
233
+ """
234
+ # Compute chunk-related metadata, with stride support
235
+ filtered_k_indices, cu_seqlens_compressed = calc_chunks_with_stride(
236
+ cu_seqlens, self.kernel_size, self.kernel_stride
237
+ )
238
+
239
+ # Extract filtered key vectors
240
+ filtered_k = k.index_select(0, filtered_k_indices.view(-1))
241
+
242
+ # split
243
+ filtered_k = filtered_k.view(filtered_k.shape[0] // self.kernel_size, self.kernel_size, self.head_num_k, self.head_dim) # [l, block_size,h,d]
244
+
245
+ compressed_k = filtered_k.mean(dim=1)
246
+ return compressed_k, cu_seqlens_compressed
247
+
248
+
249
+ class DynamicCacheQKV(DynamicCache):
250
+ """
251
+ A cache that grows dynamically as more tokens are generated. This is the default for generative models.
252
+
253
+ It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
254
+ `[batch_size, num_heads, seq_len, head_dim]`.
255
+
256
+ Example:
257
+ ```python
258
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
259
+
260
+ >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
261
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
262
+
263
+ >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
264
+
265
+ >>> # Prepare a cache class and pass it to model's forward
266
+ >>> past_key_values = DynamicCache()
267
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
268
+ >>> outputs.past_key_values # access cache filled with key/values from generation
269
+ DynamicCache()
270
+ ```
271
+ """
272
+ def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
273
+ super().__init__()
274
+ if num_hidden_layers is None:
275
+ self.key_cache: List[torch.Tensor] = []
276
+ self.value_cache: List[torch.Tensor] = []
277
+ self.compress_k_cache: List[torch.Tensor] = []
278
+ self.no_compress_k_cache: List[torch.Tensor] = []
279
+ self.cached_compressed_cu_seqlens: List[torch.Tensor] = []
280
+ self.no_rope_key_cache: List[torch.Tensor] = []
281
+ else:
282
+ self.key_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
283
+ self.value_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
284
+ self.compress_k_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
285
+ self.no_compress_k_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
286
+ self.cached_compressed_cu_seqlens: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
287
+ self.no_rope_key_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
288
+ self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
289
+
290
+ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
291
+ """
292
+ Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
293
+ sequence length.
294
+ """
295
+ if layer_idx < len(self):
296
+ return (self.key_cache[layer_idx], self.value_cache[layer_idx])
297
+ else:
298
+ raise KeyError(f'Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}')
299
+
300
+ def __iter__(self):
301
+ """
302
+ Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
303
+ keys and values
304
+ """
305
+ for layer_idx in range(len(self)):
306
+ yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
307
+
308
+ def __len__(self):
309
+ """
310
+ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
311
+ to the number of layers in the model.
312
+ """
313
+ return len(self.key_cache)
314
+
315
+ def update(
316
+ self,
317
+ key_states: torch.Tensor,
318
+ value_states: torch.Tensor,
319
+ layer_idx: int,
320
+ cache_kwargs: Optional[Dict[str, Any]] = None
321
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
322
+ """
323
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
324
+
325
+ Parameters:
326
+ key_states (`torch.Tensor`):
327
+ The new key states to cache.
328
+ value_states (`torch.Tensor`):
329
+ The new value states to cache.
330
+ layer_idx (`int`):
331
+ The index of the layer to cache the states for.
332
+ cache_kwargs (`Dict[str, Any]`, `optional`):
333
+ Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
334
+
335
+ Return:
336
+ A tuple containing the updated key and value states.
337
+ """
338
+ # Update the number of seen tokens
339
+ if layer_idx == 0:
340
+ self._seen_tokens += key_states.shape[-2]
341
+
342
+ # Update the cache
343
+ if len(self.key_cache) <= layer_idx:
344
+ self.key_cache.append(key_states)
345
+ self.value_cache.append(value_states)
346
+
347
+ # content on layer cache can be a tensor and checking not tensor causes errors
348
+ # so we explicitly check for the empty list
349
+ elif self.key_cache[layer_idx] == []:
350
+ self.key_cache[layer_idx] = key_states
351
+ self.value_cache[layer_idx] = value_states
352
+
353
+ else:
354
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
355
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
356
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
357
+
358
+ def update_no_rope_key(
359
+ self,
360
+ key_states: torch.Tensor,
361
+ layer_idx: int,
362
+ cache_kwargs: Optional[Dict[str, Any]] = None):
363
+
364
+ # Update the cache
365
+ if len(self.no_rope_key_cache) <= layer_idx:
366
+ self.no_rope_key_cache.append(key_states)
367
+
368
+ # content on layer cache can be a tensor and checking not tensor causes errors
369
+ # so we explicitly check for the empty list
370
+ elif self.no_rope_key_cache[layer_idx] == []:
371
+ self.no_rope_key_cache[layer_idx] = key_states
372
+ else:
373
+ self.no_rope_key_cache[layer_idx] = torch.cat([self.no_rope_key_cache[layer_idx], key_states], dim=1)
374
+ return self.no_rope_key_cache[layer_idx]
375
+
376
+ def update_compress_k(
377
+ self,
378
+ key_states: torch.Tensor,
379
+ layer_idx: int,
380
+ cache_kwargs: Optional[Dict[str, Any]] = None
381
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
382
+ """
383
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
384
+
385
+ Parameters:
386
+ key_states (`torch.Tensor`):
387
+ The new key states to cache.
388
+ value_states (`torch.Tensor`):
389
+ The new value states to cache.
390
+ layer_idx (`int`):
391
+ The index of the layer to cache the states for.
392
+ cache_kwargs (`Dict[str, Any]`, `optional`):
393
+ Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
394
+
395
+ Return:
396
+ The updated compressed key cache for this layer.
397
+ """
398
+
399
+ # Update the cache
400
+ if len(self.compress_k_cache) <= layer_idx:
401
+ self.compress_k_cache.append(key_states)
402
+
403
+ # content on layer cache can be a tensor and checking not tensor causes errors
404
+ # so we explicitly check for the empty list
405
+ elif self.compress_k_cache[layer_idx] == []:
406
+ self.compress_k_cache[layer_idx] = key_states
407
+ else:
408
+ self.compress_k_cache[layer_idx] = torch.cat([self.compress_k_cache[layer_idx], key_states], dim=0)
409
+ return self.compress_k_cache[layer_idx]
410
+
411
+ def update_no_compress_k(
412
+ self,
413
+ key_states: torch.Tensor,
414
+ layer_idx: int,
415
+ kernel_size: int = 32,
416
+ kernel_stride: int = 16,
417
+ cache_kwargs: Optional[Dict[str, Any]] = None
418
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
419
+ """
420
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
421
+
422
+ Parameters:
423
+ key_states (`torch.Tensor`):
424
+ The new key states to cache.
425
+ value_states (`torch.Tensor`):
426
+ The new value states to cache.
427
+ layer_idx (`int`):
428
+ The index of the layer to cache the states for.
429
+ cache_kwargs (`Dict[str, Any]`, `optional`):
430
+ Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
431
+
432
+ Return:
433
+ A tuple containing the updated key and value states.
434
+ """
435
+ # Update the cache
436
+ if len(self.no_compress_k_cache) <= layer_idx:
437
+ self.no_compress_k_cache.append(key_states)
438
+
439
+ # content on layer cache can be a tensor and checking not tensor causes errors
440
+ # so we explicitly check for the empty list
441
+ elif self.no_compress_k_cache[layer_idx] == []:
442
+ self.no_compress_k_cache[layer_idx] = key_states
443
+ else:
444
+ self.no_compress_k_cache[layer_idx] = torch.cat([self.no_compress_k_cache[layer_idx], key_states], dim=0)
445
+
446
+ current_len = self.no_compress_k_cache[layer_idx].shape[0]
447
+
448
+ if current_len >= kernel_size:
449
+ k_chunk = self.no_compress_k_cache[layer_idx][:kernel_size]
450
+ self.no_compress_k_cache[layer_idx] = self.no_compress_k_cache[layer_idx][kernel_stride:]
451
+ return k_chunk
452
+ else:
453
+ return None
454
+
455
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
456
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
457
+ # TODO: deprecate this function in favor of `cache_position`
458
+ if len(self.key_cache) <= layer_idx or (len(self.key_cache) > layer_idx and self.key_cache[layer_idx] == []):
459
+ return 0
460
+ return self.key_cache[layer_idx].shape[-2]
461
+
462
+ def get_max_length(self) -> Optional[int]:
463
+ """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
464
+ return None
465
+
466
+ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
467
+ """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
468
+ backward compatibility."""
469
+ legacy_cache = ()
470
+ for layer_idx in range(len(self)):
471
+ legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
472
+ return legacy_cache
473
+
474
+ # @classmethod
475
+ # def from_legacy_cache(
476
+ # cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
477
+ # ) -> "DynamicCacheQKV":
478
+ # """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
479
+ # backward compatibility."""
480
+ # cache = cls(num_hidden_layers)
481
+ # if past_key_values is not None:
482
+ # for layer_idx in range(len(past_key_values)):
483
+ # key_states, value_states, query_status = past_key_values[layer_idx]
484
+ # cache.update(key_states, value_states, query_status,layer_idx)
485
+ # return cache
486
+
487
+ def crop(self, max_length: int):
488
+ """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
489
+ negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
490
+ # In case it is negative
491
+ if max_length < 0:
492
+ max_length = self.get_seq_length() - abs(max_length)
493
+
494
+ if self.get_seq_length() <= max_length:
495
+ return
496
+
497
+ self._seen_tokens = max_length
498
+ for idx in range(len(self.key_cache)):
499
+ if self.key_cache[idx] != []:
500
+ self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
501
+ self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
502
+
503
+ def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: int) -> List['DynamicCacheQKV']:
504
+ """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
505
+ `_split_model_inputs()` in `generation.utils`"""
506
+ out = []
507
+ for i in range(0, full_batch_size, split_size):
508
+ current_split = DynamicCacheQKV(num_hidden_layers)
509
+ current_split._seen_tokens = self._seen_tokens
510
+ current_split.key_cache = [tensor[i: i + split_size] for tensor in self.key_cache]
511
+ current_split.value_cache = [tensor[i: i + split_size] for tensor in self.value_cache]
512
+ out.append(current_split)
513
+ return out
514
+
515
+ @classmethod
516
+ def from_batch_splits(cls, splits: List['DynamicCacheQKV'], num_hidden_layers: int) -> 'DynamicCacheQKV':
517
+ """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
518
+ `generation.utils`"""
519
+ cache = cls(num_hidden_layers)
520
+ for idx in range(len(splits[0])):
521
+ key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
522
+ value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
523
+ query_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
524
+ if key_cache != []:
525
+ layer_keys = torch.cat(key_cache, dim=0)
526
+ layer_values = torch.cat(value_cache, dim=0)
527
+ layer_query = torch.cat(query_cache, dim=0)
528
+ cache.update(layer_keys, layer_values, idx, query_states=layer_query)
529
+ return cache
530
+
531
+ def batch_repeat_interleave(self, repeats: int):
532
+ """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
533
+ for layer_idx in range(len(self)):
534
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
535
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
536
+
537
+ def batch_select_indices(self, indices: torch.Tensor):
538
+ """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
539
+ for layer_idx in range(len(self)):
540
+ self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
541
+ self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
542
 
543
 
544
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
 
567
  )
568
 
569
 
570
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
571
+ warnings.warn(
572
+ 'Calling `transformers.models.minicpm.modeling_minicpm._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask'
573
+ )
574
+ return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
575
+
576
+
577
+ def _make_causal_mask(
578
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
579
+ ):
580
+ warnings.warn(
581
+ 'Calling `transformers.models.minicpm.modeling_minicpm._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.minicpm.modeling_minicpm.AttentionMaskConverter._make_causal_mask'
582
+ )
583
+ return AttentionMaskConverter._make_causal_mask(
584
+ input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
585
+ )
586
 
587
 
588
  # @torch.jit.script # type: ignore
 
796
 
797
  return down_proj
798
 
799
 
800
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
801
  """
 
927
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
928
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
929
 
930
+ kv_seq_len = key_states.shape[-2]
931
+ if past_key_value is not None:
932
+ if self.layer_idx is None:
933
+ raise ValueError(
934
+ f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
935
+ 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
936
+ 'with a layer index.'
937
+ )
938
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
939
  cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
940
 
941
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
1037
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1038
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1039
 
1040
+ kv_seq_len = key_states.shape[-2]
1041
+ if past_key_value is not None:
1042
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
1043
  cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
1044
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
1045
 
 
1187
  )
1188
 
1189
 
1190
+ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1191
+ """
1192
+ MiniCPM InfLLM v2 sparse attention module. This module inherits from `MiniCPMAttention` as the weights of the module stay
1193
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
1194
+ flash attention and deal with padding tokens in case the input contains any of them.
1195
+ """
1196
+
1197
+ def __init__(self, *args, **kwargs):
1198
+ super().__init__(*args, **kwargs)
1199
+ assert self.config._attn_implementation == 'flash_attention_2', 'Only flash_attention_2 is supported for sparse attention'
1200
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
1201
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
1202
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
1203
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
1204
+
1205
+ # -------sparse-------
1206
+ self.kernel_size = self.config.sparse_config.get('kernel_size', 32)
1207
+ self.kernel_stride = self.config.sparse_config.get('kernel_stride', 16)
1208
+ self.init_blocks = self.config.sparse_config.get('init_blocks', 1)
1209
+ self.block_size = self.config.sparse_config.get('block_size', 64)
1210
+ self.window_size = self.config.sparse_config.get('window_size', 2048)
1211
+ self.dense_len = self.config.sparse_config.get('dense_len', 8192)
1212
+
1213
+ self.local_blocks = self.window_size // self.block_size # local_blocks
1214
+ self.topk = self.config.sparse_config.get('topk', 64)
1215
+ self.use_nope = self.config.sparse_config.get('use_nope', False)
1216
+ self.compress_k = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size, kernel_stride=self.kernel_stride)
1217
+
1218
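# Editor's note (hedged example): the hyperparameters above are read from `config.sparse_config` with
# `dict.get`, so a plain dict with these keys is sufficient. Key names and defaults are taken directly
# from the reads in __init__; representing it as a dict in config.json is an assumption.
sparse_config_example = {
    'kernel_size': 32,    # compression window over keys
    'kernel_stride': 16,  # stride of that window (windows overlap by 16 keys)
    'init_blocks': 1,     # always-attended initial blocks
    'block_size': 64,     # granularity of block selection
    'window_size': 2048,  # local sliding window -> window_size // block_size local blocks
    'topk': 64,           # number of retrieved blocks per query
    'use_nope': False,    # keep a no-RoPE copy of q/k for block scoring
    'dense_len': 8192,    # below this length, fall back to dense flash attention
}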
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # MiniCPMInfLLMv2Attention does not support output_attentions
+        if 'padding_mask' in kwargs:
+            warnings.warn(
+                'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.'
+            )
+
+            # overwrite attention_mask with padding_mask
+            attention_mask = kwargs.pop('padding_mask')
+
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+        assert bsz == 1, 'Only batch_size=1 is supported at the moment.'
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        # !save no rope: keep a copy of q/k before RoPE is applied (used for block scoring when use_nope is set)
+        if self.use_nope:
+            query_states_no_rope = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+            key_states_no_rope = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim
+        # therefore we just need to keep the original shape
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            cache_kwargs = {'sin': sin, 'cos': cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+        if self.use_nope:
+            no_rope_param = {
+                'key_states_no_rope': key_states_no_rope,
+                'query_states_no_rope': query_states_no_rope,
+            }
+            if kv_seq_len <= self.dense_len:
+                past_key_value.update_no_rope_key(key_states_no_rope, self.layer_idx)
+        else:
+            no_rope_param = None
+
+        dropout_rate = self.attention_dropout if self.training else 0.0
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
+        # therefore the input hidden states get silently cast to float32. Hence, we need to
+        # cast them back to the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference so it is recommended to not cast the LayerNorms
+        # in fp32. (MiniCPMRMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            # Handle the case where the model is quantized
+            if hasattr(self.config, '_pre_quantization_dtype'):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f'The input hidden states seem to be silently cast to float32; this might be related to'
+                f' the fact you have upcast embedding or layer norm layers in float32. We will cast back the input in'
+                f' {target_dtype}.'
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+        if kv_seq_len < self.dense_len:
+            attn_output = self._flash_attention_forward_dense(
+                query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate)
+        elif past_key_value is None or q_len != 1:  # prefilling
+            attn_output = self._flash_attention_forward(
+                query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate,
+                no_rope_param=no_rope_param,  # if past_key_value is not None else None,
+                past_key_value=past_key_value)
+        else:  # decoding with a single new token
+            attn_output = self._flash_attention_forward_with_kv_cache(
+                query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate, no_rope_param=no_rope_param, past_key_value=past_key_value)
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
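# Editor's note (hedged arithmetic, using the defaults read in __init__ above): with dense_len=8192,
# block_size=64, window_size=2048, topk=64 and init_blocks=1, the forward pass stays fully dense below
# 8192 cached tokens. Beyond that, each decode query attends to an upper bound of
# init_blocks + local_blocks + topk blocks of 64 tokens (the selection kernel may count the forced
# initial and local blocks inside topk, so the true number can be smaller):
block_size, window_size, topk, init_blocks, dense_len = 64, 2048, 64, 1, 8192
local_blocks = window_size // block_size                   # 32 sliding-window blocks
max_selected_blocks = init_blocks + local_blocks + topk    # 97 blocks at most
max_attended_tokens = max_selected_blocks * block_size     # 6208 tokens, regardless of context length
print(local_blocks, max_selected_blocks, max_attended_tokens)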
+    def _flash_attention_forward(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, no_rope_param=None, past_key_value=None
+    ):
+        """
+        Prefill path of the sparse (InfLLM v2) attention - if the input hidden states contain at least one
+        padding token, first unpad the input, then compute the attention scores and pad the final attention
+        scores back.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in MiniCPMFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+            if no_rope_param is not None:
+                # unpad the no-RoPE copies as well (batch_size is 1, so dropping the batch dim is enough)
+                no_rope_param['query_states_no_rope'] = no_rope_param['query_states_no_rope'].squeeze(0)
+                no_rope_param['key_states_no_rope'] = no_rope_param['key_states_no_rope'].squeeze(0)
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+            attn_output_unpad = self.sparse_forward(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_in_batch_q,
+                max_seqlen_in_batch_k,
+                no_rope_param=no_rope_param,
+                past_key_value=past_key_value,
+            )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            raise ValueError('An attention mask is required for the sparse attention path.')
+
+        return attn_output
+
+    def _flash_attention_forward_with_kv_cache(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, no_rope_param=None, past_key_value=None
+    ):
+        """
+        Single-token decoding path of the sparse (InfLLM v2) attention: splits the keys/values into the
+        cached prefix and the newly generated token, then calls `sparse_forward_with_kv_cache`.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in MiniCPMFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+
+            batch_size = query_states.shape[0]
+
+            # query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+            #     query_states, key_states, value_states, attention_mask, query_length=query_length
+            # )
+
+            assert batch_size == 1, 'Only batch_size=1 is supported at the moment.'
+            # split into the cached prefix (past k/v) and the newly generated token (new k/v)
+            new_q = query_states
+
+            new_k = key_states[:, -1:, :, :].contiguous()
+            new_v = value_states[:, -1:, :, :].contiguous()
+
+            past_k = key_states[:, :-1, :, :].contiguous()
+            past_v = value_states[:, :-1, :, :].contiguous()
+            if no_rope_param is not None:
+                # drop the batch dimension of the no-RoPE copies (batch_size is 1)
+                no_rope_param['query_states_no_rope'] = no_rope_param['query_states_no_rope'].squeeze(0)
+                no_rope_param['key_states_no_rope'] = no_rope_param['key_states_no_rope'].squeeze(0)
+
+            attn_output = self.sparse_forward_with_kv_cache(
+                past_k=past_k, past_v=past_v, new_k=new_k, new_v=new_v, new_q=new_q, batch_size=batch_size, no_rope_param=no_rope_param, past_key_value=past_key_value)
+
+            # attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            raise ValueError('An attention mask is required for the sparse attention path.')
+
+        return attn_output
+
+    def sparse_forward(self,
+                       query_layer,
+                       key_layer,
+                       value_layer,
+                       cu_seqlens_q,
+                       cu_seqlens_k,
+                       max_seqlen_in_batch_q,
+                       max_seqlen_in_batch_k,
+                       no_rope_param=None,
+                       past_key_value=None):
+        stage1_k = key_layer if no_rope_param is None else no_rope_param['key_states_no_rope']
+        compressed_k, compressed_cu_seqlens = self.compress_k(stage1_k, cu_seqlens_k)
+        compressed_v = compressed_k.clone()
+        if past_key_value is not None:
+            # Compute the start index of the keys (k) that were not compressed; only batch_size=1 is supported at the moment.
+            no_compress_k_start = compressed_k.shape[0] * self.kernel_stride
+            past_key_value.update_compress_k(
+                compressed_k, self.layer_idx
+            )
+            past_key_value.update_no_compress_k(
+                key_layer[no_compress_k_start:], self.layer_idx, no_compress_k_start)
+            past_key_value.cached_compressed_cu_seqlens.append(compressed_cu_seqlens)
+        compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
+        topk_idx = compressed_attention(
+            query_layer if no_rope_param is None else no_rope_param['query_states_no_rope'],
+            compressed_k,
+            compressed_v,
+            self.kernel_size,
+            self.kernel_stride,
+            self.block_size,
+            self.topk,
+            cu_seqlens_q,
+            compressed_cu_seqlens,
+            max_seqlen_in_batch_q,
+            compressed_seqlens.max().item(),
+            None,
+            init_blocks=self.init_blocks,
+            local_blocks=self.local_blocks,
+        )
+
+        topk_attn_output = infllmv2_attn_varlen_func(
+            query_layer,
+            key_layer,
+            value_layer,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_in_batch_q,
+            max_seqlen_in_batch_k,
+            dropout_p=0.0,
+            deterministic=False,
+            softmax_scale=None,
+            causal=True,
+            return_attn_probs=False,
+            block_window_size=self.window_size // self.block_size,
+            topk_idx=topk_idx
+        )
+
+        return topk_attn_output
+
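# Editor's note (hedged arithmetic): CompressK is configured above with kernel_size=32 and
# kernel_stride=16, i.e. overlapping windows of 32 keys taken every 16 keys. Assuming the usual
# floor((len - kernel_size) / kernel_stride) + 1 window count (the CompressK definition itself is not
# part of this hunk), a 1000-token prefill would yield:
kernel_size, kernel_stride = 32, 16
seq_len = 1000
num_compressed = (seq_len - kernel_size) // kernel_stride + 1  # 61 compressed keys
no_compress_k_start = num_compressed * kernel_stride           # 976: first tail key not yet covered by a full window
print(num_compressed, no_compress_k_start, seq_len - no_compress_k_start)  # 61 976 24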
+    def sparse_forward_with_kv_cache(self, past_k=None, past_v=None, new_k=None, new_v=None, new_q=None, batch_size=None, no_rope_param=None, past_key_value=None):
+
+        # stage1_k = new_k.squeeze(0) if no_rope_param is None else no_rope_param['key_states_no_rope']
+        if past_k.shape[1] + new_k.shape[1] == self.dense_len and (past_key_value.compress_k_cache == [] or len(past_key_value.compress_k_cache) < self.layer_idx + 1 or past_key_value.compress_k_cache[self.layer_idx] == []):
+            if no_rope_param is not None:
+                stage1_k = past_key_value.no_rope_key_cache[self.layer_idx].squeeze(0).contiguous()  # only batch_size == 1
+            else:
+                stage1_k = torch.cat([past_k, new_k], dim=1).contiguous().squeeze(0).contiguous()  # only batch_size == 1
+            compressed_k, compressed_cu_seqlens = self.compress_k(stage1_k, torch.tensor([0, stage1_k.shape[0]], device=stage1_k.device, dtype=torch.int32))  # only batch_size == 1
+
+            # Compute the start index of the keys (k) that were not compressed; only batch_size=1 is supported at the moment.
+            no_compress_k_start = compressed_k.shape[0] * self.kernel_stride
+            past_key_value.update_compress_k(
+                compressed_k, self.layer_idx
+            )
+            past_key_value.update_no_compress_k(
+                stage1_k[no_compress_k_start:], self.layer_idx, no_compress_k_start)
+            past_key_value.cached_compressed_cu_seqlens.append(compressed_cu_seqlens)
+
+        else:
+            stage1_k = new_k.squeeze(0) if no_rope_param is None else no_rope_param['key_states_no_rope']
+            no_compress_k = past_key_value.update_no_compress_k(
+                stage1_k, self.layer_idx, kernel_stride=self.kernel_stride, kernel_size=self.kernel_size)
+            if no_compress_k is not None:
+                compressed_k = no_compress_k.mean(dim=0, keepdim=True)  # [1, n_heads_k, head_dim]
+
+                compressed_k = past_key_value.update_compress_k(
+                    compressed_k, self.layer_idx)  # [seqlen, nheads_k, head_dim]
+
+                past_key_value.cached_compressed_cu_seqlens[self.layer_idx][-1] += 1  # !Increment the last cumulative sequence length by 1; currently supports only batch_size = 1
+                compressed_cu_seqlens = past_key_value.cached_compressed_cu_seqlens[self.layer_idx]
+            else:
+                compressed_k = past_key_value.compress_k_cache[self.layer_idx]  # [seqlen, nheads_k, head_dim]
+                compressed_cu_seqlens = past_key_value.cached_compressed_cu_seqlens[self.layer_idx]
+
+        compressed_v = compressed_k.clone()
+
+        compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
+        torch.cuda.synchronize()
+        # Manually verify that the lengths match
+        assert compressed_k.shape[0] == compressed_seqlens.sum().item(), 'The length of compressed_k does not match the sum of compressed_seqlens'
+        topk_idx = compressed_attention(
+            new_q.squeeze(0).contiguous() if no_rope_param is None else no_rope_param['query_states_no_rope'],
+            compressed_k,
+            compressed_v,
+            self.kernel_size,
+            self.kernel_stride,
+            self.block_size,
+            self.topk,
+            torch.tensor([0, 1], device=compressed_k.device, dtype=torch.int32),
+            compressed_cu_seqlens,
+            1,
+            compressed_seqlens.max().item(),
+            None,
+            init_blocks=self.init_blocks,
+            local_blocks=self.local_blocks,
+            total_seq_lens=past_k.shape[1] + 1,  # !Only batch_size=1 is supported at the moment.
+        )
+
+        repeat_times = 1
+        if repeat_times > 1:
+            new_q = new_q.repeat_interleave(repeat_times, dim=-2)
+        else:
+            new_q = new_q
+
+        cache_batch_idx = torch.arange(batch_size, device=new_q.device, dtype=torch.int32)
+
+        seqlen_k = past_k.shape[1] + new_k.shape[1]  # !Only batch_size=1 is supported at the moment.
+        seqlens_k = torch.full((batch_size,), seqlen_k - 1, dtype=torch.int32, device=new_q.device)
+
+        past_k = torch.cat([past_k, torch.zeros_like(new_k, dtype=new_k.dtype)], dim=1).contiguous()  # Append one zero vector to avoid potential out-of-bounds access
+        past_v = torch.cat([past_v, torch.zeros_like(new_v, dtype=new_v.dtype)], dim=1).contiguous()  # Append one zero vector to avoid potential out-of-bounds access
+        topk_attn_output = infllmv2_attn_with_kvcache(
+            q=new_q,
+            k_cache=past_k,
+            v_cache=past_v,
+            topk_idx=topk_idx,
+            block_window_size=self.window_size // self.block_size,
+            k=new_k,  # [batch_size, 1, nheads_k, d]
+            v=new_v,  # [batch_size, 1, nheads_k, d]
+            cache_seqlens=seqlens_k,  # current_seqlens_k - 1
+            rotary_cos=None,  # No rotary embeddings
+            rotary_sin=None,  # No rotary embeddings
+            cache_batch_idx=cache_batch_idx,
+            causal=False,
+        )
+        return topk_attn_output
+
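# Editor's note (hedged sketch): during decoding the compressed-K stream is extended lazily.
# `update_no_compress_k` buffers the raw tail keys and, judging from the call signature and the
# mean-pooling above, returns a full window of `kernel_size` keys roughly every `kernel_stride` steps,
# which is then pooled into one new compressed key. A standalone illustration of that pooling step,
# with assumed shapes [window, n_kv_heads, head_dim]:
import torch

kernel_size, n_kv_heads, head_dim = 32, 2, 64
window = torch.randn(kernel_size, n_kv_heads, head_dim)  # the full window returned by the buffer
new_compressed_key = window.mean(dim=0, keepdim=True)    # [1, n_kv_heads, head_dim]
assert new_compressed_key.shape == (1, n_kv_heads, head_dim)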
+    def _flash_attention_forward_dense(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding
+        token, first unpad the input, then compute the attention scores and pad the final attention scores back.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in MiniCPMFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            attn_output = flash_attn_func(
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+            )
+
+        return attn_output
+
+    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+
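# Editor's note (hedged example): the varlen kernels used above take packed, unpadded tensors plus
# cumulative sequence lengths instead of a padded batch. For two sequences of lengths 5 and 3, the
# expected bookkeeping looks like this (int32, on the same device as q/k/v):
import torch

seq_lens = torch.tensor([5, 3], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))
max_seqlen_in_batch = int(seq_lens.max())
# cu_seqlens == [0, 5, 8]; max_seqlen_in_batch == 5; the packed tensors hold 8 rows in total
print(cu_seqlens.tolist(), max_seqlen_in_batch)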
 class MiniCPMSdpaAttention(MiniCPMAttention):
     """
     MiniCPM attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from

         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
     def __init__(self, config: MiniCPMConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
+        if config.sparse_config is not None and torch.cuda.is_available():
+            raise NotImplementedError("MiniCPM4-0.5B does not support sparse attention yet.")
+        else:
+            self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

         self.mlp = MiniCPMMLP(config)
         self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
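# Editor's note (hedged usage sketch): the decoder layer above raises NotImplementedError when
# `sparse_config` is set and CUDA is available, so this checkpoint should be loaded with sparse
# attention left off. The `sparse_config` attribute name is taken from the check above; overriding it
# on the config before loading is an assumption about how the remote-code config exposes it.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained('openbmb/MiniCPM4-0.5B', trust_remote_code=True)
config.sparse_config = None  # keep the dense attention path
model = AutoModelForCausalLM.from_pretrained(
    'openbmb/MiniCPM4-0.5B', config=config, trust_remote_code=True, torch_dtype='auto'
)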
 
                 raise ValueError(
                     'You must use the new past_key_values format, such as the Cache class, instead of the old tuple format.'
                 )
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
+            if self.config.sparse_config is not None and torch.cuda.is_available() and past_key_values_length == 0:
+                past_key_values = DynamicCacheQKV()

         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device

     ):
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = None  # past_key_values.get_max_length()
             else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None

             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where