Support `transformers>=4.56`
#11
by
Alnusjaponica
- opened
- modeling_plamo.py +11 -5
modeling_plamo.py
CHANGED
|
@@ -1426,8 +1426,10 @@ class Plamo2Model(Plamo2PreTrainedModel):
|
|
| 1426 |
past_key_values_prev = past_key_values
|
| 1427 |
past_key_values = Plamo2Cache(self.config)
|
| 1428 |
|
| 1429 |
-
# If `past_key_values` is a `DynamicCache` object, it must be empty
|
| 1430 |
-
assert len(past_key_values_prev) == 0
|
|
|
|
|
|
|
| 1431 |
assert isinstance(past_key_values, Plamo2Cache)
|
| 1432 |
past_key_values_length = past_key_values.get_seq_length()
|
| 1433 |
seq_length_with_past = seq_length_with_past + past_key_values_length
|
|
@@ -1633,7 +1635,11 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
|
|
| 1633 |
image_features: Optional[torch.Tensor] = None,
|
| 1634 |
**kwargs: Any,
|
| 1635 |
) -> Dict[str, Any]:
|
| 1636 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1637 |
input_ids = input_ids[:, -1:]
|
| 1638 |
if image_features is not None:
|
| 1639 |
image_features = image_features[:, -1:, :]
|
|
@@ -1643,7 +1649,7 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
|
|
| 1643 |
# create position_ids on the fly for batch generation
|
| 1644 |
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 1645 |
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 1646 |
-
if past_key_values:
|
| 1647 |
position_ids = position_ids[:, -1].unsqueeze(-1)
|
| 1648 |
|
| 1649 |
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
|
@@ -1714,4 +1720,4 @@ class Bias(nn.Module):
|
|
| 1714 |
self,
|
| 1715 |
x: torch.Tensor,
|
| 1716 |
) -> torch.Tensor:
|
| 1717 |
-
return x + self._bias
|
|
|
|
| 1426 |
past_key_values_prev = past_key_values
|
| 1427 |
past_key_values = Plamo2Cache(self.config)
|
| 1428 |
|
| 1429 |
+
# If `past_key_values` is a `DynamicCache` object, it must be empty, or all of its layer caches must have zero sequence length.
|
| 1430 |
+
assert len(past_key_values_prev) == 0 or not any(
|
| 1431 |
+
layer_cache.get_seq_length() for layer_cache in past_key_values_prev.layers
|
| 1432 |
+
)
|
| 1433 |
assert isinstance(past_key_values, Plamo2Cache)
|
| 1434 |
past_key_values_length = past_key_values.get_seq_length()
|
| 1435 |
seq_length_with_past = seq_length_with_past + past_key_values_length
|
|
|
|
| 1635 |
image_features: Optional[torch.Tensor] = None,
|
| 1636 |
**kwargs: Any,
|
| 1637 |
) -> Dict[str, Any]:
|
| 1638 |
+
# Starting from transformers v4.54, `DynamicCache` is passed to `past_key_values` during the prefill stage,
|
| 1639 |
+
# and its length becomes non-zero from v4.56 onward.
|
| 1640 |
+
# `Plamo2Model.forward` converts it into a `Plamo2Cache` on the first call,
|
| 1641 |
+
# so we use the type of `past_key_values` to distinguish between the prefill and decode stages.
|
| 1642 |
+
if isinstance(past_key_values, Plamo2Cache):
|
| 1643 |
input_ids = input_ids[:, -1:]
|
| 1644 |
if image_features is not None:
|
| 1645 |
image_features = image_features[:, -1:, :]
|
|
|
|
| 1649 |
# create position_ids on the fly for batch generation
|
| 1650 |
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 1651 |
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 1652 |
+
if isinstance(past_key_values, Plamo2Cache):
|
| 1653 |
position_ids = position_ids[:, -1].unsqueeze(-1)
|
| 1654 |
|
| 1655 |
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
|
|
|
| 1720 |
self,
|
| 1721 |
x: torch.Tensor,
|
| 1722 |
) -> torch.Tensor:
|
| 1723 |
+
return x + self._bias
|