Support `transformers>=4.56`

#11
Files changed (1) hide show
  1. modeling_plamo.py +11 -5
modeling_plamo.py CHANGED
@@ -1426,8 +1426,10 @@ class Plamo2Model(Plamo2PreTrainedModel):
1426
  past_key_values_prev = past_key_values
1427
  past_key_values = Plamo2Cache(self.config)
1428
 
1429
- # If `past_key_values` is a `DynamicCache` object, it must be empty
1430
- assert len(past_key_values_prev) == 0
 
 
1431
  assert isinstance(past_key_values, Plamo2Cache)
1432
  past_key_values_length = past_key_values.get_seq_length()
1433
  seq_length_with_past = seq_length_with_past + past_key_values_length
@@ -1633,7 +1635,11 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
1633
  image_features: Optional[torch.Tensor] = None,
1634
  **kwargs: Any,
1635
  ) -> Dict[str, Any]:
1636
- if past_key_values:
 
 
 
 
1637
  input_ids = input_ids[:, -1:]
1638
  if image_features is not None:
1639
  image_features = image_features[:, -1:, :]
@@ -1643,7 +1649,7 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
1643
  # create position_ids on the fly for batch generation
1644
  position_ids = attention_mask.long().cumsum(-1) - 1
1645
  position_ids.masked_fill_(attention_mask == 0, 1)
1646
- if past_key_values:
1647
  position_ids = position_ids[:, -1].unsqueeze(-1)
1648
 
1649
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
@@ -1714,4 +1720,4 @@ class Bias(nn.Module):
1714
  self,
1715
  x: torch.Tensor,
1716
  ) -> torch.Tensor:
1717
- return x + self._bias
 
1426
  past_key_values_prev = past_key_values
1427
  past_key_values = Plamo2Cache(self.config)
1428
 
1429
+ # If `past_key_values` is a `DynamicCache` object, it must be empty or all layer caches have zero sequence length.
1430
+ assert len(past_key_values_prev) == 0 or not any(
1431
+ layer_cache.get_seq_length() for layer_cache in past_key_values_prev.layers
1432
+ )
1433
  assert isinstance(past_key_values, Plamo2Cache)
1434
  past_key_values_length = past_key_values.get_seq_length()
1435
  seq_length_with_past = seq_length_with_past + past_key_values_length
 
1635
  image_features: Optional[torch.Tensor] = None,
1636
  **kwargs: Any,
1637
  ) -> Dict[str, Any]:
1638
+ # Starting from transformers v4.54, `DynamicCache` is passed to `past_key_values` during the prefill stage,
1639
+ # and its length becomes non-zero from v4.56 onward.
1640
+ # `Plamo2Model.forward` converts it into a `Plamo2Cache` on the first call,
1641
+ # se we use the type of `past_key_values` to distinguish between the prefill and decode stages.
1642
+ if isinstance(past_key_values, Plamo2Cache):
1643
  input_ids = input_ids[:, -1:]
1644
  if image_features is not None:
1645
  image_features = image_features[:, -1:, :]
 
1649
  # create position_ids on the fly for batch generation
1650
  position_ids = attention_mask.long().cumsum(-1) - 1
1651
  position_ids.masked_fill_(attention_mask == 0, 1)
1652
+ if isinstance(past_key_values, Plamo2Cache):
1653
  position_ids = position_ids[:, -1].unsqueeze(-1)
1654
 
1655
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
 
1720
  self,
1721
  x: torch.Tensor,
1722
  ) -> torch.Tensor:
1723
+ return x + self._bias