yhirokawa committed on
Commit
07d01cb
·
verified ·
1 Parent(s): 922a883

Support transformers>=4.56 (#4)

Browse files

- Support transformers>=4.56 (a2dc46a3514d46d18a36702357e78d9d1f6f431d)

Files changed (2) hide show
  1. README.md +3 -3
  2. modeling_plamo.py +21 -5
README.md CHANGED
@@ -177,9 +177,9 @@ Please check the PLaMo community license and contact us via the following form t
177
  ```
178
  numpy>=1.26.4
179
  numba>=0.60.0
180
- torch>=2.4.1
181
- transformers>=4.44.2
182
- mamba_ssm>=2.2.2
183
  causal_conv1d>=1.4.0
184
  ```
185
 
 
177
  ```
178
  numpy>=1.26.4
179
  numba>=0.60.0
180
+ torch<=2.5.1
181
+ transformers>=4.44.2,<=4.57.1
182
+ mamba_ssm>=2.2.2,<=2.2.4
183
  causal_conv1d>=1.4.0
184
  ```
185
 
modeling_plamo.py CHANGED
@@ -19,6 +19,7 @@ import torch
19
  from torch import nn
20
  from torch.nn import functional as F
21
  from transformers import PretrainedConfig, PreTrainedModel
 
22
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
23
 
24
 
@@ -327,7 +328,8 @@ class Plamo2Cache(torch.nn.Module):
327
  if sequence_length is not None
328
  else layer_cache.key.shape[2]
329
  )
330
- assert sequence_length is not None
 
331
  return sequence_length
332
 
333
  def get_max_length(self) -> int | None:
@@ -1387,7 +1389,7 @@ class Plamo2Model(Plamo2PreTrainedModel):
1387
  input_ids: Optional[torch.LongTensor] = None,
1388
  attention_mask: Optional[torch.Tensor] = None,
1389
  position_ids: Optional[torch.Tensor] = None,
1390
- past_key_values: Optional[Plamo2Cache] = None,
1391
  inputs_embeds: Optional[torch.Tensor] = None,
1392
  image_features: Optional[torch.Tensor] = None,
1393
  use_cache: Optional[bool] = None,
@@ -1419,6 +1421,16 @@ class Plamo2Model(Plamo2PreTrainedModel):
1419
  seq_length_with_past = seq_length
1420
  past_key_values_length = 0
1421
  if past_key_values is not None:
 
 
 
 
 
 
 
 
 
 
1422
  past_key_values_length = past_key_values.get_seq_length()
1423
  seq_length_with_past = seq_length_with_past + past_key_values_length
1424
  assert cache_position is None, "cache_position is not supported yet"
@@ -1623,7 +1635,11 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
1623
  image_features: Optional[torch.Tensor] = None,
1624
  **kwargs: Any,
1625
  ) -> Dict[str, Any]:
1626
- if past_key_values:
 
 
 
 
1627
  input_ids = input_ids[:, -1:]
1628
  if image_features is not None:
1629
  image_features = image_features[:, -1:, :]
@@ -1633,7 +1649,7 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
1633
  # create position_ids on the fly for batch generation
1634
  position_ids = attention_mask.long().cumsum(-1) - 1
1635
  position_ids.masked_fill_(attention_mask == 0, 1)
1636
- if past_key_values:
1637
  position_ids = position_ids[:, -1].unsqueeze(-1)
1638
 
1639
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
@@ -1704,4 +1720,4 @@ class Bias(nn.Module):
1704
  self,
1705
  x: torch.Tensor,
1706
  ) -> torch.Tensor:
1707
- return x + self._bias
 
19
  from torch import nn
20
  from torch.nn import functional as F
21
  from transformers import PretrainedConfig, PreTrainedModel
22
+ from transformers.cache_utils import DynamicCache
23
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
24
 
25
 
 
328
  if sequence_length is not None
329
  else layer_cache.key.shape[2]
330
  )
331
+ if sequence_length is None:
332
+ return 0
333
  return sequence_length
334
 
335
  def get_max_length(self) -> int | None:
 
1389
  input_ids: Optional[torch.LongTensor] = None,
1390
  attention_mask: Optional[torch.Tensor] = None,
1391
  position_ids: Optional[torch.Tensor] = None,
1392
+ past_key_values: Optional[Plamo2Cache | DynamicCache] = None,
1393
  inputs_embeds: Optional[torch.Tensor] = None,
1394
  image_features: Optional[torch.Tensor] = None,
1395
  use_cache: Optional[bool] = None,
 
1421
  seq_length_with_past = seq_length
1422
  past_key_values_length = 0
1423
  if past_key_values is not None:
1424
+ # In some `transformers` versions, `past_key_values` may be a `DynamicCache` object.
1425
+ if not isinstance(past_key_values, Plamo2Cache):
1426
+ past_key_values_prev = past_key_values
1427
+ past_key_values = Plamo2Cache(self.config)
1428
+
1429
+ # If `past_key_values` is a `DynamicCache` object, it must be empty or all layer caches have zero sequence length.
1430
+ assert len(past_key_values_prev) == 0 or not any(
1431
+ layer_cache.get_seq_length() for layer_cache in past_key_values_prev.layers
1432
+ )
1433
+ assert isinstance(past_key_values, Plamo2Cache)
1434
  past_key_values_length = past_key_values.get_seq_length()
1435
  seq_length_with_past = seq_length_with_past + past_key_values_length
1436
  assert cache_position is None, "cache_position is not supported yet"
 
1635
  image_features: Optional[torch.Tensor] = None,
1636
  **kwargs: Any,
1637
  ) -> Dict[str, Any]:
1638
+ # Starting from transformers v4.54, `DynamicCache` is passed to `past_key_values` during the prefill stage,
1639
+ # and its length becomes non-zero from v4.56 onward.
1640
+ # `Plamo2Model.forward` converts it into a `Plamo2Cache` on the first call,
1641
+ so we use the type of `past_key_values` to distinguish between the prefill and decode stages.
1642
+ if isinstance(past_key_values, Plamo2Cache):
1643
  input_ids = input_ids[:, -1:]
1644
  if image_features is not None:
1645
  image_features = image_features[:, -1:, :]
 
1649
  # create position_ids on the fly for batch generation
1650
  position_ids = attention_mask.long().cumsum(-1) - 1
1651
  position_ids.masked_fill_(attention_mask == 0, 1)
1652
+ if isinstance(past_key_values, Plamo2Cache):
1653
  position_ids = position_ids[:, -1].unsqueeze(-1)
1654
 
1655
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
 
1720
  self,
1721
  x: torch.Tensor,
1722
  ) -> torch.Tensor:
1723
+ return x + self._bias