Fix compatibility with the latest transformers: re-add the previously removed helper that computes the usable past KV length, restoring cache compatibility.
The code for this fix is copied from another fix: https://huggingface.co/it-just-works/stella_en_1.5B_v5_bf16/commit/03aedd040580357ec688f3467f1109af5e053249
- modeling_qwen2_rm.py +23 -3
modeling_qwen2_rm.py
CHANGED
|
@@ -58,6 +58,23 @@ logger = logging.get_logger(__name__)
|
|
| 58 |
_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
|
| 59 |
_CONFIG_FOR_DOC = "Qwen2Config"
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
| 63 |
def _prepare_4d_causal_attention_mask_with_cache_position(
|
|
@@ -307,7 +324,8 @@ class Qwen2Attention(nn.Module):
|
|
| 307 |
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 308 |
"with a layer index."
|
| 309 |
)
|
| 310 |
-
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
|
|
|
| 311 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 312 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
| 313 |
|
|
@@ -399,7 +417,8 @@ class Qwen2FlashAttention2(Qwen2Attention):
|
|
| 399 |
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 400 |
"with a layer index."
|
| 401 |
)
|
| 402 |
-
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
|
|
|
| 403 |
|
| 404 |
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
| 405 |
rotary_seq_len = (
|
|
@@ -549,7 +568,8 @@ class Qwen2SdpaAttention(Qwen2Attention):
|
|
| 549 |
|
| 550 |
kv_seq_len = key_states.shape[-2]
|
| 551 |
if past_key_value is not None:
|
| 552 |
-
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
|
|
|
| 553 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 554 |
|
| 555 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
|
|
| 58 |
_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
|
| 59 |
_CONFIG_FOR_DOC = "Qwen2Config"
|
| 60 |
|
| 61 |
+
# Copied from the fix: https://huggingface.co/it-just-works/stella_en_1.5B_v5_bf16/commit/03aedd040580357ec688f3467f1109af5e053249.
|
| 62 |
+
def _get_usable_past_kv_length(cache: Cache, new_seq_length: int, layer_idx: int = 0) -> int:
|
| 63 |
+
"""Compute the usable past length for the given cache and upcoming new sequence length.
|
| 64 |
+
|
| 65 |
+
This mirrors the previous `get_usable_length(new_seq_length, layer_idx)` behavior that existed in
|
| 66 |
+
Transformers < 4.45, while being compatible with the new Cache API.
|
| 67 |
+
"""
|
| 68 |
+
try:
|
| 69 |
+
previous_length = cache.get_seq_length(layer_idx)
|
| 70 |
+
# Dynamic layers return -1, static layers return an int
|
| 71 |
+
max_length = cache.get_max_cache_shape(layer_idx)
|
| 72 |
+
if max_length is not None and max_length != -1 and previous_length + new_seq_length > max_length:
|
| 73 |
+
return max_length - new_seq_length
|
| 74 |
+
return previous_length
|
| 75 |
+
except Exception:
|
| 76 |
+
# Best-effort fallback
|
| 77 |
+
return cache.get_seq_length(layer_idx) if hasattr(cache, "get_seq_length") else 0
|
| 78 |
|
| 79 |
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
| 80 |
def _prepare_4d_causal_attention_mask_with_cache_position(
|
|
|
|
| 324 |
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 325 |
"with a layer index."
|
| 326 |
)
|
| 327 |
+
# kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 328 |
+
kv_seq_len += _get_usable_past_kv_length(past_key_value, kv_seq_len, self.layer_idx)
|
| 329 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 330 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
| 331 |
|
|
|
|
| 417 |
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 418 |
"with a layer index."
|
| 419 |
)
|
| 420 |
+
# kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 421 |
+
kv_seq_len += _get_usable_past_kv_length(past_key_value, kv_seq_len, self.layer_idx)
|
| 422 |
|
| 423 |
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
| 424 |
rotary_seq_len = (
|
|
|
|
| 568 |
|
| 569 |
kv_seq_len = key_states.shape[-2]
|
| 570 |
if past_key_value is not None:
|
| 571 |
+
# kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 572 |
+
kv_seq_len += _get_usable_past_kv_length(past_key_value, kv_seq_len, self.layer_idx)
|
| 573 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 574 |
|
| 575 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|