Mateusz Mróz committed on
Commit
8ba0d9d
·
1 Parent(s): d4318c2
Files changed (1) hide show
  1. modeling_florence2.py +12 -7
modeling_florence2.py CHANGED
@@ -831,8 +831,10 @@ class Florence2Attention(nn.Module):
831
  # reuse k, v, self_attention
832
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
833
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
834
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
835
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
 
 
836
  else:
837
  # self_attention
838
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
@@ -1173,6 +1175,7 @@ class Florence2SdpaAttention(Florence2Attention):
1173
  if (
1174
  is_cross_attention
1175
  and past_key_value is not None
 
1176
  and past_key_value[0].shape[2] == key_value_states.shape[1]
1177
  ):
1178
  # reuse k,v, cross_attentions
@@ -1186,8 +1189,10 @@ class Florence2SdpaAttention(Florence2Attention):
1186
  # reuse k, v, self_attention
1187
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
1188
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
1189
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
1190
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
 
 
1191
  else:
1192
  # self_attention
1193
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
@@ -1801,7 +1806,7 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1801
  raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1802
 
1803
  # past_key_values_length
1804
- past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1805
 
1806
  if inputs_embeds is None:
1807
  inputs_embeds = self.embed_tokens(input)
@@ -2200,7 +2205,7 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
2200
  **kwargs,
2201
  ):
2202
  # cut decoder_input_ids if past_key_values is used
2203
- if past_key_values is not None:
2204
  past_length = past_key_values[0][0].shape[2]
2205
 
2206
  # Some generation methods already pass only the last input ID
@@ -3062,7 +3067,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
3062
  **kwargs,
3063
  ):
3064
  # cut decoder_input_ids if past_key_values is used
3065
- if past_key_values is not None:
3066
  past_length = past_key_values[0][0].shape[2]
3067
 
3068
  # Some generation methods already pass only the last input ID
 
831
  # reuse k, v, self_attention
832
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
833
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
834
+ if past_key_value[0] is not None:
835
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
836
+ if past_key_value[1] is not None:
837
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
838
  else:
839
  # self_attention
840
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
 
1175
  if (
1176
  is_cross_attention
1177
  and past_key_value is not None
1178
+ and past_key_value[0] is not None
1179
  and past_key_value[0].shape[2] == key_value_states.shape[1]
1180
  ):
1181
  # reuse k,v, cross_attentions
 
1189
  # reuse k, v, self_attention
1190
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
1191
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
1192
+ if past_key_value[0] is not None:
1193
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
1194
+ if past_key_value[1] is not None:
1195
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
1196
  else:
1197
  # self_attention
1198
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
 
1806
  raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1807
 
1808
  # past_key_values_length
1809
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values and past_key_values[0] and past_key_values[0][0] is not None else 0
1810
 
1811
  if inputs_embeds is None:
1812
  inputs_embeds = self.embed_tokens(input)
 
2205
  **kwargs,
2206
  ):
2207
  # cut decoder_input_ids if past_key_values is used
2208
+ if past_key_values is not None and past_key_values[0] is not None and past_key_values[0][0] is not None:
2209
  past_length = past_key_values[0][0].shape[2]
2210
 
2211
  # Some generation methods already pass only the last input ID
 
3067
  **kwargs,
3068
  ):
3069
  # cut decoder_input_ids if past_key_values is used
3070
+ if past_key_values is not None and past_key_values[0] is not None and past_key_values[0][0] is not None:
3071
  past_length = past_key_values[0][0].shape[2]
3072
 
3073
  # Some generation methods already pass only the last input ID