oweller2 committed on
Commit ·
1f61dbc
1
Parent(s): 3cd62e1
update
Browse files- modeling_flexbert.py +19 -22
modeling_flexbert.py
CHANGED
|
@@ -1708,28 +1708,25 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
|
|
| 1708 |
attentions=None,
|
| 1709 |
)
|
| 1710 |
|
| 1711 |
-
def prepare_inputs_for_generation(
|
| 1712 |
-
|
| 1713 |
-
|
| 1714 |
-
|
| 1715 |
-
|
| 1716 |
-
|
| 1717 |
-
|
| 1718 |
-
|
| 1719 |
-
|
| 1720 |
-
|
| 1721 |
-
|
| 1722 |
-
|
| 1723 |
-
|
| 1724 |
-
|
| 1725 |
-
|
| 1726 |
-
|
| 1727 |
-
|
| 1728 |
-
|
| 1729 |
-
|
| 1730 |
-
input_ids = torch.cat([input_ids, dummy_token], dim=1)
|
| 1731 |
-
|
| 1732 |
-
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
| 1733 |
|
| 1734 |
def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
|
| 1735 |
"""Returns the number of parameters in the model.
|
|
|
|
| 1708 |
attentions=None,
|
| 1709 |
)
|
| 1710 |
|
| 1711 |
+
def prepare_inputs_for_generation(
    self,
    input_ids: torch.Tensor,
    past_key_values: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    **kwargs
) -> dict:
    """Assemble the model inputs for a single decoding step.

    When a KV cache (``past_key_values``) is supplied, only the most recent
    token — and the matching last column of ``attention_mask`` — is forwarded,
    since earlier positions are already represented in the cache. Without a
    cache, the inputs are passed through unchanged.
    """
    have_cache = past_key_values is not None
    if have_cache:
        # The cache covers every earlier position, so keep just the newest token.
        input_ids = input_ids[:, -1].unsqueeze(-1)
        if attention_mask is not None:
            attention_mask = attention_mask[:, -1:]

    model_inputs = {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache", True),
        "attention_mask": attention_mask,
    }
    return model_inputs
|
|
|
|
|
|
|
|
|
|
| 1730 |
|
| 1731 |
def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
|
| 1732 |
"""Returns the number of parameters in the model.
|