Fix compatibility issues with the latest transformers release: re-add the previously removed helper that computes the usable past KV-cache length, restoring cache compatibility.

#13
by Great-Luso - opened
Files changed (1) hide show
  1. modeling_qwen2_rm.py +23 -3
modeling_qwen2_rm.py CHANGED
@@ -58,6 +58,23 @@ logger = logging.get_logger(__name__)
58
  _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
59
  _CONFIG_FOR_DOC = "Qwen2Config"
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
63
  def _prepare_4d_causal_attention_mask_with_cache_position(
@@ -307,7 +324,8 @@ class Qwen2Attention(nn.Module):
307
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
308
  "with a layer index."
309
  )
310
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
311
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
312
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
313
 
@@ -399,7 +417,8 @@ class Qwen2FlashAttention2(Qwen2Attention):
399
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
400
  "with a layer index."
401
  )
402
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
403
 
404
  # Because the input can be padded, the absolute sequence length depends on the max position id.
405
  rotary_seq_len = (
@@ -549,7 +568,8 @@ class Qwen2SdpaAttention(Qwen2Attention):
549
 
550
  kv_seq_len = key_states.shape[-2]
551
  if past_key_value is not None:
552
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
553
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
554
 
555
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
58
  _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
59
  _CONFIG_FOR_DOC = "Qwen2Config"
60
 
61
+ # Copied from the fix: https://huggingface.co/it-just-works/stella_en_1.5B_v5_bf16/commit/03aedd040580357ec688f3467f1109af5e053249.
62
+ def _get_usable_past_kv_length(cache: Cache, new_seq_length: int, layer_idx: int = 0) -> int:
63
+ """Compute the usable past length for the given cache and upcoming new sequence length.
64
+
65
+ This mirrors the previous `get_usable_length(new_seq_length, layer_idx)` behavior that existed in
66
+ Transformers < 4.45, while being compatible with the new Cache API.
67
+ """
68
+ try:
69
+ previous_length = cache.get_seq_length(layer_idx)
70
+ # Dynamic layers return -1, static layers return an int
71
+ max_length = cache.get_max_cache_shape(layer_idx)
72
+ if max_length is not None and max_length != -1 and previous_length + new_seq_length > max_length:
73
+ return max_length - new_seq_length
74
+ return previous_length
75
+ except Exception:
76
+ # Best-effort fallback
77
+ return cache.get_seq_length(layer_idx) if hasattr(cache, "get_seq_length") else 0
78
 
79
  # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
80
  def _prepare_4d_causal_attention_mask_with_cache_position(
 
324
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
325
  "with a layer index."
326
  )
327
+ # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
328
+ kv_seq_len += _get_usable_past_kv_length(past_key_value, kv_seq_len, self.layer_idx)
329
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
330
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
331
 
 
417
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
418
  "with a layer index."
419
  )
420
+ # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
421
+ kv_seq_len += _get_usable_past_kv_length(past_key_value, kv_seq_len, self.layer_idx)
422
 
423
  # Because the input can be padded, the absolute sequence length depends on the max position id.
424
  rotary_seq_len = (
 
568
 
569
  kv_seq_len = key_states.shape[-2]
570
  if past_key_value is not None:
571
+ # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
572
+ kv_seq_len += _get_usable_past_kv_length(past_key_value, kv_seq_len, self.layer_idx)
573
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
574
 
575
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)