Fix KV cache compatibility with transformers 4.50+

Transformers 4.50+ can pass attention layers a `past_key_value` tuple whose entries are `None` (an empty cache), so every `past_key_value[0]` / `past_key_values[0][0]` access that reads `.shape` is now guarded with an explicit `is not None` (and, for nested accesses, a length) check.

#42
Files changed (1)
  1. modeling_florence2.py +10 -10
modeling_florence2.py CHANGED
@@ -794,7 +794,7 @@ class Florence2Attention(nn.Module):
794
  if (
795
  is_cross_attention
796
  and past_key_value is not None
797
- and past_key_value[0].shape[2] == key_value_states.shape[1]
798
  ):
799
  # reuse k,v, cross_attentions
800
  key_states = past_key_value[0]
@@ -803,7 +803,7 @@ class Florence2Attention(nn.Module):
803
  # cross_attentions
804
  key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
805
  value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
806
- elif past_key_value is not None:
807
  # reuse k, v, self_attention
808
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
809
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
@@ -936,7 +936,7 @@ class Florence2FlashAttention2(Florence2Attention):
936
  if (
937
  is_cross_attention
938
  and past_key_value is not None
939
- and past_key_value[0].shape[2] == key_value_states.shape[1]
940
  ):
941
  # reuse k,v, cross_attentions
942
  key_states = past_key_value[0].transpose(1, 2)
@@ -945,7 +945,7 @@ class Florence2FlashAttention2(Florence2Attention):
945
  # cross_attentions
946
  key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
947
  value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
948
- elif past_key_value is not None:
949
  # reuse k, v, self_attention
950
  key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
951
  value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
@@ -968,7 +968,7 @@ class Florence2FlashAttention2(Florence2Attention):
968
 
969
  kv_seq_len = key_states.shape[-2]
970
  if past_key_value is not None:
971
- kv_seq_len += past_key_value[0].shape[-2]
972
 
973
  # In PEFT, usually we cast the layer norms in float32 for training stability reasons
974
  # therefore the input hidden states gets silently casted in float32. Hence, we need
@@ -1149,7 +1149,7 @@ class Florence2SdpaAttention(Florence2Attention):
1149
  if (
1150
  is_cross_attention
1151
  and past_key_value is not None
1152
- and past_key_value[0].shape[2] == key_value_states.shape[1]
1153
  ):
1154
  # reuse k,v, cross_attentions
1155
  key_states = past_key_value[0]
@@ -1158,7 +1158,7 @@ class Florence2SdpaAttention(Florence2Attention):
1158
  # cross_attentions
1159
  key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
1160
  value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
1161
- elif past_key_value is not None:
1162
  # reuse k, v, self_attention
1163
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
1164
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
@@ -1788,7 +1788,7 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1788
  raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1789
 
1790
  # past_key_values_length
1791
- past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1792
 
1793
  if inputs_embeds is None:
1794
  inputs_embeds = self.embed_tokens(input)
@@ -2193,7 +2193,7 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
2193
  **kwargs,
2194
  ):
2195
  # cut decoder_input_ids if past_key_values is used
2196
- if past_key_values is not None:
2197
  past_length = past_key_values[0][0].shape[2]
2198
 
2199
  # Some generation methods already pass only the last input ID
@@ -2813,7 +2813,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2813
  **kwargs,
2814
  ):
2815
  # cut decoder_input_ids if past_key_values is used
2816
- if past_key_values is not None:
2817
  past_length = past_key_values[0][0].shape[2]
2818
 
2819
  # Some generation methods already pass only the last input ID
 
794
  if (
795
  is_cross_attention
796
  and past_key_value is not None
797
+ and past_key_value[0] is not None and past_key_value[0].shape[2] == key_value_states.shape[1]
798
  ):
799
  # reuse k,v, cross_attentions
800
  key_states = past_key_value[0]
 
803
  # cross_attentions
804
  key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
805
  value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
806
+ elif past_key_value is not None and past_key_value[0] is not None:
807
  # reuse k, v, self_attention
808
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
809
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
 
936
  if (
937
  is_cross_attention
938
  and past_key_value is not None
939
+ and past_key_value[0] is not None and past_key_value[0].shape[2] == key_value_states.shape[1]
940
  ):
941
  # reuse k,v, cross_attentions
942
  key_states = past_key_value[0].transpose(1, 2)
 
945
  # cross_attentions
946
  key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
947
  value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
948
+ elif past_key_value is not None and past_key_value[0] is not None:
949
  # reuse k, v, self_attention
950
  key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
951
  value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
 
968
 
969
  kv_seq_len = key_states.shape[-2]
970
  if past_key_value is not None:
971
+ if past_key_value[0] is not None: kv_seq_len += past_key_value[0].shape[-2]
972
 
973
  # In PEFT, usually we cast the layer norms in float32 for training stability reasons
974
  # therefore the input hidden states gets silently casted in float32. Hence, we need
 
1149
  if (
1150
  is_cross_attention
1151
  and past_key_value is not None
1152
+ and past_key_value[0] is not None and past_key_value[0].shape[2] == key_value_states.shape[1]
1153
  ):
1154
  # reuse k,v, cross_attentions
1155
  key_states = past_key_value[0]
 
1158
  # cross_attentions
1159
  key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
1160
  value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
1161
+ elif past_key_value is not None and past_key_value[0] is not None:
1162
  # reuse k, v, self_attention
1163
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
1164
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
 
1788
  raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1789
 
1790
  # past_key_values_length
1791
+ past_key_values_length = past_key_values[0][0].shape[2] if (past_key_values is not None and len(past_key_values) > 0 and past_key_values[0][0] is not None) else 0
1792
 
1793
  if inputs_embeds is None:
1794
  inputs_embeds = self.embed_tokens(input)
 
2193
  **kwargs,
2194
  ):
2195
  # cut decoder_input_ids if past_key_values is used
2196
+ if past_key_values is not None and len(past_key_values) > 0 and past_key_values[0][0] is not None:
2197
  past_length = past_key_values[0][0].shape[2]
2198
 
2199
  # Some generation methods already pass only the last input ID
 
2813
  **kwargs,
2814
  ):
2815
  # cut decoder_input_ids if past_key_values is used
2816
+ if past_key_values is not None and len(past_key_values) > 0 and past_key_values[0][0] is not None:
2817
  past_length = past_key_values[0][0].shape[2]
2818
 
2819
  # Some generation methods already pass only the last input ID