20250912
#2
by
ituser1
- opened
- README.md +2 -51
- config.json +1 -1
- modeling_minicpm.py +72 -134
README.md
CHANGED

````diff
@@ -20,7 +20,6 @@ library_name: transformers
 </p>
 
 ## What's New
-- [2025.09.29] **[InfLLM-V2 paper](https://arxiv.org/abs/2509.24663) is released!** We can train a sparse attention model with only 5B long-text tokens. 🔥🔥🔥
 - [2025.09.05] **MiniCPM4.1** series are released! This series is a hybrid reasoning model with trainable sparse attention, which can be used in both deep reasoning mode and non-reasoning mode. 🔥🔥🔥
 - [2025.06.06] **MiniCPM4** series are released! This model achieves ultimate efficiency improvements while maintaining optimal performance at the same scale! It can achieve over 5x generation acceleration on typical end-side chips! You can find the technical report [here](https://arxiv.org/abs/2506.07900). 🔥🔥🔥
 
````
````diff
@@ -64,11 +63,6 @@ MiniCPM4.1 launches end-side versions with 8B parameter scale, both achieving be
 
 
 
-### Best Practices
-1. We advise using temperature=0.9 and top_p=0.95, and suggest setting max_output_token to 65,536 tokens.
-2. For math problems, we recommend prompting with "Please reason step by step, and put your final answer within \boxed{}."
-3. For English multiple-choice questions, we recommend starting with "Answer the following multiple choice question. The last line of your response should be of the following format: 'ANSWER: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering." For Chinese multiple-choice questions, use "你回答的最后一行必须是以下格式 '答案:$选项' (不带引号), 其中选项是ABCD之一。请在回答之前一步步思考" (the same instruction in Chinese).
-
 ### Efficiency Evaluation
 MiniCPM4.1 adopts sparse attention and speculative decoding to improve inference efficiency. On RTX 4090, MiniCPM4.1 achieves a 3x decoding speed improvement in reasoning.
 
````
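For reference, the sampling setup from the removed Best Practices section maps onto a standard `transformers` generation config. A minimal sketch (the original's `topp` is transformers' `top_p`; `max_output_token` corresponds to `max_new_tokens` here):

```python
from transformers import GenerationConfig

# Matches the removed recommendations: temperature=0.9, top_p=0.95,
# and up to 65,536 generated tokens.
gen_config = GenerationConfig(
    do_sample=True,
    temperature=0.9,
    top_p=0.95,
    max_new_tokens=65536,
)
```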
````diff
@@ -84,17 +78,8 @@ MiniCPM4.1 adopts sparse attention and speculative decoding to improve the infer
 ## Usage
 MiniCPM4.1 can be used with the following frameworks: Huggingface Transformers, SGLang, vLLM, and CPM.cu. For the highest inference speed, we highly recommend CPM.cu.
 
-MiniCPM4/MiniCPM4.1 support both dense attention and sparse attention inference modes; vLLM and SGLang currently only support the dense mode. If you want to use the sparse mode, please use Huggingface Transformers or CPM.cu.
-
-- Dense attention inference: vLLM, SGLang, Huggingface Transformers
-- Sparse attention inference: Huggingface Transformers, CPM.cu
-
-**To facilitate research on sparse attention, we provide [InfLLM-V2 Training Kernels](https://github.com/OpenBMB/infllmv2_cuda_impl) and [InfLLM-V2 Inference Kernels](https://github.com/openbmb/cpm.cu).**
 
 ### Inference with Transformers
-MiniCPM4.1-8B requires `transformers>=4.56`.
-
-- **Inference with Dense Attention**
 ```python
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
````
````diff
@@ -134,7 +119,6 @@ responses = tokenizer.batch_decode(output_token_ids, skip_special_tokens=True)[0
 print(responses)
 ```
 
-- **Inference with Sparse Attention**
 MiniCPM4.1-8B supports `InfLLM v2`, a sparse attention mechanism designed for efficient long-sequence inference. It requires the [infllmv2_cuda_impl](https://github.com/OpenBMB/infllmv2_cuda_impl) library.
 
 You can install it by running the following command:
````
````diff
@@ -172,7 +156,6 @@ These parameters control the behavior of InfLLM v2:
 * `use_nope` (default: false): Whether to use the NOPE technique in block selection for improved performance.
 * `dense_len` (default: 8192): Since sparse attention offers limited benefits for short sequences, the model can use standard (dense) attention for shorter texts. The model uses dense attention for sequences shorter than `dense_len` tokens and switches to sparse attention beyond that length. Set this to `-1` to always use sparse attention regardless of sequence length.
 
-- **Long Context Extension**
 MiniCPM4.1 natively supports context lengths of up to 65,536 (64K) tokens. For conversations where the total length (including both input and output) significantly exceeds this limit, we recommend RoPE scaling techniques for effective handling of long texts. We have validated the model's performance on context lengths of up to 131,072 tokens by modifying the LongRoPE factor.
 
 You can apply the LongRoPE factor modification by editing the model files. Specifically, in the `config.json` file, adjust the `rope_scaling` fields.
````
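As an illustration of that `config.json` edit, the `rope_scaling` block can also be inspected and overridden at load time. A minimal sketch (the model id is assumed; the concrete LongRoPE factor lists are model-specific and not reproduced here):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("openbmb/MiniCPM4.1-8B", trust_remote_code=True)

# Inspect the shipped LongRoPE settings; the config.json diff below shows
# "original_max_position_embeddings": 65536 inside this block.
print(config.rope_scaling)

# Adjust the rope_scaling fields here, then pass `config=config` to
# AutoModelForCausalLM.from_pretrained to extend beyond the native 64K window.
```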
````diff
@@ -484,37 +467,6 @@ python3 -m cpmcu.cli \
 
 For more details about CPM.cu, please refer to [the CPM.cu repo](https://github.com/OpenBMB/cpm.cu).
 
-### Inference with llama.cpp and Ollama
-
-We also support inference with [llama.cpp](https://github.com/ggml-org/llama.cpp) and [Ollama](https://ollama.com/).
-
-##### llama.cpp
-
-You can download the GGUF format of the MiniCPM4.1-8B model from [huggingface](https://huggingface.co/openbmb/MiniCPM4.1-8B-GGUF) and run it with llama.cpp for efficient CPU or GPU inference.
-```
-# case 1: main CLI
-./build/bin/llama-cli -m MiniCPM4.1-8B-Q4_K_M.gguf -p "Write an article about Artificial Intelligence." -n 1500
-
-# case 2: server
-## launch server
-./build/bin/llama-server -m MiniCPM4.1-8B-Q4_K_M.gguf --host 127.0.0.1 --port 8080 -c 4096 -fa on &
-
-## send request
-curl -X POST http://127.0.0.1:8080/v1/chat/completions \
-  -H "Content-Type: application/json" \
-  -d '{
-    "model": "gpt-3.5-turbo",
-    "messages": [{"role": "user", "content": "Write an article about Artificial Intelligence."}],
-    "max_tokens": 1500
-  }'
-```
-
-##### Ollama
-Please refer to the [model hub](https://ollama.com/openbmb/minicpm4.1) for model download. After installing the ollama package, you can use MiniCPM4.1 with the following commands:
-```
-ollama run openbmb/minicpm4.1
-```
-
 ### Hybrid Reasoning Mode
 
 MiniCPM4.1 supports a hybrid reasoning mode and can be used in both deep reasoning mode and non-reasoning mode. Set `enable_thinking=True` in `tokenizer.apply_chat_template` to enable reasoning, and `enable_thinking=False` to enable the non-reasoning mode. Alternatively, append `/no_think` to the end of the query to enable the non-reasoning mode; if no special token is added, or `/think` is appended, the model runs in reasoning mode.
````
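A minimal sketch of toggling the two modes through the chat template, assuming the checkpoint's template accepts the `enable_thinking` switch exactly as described above:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM4.1-8B", trust_remote_code=True)
messages = [{"role": "user", "content": "How many primes are below 100?"}]

# Deep reasoning mode (also the default when no suffix is appended).
thinking_prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
)

# Non-reasoning mode via the template flag...
plain_prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
)

# ...or equivalently by appending /no_think to the query itself.
messages = [{"role": "user", "content": "How many primes are below 100? /no_think"}]
```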
````diff
@@ -550,9 +502,8 @@ prompt_text = tokenizer.apply_chat_template(
 
 ```bibtex
 @article{minicpm4,
-  title={
-  author={MiniCPM
-  journal={arXiv preprint arXiv:2506.07900},
+  title={{MiniCPM4}: Ultra-Efficient LLMs on End Devices},
+  author={MiniCPM Team},
   year={2025}
 }
 ```
````
config.json
CHANGED

```diff
@@ -30,7 +30,7 @@
     "original_max_position_embeddings": 65536
   },
   "torch_dtype": "bfloat16",
-  "transformers_version": "4.
+  "transformers_version": "4.46.3",
   "use_cache": true,
   "vocab_size": 73448,
   "rope_theta": 10000.0,
```
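One note on this field: `transformers_version` in `config.json` records the library version that serialized the config; it is metadata rather than a runtime pin (the README's `transformers>=4.56` requirement for inference is a separate constraint). A quick way to read it back, as a sketch with a hypothetical local path:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("./MiniCPM4.1-8B", trust_remote_code=True)
print(config.transformers_version)  # "4.46.3" after this change
```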
modeling_minicpm.py
CHANGED

```diff
@@ -21,7 +21,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from torch import
+from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache, CacheLayerMixin, DynamicLayer
```
```diff
@@ -47,9 +47,7 @@ from transformers.utils import (
 )
 from transformers.utils.import_utils import is_torch_fx_available
 
-
-
-from .configuration_minicpm import MiniCPMConfig  #! must be changed
+from .configuration_minicpm import MiniCPMConfig
 
 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
```
```diff
@@ -70,28 +68,50 @@ from functools import lru_cache
 def compressed_attention(
     q: torch.Tensor,
     k: torch.Tensor,
-
+    v: torch.Tensor,
     kernel_size: int,
     kernel_stride: int,
     block_size: int,
     topk: int,
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
-    cu_seqlens_k2: torch.Tensor,
     max_seqlen_q: int,
     max_seqlen_k: int,
     sm_scale: float = None,
     init_blocks: int = 1,
     local_blocks: int = 2,
-    cache_lens=None,
+    cache_lens: torch.Tensor = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Attention between the query and the compressed key and value. Computes the attention output and the top-k block indices used in topk_sparse_attention.
+
+    Args:
+        q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
+        k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
+        v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
+        kernel_size (int): kernel size in compress_key_value
+        kernel_stride (int): stride of compress_key_value
+        block_size (int): key/value block size for top-k sparse attention.
+        topk (int): number of blocks for each query.
+        cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_varlen_func.
+        cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_varlen_func.
+        max_seqlen_q (int): max q len of the batch.
+        max_seqlen_k (int): max k len of the batch.
+        sm_scale (float, optional): softmax scale. Defaults to None, meaning 1/sqrt(head_dim).
+        init_blocks (int, optional): number of initial blocks for each query. Defaults to 1.
+        local_blocks (int, optional): number of local blocks for each query. Defaults to 2.
+        cache_lens (torch.Tensor, optional): shape [batch_size], records the cache length of each query. Defaults to None.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention
+    """
     with torch.no_grad():
         batch_size = cu_seqlens_q.shape[0] - 1
 
         # Check if it's the prefilling stage
         is_prefilling = cache_lens is None or (cache_lens == 0).all().item()
 
+        # prefilling stage
+        if is_prefilling:
             # Calculate q_idx for each query position in each batch
             cache_lens = torch.zeros(batch_size, dtype=torch.int32, device=q.device)
             q_idx = torch.cat([
```
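The `cu_seqlens_*` arguments follow flash-attn's varlen convention: the batch is packed into one token dimension and delimited by cumulative offsets. A self-contained sketch of how such offsets relate to per-sequence lengths (illustrative only, not part of the model code):

```python
import torch

# Three packed sequences of lengths 5, 3 and 7.
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

# cu_seqlens[i] is the start offset of sequence i in the packed tensor and
# cu_seqlens[-1] is the total token count: tensor([0, 5, 8, 15]).
cu_seqlens = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    torch.cumsum(seq_lens, dim=0, dtype=torch.int32),
])

# max_seqlen_q / max_seqlen_k in compressed_attention correspond to:
max_seqlen = int(seq_lens.max())
```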
```diff
@@ -99,24 +119,25 @@ def compressed_attention(
                 max_seqlen_q - (cu_seqlens_q[i + 1] - cu_seqlens_q[i])) // block_size
                 for i in range(batch_size)
             ], dim=0)  # shape: [total_q_len]
-
-
-
+        # decoding stage
+        else:
+            # Each batch has only one query (the last position). Shape: [batch_size] = [total_q_len] in decoding
+            q_idx = cache_lens // block_size
 
-        #
+        # compute attention score
         score = infllmv2_attn_stage1(
             q.contiguous(),
             k.contiguous(),
-
+            v.contiguous(),
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_k=cu_seqlens_k,
-            cu_seqlens_v=cu_seqlens_k2,
             max_seqlen_q=max_seqlen_q,
             max_seqlen_k=max_seqlen_k,
-            causal=is_prefilling
-
-        score = score[:, :q_idx.shape[0], :]
-
+            causal=is_prefilling)
+        # Shape: [num_heads, total_q_len, num_blocks]
+        score = score[:, :q_idx.shape[0], :]
+
+        # Shape: [num_heads, total_q_len, num_blocks]
         block_score = max_pooling_1d_varlen(
             score.contiguous(),
             cu_seqlens_q,
```
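The decode-stage index deserves a tiny worked example: each sequence contributes a single query at position `cache_lens[i]`, which falls in key/value block `cache_lens[i] // block_size`. A runnable toy with hypothetical values:

```python
import torch

block_size = 64
cache_lens = torch.tensor([100, 4096, 77], dtype=torch.int32)

# One query per sequence while decoding, so q_idx has shape [batch_size].
q_idx = cache_lens // block_size  # tensor([ 1, 64,  1], dtype=torch.int32)
```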
```diff
@@ -127,9 +148,7 @@ def compressed_attention(
             local_blocks=local_blocks,
             init_blocks=init_blocks,
             block_size=block_size,
-            stride=kernel_stride
-        )  # shape: [num_heads, total_q_len, num_blocks]
-
+            stride=kernel_stride)
 
         # get topk
         topk = min(topk, block_score.shape[-1])
```
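Putting the scoring pipeline together: queries are scored against compressed keys, per-position scores are pooled into raw-token blocks, and each query keeps its top-k blocks. A plain-torch toy of the same idea (dense, single-sequence; the real code uses the varlen kernels infllmv2_attn_stage1 and max_pooling_1d_varlen):

```python
import torch

num_heads, q_len, head_dim = 2, 4, 8
block_size, kernel_stride, topk = 4, 2, 2

q = torch.randn(q_len, num_heads, head_dim)
compressed_k = torch.randn(16, num_heads, head_dim)  # 16 compressed key positions

# Attention scores per compressed position: [num_heads, q_len, 16].
score = torch.einsum('qhd,khd->hqk', q, compressed_k)

# Pool compressed-position scores into raw-token blocks: with stride 2, a
# block of block_size raw tokens covers block_size // kernel_stride positions.
positions_per_block = block_size // kernel_stride
block_score = score.reshape(num_heads, q_len, -1, positions_per_block).max(dim=-1).values

# Keep the top-k blocks per query (clamped as in the code above).
topk = min(topk, block_score.shape[-1])
topk_idx = block_score.topk(topk, dim=-1).indices  # [num_heads, q_len, topk]
```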
```diff
@@ -243,11 +262,6 @@ class InfLLMv2CacheLayer(DynamicLayer):
         self.no_compress_k_cache = []
         self.cached_compressed_cu_seqlens = torch.tensor([], dtype=torch.int32)
         self.compress_k_cache_varlen = torch.tensor([], dtype=torch.float32)
-        # Add support for compress_k2
-        self.compress_k2_cache = []
-        self.cached_compressed_cu_seqlens2 = torch.tensor([], dtype=torch.int32)
-        self.compress_k2_cache_varlen = torch.tensor([], dtype=torch.float32)
-        self.no_compress_k2_cache = []
 
     def update_no_rope_key(self, key_states):
         if self.no_rope_keys.numel() == 0:
```
```diff
@@ -289,45 +303,12 @@ class InfLLMv2CacheLayer(DynamicLayer):
                 k_chunk_list.append(None)
         return k_chunk_list
 
-    def update_compress_k2(self, key_states, cu_seqlens=None):
-        if len(self.compress_k2_cache) == 0:
-            if cu_seqlens is not None:
-                self.cached_compressed_cu_seqlens2 = cu_seqlens.clone()
-                self.compress_k2_cache_varlen = key_states
-                split_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
-                self.compress_k2_cache = list(torch.split(key_states, split_sizes))
-        else:
-            for index, k in enumerate(key_states):
-                if k is not None:
-                    self.compress_k2_cache[index] = torch.cat([self.compress_k2_cache[index], k], dim=0)
-            new_seq_lens = torch.tensor([tensor.shape[0] for tensor in self.compress_k2_cache], dtype=torch.int32)
-            new_cumsum = torch.cumsum(new_seq_lens, dim=0, dtype=torch.int32)
-
-            self.compress_k2_cache_varlen = torch.cat(self.compress_k2_cache, dim=0)
-            self.cached_compressed_cu_seqlens2 = torch.cat([torch.tensor([0], dtype=torch.int32), new_cumsum]).to(self.compress_k2_cache_varlen.device)
-        return self.compress_k2_cache_varlen, self.cached_compressed_cu_seqlens2
-
-    def update_no_compress_k2(self, key_states, kernel_size=128, kernel_stride=64):
-        k_chunk_list = []
-        for index, k in enumerate(key_states):
-            if len(self.no_compress_k2_cache) <= index:
-                self.no_compress_k2_cache.append(k)
-            else:
-                self.no_compress_k2_cache[index] = torch.cat([self.no_compress_k2_cache[index], k], dim=0)
-            current_len = self.no_compress_k2_cache[index].shape[0]
-            if current_len >= kernel_size:
-                k_chunk_list.append(self.no_compress_k2_cache[index][:kernel_size])
-                self.no_compress_k2_cache[index] = self.no_compress_k2_cache[index][kernel_stride:]
-            else:
-                k_chunk_list.append(None)
-        return k_chunk_list
-
 class InfLLMv2Cache(DynamicCache):
-    def __init__(self,
+    def __init__(self,
+                 config, num_hidden_layers: Optional[int] = None) -> None:
         super().__init__(config=config)
         self.layers = [InfLLMv2CacheLayer() for _ in range(num_hidden_layers)] if num_hidden_layers else []
         self._seen_tokens = 0
-
 
     def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
         if layer_idx == 0:
```
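The removed `update_no_compress_k2` (like its surviving `update_no_compress_k` counterpart) maintains a rolling buffer per sequence: raw keys accumulate until a full `kernel_size` window exists, the window is emitted for compression, and the buffer slides forward by `kernel_stride`. A standalone toy of the same mechanics, with the mean-pooling the decode path applies to each emitted window (hypothetical helper, not model code):

```python
import torch

def rolling_compress(buffer, new_keys, kernel_size=32, kernel_stride=16):
    """Append new keys; emit the mean of each full window, then slide."""
    buffer = torch.cat([buffer, new_keys], dim=0)
    compressed = []
    while buffer.shape[0] >= kernel_size:
        window = buffer[:kernel_size]                        # [kernel_size, n_heads_k, head_dim]
        compressed.append(window.mean(dim=0, keepdim=True))  # [1, n_heads_k, head_dim]
        buffer = buffer[kernel_stride:]                      # windows overlap by kernel_size - kernel_stride
    return buffer, compressed

buf = torch.empty(0, 2, 4)  # toy: 2 kv heads, head_dim 4
buf, out = rolling_compress(buf, torch.randn(40, 2, 4))
print(len(out), buf.shape[0])  # 1 compressed block emitted; 24 tokens still buffered
```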
```diff
@@ -343,12 +324,6 @@ class InfLLMv2Cache(DynamicCache):
     def update_no_compress_k(self, key_states, layer_idx, kernel_size=32, kernel_stride=16, cache_kwargs=None):
         return self.layers[layer_idx].update_no_compress_k(key_states, kernel_size, kernel_stride)
 
-    def update_compress_k2(self, key_states, layer_idx, cu_seqlens=None, cache_kwargs=None):
-        return self.layers[layer_idx].update_compress_k2(key_states, cu_seqlens)
-
-    def update_no_compress_k2(self, key_states, layer_idx, kernel_size=128, kernel_stride=64, cache_kwargs=None):
-        return self.layers[layer_idx].update_no_compress_k2(key_states, kernel_size, kernel_stride)
-
     def crop(self, max_length):
         for layer in self.layers:
             layer.crop(max_length)
```
```diff
@@ -616,6 +591,7 @@ def _unpad_one_tensor(hidden_states, attention_mask):
     unpadded_states = index_first_axis(reshaped_states, indices)
 
     return unpadded_states, indices, cu_seqlens, max_seqlen_in_batch
+
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
```
```diff
@@ -1022,9 +998,7 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         self.local_blocks = self.window_size // self.block_size  # local_blocks
         self.topk = self.config.sparse_config.get('topk', 64) + (self.window_size // self.block_size)
         self.use_nope = self.config.sparse_config.get('use_nope', False)
-
         self.compress_k = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size, kernel_stride=self.kernel_stride)
-        self.compress_k2 = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size*4, kernel_stride=self.kernel_stride*4)
 
     def forward(
         self,
```
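These attributes are read from `config.sparse_config`, the same dictionary the README's InfLLM v2 section documents (`topk`, `use_nope`, `dense_len`, ...). A sketch of overriding them at load time (model id assumed; check the shipped config.json for the exact key set):

```python
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("openbmb/MiniCPM4.1-8B", trust_remote_code=True)

# sparse_config is consulted via .get(...) in MiniCPMInfLLMv2Attention, so
# missing keys fall back to the defaults in the code (e.g. topk=64, use_nope=False).
config.sparse_config.update({
    "topk": 64,
    "dense_len": 8192,  # below this many tokens the model stays on dense attention
})

model = AutoModelForCausalLM.from_pretrained(
    "openbmb/MiniCPM4.1-8B", config=config, trust_remote_code=True
)
```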
```diff
@@ -1049,7 +1023,6 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
 
         bsz, q_len, _ = hidden_states.size()
 
-
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
```
```diff
@@ -1080,12 +1053,11 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
         if self.use_nope:
-            key_states_no_rope =past_key_value.update_no_rope_key(key_states_no_rope, self.layer_idx)
+            key_states_no_rope = past_key_value.update_no_rope_key(key_states_no_rope, self.layer_idx)
             no_rope_param = {
                 'key_states_no_rope': key_states_no_rope,
                 'query_states_no_rope': query_states_no_rope,
             }
-
         else:
             no_rope_param = None
 
```
```diff
@@ -1131,8 +1103,16 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         return attn_output, attn_weights, past_key_value
 
     def _sparse_attention_forward(
-        self,
-
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        query_length,
+        dropout=0.0,
+        softmax_scale=None,
+        no_rope_param=None,
+        past_key_value=None):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
         first unpad the input, then compute the attention scores and pad the final attention scores.
```
```diff
@@ -1162,17 +1142,15 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         batch_size = query_states.shape[0]
         # assert batch_size == 1, 'Only batch_size=1 is supported at the moment.'
         if past_key_value != None:
-            compressed_k, compressed_cu_seqlens
+            compressed_k, compressed_cu_seqlens = self.get_compress_k(
                 key_states=key_states if self.use_nope == False else no_rope_param['key_states_no_rope'],  # This can be optimized a bit;
                 attention_mask=attention_mask,
-                past_key_value=past_key_value
-
-            )
+                past_key_value=past_key_value)
 
             query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                 query_states, key_states, value_states, attention_mask, query_length
             )
+
             cu_seqlens_q, cu_seqlens_k = cu_seq_lens
             max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
             if no_rope_param != None:
```
```diff
@@ -1183,12 +1161,7 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         if past_key_value == None:
             # compress_k uses the varlen form
             compressed_k, compressed_cu_seqlens = self.compress_k(key_states, cu_seqlens_k)
-            compressed_k2, compressed_cu_seqlens2 = self.compress_k2(key_states, cu_seqlens_k)
-        else:
-            # compressed_k and compressed_k2 were already retrieved from get_compress_k above
-            pass
 
-
         attn_output_unpad = self.sparse_forward(
             query_states,
             key_states,
```
```diff
@@ -1198,16 +1171,15 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
                 max_seqlen_in_batch_q,
                 max_seqlen_in_batch_k,
                 no_rope_param=no_rope_param,
-                compressed_k=compressed_k,
-
-            )
+                compressed_k=compressed_k,
+                compressed_cu_seqlens=compressed_cu_seqlens)
 
             attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
-
         else:
             raise ValueError('Need attention mask')
 
         return attn_output
+
     def get_compress_k(self, key_states, attention_mask, past_key_value):
         """
         Get compressed key states and corresponding cumulative sequence lengths.
```
```diff
@@ -1219,51 +1191,34 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
             no_rope_param: Optional parameter containing key states without rope
 
         Returns:
-            Tuple of (compressed_k, compressed_cu_seqlens
+            Tuple of (compressed_k, compressed_cu_seqlens)
         """
-
         # Check if this is the prefilling or initial compression condition
-
         is_prefilling = (
             key_states.shape[1] >= self.dense_len and
             (
                 not past_key_value.layers[self.layer_idx].compress_k_cache
             )
         )
 
         if is_prefilling:
             unpadded_key_states, indices, cu_seqlens, max_seqlen_in_batch = _unpad_one_tensor(key_states, attention_mask=attention_mask)
             # Compress the keys
             compressed_k, compressed_cu_seqlens = self.compress_k(unpadded_key_states, cu_seqlens)
-
-
+
             past_key_value.update_compress_k(
                 compressed_k, self.layer_idx, compressed_cu_seqlens)
-
-            past_key_value.update_compress_k2(
-                compressed_k2, self.layer_idx, compressed_cu_seqlens2)
+
             no_compress_k_list = []
             # Compute and update no_compress_k
             for i in range(len(compressed_cu_seqlens) - 1):
                 no_compress_k_start = (compressed_cu_seqlens[i+1] - compressed_cu_seqlens[i]) * self.kernel_stride
-
                 no_compress_k_list.append(unpadded_key_states[cu_seqlens[i]+no_compress_k_start:cu_seqlens[i+1]].clone())
 
             past_key_value.update_no_compress_k(
                 no_compress_k_list, self.layer_idx, kernel_stride=self.kernel_stride,
                 kernel_size=self.kernel_size)
-
-            # Also update no_compress_k2
-            no_compress_k2_list = []
-            for i in range(len(compressed_cu_seqlens2) - 1):
-                no_compress_k2_start = (compressed_cu_seqlens2[i+1] - compressed_cu_seqlens2[i]) * self.kernel_stride * 4
-
-                no_compress_k2_list.append(unpadded_key_states[cu_seqlens[i]+no_compress_k2_start:cu_seqlens[i+1]].clone())
-
-            past_key_value.update_no_compress_k2(
-                no_compress_k2_list, self.layer_idx, kernel_stride=self.kernel_stride*4,
-                kernel_size=self.kernel_size*4)
-
         else:
             # Decode case: incremental update
             batch_size = key_states.shape[0]  # key_states.shape = [batch_size, seq, k_head_num, head_dim]
```
```diff
@@ -1278,32 +1233,16 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
                     kernel_size=self.kernel_size)
             new_compressed_k_list = []
             for no_compress_k in no_compress_k_list:
-
                 if no_compress_k is not None:
                     # We have enough tokens to compress
                     new_compressed_k = no_compress_k.mean(dim=0, keepdim=True)  # [1, n_heads_k, head_dim]
-
                     new_compressed_k_list.append(new_compressed_k)
                 else:
                     new_compressed_k_list.append(None)
             compressed_k, compressed_cu_seqlens = past_key_value.update_compress_k(new_compressed_k_list, self.layer_idx,)
-
-
-            no_compress_k2_list = past_key_value.update_no_compress_k2(
-                key_states_split, self.layer_idx,
-                kernel_stride=self.kernel_stride*4,
-                kernel_size=self.kernel_size*4)
-            new_compressed_k2_list = []
-            for no_compress_k2 in no_compress_k2_list:
-                if no_compress_k2 is not None:
-                    # We have enough tokens to compress for k2
-                    new_compressed_k2 = no_compress_k2.mean(dim=0, keepdim=True)  # [1, n_heads_k, head_dim]
-                    new_compressed_k2_list.append(new_compressed_k2)
-                else:
-                    new_compressed_k2_list.append(None)
-            compressed_k2, compressed_cu_seqlens2 = past_key_value.update_compress_k2(new_compressed_k2_list, self.layer_idx,)
-
-        return compressed_k, compressed_cu_seqlens, compressed_k2, compressed_cu_seqlens2
+
+        return compressed_k, compressed_cu_seqlens
+
     def sparse_forward(self,
                        query_layer,
                        key_layer,
```
```diff
@@ -1313,8 +1252,8 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
                        max_seqlen_in_batch_q,
                        max_seqlen_in_batch_k,
                        no_rope_param=None,
-                       compressed_k=None,
-
+                       compressed_k=None,
+                       compressed_cu_seqlens=None):
         compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
         cache_lens = None
         if max_seqlen_in_batch_q == 1 and max_seqlen_in_batch_k > 1:  # decoding
```
```diff
@@ -1324,14 +1263,13 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
         topk_idx = compressed_attention(
             query_layer if no_rope_param is None else no_rope_param['query_states_no_rope'],
             compressed_k,
-
+            compressed_k.clone(),
             self.kernel_size,
             self.kernel_stride,
             self.block_size,
             self.topk,
             cu_seqlens_q,
             compressed_cu_seqlens,
-            compressed_cu_seqlens2,
             max_seqlen_in_batch_q,
             compressed_seqlens.max().item(),
             None,
```