Fix bug for newer transformers versions

#1
by fishmingyu - opened
Files changed (2) hide show
  1. config.json +3 -3
  2. modeling_qwen.py +84 -19
config.json CHANGED
@@ -4,9 +4,9 @@
4
  ],
5
  "attention_dropout": 0.0,
6
  "auto_map": {
7
- "AutoModel": "Alibaba-NLP/gte-Qwen2-7B-instruct--modeling_qwen.Qwen2Model",
8
- "AutoModelForCausalLM": "Alibaba-NLP/gte-Qwen2-7B-instruct--modeling_qwen.Qwen2ForCausalLM",
9
- "AutoModelForSequenceClassification": "Alibaba-NLP/gte-Qwen2-7B-instruct--modeling_qwen.Qwen2ForSequenceClassification"
10
  },
11
  "bos_token_id": 151643,
12
  "eos_token_id": 151643,
 
4
  ],
5
  "attention_dropout": 0.0,
6
  "auto_map": {
7
+ "AutoModel": "modeling_qwen.Qwen2Model",
8
+ "AutoModelForCausalLM": "modeling_qwen.Qwen2ForCausalLM",
9
+ "AutoModelForSequenceClassification": "modeling_qwen.Qwen2ForSequenceClassification"
10
  },
11
  "bos_token_id": 151643,
12
  "eos_token_id": 151643,
modeling_qwen.py CHANGED
@@ -17,6 +17,9 @@
17
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
  # See the License for the specific language governing permissions and
19
  # limitations under the License.
 
 
 
20
  """ PyTorch Qwen2 model."""
21
  from transformers import Qwen2Config
22
  import inspect
@@ -46,11 +49,11 @@ from transformers.utils import (
46
  )
47
 
48
 
49
- if is_flash_attn_2_available():
50
- from flash_attn import flash_attn_func, flash_attn_varlen_func
51
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
52
 
53
- _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
54
 
55
 
56
  logger = logging.get_logger(__name__)
@@ -144,6 +147,7 @@ def rotate_half(x):
144
  # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
145
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
146
  """Applies Rotary Position Embedding to the query and key tensors.
 
147
  Args:
148
  q (`torch.Tensor`): The query tensor.
149
  k (`torch.Tensor`): The key tensor.
@@ -273,7 +277,9 @@ class Qwen2Attention(nn.Module):
273
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
274
  "with a layer index."
275
  )
276
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
 
277
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
278
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
279
 
@@ -377,7 +383,9 @@ class Qwen2FlashAttention2(Qwen2Attention):
377
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
378
  "with a layer index."
379
  )
380
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
 
381
 
382
  # Because the input can be padded, the absolute sequence length depends on the max position id.
383
  rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
@@ -494,6 +502,7 @@ class Qwen2FlashAttention2(Qwen2Attention):
494
  """
495
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
496
  first unpad the input, then computes the attention scores and pad the final attention scores.
 
497
  Args:
498
  query_states (`torch.Tensor`):
499
  Input query states to be passed to Flash Attention API
@@ -674,7 +683,9 @@ class Qwen2SdpaAttention(Qwen2Attention):
674
 
675
  kv_seq_len = key_states.shape[-2]
676
  if past_key_value is not None:
677
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
 
678
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
679
 
680
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -807,9 +818,11 @@ QWEN2_START_DOCSTRING = r"""
807
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
808
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
809
  etc.)
 
810
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
811
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
812
  and behavior.
 
813
  Parameters:
814
  config ([`Qwen2Config`]):
815
  Model configuration class with all the parameters of the model. Initializing with a config file does not
@@ -849,38 +862,50 @@ QWEN2_INPUTS_DOCSTRING = r"""
849
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
850
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
851
  it.
 
852
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
853
  [`PreTrainedTokenizer.__call__`] for details.
 
