Files changed (3)
  1. README.md +2 -51
  2. config.json +1 -1
  3. modeling_minicpm.py +72 -134
README.md CHANGED
@@ -20,7 +20,6 @@ library_name: transformers
20
  </p>
21
 
22
  ## What's New
23
- - [2025.09.29] **[InfLLM-V2 paper](https://arxiv.org/abs/2509.24663) is released!** We can train a sparse attention model with only 5B long-text tokens. 🔥🔥🔥
24
  - [2025.09.05] **MiniCPM4.1** series are released! This series is a hybrid reasoning model with trainable sparse attention, which can be used in both deep reasoning mode and non-reasoning mode. 🔥🔥🔥
25
  - [2025.06.06] **MiniCPM4** series are released! This model achieves ultimate efficiency improvements while maintaining optimal performance at the same scale! It can achieve over 5x generation acceleration on typical end-side chips! You can find the technical report [here](https://arxiv.org/abs/2506.07900). 🔥🔥🔥
26
 
@@ -64,11 +63,6 @@ MiniCPM4.1 launches end-side versions with 8B parameter scale, both achieving be
64
 
65
  ![benchmark](https://github.com/OpenBMB/MiniCPM/blob/main/assets/minicpm4/benchmark4.1.png?raw=true)
66
 
67
- ### Best Practices
68
- 1. It is advisable to use temperature=0.9, topp=0.95. And we suggest setting max_output_token to 65,536 tokens.
69
- 2. For math problems, we recommend using "Please reason step by step, and put your final answer within \boxed{}."
70
- 3. And for English multiple-choice questions, we recommend starting with "Answer the following multiple choice question. The last line of your response should be of the following format: 'ANSWER: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering." And "你回答的最后一行必须是以下格式 '答案:$选项' (不带引号), 其中选项是ABCD之一。请在回答之前一步步思考" for Chinese MCQ.
71
-
72
  ### Efficiency Evaluation
73
  MiniCPM4.1 adopts sparse attention and speculative decoding to improve inference efficiency. On an RTX 4090, MiniCPM4.1 achieves a 3x decoding speed improvement in reasoning mode.
74
 
@@ -84,17 +78,8 @@ MiniCPM4.1 adopts sparse attention and speculative decoding to improve the infer
84
  ## Usage
85
  MiniCPM4.1 can be used with the following frameworks: Huggingface Transformers, SGLang, vLLM, and CPM.cu. For the best inference speed, we highly recommend CPM.cu.
86
 
87
- MiniCPM4/MiniCPM4.1 supports both dense attention inference and sparse attention inference modes, where vLLM and SGLang currently only support dense inference mode. If you want to use sparse inference mode, please use Huggingface Transformers and CPM.cu.
88
-
89
- - Dense attention inference: vLLM, SGLang, Huggingface Transformers
90
- - Sparse attention inference: Huggingface Transformers, CPM.cu
91
-
92
- **To facilitate researches in sparse attention, we provide [InfLLM-V2 Training Kernels](https://github.com/OpenBMB/infllmv2_cuda_impl) and [InfLLM-V2 Inference Kernels](https://github.com/openbmb/cpm.cu).**
93
 
94
  ### Inference with Transformers
95
- MiniCPM4.1-8B requires `transformers>=4.56`.
96
-
97
- - **Inference with Dense Attention**
98
  ```python
99
  from transformers import AutoModelForCausalLM, AutoTokenizer
100
  import torch
@@ -134,7 +119,6 @@ responses = tokenizer.batch_decode(output_token_ids, skip_special_tokens=True)[0
134
  print(responses)
135
  ```
136
 
137
- - **Inference with Sparse Attention**
138
  MiniCPM4.1-8B supports `InfLLM v2`, a sparse attention mechanism designed for efficient long-sequence inference. It requires the [infllmv2_cuda_impl](https://github.com/OpenBMB/infllmv2_cuda_impl) library.
139
 
140
  You can install it by running the following command:
@@ -172,7 +156,6 @@ These parameters control the behavior of InfLLM v2:
172
  * `use_nope` (default: false): Whether to use the NOPE technique in block selection for improved performance.
173
  * `dense_len` (default: 8192): Since Sparse Attention offers limited benefits for short sequences, the model can use standard (dense) attention for shorter texts. The model will use dense attention for sequences with a token length below `dense_len` and switch to sparse attention for sequences exceeding this length. Set this to `-1` to always use sparse attention regardless of sequence length.
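  For reference, a minimal sketch of adjusting these switches by editing `config.json` in the model directory (the `sparse_config` block name follows `self.config.sparse_config` in `modeling_minicpm.py`; treat it as an assumption and keep any other keys from the shipped config):

  ```python
  # Minimal sketch: toggling the InfLLM v2 switches described above by editing
  # config.json directly. Only `use_nope` and `dense_len` are documented here;
  # the `sparse_config` block name is assumed from modeling_minicpm.py.
  import json

  with open("config.json") as f:
      cfg = json.load(f)

  sparse_cfg = cfg.setdefault("sparse_config", {})
  sparse_cfg["use_nope"] = False   # NOPE technique in block selection (default: false)
  sparse_cfg["dense_len"] = 8192   # dense attention below this length; -1 = always sparse

  with open("config.json", "w") as f:
      json.dump(cfg, f, indent=2)
  ```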
174
 
175
- - **Long Context Extension**
176
  MiniCPM4.1 natively supports context lengths of up to 65,536 (64K) tokens. For conversations where the total length (including both input and output) significantly exceeds this limit, we recommend using RoPE scaling techniques to handle long texts effectively. We have validated the model's performance on context lengths of up to 131,072 tokens by modifying the LongRoPE factor.
177
 
178
  You can apply the LongRoPE factor modification by editing the model files. Specifically, adjust the `rope_scaling` fields in the `config.json` file.
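  As a hedged sketch (the field names and the 131,072 target below are illustrative assumptions, not the validated LongRoPE factors), such an edit could look like:

  ```python
  # Hedged sketch: editing the rope_scaling block of config.json to extend the
  # context window. Field names and the target length are illustrative; use the
  # officially validated LongRoPE factors instead of the placeholder below.
  import json

  with open("config.json") as f:
      cfg = json.load(f)

  rope_scaling = cfg.get("rope_scaling", {})
  rope_scaling["original_max_position_embeddings"] = 65536  # native context length
  # rope_scaling["long_factor"] = [...]  # per-dimension LongRoPE factors (placeholder)
  cfg["rope_scaling"] = rope_scaling
  cfg["max_position_embeddings"] = 131072  # assumed target length

  with open("config.json", "w") as f:
      json.dump(cfg, f, indent=2)
  ```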
@@ -484,37 +467,6 @@ python3 -m cpmcu.cli \
484
 
485
  For more details about CPM.cu, please refer to [the CPM.cu repo](https://github.com/OpenBMB/cpm.cu).
486
 
487
- ### Inference with llama.cpp and Ollama
488
-
489
- We also support inference with [llama.cpp](https://github.com/ggml-org/llama.cpp) and [Ollama](https://ollama.com/).
490
-
491
- ##### llama.cpp
492
-
493
- You can download the GGUF format of MiniCPM4.1-8B model from [huggingface](https://huggingface.co/openbmb/MiniCPM4.1-8B-GGUF) and run it with llama.cpp for efficient CPU or GPU inference.
494
- ```
495
- # case 1: main-cli
496
- ./build/bin/llama-cli -m MiniCPM4.1-8B-Q4_K_M.gguf -p "Write an article about Artificial Intelligence." -n 1500
497
-
498
- # case 2: server
499
- ## launch server
500
- ./build/bin/llama-server -m MiniCPM4.1-8B-Q4_K_M.gguf --host 127.0.0.1 --port 8080 -c 4096 -fa on &
501
-
502
- ## send request
503
- curl -X POST http://127.0.0.1:8080/v1/chat/completions \
504
- -H "Content-Type: application/json" \
505
- -d '{
506
- "model": "gpt-3.5-turbo",
507
- "messages": [{"role": "user", "content": "Write an article about Artificial Intelligence."}],
508
- "max_tokens": 1500
509
- }'
510
- ```
511
-
512
- ##### Ollama
513
- Please refer to [model hub](https://ollama.com/openbmb/minicpm4.1) for model download. After installing ollama package, you can use MiniCPM4.1 with following commands:
514
- ```
515
- ollama run openbmb/minicpm4.1
516
- ```
517
-
518
  ### Hybrid Reasoning Mode
519
 
520
  MiniCPM4.1 supports a hybrid reasoning mode and can be used in both deep reasoning mode and non-reasoning mode. Users can set `enable_thinking=True` in `tokenizer.apply_chat_template` to enable reasoning mode and `enable_thinking=False` to enable non-reasoning mode. Similarly, users can append `/no_think` to the end of the query to enable non-reasoning mode. If no special token is appended, or `/think` is appended to the end of the query, the model will use reasoning mode.
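  For example, a minimal sketch of switching modes through the chat template (the model id and `trust_remote_code` flag are assumptions; the `enable_thinking` switch is as described above):

  ```python
  from transformers import AutoTokenizer

  # Model id assumed for illustration.
  tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM4.1-8B", trust_remote_code=True)

  messages = [{"role": "user", "content": "Write an article about Artificial Intelligence."}]

  # Deep reasoning mode (the default; appending "/think" to the query behaves the same).
  prompt_think = tokenizer.apply_chat_template(
      messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
  )

  # Non-reasoning mode (appending "/no_think" to the query behaves the same).
  prompt_no_think = tokenizer.apply_chat_template(
      messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
  )
  ```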
@@ -550,9 +502,8 @@ prompt_text = tokenizer.apply_chat_template(
550
 
551
  ```bibtex
552
  @article{minicpm4,
553
- title={Minicpm4: Ultra-efficient llms on end devices},
554
- author={MiniCPM, Team},
555
- journal={arXiv preprint arXiv:2506.07900},
505
+ title={{MiniCPM4}: Ultra-Efficient LLMs on End Devices},
506
+ author={MiniCPM Team},
556
  year={2025}
557
  }
558
  ```
 
config.json CHANGED
@@ -30,7 +30,7 @@
30
  "original_max_position_embeddings": 65536
31
  },
32
  "torch_dtype": "bfloat16",
33
- "transformers_version": "4.56.1",
33
+ "transformers_version": "4.46.3",
34
  "use_cache": true,
35
  "vocab_size": 73448,
36
  "rope_theta": 10000.0,
 
modeling_minicpm.py CHANGED
@@ -21,7 +21,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
21
  import torch
22
  import torch.nn.functional as F
23
  import torch.utils.checkpoint
24
- from torch import nn
25
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
  from transformers.activations import ACT2FN
27
  from transformers.cache_utils import Cache, DynamicCache, CacheLayerMixin, DynamicLayer
@@ -47,9 +47,7 @@ from transformers.utils import (
47
  )
48
  from transformers.utils.import_utils import is_torch_fx_available
49
 
50
-
51
-
52
- from .configuration_minicpm import MiniCPMConfig #!一定要改
53
 
54
  try:
55
  from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -70,28 +68,50 @@ from functools import lru_cache
70
  def compressed_attention(
71
  q: torch.Tensor,
72
  k: torch.Tensor,
73
- k2: torch.Tensor,
74
  kernel_size: int,
75
  kernel_stride: int,
76
  block_size: int,
77
  topk: int,
78
  cu_seqlens_q: torch.Tensor,
79
  cu_seqlens_k: torch.Tensor,
80
- cu_seqlens_k2: torch.Tensor,
81
  max_seqlen_q: int,
82
  max_seqlen_k: int,
83
  sm_scale: float = None,
84
  init_blocks: int = 1,
85
  local_blocks: int = 2,
86
- cache_lens=None,
87
  ) -> Tuple[torch.Tensor, torch.Tensor]:
88
  with torch.no_grad():
89
  batch_size = cu_seqlens_q.shape[0] - 1
90
 
91
  # Check if it's prefilling stage
92
  is_prefilling = cache_lens is None or (cache_lens == 0).all().item()
93
-
94
- if is_prefilling: # prefilling stage
 
95
  # Calculate q_idx for each query position in each batch
96
  cache_lens = torch.zeros(batch_size, dtype=torch.int32, device=q.device)
97
  q_idx = torch.cat([
@@ -99,24 +119,25 @@ def compressed_attention(
99
  max_seqlen_q - (cu_seqlens_q[i + 1] - cu_seqlens_q[i])) // block_size
100
  for i in range(batch_size)
101
  ], dim=0) # shape: [total_q_len]
102
- else: # decoding stage
103
- # Each batch has only one query (last position)
104
- q_idx = cache_lens // block_size # shape: [batch_size] = [total_q_len] in decoding
 
105
 
106
- # 计算attention score
107
  score = infllmv2_attn_stage1(
108
  q.contiguous(),
109
  k.contiguous(),
110
- k2.contiguous(),
111
  cu_seqlens_q=cu_seqlens_q,
112
  cu_seqlens_k=cu_seqlens_k,
113
- cu_seqlens_v=cu_seqlens_k2,
114
  max_seqlen_q=max_seqlen_q,
115
  max_seqlen_k=max_seqlen_k,
116
- causal=is_prefilling
117
- )
118
- score = score[:, :q_idx.shape[0], :] # [num_heads, total_q_len, num_blocks]
119
-
 
120
  block_score = max_pooling_1d_varlen(
121
  score.contiguous(),
122
  cu_seqlens_q,
@@ -127,9 +148,7 @@ def compressed_attention(
127
  local_blocks=local_blocks,
128
  init_blocks=init_blocks,
129
  block_size=block_size,
130
- stride=kernel_stride
131
- ) # shape: [num_heads, total_q_len, num_blocks]
132
-
133
 
134
  # get topk
135
  topk = min(topk, block_score.shape[-1])
@@ -243,11 +262,6 @@ class InfLLMv2CacheLayer(DynamicLayer):
243
  self.no_compress_k_cache = []
244
  self.cached_compressed_cu_seqlens = torch.tensor([], dtype=torch.int32)
245
  self.compress_k_cache_varlen = torch.tensor([], dtype=torch.float32)
246
- # Add support for compress_k2
247
- self.compress_k2_cache = []
248
- self.cached_compressed_cu_seqlens2 = torch.tensor([], dtype=torch.int32)
249
- self.compress_k2_cache_varlen = torch.tensor([], dtype=torch.float32)
250
- self.no_compress_k2_cache = []
251
 
252
  def update_no_rope_key(self, key_states):
253
  if self.no_rope_keys.numel() == 0:
@@ -289,45 +303,12 @@ class InfLLMv2CacheLayer(DynamicLayer):
289
  k_chunk_list.append(None)
290
  return k_chunk_list
291
 
292
- def update_compress_k2(self, key_states, cu_seqlens=None):
293
- if len(self.compress_k2_cache) == 0:
294
- if cu_seqlens is not None:
295
- self.cached_compressed_cu_seqlens2 = cu_seqlens.clone()
296
- self.compress_k2_cache_varlen = key_states
297
- split_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
298
- self.compress_k2_cache = list(torch.split(key_states, split_sizes))
299
- else:
300
- for index, k in enumerate(key_states):
301
- if k is not None:
302
- self.compress_k2_cache[index] = torch.cat([self.compress_k2_cache[index], k], dim=0)
303
- new_seq_lens = torch.tensor([tensor.shape[0] for tensor in self.compress_k2_cache], dtype=torch.int32)
304
- new_cumsum = torch.cumsum(new_seq_lens, dim=0, dtype=torch.int32)
305
-
306
- self.compress_k2_cache_varlen = torch.cat(self.compress_k2_cache, dim=0)
307
- self.cached_compressed_cu_seqlens2 = torch.cat([torch.tensor([0], dtype=torch.int32), new_cumsum]).to(self.compress_k2_cache_varlen.device)
308
- return self.compress_k2_cache_varlen, self.cached_compressed_cu_seqlens2
309
-
310
- def update_no_compress_k2(self, key_states, kernel_size=128, kernel_stride=64):
311
- k_chunk_list = []
312
- for index, k in enumerate(key_states):
313
- if len(self.no_compress_k2_cache) <= index:
314
- self.no_compress_k2_cache.append(k)
315
- else:
316
- self.no_compress_k2_cache[index] = torch.cat([self.no_compress_k2_cache[index], k], dim=0)
317
- current_len = self.no_compress_k2_cache[index].shape[0]
318
- if current_len >= kernel_size:
319
- k_chunk_list.append(self.no_compress_k2_cache[index][:kernel_size])
320
- self.no_compress_k2_cache[index] = self.no_compress_k2_cache[index][kernel_stride:]
321
- else:
322
- k_chunk_list.append(None)
323
- return k_chunk_list
324
-
325
  class InfLLMv2Cache(DynamicCache):
326
- def __init__(self, config,num_hidden_layers: Optional[int] = None) -> None:
 
327
  super().__init__(config=config)
328
  self.layers = [InfLLMv2CacheLayer() for _ in range(num_hidden_layers)] if num_hidden_layers else []
329
  self._seen_tokens = 0
330
-
331
 
332
  def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
333
  if layer_idx == 0:
@@ -343,12 +324,6 @@ class InfLLMv2Cache(DynamicCache):
343
  def update_no_compress_k(self, key_states, layer_idx, kernel_size=32, kernel_stride=16, cache_kwargs=None):
344
  return self.layers[layer_idx].update_no_compress_k(key_states, kernel_size, kernel_stride)
345
 
346
- def update_compress_k2(self, key_states, layer_idx, cu_seqlens=None, cache_kwargs=None):
347
- return self.layers[layer_idx].update_compress_k2(key_states, cu_seqlens)
348
-
349
- def update_no_compress_k2(self, key_states, layer_idx, kernel_size=128, kernel_stride=64, cache_kwargs=None):
350
- return self.layers[layer_idx].update_no_compress_k2(key_states, kernel_size, kernel_stride)
351
-
352
  def crop(self, max_length):
353
  for layer in self.layers:
354
  layer.crop(max_length)
@@ -616,6 +591,7 @@ def _unpad_one_tensor(hidden_states, attention_mask):
616
  unpadded_states = index_first_axis(reshaped_states, indices)
617
 
618
  return unpadded_states, indices, cu_seqlens, max_seqlen_in_batch
 
619
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
620
  """
621
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -1022,9 +998,7 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1022
  self.local_blocks = self.window_size // self.block_size # local_blocks
1023
  self.topk = self.config.sparse_config.get('topk', 64) + (self.window_size//self.block_size)
1024
  self.use_nope = self.config.sparse_config.get('use_nope', False)
1025
-
1026
  self.compress_k = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size, kernel_stride=self.kernel_stride)
1027
- self.compress_k2 = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size*4, kernel_stride=self.kernel_stride*4)
1028
 
1029
  def forward(
1030
  self,
@@ -1049,7 +1023,6 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1049
 
1050
  bsz, q_len, _ = hidden_states.size()
1051
 
1052
-
1053
  query_states = self.q_proj(hidden_states)
1054
  key_states = self.k_proj(hidden_states)
1055
  value_states = self.v_proj(hidden_states)
@@ -1080,12 +1053,11 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1080
  key_states = key_states.transpose(1, 2)
1081
  value_states = value_states.transpose(1, 2)
1082
  if self.use_nope:
1083
- key_states_no_rope =past_key_value.update_no_rope_key(key_states_no_rope, self.layer_idx)
1084
  no_rope_param = {
1085
  'key_states_no_rope': key_states_no_rope,
1086
  'query_states_no_rope': query_states_no_rope,
1087
  }
1088
-
1089
  else:
1090
  no_rope_param = None
1091
 
@@ -1131,8 +1103,16 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1131
  return attn_output, attn_weights, past_key_value
1132
 
1133
  def _sparse_attention_forward(
1134
- self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, no_rope_param=None, past_key_value=None
1135
- ):
1136
  """
1137
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
1138
  first unpad the input, then computes the attention scores and pad the final attention scores.
@@ -1162,17 +1142,15 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1162
  batch_size = query_states.shape[0]
1163
  # assert batch_size == 1, 'Only batch_size=1 is supported at the moment.'
1164
  if past_key_value!=None:
1165
- compressed_k, compressed_cu_seqlens, compressed_k2, compressed_cu_seqlens2 = self.get_compress_k(
1166
  key_states=key_states if self.use_nope ==False else no_rope_param['key_states_no_rope'], # This can be optimized a bit;
1167
  attention_mask=attention_mask,
1168
- past_key_value=past_key_value,
1169
-
1170
- )
1171
 
1172
  query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
1173
  query_states, key_states, value_states, attention_mask, query_length
1174
  )
1175
-
1176
  cu_seqlens_q, cu_seqlens_k = cu_seq_lens
1177
  max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
1178
  if no_rope_param != None:
@@ -1183,12 +1161,7 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1183
  if past_key_value==None:
1184
  # compress_k use varlen form
1185
  compressed_k, compressed_cu_seqlens = self.compress_k(key_states,cu_seqlens_k)
1186
- compressed_k2, compressed_cu_seqlens2 = self.compress_k2(key_states,cu_seqlens_k)
1187
- else:
1188
- # compressed_k and compressed_k2 already retrieved from get_compress_k above
1189
- pass
1190
 
1191
-
1192
  attn_output_unpad = self.sparse_forward(
1193
  query_states,
1194
  key_states,
@@ -1198,16 +1171,15 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1198
  max_seqlen_in_batch_q,
1199
  max_seqlen_in_batch_k,
1200
  no_rope_param=no_rope_param,
1201
- compressed_k=compressed_k, compressed_cu_seqlens=compressed_cu_seqlens,
1202
- compressed_k2=compressed_k2, compressed_cu_seqlens2=compressed_cu_seqlens2
1203
- )
1204
 
1205
  attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
1206
-
1207
  else:
1208
  raise ValueError('Need attention mask')
1209
 
1210
  return attn_output
 
1211
  def get_compress_k(self, key_states, attention_mask, past_key_value):
1212
  """
1213
  Get compressed key states and corresponding cumulative sequence lengths.
@@ -1219,51 +1191,34 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1219
  no_rope_param: Optional parameter containing key states without rope
1220
 
1221
  Returns:
1222
- Tuple of (compressed_k, compressed_cu_seqlens, compressed_k2, compressed_cu_seqlens2)
1223
  """
1224
-
1225
  # Check if this is prefilling or initial compression condition
1226
-
1227
  is_prefilling = (
1228
  key_states.shape[1] >= self.dense_len and
1229
  (
1230
  not past_key_value.layers[self.layer_idx].compress_k_cache
1231
  )
1232
  )
1233
-
1234
  if is_prefilling:
1235
  unpadded_key_states, indices, cu_seqlens, max_seqlen_in_batch = _unpad_one_tensor(key_states,attention_mask=attention_mask)
1236
  # Compress the keys
1237
  compressed_k, compressed_cu_seqlens = self.compress_k(unpadded_key_states, cu_seqlens)
1238
- compressed_k2, compressed_cu_seqlens2 = self.compress_k2(unpadded_key_states, cu_seqlens)
1239
-
1240
  past_key_value.update_compress_k(
1241
  compressed_k, self.layer_idx, compressed_cu_seqlens)
1242
- past_key_value.update_compress_k2(
1243
- compressed_k2, self.layer_idx, compressed_cu_seqlens2)
1244
-
1245
  no_compress_k_list = []
1246
  # Compute and update no_compress_k
1247
  for i in range(len(compressed_cu_seqlens)-1):
1248
  no_compress_k_start = (compressed_cu_seqlens[i+1]- compressed_cu_seqlens[i]) * self.kernel_stride
1249
-
1250
  no_compress_k_list.append(unpadded_key_states[cu_seqlens[i]+no_compress_k_start:cu_seqlens[i+1]].clone())
1251
 
1252
  past_key_value.update_no_compress_k(
1253
  no_compress_k_list, self.layer_idx,kernel_stride=self.kernel_stride,
1254
  kernel_size=self.kernel_size)
1255
-
1256
- # Also update no_compress_k2
1257
- no_compress_k2_list = []
1258
- for i in range(len(compressed_cu_seqlens2)-1):
1259
- no_compress_k2_start = (compressed_cu_seqlens2[i+1]- compressed_cu_seqlens2[i]) * self.kernel_stride * 4
1260
-
1261
- no_compress_k2_list.append(unpadded_key_states[cu_seqlens[i]+no_compress_k2_start:cu_seqlens[i+1]].clone())
1262
-
1263
- past_key_value.update_no_compress_k2(
1264
- no_compress_k2_list, self.layer_idx,kernel_stride=self.kernel_stride*4,
1265
- kernel_size=self.kernel_size*4)
1266
-
1267
  else:
1268
  # Decode case: incremental update
1269
  batch_size = key_states.shape[0] # key_states.shape = [batch_size, seq, k_head_num, head_dim]
@@ -1278,32 +1233,16 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1278
  kernel_size=self.kernel_size)
1279
  new_compressed_k_list = []
1280
  for no_compress_k in no_compress_k_list:
1281
-
1282
  if no_compress_k is not None:
1283
  # We have enough tokens to compress
1284
  new_compressed_k = no_compress_k.mean(dim=0, keepdim=True) # [1, n_heads_k, head_dim]
1285
-
1286
  new_compressed_k_list.append(new_compressed_k)
1287
  else:
1288
  new_compressed_k_list.append(None)
1289
  compressed_k, compressed_cu_seqlens = past_key_value.update_compress_k(new_compressed_k_list, self.layer_idx,)
1290
-
1291
- # For compress_k2, update no_compress_k2 buffer and compress when ready
1292
- no_compress_k2_list = past_key_value.update_no_compress_k2(
1293
- key_states_split, self.layer_idx,
1294
- kernel_stride=self.kernel_stride*4,
1295
- kernel_size=self.kernel_size*4)
1296
- new_compressed_k2_list = []
1297
- for no_compress_k2 in no_compress_k2_list:
1298
- if no_compress_k2 is not None:
1299
- # We have enough tokens to compress for k2
1300
- new_compressed_k2 = no_compress_k2.mean(dim=0, keepdim=True) # [1, n_heads_k, head_dim]
1301
- new_compressed_k2_list.append(new_compressed_k2)
1302
- else:
1303
- new_compressed_k2_list.append(None)
1304
- compressed_k2, compressed_cu_seqlens2 = past_key_value.update_compress_k2(new_compressed_k2_list, self.layer_idx,)
1305
-
1306
- return compressed_k, compressed_cu_seqlens, compressed_k2, compressed_cu_seqlens2
1307
  def sparse_forward(self,
1308
  query_layer,
1309
  key_layer,
@@ -1313,8 +1252,8 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1313
  max_seqlen_in_batch_q,
1314
  max_seqlen_in_batch_k,
1315
  no_rope_param=None,
1316
- compressed_k=None, compressed_cu_seqlens=None,
1317
- compressed_k2=None, compressed_cu_seqlens2=None):
1318
  compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
1319
  cache_lens = None
1320
  if max_seqlen_in_batch_q==1 and max_seqlen_in_batch_k>1: #decoding
@@ -1324,14 +1263,13 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1324
  topk_idx = compressed_attention(
1325
  query_layer if no_rope_param is None else no_rope_param['query_states_no_rope'],
1326
  compressed_k,
1327
- compressed_k2,
1328
  self.kernel_size,
1329
  self.kernel_stride,
1330
  self.block_size,
1331
  self.topk,
1332
  cu_seqlens_q,
1333
  compressed_cu_seqlens,
1334
- compressed_cu_seqlens2,
1335
  max_seqlen_in_batch_q,
1336
  compressed_seqlens.max().item(),
1337
  None,
 
21
  import torch
22
  import torch.nn.functional as F
23
  import torch.utils.checkpoint
24
+ from torch import nn
25
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
  from transformers.activations import ACT2FN
27
  from transformers.cache_utils import Cache, DynamicCache, CacheLayerMixin, DynamicLayer
 
47
  )
48
  from transformers.utils.import_utils import is_torch_fx_available
49
 
50
+ from .configuration_minicpm import MiniCPMConfig
 
 
51
 
52
  try:
53
  from flash_attn import flash_attn_func, flash_attn_varlen_func
 
68
  def compressed_attention(
69
  q: torch.Tensor,
70
  k: torch.Tensor,
71
+ v: torch.Tensor,
72
  kernel_size: int,
73
  kernel_stride: int,
74
  block_size: int,
75
  topk: int,
76
  cu_seqlens_q: torch.Tensor,
77
  cu_seqlens_k: torch.Tensor,
 
78
  max_seqlen_q: int,
79
  max_seqlen_k: int,
80
  sm_scale: float = None,
81
  init_blocks: int = 1,
82
  local_blocks: int = 2,
83
+ cache_lens: torch.Tensor = None,
84
  ) -> Tuple[torch.Tensor, torch.Tensor]:
85
+ """Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention.
86
+
87
+ Args:
88
+ q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
89
+ k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
90
+ v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
91
+ kernel_size (int): kernel size in compress_key_value
92
+ kernel_stride (int): stride of compress_key_value
93
+ block_size (int): key value block size for topk sparse attention.
94
+ topk (int): number of blocks for each query.
95
+ cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
96
+ cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen.
97
+ max_seqlen_q (int): max q len of the batch.
98
+ max_seqlen_k (int): max k len of the batch.
99
+ sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
100
+ init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
101
+ local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
102
+ cache_lens (torch.Tensor, optional): shape [batch_size], used to record the cache length of each query. Defaults to None.
103
+
104
+ Returns:
105
+ Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention
106
+ """
107
  with torch.no_grad():
108
  batch_size = cu_seqlens_q.shape[0] - 1
109
 
110
  # Check if it's prefilling stage
111
  is_prefilling = cache_lens is None or (cache_lens == 0).all().item()
112
+
113
+ # prefilling stage
114
+ if is_prefilling:
115
  # Calculate q_idx for each query position in each batch
116
  cache_lens = torch.zeros(batch_size, dtype=torch.int32, device=q.device)
117
  q_idx = torch.cat([
 
119
  max_seqlen_q - (cu_seqlens_q[i + 1] - cu_seqlens_q[i])) // block_size
120
  for i in range(batch_size)
121
  ], dim=0) # shape: [total_q_len]
122
+ # decoding stage
123
+ else:
124
+ # Each batch has only one query (last position). Shape: [batch_size] = [total_q_len] in decoding
125
+ q_idx = cache_lens // block_size
126
 
127
+ # compute attention score
128
  score = infllmv2_attn_stage1(
129
  q.contiguous(),
130
  k.contiguous(),
131
+ v.contiguous(),
132
  cu_seqlens_q=cu_seqlens_q,
133
  cu_seqlens_k=cu_seqlens_k,
 
134
  max_seqlen_q=max_seqlen_q,
135
  max_seqlen_k=max_seqlen_k,
136
+ causal=is_prefilling)
137
+ # Shape: [num_heads, total_q_len, num_blocks]
138
+ score = score[:, :q_idx.shape[0], :]
139
+
140
+ # Shape: [num_heads, total_q_len, num_blocks]
141
  block_score = max_pooling_1d_varlen(
142
  score.contiguous(),
143
  cu_seqlens_q,
 
148
  local_blocks=local_blocks,
149
  init_blocks=init_blocks,
150
  block_size=block_size,
151
+ stride=kernel_stride)
 
 
152
 
153
  # get topk
154
  topk = min(topk, block_score.shape[-1])
 
262
  self.no_compress_k_cache = []
263
  self.cached_compressed_cu_seqlens = torch.tensor([], dtype=torch.int32)
264
  self.compress_k_cache_varlen = torch.tensor([], dtype=torch.float32)
265
 
266
  def update_no_rope_key(self, key_states):
267
  if self.no_rope_keys.numel() == 0:
 
303
  k_chunk_list.append(None)
304
  return k_chunk_list
305

306
  class InfLLMv2Cache(DynamicCache):
307
+ def __init__(self,
308
+ config,num_hidden_layers: Optional[int] = None) -> None:
309
  super().__init__(config=config)
310
  self.layers = [InfLLMv2CacheLayer() for _ in range(num_hidden_layers)] if num_hidden_layers else []
311
  self._seen_tokens = 0
 
312
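A hedged sketch of constructing the cache under this new signature (the `num_hidden_layers` attribute name on the config is assumed from standard transformers configs):

```python
# Hedged sketch: building the InfLLM v2 cache with the (config, num_hidden_layers)
# signature shown above, then passing it to generate/forward as past_key_values.
past_key_values = InfLLMv2Cache(config, num_hidden_layers=config.num_hidden_layers)
```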
 
313
  def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
314
  if layer_idx == 0:
 
324
  def update_no_compress_k(self, key_states, layer_idx, kernel_size=32, kernel_stride=16, cache_kwargs=None):
325
  return self.layers[layer_idx].update_no_compress_k(key_states, kernel_size, kernel_stride)
326

327
  def crop(self, max_length):
328
  for layer in self.layers:
329
  layer.crop(max_length)
 
591
  unpadded_states = index_first_axis(reshaped_states, indices)
592
 
593
  return unpadded_states, indices, cu_seqlens, max_seqlen_in_batch
594
+
595
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
596
  """
597
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
 
998
  self.local_blocks = self.window_size // self.block_size # local_blocks
999
  self.topk = self.config.sparse_config.get('topk', 64) + (self.window_size//self.block_size)
1000
  self.use_nope = self.config.sparse_config.get('use_nope', False)
 
1001
  self.compress_k = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size, kernel_stride=self.kernel_stride)
 
1002
 
1003
  def forward(
1004
  self,
 
1023
 
1024
  bsz, q_len, _ = hidden_states.size()
1025
 
 
1026
  query_states = self.q_proj(hidden_states)
1027
  key_states = self.k_proj(hidden_states)
1028
  value_states = self.v_proj(hidden_states)
 
1053
  key_states = key_states.transpose(1, 2)
1054
  value_states = value_states.transpose(1, 2)
1055
  if self.use_nope:
1056
+ key_states_no_rope = past_key_value.update_no_rope_key(key_states_no_rope, self.layer_idx)
1057
  no_rope_param = {
1058
  'key_states_no_rope': key_states_no_rope,
1059
  'query_states_no_rope': query_states_no_rope,
1060
  }
 
1061
  else:
1062
  no_rope_param = None
1063
 
 
1103
  return attn_output, attn_weights, past_key_value
1104
 
1105
  def _sparse_attention_forward(
1106
+ self,
1107
+ query_states,
1108
+ key_states,
1109
+ value_states,
1110
+ attention_mask,
1111
+ query_length,
1112
+ dropout=0.0,
1113
+ softmax_scale=None,
1114
+ no_rope_param=None,
1115
+ past_key_value=None):
1116
  """
1117
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
1118
  first unpad the input, then computes the attention scores and pad the final attention scores.
 
1142
  batch_size = query_states.shape[0]
1143
  # assert batch_size == 1, 'Only batch_size=1 is supported at the moment.'
1144
  if past_key_value!=None:
1145
+ compressed_k, compressed_cu_seqlens = self.get_compress_k(
1146
  key_states=key_states if self.use_nope ==False else no_rope_param['key_states_no_rope'], # This can be optimized a bit;
1147
  attention_mask=attention_mask,
1148
+ past_key_value=past_key_value)
 
 
1149
 
1150
  query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
1151
  query_states, key_states, value_states, attention_mask, query_length
1152
  )
1153
+
1154
  cu_seqlens_q, cu_seqlens_k = cu_seq_lens
1155
  max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
1156
  if no_rope_param != None:
 
1161
  if past_key_value==None:
1162
  # compress_k use varlen form
1163
  compressed_k, compressed_cu_seqlens = self.compress_k(key_states,cu_seqlens_k)
1164

1165
  attn_output_unpad = self.sparse_forward(
1166
  query_states,
1167
  key_states,
 
1171
  max_seqlen_in_batch_q,
1172
  max_seqlen_in_batch_k,
1173
  no_rope_param=no_rope_param,
1174
+ compressed_k=compressed_k,
1175
+ compressed_cu_seqlens=compressed_cu_seqlens)
 
1176
 
1177
  attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
 
1178
  else:
1179
  raise ValueError('Need attention mask')
1180
 
1181
  return attn_output
1182
+
1183
  def get_compress_k(self, key_states, attention_mask, past_key_value):
1184
  """
1185
  Get compressed key states and corresponding cumulative sequence lengths.
 
1191
  no_rope_param: Optional parameter containing key states without rope
1192
 
1193
  Returns:
1194
+ Tuple of (compressed_k, compressed_cu_seqlens)
1195
  """
 
1196
  # Check if this is prefilling or initial compression condition
 
1197
  is_prefilling = (
1198
  key_states.shape[1] >= self.dense_len and
1199
  (
1200
  not past_key_value.layers[self.layer_idx].compress_k_cache
1201
  )
1202
  )
1203
+
1204
  if is_prefilling:
1205
  unpadded_key_states, indices, cu_seqlens, max_seqlen_in_batch = _unpad_one_tensor(key_states,attention_mask=attention_mask)
1206
  # Compress the keys
1207
  compressed_k, compressed_cu_seqlens = self.compress_k(unpadded_key_states, cu_seqlens)
1208
+
 
1209
  past_key_value.update_compress_k(
1210
  compressed_k, self.layer_idx, compressed_cu_seqlens)
1211
+
 
 
1212
  no_compress_k_list = []
1213
  # Compute and update no_compress_k
1214
  for i in range(len(compressed_cu_seqlens)-1):
1215
  no_compress_k_start = (compressed_cu_seqlens[i+1]- compressed_cu_seqlens[i]) * self.kernel_stride
1216
+
1217
  no_compress_k_list.append(unpadded_key_states[cu_seqlens[i]+no_compress_k_start:cu_seqlens[i+1]].clone())
1218
 
1219
  past_key_value.update_no_compress_k(
1220
  no_compress_k_list, self.layer_idx,kernel_stride=self.kernel_stride,
1221
  kernel_size=self.kernel_size)

1222
  else:
1223
  # Decode case: incremental update
1224
  batch_size = key_states.shape[0] # key_states.shape = [batch_size, seq, k_head_num, head_dim]
 
1233
  kernel_size=self.kernel_size)
1234
  new_compressed_k_list = []
1235
  for no_compress_k in no_compress_k_list:
 
1236
  if no_compress_k is not None:
1237
  # We have enough tokens to compress
1238
  new_compressed_k = no_compress_k.mean(dim=0, keepdim=True) # [1, n_heads_k, head_dim]
 
1239
  new_compressed_k_list.append(new_compressed_k)
1240
  else:
1241
  new_compressed_k_list.append(None)
1242
  compressed_k, compressed_cu_seqlens = past_key_value.update_compress_k(new_compressed_k_list, self.layer_idx,)
1243
+
1244
+ return compressed_k, compressed_cu_seqlens
1245
+
1246
  def sparse_forward(self,
1247
  query_layer,
1248
  key_layer,
 
1252
  max_seqlen_in_batch_q,
1253
  max_seqlen_in_batch_k,
1254
  no_rope_param=None,
1255
+ compressed_k=None,
1256
+ compressed_cu_seqlens=None):
1257
  compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
1258
  cache_lens = None
1259
  if max_seqlen_in_batch_q==1 and max_seqlen_in_batch_k>1: #decoding
 
1263
  topk_idx = compressed_attention(
1264
  query_layer if no_rope_param is None else no_rope_param['query_states_no_rope'],
1265
  compressed_k,
1266
+ compressed_k.clone(),
1267
  self.kernel_size,
1268
  self.kernel_stride,
1269
  self.block_size,
1270
  self.topk,
1271
  cu_seqlens_q,
1272
  compressed_cu_seqlens,
 
1273
  max_seqlen_in_batch_q,
1274
  compressed_seqlens.max().item(),
1275
  None,