fixes_transformers_4_55

#2
by lerignoux - opened
Files changed (1) hide show
  1. modeling_phi3_v.py +53 -9
modeling_phi3_v.py CHANGED
@@ -665,7 +665,13 @@ class Phi3Attention(nn.Module):
665
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
666
  "with a layer index."
667
  )
668
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
 
 
 
 
 
669
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
670
 
671
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -783,7 +789,14 @@ class Phi3FlashAttention2(Phi3Attention):
783
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
784
  "with a layer index."
785
  )
786
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
 
 
 
 
 
 
787
 
788
  # Because the input can be padded, the absolute sequence length depends on the max position id.
789
  rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
@@ -1073,7 +1086,13 @@ class Phi3SdpaAttention(Phi3Attention):
1073
 
1074
  kv_seq_len = key_states.shape[-2]
1075
  if past_key_value is not None:
1076
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
 
 
 
 
 
1077
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
1078
 
1079
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -1414,7 +1433,16 @@ class Phi3VModel(Phi3VPreTrainedModel):
1414
  use_legacy_cache = not isinstance(past_key_values, Cache)
1415
  if use_legacy_cache:
1416
  past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1417
- past_key_values_length = past_key_values.get_usable_length(seq_length)
 
 
 
 
 
 
 
 
 
1418
 
1419
  if position_ids is None:
1420
  device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1650,19 +1678,35 @@ class Phi3VForCausalLM(Phi3VPreTrainedModel):
1650
  # When the first time input length reached long and short factor switching point, enforce re-compute cache
1651
  # It will cause downside of slower at this single token position, however, better than current failure.
1652
  if past_key_values and self.config.rope_scaling and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1:
1653
- past_length = past_key_values.seen_tokens if isinstance(past_key_values, Cache) else past_key_values[0][0].shape[2]
 
 
 
 
 
 
 
 
 
1654
  if past_length <= self.config.original_max_position_embeddings:
1655
  past_key_values = None
1656
 
 
 
 
1657
  if past_key_values is not None:
1658
- if isinstance(past_key_values, Cache):
 
 
 
 
 
 
1659
  cache_length = past_key_values.get_seq_length()
1660
  past_length = past_key_values.seen_tokens
1661
- # Fixing AttributeError: 'DynamicCache' object has no attribute 'get_max_length'
1662
- # https://github.com/huggingface/transformers/issues/36071
1663
- # max_cache_length = past_key_values.get_max_length()
1664
  max_cache_length = past_key_values.get_max_cache_shape()
1665
  else:
 
1666
  cache_length = past_length = past_key_values[0][0].shape[2]
1667
  max_cache_length = None
1668
 
 
665
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
666
  "with a layer index."
667
  )
668
+ if not hasattr(past_key_value, 'get_usable_length'):
669
+ # Transformers >= 4.55
670
+ past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
671
+ kv_seq_len += past_len
672
+ else:
673
+ # Transformers < 4.55
674
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
675
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
676
 
677
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
789
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
790
  "with a layer index."
791
  )
792
+
793
+ if not hasattr(past_key_value, 'get_usable_length'):
794
+ # Transformers >= 4.55
795
+ past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
796
+ kv_seq_len += past_len
797
+ else:
798
+ # Transformers < 4.55
799
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
800
 
801
  # Because the input can be padded, the absolute sequence length depends on the max position id.
802
  rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
 
1086
 
1087
  kv_seq_len = key_states.shape[-2]
1088
  if past_key_value is not None:
1089
+ if not hasattr(past_key_value, 'get_usable_length'):
1090
+ # Transformers >= 4.55
1091
+ past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
1092
+ kv_seq_len += past_len
1093
+ else:
1094
+ # 4.49 <= Transformers < 4.55
1095
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
1096
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
1097
 
1098
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
1433
  use_legacy_cache = not isinstance(past_key_values, Cache)
1434
  if use_legacy_cache:
1435
  past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1436
+
1437
+ if isinstance(past_key_values, Cache) and not hasattr(past_key_values, 'get_usable_length'):
1438
+ # Transformers >= 4.55
1439
+ past_key_values_length = past_key_values.get_seq_length()
1440
+ elif isinstance(past_key_values, Cache):
1441
+ # 4.49 <= Transformers < 4.55
1442
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1443
+ else:
1444
+ # No cache given on first forward, keep length at 0
1445
+ past_key_values_length = 0
1446
 
1447
  if position_ids is None:
1448
  device = input_ids.device if input_ids is not None else inputs_embeds.device
 
1678
  # When the first time input length reached long and short factor switching point, enforce re-compute cache
1679
  # It will cause downside of slower at this single token position, however, better than current failure.
1680
  if past_key_values and self.config.rope_scaling and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1:
1681
+ if isinstance(past_key_values, Cache) and not hasattr(past_key_values, 'seen_tokens'):
1682
+ # Transformers >= 4.55
1683
+ cache_length = past_key_values.get_seq_length()
1684
+ past_length = cache_length
1685
+ elif isinstance(past_key_values, Cache):
1686
+ # 4.49 <= Transformers < 4.55
1687
+ past_length = past_key_values.seen_tokens
1688
+ else:
1689
+ # Transformers < 4.49
1690
+ past_length = past_key_values[0][0].shape[2]
1691
  if past_length <= self.config.original_max_position_embeddings:
1692
  past_key_values = None
1693
 
1694
+ cache_length = None
1695
+ past_length = None
1696
+ max_cache_length = None
1697
  if past_key_values is not None:
1698
+ if isinstance(past_key_values, Cache) and not hasattr(past_key_values, 'seen_tokens'):
1699
+ # Transformers >= 4.55
1700
+ cache_length = past_key_values.get_seq_length()
1701
+ past_length = cache_length
1702
+ max_cache_length = past_key_values.get_max_cache_shape()
1703
+ elif isinstance(past_key_values, Cache):
1704
+ # 4.49 <= Transformers < 4.55
1705
  cache_length = past_key_values.get_seq_length()
1706
  past_length = past_key_values.seen_tokens
 
 
 
1707
  max_cache_length = past_key_values.get_max_cache_shape()
1708
  else:
1709
+ # Transformers < 4.49
1710
  cache_length = past_length = past_key_values[0][0].shape[2]
1711
  max_cache_length = None
1712