fixes_transformers_4_55
#2
by
lerignoux
- opened
- modeling_phi3_v.py +53 -9
modeling_phi3_v.py
CHANGED
|
@@ -665,7 +665,13 @@ class Phi3Attention(nn.Module):
|
|
| 665 |
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 666 |
"with a layer index."
|
| 667 |
)
|
| 668 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 669 |
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
|
| 670 |
|
| 671 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
@@ -783,7 +789,14 @@ class Phi3FlashAttention2(Phi3Attention):
|
|
| 783 |
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 784 |
"with a layer index."
|
| 785 |
)
|
| 786 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 787 |
|
| 788 |
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
| 789 |
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
|
|
@@ -1073,7 +1086,13 @@ class Phi3SdpaAttention(Phi3Attention):
|
|
| 1073 |
|
| 1074 |
kv_seq_len = key_states.shape[-2]
|
| 1075 |
if past_key_value is not None:
|
| 1076 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1077 |
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
|
| 1078 |
|
| 1079 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
@@ -1414,7 +1433,16 @@ class Phi3VModel(Phi3VPreTrainedModel):
|
|
| 1414 |
use_legacy_cache = not isinstance(past_key_values, Cache)
|
| 1415 |
if use_legacy_cache:
|
| 1416 |
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
| 1417 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1418 |
|
| 1419 |
if position_ids is None:
|
| 1420 |
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
@@ -1650,19 +1678,35 @@ class Phi3VForCausalLM(Phi3VPreTrainedModel):
|
|
| 1650 |
# When the first time input length reached long and short factor switching point, enforce re-compute cache
|
| 1651 |
# It will cause downside of slower at this single token position, however, better than current failure.
|
| 1652 |
if past_key_values and self.config.rope_scaling and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1:
|
| 1653 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1654 |
if past_length <= self.config.original_max_position_embeddings:
|
| 1655 |
past_key_values = None
|
| 1656 |
|
|
|
|
|
|
|
|
|
|
| 1657 |
if past_key_values is not None:
|
| 1658 |
-
if isinstance(past_key_values, Cache):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1659 |
cache_length = past_key_values.get_seq_length()
|
| 1660 |
past_length = past_key_values.seen_tokens
|
| 1661 |
-
# Fixing AttributeError: 'DynamicCache' object has no attribute 'get_max_length'
|
| 1662 |
-
# https://github.com/huggingface/transformers/issues/36071
|
| 1663 |
-
# max_cache_length = past_key_values.get_max_length()
|
| 1664 |
max_cache_length = past_key_values.get_max_cache_shape()
|
| 1665 |
else:
|
|
|
|
| 1666 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 1667 |
max_cache_length = None
|
| 1668 |
|
|
|
|
| 665 |
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 666 |
"with a layer index."
|
| 667 |
)
|
| 668 |
+
if not hasattr(past_key_value, 'get_usable_length'):
|
| 669 |
+
# Transformers >= 4.55
|
| 670 |
+
past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
|
| 671 |
+
kv_seq_len += past_len
|
| 672 |
+
else:
|
| 673 |
+
# Transformers < 4.55
|
| 674 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 675 |
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
|
| 676 |
|
| 677 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
|
|
| 789 |
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 790 |
"with a layer index."
|
| 791 |
)
|
| 792 |
+
|
| 793 |
+
if not hasattr(past_key_value, 'get_usable_length'):
|
| 794 |
+
# Transformers >= 4.55
|
| 795 |
+
past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
|
| 796 |
+
kv_seq_len += past_len
|
| 797 |
+
else:
|
| 798 |
+
# Transformers < 4.55
|
| 799 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 800 |
|
| 801 |
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
| 802 |
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
|
|
|
|
| 1086 |
|
| 1087 |
kv_seq_len = key_states.shape[-2]
|
| 1088 |
if past_key_value is not None:
|
| 1089 |
+
if not hasattr(past_key_value, 'get_usable_length'):
|
| 1090 |
+
# Transformers >= 4.55
|
| 1091 |
+
past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
|
| 1092 |
+
kv_seq_len += past_len
|
| 1093 |
+
else:
|
| 1094 |
+
# 4.49 <= Transformers < 4.55
|
| 1095 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 1096 |
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
|
| 1097 |
|
| 1098 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
|
|
| 1433 |
use_legacy_cache = not isinstance(past_key_values, Cache)
|
| 1434 |
if use_legacy_cache:
|
| 1435 |
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
| 1436 |
+
|
| 1437 |
+
if isinstance(past_key_values, Cache) and not hasattr(past_key_values, 'get_usable_length'):
|
| 1438 |
+
# Transformers >= 4.55
|
| 1439 |
+
past_key_values_length = past_key_values.get_seq_length()
|
| 1440 |
+
elif isinstance(past_key_values, Cache):
|
| 1441 |
+
# 4.49 <= Transformers < 4.55
|
| 1442 |
+
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
| 1443 |
+
else:
|
| 1444 |
+
# No cache given on first forward, keep length at 0
|
| 1445 |
+
past_key_values_length = 0
|
| 1446 |
|
| 1447 |
if position_ids is None:
|
| 1448 |
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
|
|
| 1678 |
# When the first time input length reached long and short factor switching point, enforce re-compute cache
|
| 1679 |
# It will cause downside of slower at this single token position, however, better than current failure.
|
| 1680 |
if past_key_values and self.config.rope_scaling and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1:
|
| 1681 |
+
if isinstance(past_key_values, Cache) and not hasattr(past_key_values, 'seen_tokens'):
|
| 1682 |
+
# Transformers >= 4.55
|
| 1683 |
+
cache_length = past_key_values.get_seq_length()
|
| 1684 |
+
past_length = cache_length
|
| 1685 |
+
elif isinstance(past_key_values, Cache):
|
| 1686 |
+
# 4.49 <= Transformers < 4.55
|
| 1687 |
+
past_length = past_key_values.seen_tokens
|
| 1688 |
+
else:
|
| 1689 |
+
# Transformers < 4.49
|
| 1690 |
+
past_length = past_key_values[0][0].shape[2]
|
| 1691 |
if past_length <= self.config.original_max_position_embeddings:
|
| 1692 |
past_key_values = None
|
| 1693 |
|
| 1694 |
+
cache_length = None
|
| 1695 |
+
past_length = None
|
| 1696 |
+
max_cache_length = None
|
| 1697 |
if past_key_values is not None:
|
| 1698 |
+
if isinstance(past_key_values, Cache) and not hasattr(past_key_values, 'seen_tokens'):
|
| 1699 |
+
# Transformers >= 4.55
|
| 1700 |
+
cache_length = past_key_values.get_seq_length()
|
| 1701 |
+
past_length = cache_length
|
| 1702 |
+
max_cache_length = past_key_values.get_max_cache_shape()
|
| 1703 |
+
elif isinstance(past_key_values, Cache):
|
| 1704 |
+
# 4.49 <= Transformers < 4.55
|
| 1705 |
cache_length = past_key_values.get_seq_length()
|
| 1706 |
past_length = past_key_values.seen_tokens
|
|
|
|
|
|
|
|
|
|
| 1707 |
max_cache_length = past_key_values.get_max_cache_shape()
|
| 1708 |
else:
|
| 1709 |
+
# Transformers < 4.49
|
| 1710 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 1711 |
max_cache_length = None
|
| 1712 |
|