854
  [What are input IDs?](../glossary#input-ids)
855
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
856
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
 
857
  - 1 for tokens that are **not masked**,
858
  - 0 for tokens that are **masked**.
 
859
  [What are attention masks?](../glossary#attention-mask)
 
860
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
861
  [`PreTrainedTokenizer.__call__`] for details.
 
862
  If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
863
  `past_key_values`).
 
864
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
865
  and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
866
  information on the default strategy.
 
867
  - 1 indicates the head is **not masked**,
868
  - 0 indicates the head is **masked**.
869
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
870
  Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
871
  config.n_positions - 1]`.
 
872
  [What are position IDs?](../glossary#position-ids)
873
  past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
874
  Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
875
  blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
876
  returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
 
877
  Two formats are allowed:
878
  - a [`~cache_utils.Cache`] instance;
879
  - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
880
  shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
881
  cache format.
 
882
  The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
883
  legacy cache format will be returned.
 
884
  If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
885
  have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
886
  of shape `(batch_size, sequence_length)`.
@@ -909,6 +934,7 @@ QWEN2_INPUTS_DOCSTRING = r"""
909
  class Qwen2Model(Qwen2PreTrainedModel):
910
  """
911
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
 
912
  Args:
913
  config: Qwen2Config
914
  """
@@ -955,7 +981,6 @@ class Qwen2Model(Qwen2PreTrainedModel):
955
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
956
  )
957
  use_cache = use_cache if use_cache is not None else self.config.use_cache
958
-
959
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
960
 
961
  # retrieve input_ids and inputs_embeds
@@ -976,12 +1001,28 @@ class Qwen2Model(Qwen2PreTrainedModel):
976
  use_cache = False
977
 
978
  past_key_values_length = 0
 
979
 
980
  if use_cache:
981
- use_legacy_cache = not isinstance(past_key_values, Cache)
982
- if use_legacy_cache:
 
 
 
 
 
 
 
 
 
983
  past_key_values = DynamicCache.from_legacy_cache(past_key_values)
984
- past_key_values_length = past_key_values.get_usable_length(seq_length)
 
 
 
 
 
 
985
 
986
  if position_ids is None:
987
  device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1087,7 +1128,10 @@ class Qwen2Model(Qwen2PreTrainedModel):
1087
 
1088
  next_cache = None
1089
  if use_cache:
1090
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
 
 
 
1091
 
1092
  if not return_dict:
1093
  return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
@@ -1099,6 +1143,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
1099
  )
1100
 
1101
 
 
1102
  class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1103
  _tied_weights_keys = ["lm_head.weight"]
1104
 
@@ -1151,14 +1196,20 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1151
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1152
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1153
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
1154
  Returns:
 
1155
  Example:
 
1156
  ```python
1157
  >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
 
1158
  >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1159
  >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
 
1160
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
1161
  >>> inputs = tokenizer(prompt, return_tensors="pt")
 
1162
  >>> # Generate
1163
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1164
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
@@ -1220,21 +1271,32 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1220
  # Omit tokens covered by past_key_values
1221
  if past_key_values is not None:
1222
  if isinstance(past_key_values, Cache):
 
1223
  cache_length = past_key_values.get_seq_length()
1224
- past_length = past_key_values.seen_tokens
1225
- max_cache_length = past_key_values.get_max_length()
 
 
 
 
 
 
 
 
 
1226
  else:
 
 
1227
  cache_length = past_length = past_key_values[0][0].shape[2]
1228
  max_cache_length = None
1229
 
1230
  # Keep only the unprocessed tokens:
1231
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1232
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1233
- # input)
1234
  if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1235
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1236
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1237
- # input_ids based on the past_length.
1238
  elif past_length < input_ids.shape[1]:
1239
  input_ids = input_ids[:, past_length:]
1240
  # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
@@ -1264,13 +1326,14 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1264
  model_inputs.update(
1265
  {
1266
  "position_ids": position_ids,
1267
- "past_key_values": past_key_values,
1268
  "use_cache": kwargs.get("use_cache"),
1269
  "attention_mask": attention_mask,
1270
  }
1271
  )
1272
  return model_inputs
1273
 
 
1274
  @staticmethod
1275
  def _reorder_cache(past_key_values, beam_idx):
1276
  reordered_past = ()
