test
#12
by
Alnusjaponica
- opened
- README.md +3 -3
- modeling_plamo.py +5 -14
README.md
CHANGED
|
@@ -27,9 +27,9 @@ PLaMo 2 1B is released under Apache License version 2.0.
|
|
| 27 |
```
|
| 28 |
numpy>=1.26.4
|
| 29 |
numba>=0.60.0
|
| 30 |
-
torch
|
| 31 |
-
transformers>=4.44.2
|
| 32 |
-
mamba_ssm>=2.2.2
|
| 33 |
causal_conv1d>=1.4.0
|
| 34 |
```
|
| 35 |
|
|
|
|
| 27 |
```
|
| 28 |
numpy>=1.26.4
|
| 29 |
numba>=0.60.0
|
| 30 |
+
torch>=2.4.1
|
| 31 |
+
transformers>=4.44.2
|
| 32 |
+
mamba_ssm>=2.2.2
|
| 33 |
causal_conv1d>=1.4.0
|
| 34 |
```
|
| 35 |
|
modeling_plamo.py
CHANGED
|
@@ -1426,10 +1426,8 @@ class Plamo2Model(Plamo2PreTrainedModel):
|
|
| 1426 |
past_key_values_prev = past_key_values
|
| 1427 |
past_key_values = Plamo2Cache(self.config)
|
| 1428 |
|
| 1429 |
-
# If `past_key_values` is a `DynamicCache` object, it must be empty
|
| 1430 |
-
assert len(past_key_values_prev) == 0
|
| 1431 |
-
layer_cache.get_seq_length() for layer_cache in past_key_values_prev.layers
|
| 1432 |
-
)
|
| 1433 |
assert isinstance(past_key_values, Plamo2Cache)
|
| 1434 |
past_key_values_length = past_key_values.get_seq_length()
|
| 1435 |
seq_length_with_past = seq_length_with_past + past_key_values_length
|
|
@@ -1635,11 +1633,7 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
|
|
| 1635 |
image_features: Optional[torch.Tensor] = None,
|
| 1636 |
**kwargs: Any,
|
| 1637 |
) -> Dict[str, Any]:
|
| 1638 |
-
|
| 1639 |
-
# and its length becomes non-zero from v4.56 onward.
|
| 1640 |
-
# `Plamo2Model.forward` converts it into a `Plamo2Cache` on the first call,
|
| 1641 |
-
# so we use the type of `past_key_values` to distinguish between the prefill and decode stages.
|
| 1642 |
-
if isinstance(past_key_values, Plamo2Cache):
|
| 1643 |
input_ids = input_ids[:, -1:]
|
| 1644 |
if image_features is not None:
|
| 1645 |
image_features = image_features[:, -1:, :]
|
|
@@ -1649,7 +1643,7 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
|
|
| 1649 |
# create position_ids on the fly for batch generation
|
| 1650 |
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 1651 |
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 1652 |
-
if
|
| 1653 |
position_ids = position_ids[:, -1].unsqueeze(-1)
|
| 1654 |
|
| 1655 |
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
|
@@ -1663,9 +1657,6 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
|
|
| 1663 |
"position_ids": position_ids,
|
| 1664 |
"past_key_values": past_key_values,
|
| 1665 |
"use_cache": kwargs.get("use_cache"),
|
| 1666 |
-
"output_attentions": kwargs.get("output_attentions"),
|
| 1667 |
-
"output_hidden_states": kwargs.get("output_hidden_states"),
|
| 1668 |
-
"logits_to_keep": kwargs.get("logits_to_keep"),
|
| 1669 |
"attention_mask": attention_mask,
|
| 1670 |
"image_features": image_features,
|
| 1671 |
}
|
|
@@ -1723,4 +1714,4 @@ class Bias(nn.Module):
|
|
| 1723 |
self,
|
| 1724 |
x: torch.Tensor,
|
| 1725 |
) -> torch.Tensor:
|
| 1726 |
-
return x + self._bias
|
|
|
|
| 1426 |
past_key_values_prev = past_key_values
|
| 1427 |
past_key_values = Plamo2Cache(self.config)
|
| 1428 |
|
| 1429 |
+
# If `past_key_values` is a `DynamicCache` object, it must be empty
|
| 1430 |
+
assert len(past_key_values_prev) == 0
|
|
|
|
|
|
|
| 1431 |
assert isinstance(past_key_values, Plamo2Cache)
|
| 1432 |
past_key_values_length = past_key_values.get_seq_length()
|
| 1433 |
seq_length_with_past = seq_length_with_past + past_key_values_length
|
|
|
|
| 1633 |
image_features: Optional[torch.Tensor] = None,
|
| 1634 |
**kwargs: Any,
|
| 1635 |
) -> Dict[str, Any]:
|
| 1636 |
+
if past_key_values:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1637 |
input_ids = input_ids[:, -1:]
|
| 1638 |
if image_features is not None:
|
| 1639 |
image_features = image_features[:, -1:, :]
|
|
|
|
| 1643 |
# create position_ids on the fly for batch generation
|
| 1644 |
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 1645 |
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 1646 |
+
if past_key_values:
|
| 1647 |
position_ids = position_ids[:, -1].unsqueeze(-1)
|
| 1648 |
|
| 1649 |
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
|
|
|
| 1657 |
"position_ids": position_ids,
|
| 1658 |
"past_key_values": past_key_values,
|
| 1659 |
"use_cache": kwargs.get("use_cache"),
|
|
|
|
|
|
|
|
|
|
| 1660 |
"attention_mask": attention_mask,
|
| 1661 |
"image_features": image_features,
|
| 1662 |
}
|
|
|
|
| 1714 |
self,
|
| 1715 |
x: torch.Tensor,
|
| 1716 |
) -> torch.Tensor:
|
| 1717 |
+
return x + self._bias
|