Files changed (2)
  1. README.md +3 -3
  2. modeling_plamo.py +5 -14
README.md CHANGED
@@ -27,9 +27,9 @@ PLaMo 2 1B is released under Apache License version 2.0.
 ```
 numpy>=1.26.4
 numba>=0.60.0
-torch<=2.5.1
-transformers>=4.44.2,<=4.57.1
-mamba_ssm>=2.2.2,<=2.2.4
+torch>=2.4.1
+transformers>=4.44.2
+mamba_ssm>=2.2.2
 causal_conv1d>=1.4.0
 ```
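For reference, the snippet below is a quick local check that an environment satisfies the loosened minimum versions listed above. It is a sketch rather than part of the repository, assumes the packages are installed under Python 3.10+ with `packaging` available, and uses the PyPI distribution names (`mamba-ssm`, `causal-conv1d`) rather than the import names.

```
from importlib.metadata import version

from packaging.version import Version

# Minimum versions taken from the updated requirements block above.
MINIMUMS = {
    "numpy": "1.26.4",
    "numba": "0.60.0",
    "torch": "2.4.1",
    "transformers": "4.44.2",
    "mamba-ssm": "2.2.2",
    "causal-conv1d": "1.4.0",
}

for package, floor in MINIMUMS.items():
    installed = Version(version(package))
    assert installed >= Version(floor), f"{package} {installed} is older than {floor}"
```
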
modeling_plamo.py CHANGED
@@ -1426,10 +1426,8 @@ class Plamo2Model(Plamo2PreTrainedModel):
             past_key_values_prev = past_key_values
             past_key_values = Plamo2Cache(self.config)
 
-            # If `past_key_values` is a `DynamicCache` object, it must be empty or all layer caches have zero sequence length.
-            assert len(past_key_values_prev) == 0 or not any(
-                layer_cache.get_seq_length() for layer_cache in past_key_values_prev.layers
-            )
+            # If `past_key_values` is a `DynamicCache` object, it must be empty
+            assert len(past_key_values_prev) == 0
         assert isinstance(past_key_values, Plamo2Cache)
         past_key_values_length = past_key_values.get_seq_length()
         seq_length_with_past = seq_length_with_past + past_key_values_length
@@ -1635,11 +1633,7 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
         image_features: Optional[torch.Tensor] = None,
         **kwargs: Any,
     ) -> Dict[str, Any]:
-        # Starting from transformers v4.54, `DynamicCache` is passed to `past_key_values` during the prefill stage,
-        # and its length becomes non-zero from v4.56 onward.
-        # `Plamo2Model.forward` converts it into a `Plamo2Cache` on the first call,
-        # se we use the type of `past_key_values` to distinguish between the prefill and decode stages.
-        if isinstance(past_key_values, Plamo2Cache):
+        if past_key_values:
             input_ids = input_ids[:, -1:]
             if image_features is not None:
                 image_features = image_features[:, -1:, :]
@@ -1649,7 +1643,7 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
-            if isinstance(past_key_values, Plamo2Cache):
+            if past_key_values:
                 position_ids = position_ids[:, -1].unsqueeze(-1)
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
@@ -1663,9 +1657,6 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
                 "position_ids": position_ids,
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
-                "output_attentions": kwargs.get("output_attentions"),
-                "output_hidden_states": kwargs.get("output_hidden_states"),
-                "logits_to_keep": kwargs.get("logits_to_keep"),
                 "attention_mask": attention_mask,
                 "image_features": image_features,
             }
@@ -1723,4 +1714,4 @@ class Bias(nn.Module):
         self,
         x: torch.Tensor,
     ) -> torch.Tensor:
-        return x + self._bias
+        return x + self._bias
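
As background for the `if past_key_values:` checks introduced above: an empty cache (or `None`) is falsy, while a cache that already holds tokens is truthy, so plain truthiness can separate the prefill call from later decode calls without testing for a concrete cache class. The sketch below is illustrative only and not the repository's code; `ToyCache` and `prepare_inputs` are hypothetical stand-ins for the real cache classes and `prepare_inputs_for_generation`, and the actual `transformers` caches may define length and truthiness differently.

```
from typing import Any, Dict, Optional

import torch


class ToyCache:
    """Hypothetical cache: falsy while empty, truthy once tokens have been written."""

    def __init__(self) -> None:
        self.cached_tokens = 0

    def __len__(self) -> int:
        return self.cached_tokens


def prepare_inputs(
    input_ids: torch.Tensor,
    past_key_values: Optional[ToyCache] = None,
) -> Dict[str, Any]:
    # Prefill: no cache yet (None or empty) -> pass the full prompt.
    # Decode: the cache already holds the prompt -> pass only the newest token.
    if past_key_values:
        input_ids = input_ids[:, -1:]
    return {"input_ids": input_ids, "past_key_values": past_key_values}


prompt = torch.tensor([[1, 2, 3, 4]])
cache = ToyCache()
assert prepare_inputs(prompt, cache)["input_ids"].shape[1] == 4  # prefill step
cache.cached_tokens = 4
assert prepare_inputs(prompt, cache)["input_ids"].shape[1] == 1  # decode step
```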