@@ -1284,8 +1347,10 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1284
  @add_start_docstrings(
1285
  """
1286
  The Qwen2 Model transformer with a sequence classification head on top (linear layer).
 
1287
  [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1288
  (e.g. GPT-2) do.
 
1289
  Since it does classification on the last token, it requires to know the position of the last token. If a
1290
  `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1291
  no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
 
17
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
  # See the License for the specific language governing permissions and
19
  # limitations under the License.
20
+
21
+ # includes edits by https://github.com/BBC-Esq to fix cache errors on transformers versions after the major cache refactor (post 4.53.3)
22
+
23
  """ PyTorch Qwen2 model."""
24
  from transformers import Qwen2Config
25
  import inspect
 
49
  )
50
 
51
 
52
+ # if is_flash_attn_2_available():
53
+ # from flash_attn import flash_attn_func, flash_attn_varlen_func
54
+ # from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
55
 
56
+ # _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
57
 
58
 
59
  logger = logging.get_logger(__name__)
 
147
  # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
148
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
149
  """Applies Rotary Position Embedding to the query and key tensors.
150
+
151
  Args:
152
  q (`torch.Tensor`): The query tensor.
153
  k (`torch.Tensor`): The key tensor.
 
277
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
278
  "with a layer index."
279
  )
280
+ # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
281
+ past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
282
+ kv_seq_len += past_len
283
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
284
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
285
 
 
383
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
384
  "with a layer index."
385
  )
386
+ # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
387
+ past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
388
+ kv_seq_len += past_len
389
 
390
  # Because the input can be padded, the absolute sequence length depends on the max position id.
391
  rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
 
502
  """
503
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
504
  first unpad the input, then computes the attention scores and pad the final attention scores.
505
+
506
  Args:
507
  query_states (`torch.Tensor`):
508
  Input query states to be passed to Flash Attention API
 
683
 
684
  kv_seq_len = key_states.shape[-2]
685
  if past_key_value is not None:
686
+ # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
687
+ past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
688
+ kv_seq_len += past_len
689
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
690
 
691
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
818
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
819
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
820
  etc.)
821
+
822
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
823
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
824
  and behavior.
825
+
826
  Parameters:
827
  config ([`Qwen2Config`]):
828
  Model configuration class with all the parameters of the model. Initializing with a config file does not
 
862
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
863
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
864
  it.
865
+
866
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
867
  [`PreTrainedTokenizer.__call__`] for details.
868
+
869
  [What are input IDs?](../glossary#input-ids)
870
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
871
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
872
+
873
  - 1 for tokens that are **not masked**,
874
  - 0 for tokens that are **masked**.
875
+
876
  [What are attention masks?](../glossary#attention-mask)
877
+
878
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
879
  [`PreTrainedTokenizer.__call__`] for details.
880
+
881
  If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
882
  `past_key_values`).
883
+
884
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
885
  and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
886
  information on the default strategy.
887
+
888
  - 1 indicates the head is **not masked**,
889
  - 0 indicates the head is **masked**.
890
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
891
  Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
892
  config.n_positions - 1]`.
893
+
894
  [What are position IDs?](../glossary#position-ids)
895
  past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
896
  Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
897
  blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
898
  returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
899
+
900
  Two formats are allowed:
901
  - a [`~cache_utils.Cache`] instance;
902
  - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
903
  shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
904
  cache format.
905
+
906
  The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
907
  legacy cache format will be returned.
908
+
909
  If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
910
  have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
911
  of shape `(batch_size, sequence_length)`.
 
934
  class Qwen2Model(Qwen2PreTrainedModel):
935
  """
936
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
937
+
938
  Args:
939
  config: Qwen2Config
