ChuuniZ committed on
Commit
5efdb02
·
verified ·
1 Parent(s): 912454f

Upload LLM/Florence-2-base-PromptGen-v2.0/modeling_florence2.py

Browse files
LLM/Florence-2-base-PromptGen-v2.0/modeling_florence2.py CHANGED
@@ -29,6 +29,12 @@ from einops import rearrange
29
  from timm.models.layers import DropPath, trunc_normal_
30
 
31
  from transformers.modeling_utils import PreTrainedModel
 
 
 
 
 
 
32
  from transformers.utils import (
33
  ModelOutput,
34
  add_start_docstrings,
@@ -812,6 +818,8 @@ class Florence2Attention(nn.Module):
812
  if (
813
  is_cross_attention
814
  and past_key_value is not None
 
 
815
  and past_key_value[0].shape[2] == key_value_states.shape[1]
816
  ):
817
  # reuse k,v, cross_attentions
@@ -821,7 +829,7 @@ class Florence2Attention(nn.Module):
821
  # cross_attentions
822
  key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
823
  value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
824
- elif past_key_value is not None:
825
  # reuse k, v, self_attention
826
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
827
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
@@ -954,6 +962,8 @@ class Florence2FlashAttention2(Florence2Attention):
954
  if (
955
  is_cross_attention
956
  and past_key_value is not None
 
 
957
  and past_key_value[0].shape[2] == key_value_states.shape[1]
958
  ):
959
  # reuse k,v, cross_attentions
@@ -963,7 +973,7 @@ class Florence2FlashAttention2(Florence2Attention):
963
  # cross_attentions
964
  key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
965
  value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
966
- elif past_key_value is not None:
967
  # reuse k, v, self_attention
968
  key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
969
  value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
@@ -985,7 +995,7 @@ class Florence2FlashAttention2(Florence2Attention):
985
  past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
986
 
987
  kv_seq_len = key_states.shape[-2]
988
- if past_key_value is not None:
989
  kv_seq_len += past_key_value[0].shape[-2]
990
 
991
  # In PEFT, usually we cast the layer norms in float32 for training stability reasons
@@ -1167,6 +1177,8 @@ class Florence2SdpaAttention(Florence2Attention):
1167
  if (
1168
  is_cross_attention
1169
  and past_key_value is not None
 
 
1170
  and past_key_value[0].shape[2] == key_value_states.shape[1]
1171
  ):
1172
  # reuse k,v, cross_attentions
@@ -1176,7 +1188,7 @@ class Florence2SdpaAttention(Florence2Attention):
1176
  # cross_attentions
1177
  key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
1178
  value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
1179
- elif past_key_value is not None:
1180
  # reuse k, v, self_attention
1181
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
1182
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
@@ -1795,7 +1807,7 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1795
  raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1796
 
1797
  # past_key_values_length
1798
- past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1799
 
1800
  if inputs_embeds is None:
1801
  inputs_embeds = self.embed_tokens(input)
@@ -2059,10 +2071,14 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
2059
  )
2060
 
2061
 
2062
- class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel):
2063
  base_model_prefix = "model"
2064
  _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
2065
  _keys_to_ignore_on_load_missing = ["final_logits_bias"]
 
 
 
 
2066
 
2067
  def __init__(self, config: Florence2LanguageConfig):
2068
  super().__init__(config)
@@ -2194,7 +2210,7 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
2194
  **kwargs,
2195
  ):
2196
  # cut decoder_input_ids if past_key_values is used
2197
- if past_key_values is not None:
2198
  past_length = past_key_values[0][0].shape[2]
2199
 
2200
  # Some generation methods already pass only the last input ID
@@ -2530,6 +2546,11 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
2530
  FLORENCE2_START_DOCSTRING,
2531
  )
2532
  class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
 
 
 
 
 
2533
  def __init__(self, config: Florence2Config):
2534
  super().__init__(config)
2535
  assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
@@ -2814,7 +2835,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2814
  **kwargs,
2815
  ):
2816
  # cut decoder_input_ids if past_key_values is used
2817
- if past_key_values is not None:
2818
  past_length = past_key_values[0][0].shape[2]
2819
 
2820
  # Some generation methods already pass only the last input ID
 
29
  from timm.models.layers import DropPath, trunc_normal_
30
 
