Fix KV cache compatibility with transformers 4.50+
Pull request #42, opened by kebabman
Changed file: modeling_florence2.py (+10 −10)
modeling_florence2.py
CHANGED
|
@@ -794,7 +794,7 @@ class Florence2Attention(nn.Module):
|
|
| 794 |
if (
|
| 795 |
is_cross_attention
|
| 796 |
and past_key_value is not None
|
| 797 |
-
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
| 798 |
):
|
| 799 |
# reuse k,v, cross_attentions
|
| 800 |
key_states = past_key_value[0]
|
|
@@ -803,7 +803,7 @@ class Florence2Attention(nn.Module):
|
|
| 803 |
# cross_attentions
|
| 804 |
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
| 805 |
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
| 806 |
-
elif past_key_value is not None:
|
| 807 |
# reuse k, v, self_attention
|
| 808 |
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
| 809 |
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
|
@@ -936,7 +936,7 @@ class Florence2FlashAttention2(Florence2Attention):
|
|
| 936 |
if (
|
| 937 |
is_cross_attention
|
| 938 |
and past_key_value is not None
|
| 939 |
-
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
| 940 |
):
|
| 941 |
# reuse k,v, cross_attentions
|
| 942 |
key_states = past_key_value[0].transpose(1, 2)
|
|
@@ -945,7 +945,7 @@ class Florence2FlashAttention2(Florence2Attention):
|
|
| 945 |
# cross_attentions
|
| 946 |
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
|
| 947 |
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
|
| 948 |
-
elif past_key_value is not None:
|
| 949 |
# reuse k, v, self_attention
|
| 950 |
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
| 951 |
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
|
@@ -968,7 +968,7 @@ class Florence2FlashAttention2(Florence2Attention):
|
|
| 968 |
|
| 969 |
kv_seq_len = key_states.shape[-2]
|
| 970 |
if past_key_value is not None:
|
| 971 |
-
kv_seq_len += past_key_value[0].shape[-2]
|
| 972 |
|
| 973 |
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
| 974 |
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
|
@@ -1149,7 +1149,7 @@ class Florence2SdpaAttention(Florence2Attention):
|
|
| 1149 |
if (
|
| 1150 |
is_cross_attention
|
| 1151 |
and past_key_value is not None
|
| 1152 |
-
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
| 1153 |
):
|
| 1154 |
# reuse k,v, cross_attentions
|
| 1155 |
key_states = past_key_value[0]
|
|
@@ -1158,7 +1158,7 @@ class Florence2SdpaAttention(Florence2Attention):
|
|
| 1158 |
# cross_attentions
|
| 1159 |
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
| 1160 |
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
| 1161 |
-
elif past_key_value is not None:
|
| 1162 |
# reuse k, v, self_attention
|
| 1163 |
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
| 1164 |
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
|
@@ -1788,7 +1788,7 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
|
|
| 1788 |
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
| 1789 |
|
| 1790 |
# past_key_values_length
|
| 1791 |
-
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
| 1792 |
|
| 1793 |
if inputs_embeds is None:
|
| 1794 |
inputs_embeds = self.embed_tokens(input)
|
|
@@ -2193,7 +2193,7 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
|
|
| 2193 |
**kwargs,
|
| 2194 |
):
|
| 2195 |
# cut decoder_input_ids if past_key_values is used
|
| 2196 |
-
if past_key_values is not None:
|
| 2197 |
past_length = past_key_values[0][0].shape[2]
|
| 2198 |
|
| 2199 |
# Some generation methods already pass only the last input ID
|
|
@@ -2813,7 +2813,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2813 |
**kwargs,
|
| 2814 |
):
|
| 2815 |
# cut decoder_input_ids if past_key_values is used
|
| 2816 |
-
if past_key_values is not None:
|
| 2817 |
past_length = past_key_values[0][0].shape[2]
|
| 2818 |
|
| 2819 |
# Some generation methods already pass only the last input ID
|
|
|
|
| 794 |
if (
|
| 795 |
is_cross_attention
|
| 796 |
and past_key_value is not None
|
| 797 |
+
and past_key_value[0] is not None and past_key_value[0].shape[2] == key_value_states.shape[1]
|
| 798 |
):
|
| 799 |
# reuse k,v, cross_attentions
|
| 800 |
key_states = past_key_value[0]
|
|
|
|
| 803 |
# cross_attentions
|
| 804 |
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
| 805 |
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
| 806 |
+
elif past_key_value is not None and past_key_value[0] is not None:
|
| 807 |
# reuse k, v, self_attention
|
| 808 |
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
| 809 |
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
|
|
|
| 936 |
if (
|
| 937 |
is_cross_attention
|
| 938 |
and past_key_value is not None
|
| 939 |
+
and past_key_value[0] is not None and past_key_value[0].shape[2] == key_value_states.shape[1]
|
| 940 |
):
|
| 941 |
# reuse k,v, cross_attentions
|
| 942 |
key_states = past_key_value[0].transpose(1, 2)
|
|
|
|
| 945 |
# cross_attentions
|
| 946 |
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
|
| 947 |
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
|
| 948 |
+
elif past_key_value is not None and past_key_value[0] is not None:
|
| 949 |
# reuse k, v, self_attention
|
| 950 |
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
| 951 |
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
|
|
|
| 968 |
|
| 969 |
kv_seq_len = key_states.shape[-2]
|
| 970 |
if past_key_value is not None:
|
| 971 |
+
if past_key_value[0] is not None: kv_seq_len += past_key_value[0].shape[-2]
|
| 972 |
|
| 973 |
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
| 974 |
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
|
|
|
| 1149 |
if (
|
| 1150 |
is_cross_attention
|
| 1151 |
and past_key_value is not None
|
| 1152 |
+
and past_key_value[0] is not None and past_key_value[0].shape[2] == key_value_states.shape[1]
|
| 1153 |
):
|
| 1154 |
# reuse k,v, cross_attentions
|
| 1155 |
key_states = past_key_value[0]
|
|
|
|
| 1158 |
# cross_attentions
|
| 1159 |
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
| 1160 |
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
| 1161 |
+
elif past_key_value is not None and past_key_value[0] is not None:
|
| 1162 |
# reuse k, v, self_attention
|
| 1163 |
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
| 1164 |
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
|
|
|
| 1788 |
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
| 1789 |
|
| 1790 |
# past_key_values_length
|
| 1791 |
+
past_key_values_length = past_key_values[0][0].shape[2] if (past_key_values is not None and len(past_key_values) > 0 and past_key_values[0][0] is not None) else 0
|
| 1792 |
|
| 1793 |
if inputs_embeds is None:
|
| 1794 |
inputs_embeds = self.embed_tokens(input)
|
|
|
|
| 2193 |
**kwargs,
|
| 2194 |
):
|
| 2195 |
# cut decoder_input_ids if past_key_values is used
|
| 2196 |
+
if past_key_values is not None and len(past_key_values) > 0 and past_key_values[0][0] is not None:
|
| 2197 |
past_length = past_key_values[0][0].shape[2]
|
| 2198 |
|
| 2199 |
# Some generation methods already pass only the last input ID
|
|
|
|
| 2813 |
**kwargs,
|
| 2814 |
):
|
| 2815 |
# cut decoder_input_ids if past_key_values is used
|
| 2816 |
+
if past_key_values is not None and len(past_key_values) > 0 and past_key_values[0][0] is not None:
|
| 2817 |
past_length = past_key_values[0][0].shape[2]
|
| 2818 |
|
| 2819 |
# Some generation methods already pass only the last input ID
|