Fix the issues with the latest transformers: re-add a previously removed helper that computes the usable past KV length, for cache compatibility.
#13
by
Great-Luso
- opened
- modeling_qwen2_rm.py +23 -3
modeling_qwen2_rm.py
CHANGED
|
@@ -58,6 +58,23 @@ logger = logging.get_logger(__name__)
|
|
| 58 |
_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
|
| 59 |
_CONFIG_FOR_DOC = "Qwen2Config"
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
| 63 |
def _prepare_4d_causal_attention_mask_with_cache_position(
|
|
@@ -307,7 +324,8 @@ class Qwen2Attention(nn.Module):
|
|
| 307 |
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 308 |
"with a layer index."
|
| 309 |
)
|
| 310 |
-
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
|
|
|
| 311 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 312 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
| 313 |
|
|
@@ -399,7 +417,8 @@ class Qwen2FlashAttention2(Qwen2Attention):
|
|
| 399 |
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 400 |
"with a layer index."
|
| 401 |
)
|
| 402 |
-
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
|
|
|
| 403 |
|
| 404 |
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
| 405 |
rotary_seq_len = (
|
|
@@ -549,7 +568,8 @@ class Qwen2SdpaAttention(Qwen2Attention):
|
|
| 549 |
|
| 550 |
kv_seq_len = key_states.shape[-2]
|
| 551 |
if past_key_value is not None:
|
| 552 |
-
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
|
|
|
| 553 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 554 |
|
| 555 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
|
|
| 58 |
_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
|
| 59 |
_CONFIG_FOR_DOC = "Qwen2Config"
|
| 60 |
|
| 61 |
+
# Copied from the fix: https://huggingface.co/it-just-works/stella_en_1.5B_v5_bf16/commit/03aedd040580357ec688f3467f1109af5e053249.
|
| 62 |
+
def _get_usable_past_kv_length(cache: Cache, new_seq_length: int, layer_idx: int = 0) -> int:
|
| 63 |
+
"""Compute the usable past length for the given cache and upcoming new sequence length.
|
| 64 |
+
|
| 65 |
+
This mirrors the previous `get_usable_length(new_seq_length, layer_idx)` behavior that existed in
|
| 66 |
+
Transformers < 4.45, while being compatible with the new Cache API.
|
| 67 |
+
"""
|
| 68 |
+
try:
|
| 69 |
+
previous_length = cache.get_seq_length(layer_idx)
|
| 70 |
+
# Dynamic layers return -1, static layers return an int
|
| 71 |
+
max_length = cache.get_max_cache_shape(layer_idx)
|
| 72 |
+
if max_length is not None and max_length != -1 and previous_length + new_seq_length > max_length:
|
| 73 |
+
return max_length - new_seq_length
|
| 74 |
+
return previous_length
|
| 75 |
+
except Exception:
|
| 76 |
+
# Best-effort fallback
|
| 77 |
+
return cache.get_seq_length(layer_idx) if hasattr(cache, "get_seq_length") else 0
|
| 78 |
|
| 79 |
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
| 80 |
def _prepare_4d_causal_attention_mask_with_cache_position(
|
|
|
|
| 324 |
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 325 |
"with a layer index."
|
| 326 |
)
|
| 327 |
+
# kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 328 |
+
kv_seq_len += _get_usable_past_kv_length(past_key_value, kv_seq_len, self.layer_idx)
|
| 329 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 330 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
| 331 |
|
|
|
|
| 417 |
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 418 |
"with a layer index."
|
| 419 |
)
|
| 420 |
+
# kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 421 |
+
kv_seq_len += _get_usable_past_kv_length(past_key_value, kv_seq_len, self.layer_idx)
|
| 422 |
|
| 423 |
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
| 424 |
rotary_seq_len = (
|
|
|
|
| 568 |
|
| 569 |
kv_seq_len = key_states.shape[-2]
|
| 570 |
if past_key_value is not None:
|
| 571 |
+
# kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 572 |
+
kv_seq_len += _get_usable_past_kv_length(past_key_value, kv_seq_len, self.layer_idx)
|
| 573 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 574 |
|
| 575 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|