Great-Luso committed on
Commit
dca93c9
·
verified ·
1 Parent(s): 0610740

Fix the issues with the latest transformers; add back a previously removed function to compute the usable past KV length for cache compatibility.

Browse files

The code for this fix is copied from another pull request addressing the same issue: https://huggingface.co/it-just-works/stella_en_1.5B_v5_bf16/commit/03aedd040580357ec688f3467f1109af5e053249

Files changed (1) hide show
  1. modeling_qwen2_rm.py +23 -3
modeling_qwen2_rm.py CHANGED
@@ -58,6 +58,23 @@ logger = logging.get_logger(__name__)
58
  _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
59
  _CONFIG_FOR_DOC = "Qwen2Config"
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
63
  def _prepare_4d_causal_attention_mask_with_cache_position(
@@ -307,7 +324,8 @@ class Qwen2Attention(nn.Module):
307
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
308
  "with a layer index."
309
  )
310
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
311
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
312
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
313
 
@@ -399,7 +417,8 @@ class Qwen2FlashAttention2(Qwen2Attention):
399
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
400
  "with a layer index."
401
  )
402
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
403
 
404
  # Because the input can be padded, the absolute sequence length depends on the max position id.
405
  rotary_seq_len = (
@@ -549,7 +568,8 @@ class Qwen2SdpaAttention(Qwen2Attention):
549
 
550
  kv_seq_len = key_states.shape[-2]
551
  if past_key_value is not None:
552
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
553
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
554
 
555
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
58
  _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
59
  _CONFIG_FOR_DOC = "Qwen2Config"
60
 
61
+ # Copied from the fix: https://huggingface.co/it-just-works/stella_en_1.5B_v5_bf16/commit/03aedd040580357ec688f3467f1109af5e053249.
62
+ def _get_usable_past_kv_length(cache: Cache, new_seq_length: int, layer_idx: int = 0) -> int:
63
+ """Compute the usable past length for the given cache and upcoming new sequence length.
64
+
65
+ This mirrors the previous `get_usable_length(new_seq_length, layer_idx)` behavior that existed in
66
+ Transformers < 4.45, while being compatible with the new Cache API.
67
+ """
68
+ try:
69
+ previous_length = cache.get_seq_length(layer_idx)
70
+ # Dynamic layers return -1, static layers return an int
71
+ max_length = cache.get_max_cache_shape(layer_idx)
72
+ if max_length is not None and max_length != -1 and previous_length + new_seq_length > max_length:
73
+ return max_length - new_seq_length
74
+ return previous_length
75
+ except Exception:
76
+ # Best-effort fallback
77
+ return cache.get_seq_length(layer_idx) if hasattr(cache, "get_seq_length") else 0
78
 
79
  # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
80
  def _prepare_4d_causal_attention_mask_with_cache_position(
 
324
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
325
  "with a layer index."
326
  )
327
+ # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
328
+ kv_seq_len += _get_usable_past_kv_length(past_key_value, kv_seq_len, self.layer_idx)
329
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
330
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
331
 
 
417
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
418
  "with a layer index."
419
  )
420
+ # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
421
+ kv_seq_len += _get_usable_past_kv_length(past_key_value, kv_seq_len, self.layer_idx)
422
 
423
  # Because the input can be padded, the absolute sequence length depends on the max position id.
424
  rotary_seq_len = (
 
568
 
569
  kv_seq_len = key_states.shape[-2]
570
  if past_key_value is not None:
571
+ # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
572
+ kv_seq_len += _get_usable_past_kv_length(past_key_value, kv_seq_len, self.layer_idx)
573
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
574
 
575
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)