31
  from transformers.modeling_utils import PreTrainedModel
32
+ try:
33
+ # Try new import path first (transformers >= 4.40.0)
34
+ from transformers.generation import GenerationMixin
35
+ except ImportError:
36
+ # Fallback to old import path (transformers < 4.40.0)
37
+ from transformers.generation.utils import GenerationMixin
38
  from transformers.utils import (
39
  ModelOutput,
40
  add_start_docstrings,
 
818
  if (
819
  is_cross_attention
820
  and past_key_value is not None
821
+ and past_key_value[0] is not None
822
+ and past_key_value[1] is not None
823
  and past_key_value[0].shape[2] == key_value_states.shape[1]
824
  ):
825
  # reuse k,v, cross_attentions
 
829
  # cross_attentions
830
  key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
831
  value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
832
+ elif past_key_value is not None and past_key_value[0] is not None and past_key_value[1] is not None:
833
  # reuse k, v, self_attention
834
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
835
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
 
962
  if (
963
  is_cross_attention
964
  and past_key_value is not None
965
+ and past_key_value[0] is not None
966
+ and past_key_value[1] is not None
967
  and past_key_value[0].shape[2] == key_value_states.shape[1]
968
  ):
969
  # reuse k,v, cross_attentions
 
973
  # cross_attentions
974
  key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
975
  value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
976
+ elif past_key_value is not None and past_key_value[0] is not None and past_key_value[1] is not None:
977
  # reuse k, v, self_attention
978
  key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
979
  value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
 
995
  past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
996
 
997
  kv_seq_len = key_states.shape[-2]
998
+ if past_key_value is not None and past_key_value[0] is not None:
999
  kv_seq_len += past_key_value[0].shape[-2]
1000
 
1001
  # In PEFT, usually we cast the layer norms in float32 for training stability reasons
 
1177
  if (
1178
  is_cross_attention
1179
  and past_key_value is not None
1180
+ and past_key_value[0] is not None
1181
+ and past_key_value[1] is not None
1182
  and past_key_value[0].shape[2] == key_value_states.shape[1]
1183
  ):
1184
  # reuse k,v, cross_attentions
 
1188
  # cross_attentions
1189
  key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
1190
  value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
1191
+ elif past_key_value is not None and past_key_value[0] is not None and past_key_value[1] is not None:
1192
  # reuse k, v, self_attention
1193
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
1194
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
 
1807
  raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1808
 
1809
  # past_key_values_length
1810
+ past_key_values_length = past_key_values[0][0].shape[2] if (past_key_values is not None and past_key_values[0] is not None and past_key_values[0][0] is not None) else 0
1811
 
1812
  if inputs_embeds is None:
1813
  inputs_embeds = self.embed_tokens(input)
 
2071
  )
2072
 
2073
 
2074
+ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin):
2075
  base_model_prefix = "model"
2076
  _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
2077
  _keys_to_ignore_on_load_missing = ["final_logits_bias"]
2078
+ # Add support for new transformers versions
2079
+ _supports_sdpa = True
2080
+ _supports_flash_attn_2 = False
2081
+ _supports_sdpa_4d_causal_mask = True
2082
 
2083
  def __init__(self, config: Florence2LanguageConfig):
2084
  super().__init__(config)
 
2210
  **kwargs,
2211
  ):
2212
  # cut decoder_input_ids if past_key_values is used
2213
+ if past_key_values is not None and past_key_values[0] is not None and past_key_values[0][0] is not None:
2214
  past_length = past_key_values[0][0].shape[2]
2215
 
2216
  # Some generation methods already pass only the last input ID
 
2546
  FLORENCE2_START_DOCSTRING,
2547
  )
2548
  class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2549
+ # Add support for new transformers versions
2550
+ _supports_sdpa = True
2551
+ _supports_flash_attn_2 = False
2552
+ _supports_sdpa_4d_causal_mask = True
2553
+
2554
  def __init__(self, config: Florence2Config):
2555
  super().__init__(config)
2556
  assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
 
2835
  **kwargs,
2836
  ):
2837
  # cut decoder_input_ids if past_key_values is used
2838
+ if past_key_values is not None and past_key_values[0] is not None and past_key_values[0][0] is not None:
2839
  past_length = past_key_values[0][0].shape[2]
2840
 
2841
  # Some generation methods already pass only the last input ID