940
  """
 
981
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
982
  )
983
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
984
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
985
 
986
  # retrieve input_ids and inputs_embeds
 
1001
  use_cache = False
1002
 
1003
  past_key_values_length = 0
1004
+ use_legacy_cache = False
1005
 
1006
  if use_cache:
1007
+ # OLD behavior (removed in HF >= 4.55): treated anything that was not a Cache as "legacy", but then
1008
+ # called legacy methods on it directly (this would crash on None or on the new Cache API).
1009
+ # use_legacy_cache = not isinstance(past_key_values, Cache)
1010
+ # if use_legacy_cache:
1011
+ # # past_key_values_length = past_key_values.get_seq_length()
1012
+ # past_key_values_length = past_key_values.get_usable_length(seq_length)
1013
+
1014
+ # NEW behavior: if a legacy tuple is passed, convert it to the new Cache API,
1015
+ # compute length via .get_seq_length(), and remember to return legacy if that's what came in.
1016
+ if past_key_values is not None and not isinstance(past_key_values, Cache):
1017
+ use_legacy_cache = True # remember input format for return
1018
  past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1019
+
1020
+ if isinstance(past_key_values, Cache):
1021
+ # Layer-agnostic total length; cache_position is handled deeper if needed
1022
+ past_key_values_length = past_key_values.get_seq_length()
1023
+ else:
1024
+ # No cache given on first forward, keep length at 0
1025
+ past_key_values_length = 0
1026
 
1027
  if position_ids is None:
1028
  device = input_ids.device if input_ids is not None else inputs_embeds.device
 
1128
 
1129
  next_cache = None
1130
  if use_cache:
1131
+ # If the caller passed legacy, return legacy. Otherwise return the Cache object.
1132
+ next_cache = (
1133
+ next_decoder_cache.to_legacy_cache() if (use_legacy_cache and next_decoder_cache is not None) else next_decoder_cache
1134
+ )
1135
 
1136
  if not return_dict:
1137
  return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
 
1143
  )
1144
 
1145
 
1146
+
1147
  class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1148
  _tied_weights_keys = ["lm_head.weight"]
1149
 
 
1196
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1197
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1198
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1199
+
1200
  Returns:
1201
+
1202
  Example:
1203
+
1204
  ```python
1205
  >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
1206
+
1207
  >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1208
  >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1209
+
1210
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
1211
  >>> inputs = tokenizer(prompt, return_tensors="pt")
1212
+
1213
  >>> # Generate
1214
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1215
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 
1271
  # Omit tokens covered by past_key_values
1272
  if past_key_values is not None:
1273
  if isinstance(past_key_values, Cache):
1274
+ # NEW API (HF >= 4.55): use Cache methods
1275
  cache_length = past_key_values.get_seq_length()
1276
+ past_length = cache_length # `seen_tokens` removed; use total seq length instead
1277
+ try:
1278
+ max_cache_length = past_key_values.get_max_cache_shape()
1279
+ except Exception:
1280
+ max_cache_length = None
1281
+
1282
+ # OLD API (deprecated/removed):
1283
+ # cache_length = past_key_values.get_seq_length()
1284
+ # past_length = past_key_values.seen_tokens
1285
+ # max_cache_length = past_key_values.get_max_length()
1286
+
1287
  else:
1288
+ # Legacy tuple format: keep computing lengths directly from tensors
1289
+ # (We keep it compatible without forcing a conversion here)
1290
  cache_length = past_length = past_key_values[0][0].shape[2]
1291
  max_cache_length = None
1292
 
1293
  # Keep only the unprocessed tokens:
1294
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1295
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
 
1296
  if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1297
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1298
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens.
1299
+ # We can discard input_ids based on the past_length.
1300
  elif past_length < input_ids.shape[1]:
1301
  input_ids = input_ids[:, past_length:]
1302
  # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
1326
  model_inputs.update(
1327
  {
1328
  "position_ids": position_ids,
1329
+ "past_key_values": past_key_values, # pass through unchanged (legacy or new Cache object)
1330
  "use_cache": kwargs.get("use_cache"),
1331
  "attention_mask": attention_mask,
1332
  }
1333
  )
1334
  return model_inputs
1335
 
1336
+
1337
  @staticmethod
1338
  def _reorder_cache(past_key_values, beam_idx):
1339
  reordered_past = ()
 
1347
  @add_start_docstrings(
1348
  """
1349
  The Qwen2 Model transformer with a sequence classification head on top (linear layer).
1350
+
1351
  [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1352
  (e.g. GPT-2) do.
1353
+
1354
  Since it does classification on the last token, it requires to know the position of the last token. If a
1355
  `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1356
  no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the