support transformers 4.57

#3
by xf2022 - opened
Files changed (2) hide show
  1. config.json +1 -1
  2. modeling_qwen2_flash.py +45 -7
config.json CHANGED
@@ -200,7 +200,7 @@
200
  "mm_vision_select_layer": -2,
201
  "mm_vision_tower": "internvideo2",
202
  "mm_vision_tower_lr": 2e-06,
203
- "model_type": "qwen2",
204
  "num_attention_heads": 28,
205
  "num_hidden_layers": 28,
206
  "num_key_value_heads": 4,
 
200
  "mm_vision_select_layer": -2,
201
  "mm_vision_tower": "internvideo2",
202
  "mm_vision_tower_lr": 2e-06,
203
+ "model_type": "videochat_flash_qwen",
204
  "num_attention_heads": 28,
205
  "num_hidden_layers": 28,
206
  "num_key_value_heads": 4,
modeling_qwen2_flash.py CHANGED
@@ -276,7 +276,11 @@ class Qwen2Attention(nn.Module):
276
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
277
  "with a layer index."
278
  )
279
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
 
 
 
280
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
281
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
282
 
@@ -379,7 +383,11 @@ class Qwen2FlashAttention2(Qwen2Attention):
379
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
380
  "with a layer index."
381
  )
382
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
 
 
 
383
 
384
  # Because the input can be padded, the absolute sequence length depends on the max position id.
385
  rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
@@ -673,7 +681,11 @@ class Qwen2SdpaAttention(Qwen2Attention):
673
 
674
  kv_seq_len = key_states.shape[-2]
675
  if past_key_value is not None:
676
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
 
 
 
677
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
678
 
679
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -993,7 +1005,11 @@ class Qwen2Model_Flash(Qwen2PreTrainedModel):
993
  use_legacy_cache = not isinstance(past_key_values, Cache)
994
  if use_legacy_cache:
995
  past_key_values = DynamicCache.from_legacy_cache(past_key_values)
996
- past_key_values_length = past_key_values.get_usable_length(seq_length)
 
 
 
 
997
 
998
  if position_ids is None:
999
  device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1483,8 +1499,16 @@ class Qwen2ForCausalLM_Flash(Qwen2PreTrainedModel):
1483
  if past_key_values is not None:
1484
  if isinstance(past_key_values, Cache):
1485
  cache_length = past_key_values.get_seq_length()
1486
- past_length = past_key_values.seen_tokens
1487
- max_cache_length = past_key_values.get_max_length()
 
 
 
 
 
 
 
 
1488
  else:
1489
  cache_length = past_length = past_key_values[0][0].shape[2]
1490
  max_cache_length = None
@@ -1517,8 +1541,22 @@ class Qwen2ForCausalLM_Flash(Qwen2PreTrainedModel):
1517
  if past_key_values:
1518
  position_ids = position_ids[:, -input_ids.shape[1] :]
1519
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1520
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1521
- if inputs_embeds is not None and past_key_values is None:
 
1522
  model_inputs = {"inputs_embeds": inputs_embeds}
1523
  else:
1524
  model_inputs = {"input_ids": input_ids}
 
276
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
277
  "with a layer index."
278
  )
279
+ # get_usable_length has been removed in transformers 4.54.0
280
+ if hasattr(past_key_value, "get_usable_length"):
281
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
282
+ else:
283
+ kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
284
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
285
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
286
 
 
383
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
384
  "with a layer index."
385
  )
386
+ # get_usable_length has been removed in transformers 4.54.0
387
+ if hasattr(past_key_value, "get_usable_length"):
388
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
389
+ else:
390
+ kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
391
 
392
  # Because the input can be padded, the absolute sequence length depends on the max position id.
393
  rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
 
681
 
682
  kv_seq_len = key_states.shape[-2]
683
  if past_key_value is not None:
684
+ # get_usable_length has been removed in transformers 4.54.0
685
+ if hasattr(past_key_value, "get_usable_length"):
686
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
687
+ else:
688
+ kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
689
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
690
 
691
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
1005
  use_legacy_cache = not isinstance(past_key_values, Cache)
1006
  if use_legacy_cache:
1007
  past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1008
+ # get_usable_length has been removed in transformers 4.54.0
1009
+ if hasattr(past_key_values, "get_usable_length"):
1010
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1011
+ else:
1012
+ past_key_values_length = past_key_values.get_seq_length()
1013
 
1014
  if position_ids is None:
1015
  device = input_ids.device if input_ids is not None else inputs_embeds.device
 
1499
  if past_key_values is not None:
1500
  if isinstance(past_key_values, Cache):
1501
  cache_length = past_key_values.get_seq_length()
1502
+ # seen_tokens property has been removed in transformers 4.54.0
1503
+ past_length = getattr(past_key_values, 'seen_tokens', cache_length)
1504
+ # get_max_length() has been replaced by get_max_cache_shape() in transformers 4.49.0
1505
+ # in transformers 4.54.0, DynamicCache returns -1 instead of None to indicate no limit
1506
+ if hasattr(past_key_values, 'get_max_cache_shape'):
1507
+ max_cache_length = past_key_values.get_max_cache_shape()
1508
+ # Convert -1 to None for consistency with old behavior
1509
+ max_cache_length = None if max_cache_length == -1 else max_cache_length
1510
+ else:
1511
+ max_cache_length = past_key_values.get_max_length()
1512
  else:
1513
  cache_length = past_length = past_key_values[0][0].shape[2]
1514
  max_cache_length = None
 
1541
  if past_key_values:
1542
  position_ids = position_ids[:, -input_ids.shape[1] :]
1543
 
1544
def is_cache_empty(past_key_values):
    """Return True when *past_key_values* holds no cached tokens yet.

    Handles the three cache representations seen across transformers
    versions:
      - ``None`` or an empty container (legacy tuple cache before the
        first generation step),
      - newer ``Cache`` objects exposing a boolean ``is_initialized``
        flag,
      - older ``Cache`` objects, checked layer by layer via
        ``get_seq_length``.
    """
    if past_key_values is None or len(past_key_values) == 0:
        return True
    # Newest transformers caches expose a boolean `is_initialized` flag.
    # NOTE(review): assumes the flag is a plain bool — `not flag` replaces
    # the original `flag == False` comparison (E712); equivalent for bools.
    if hasattr(past_key_values, "is_initialized"):
        return not past_key_values.is_initialized
    if isinstance(past_key_values, Cache):
        # Empty iff no layer has cached any tokens yet.
        return not any(
            past_key_values.get_seq_length(idx) > 0
            for idx in range(len(past_key_values.layers))
        )
    # A non-empty legacy (tuple) cache: tokens are present.
    return False
1556
+
1557
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1558
+ # in newer transformers versions, past_key_values can be an empty cache in the 1st generation step.
1559
+ if inputs_embeds is not None and is_cache_empty(past_key_values):
1560
  model_inputs = {"inputs_embeds": inputs_embeds}
1561
  else:
1562
  model_inputs = {"input_ids": input_ids}