Pull request "fix-generation" (#1), opened by echarlaix (HF Staff).
File changed: modeling_chatglm.py (+11 −3).

modeling_chatglm.py — CHANGED
@@ -40,6 +40,9 @@ logger = logging.get_logger(__name__)
 40   _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
 41   _CONFIG_FOR_DOC = "ChatGLMConfig"
 42   
def default_init(cls, *args, **kwargs):
    """Instantiate *cls*, forwarding every positional and keyword argument."""
    instance = cls(*args, **kwargs)
    return instance
@@ -809,9 +812,14 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 809           standardize_cache_format: bool = False,
 810       ) -> Dict[str, Any]:
 811           # update past_key_values
 812 -         model_kwargs["past_key_values"] = self._extract_past_from_model_output(
 813 -             outputs, standardize_cache_format=standardize_cache_format
 814 -         )
 815   
 816           # update attention mask
 817           if "attention_mask" in model_kwargs:
(Note: the extraction showed the three removed lines only as bare "-" markers; their content is reconstructed from the unchanged else-branch on the new side of the diff — confirm against the original file.)
 40   _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
 41   _CONFIG_FOR_DOC = "ChatGLMConfig"
 42   
# Feature gate for the transformers v4.42 behavior change: from 4.42 on,
# GenerationMixin._extract_past_from_model_output returns a tuple whose
# second element is the cache, instead of the cache itself.
# Compare the (major, minor) pair rather than the bare minor component:
# the original `int(__version__.split(".")[1]) >= 42` would wrongly be
# False for any major version other than 4 (e.g. transformers 5.0).
is_transformers_4_42_or_higher = tuple(
    int(part) for part in transformers.__version__.split(".")[:2]
) >= (4, 42)
 46   def default_init(cls, *args, **kwargs):
 47       return cls(*args, **kwargs)
 48   
 812           standardize_cache_format: bool = False,
 813       ) -> Dict[str, Any]:
 814           # update past_key_values
 815 +         if is_transformers_4_42_or_higher:
 816 +             model_kwargs["past_key_values"] = self._extract_past_from_model_output(
 817 +                 outputs, standardize_cache_format=standardize_cache_format
 818 +             )[1]
 819 +         else:
 820 +             model_kwargs["past_key_values"] = self._extract_past_from_model_output(
 821 +                 outputs, standardize_cache_format=standardize_cache_format
 822 +             )
 823   
 824           # update attention mask
 825           if "attention_mask" in model_kwargs: