适配新版transformers | adapt to the new transformers version (https://github.com/huggingface/transformers/pull/31116)
#58
by
HibernantBear
- opened
- modeling_chatglm.py +10 -1
modeling_chatglm.py
CHANGED
|
@@ -936,9 +936,18 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 936 |
standardize_cache_format: bool = False,
|
| 937 |
) -> Dict[str, Any]:
|
| 938 |
# update past_key_values
|
| 939 |
-
|
| 940 |
outputs, standardize_cache_format=standardize_cache_format
|
| 941 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 942 |
|
| 943 |
# update attention mask
|
| 944 |
if "attention_mask" in model_kwargs:
|
|
|
|
| 936 |
standardize_cache_format: bool = False,
|
| 937 |
) -> Dict[str, Any]:
|
| 938 |
# update past_key_values
|
| 939 |
+
past_output = self._extract_past_from_model_output(
|
| 940 |
outputs, standardize_cache_format=standardize_cache_format
|
| 941 |
)
|
| 942 |
+
# adapt transformers update (https://github.com/huggingface/transformers/pull/31116)
|
| 943 |
+
if(type(past_output) is tuple and type(past_output[0]) is str):
|
| 944 |
+
if past_output[0]=="past_key_values":
|
| 945 |
+
model_kwargs["past_key_values"] = past_output[1]
|
| 946 |
+
else:
|
| 947 |
+
model_kwargs["past_key_values"] = None
|
| 948 |
+
print(f"WARN: Get \"{past_output[0]}\" during self._extract_past_from_model_output, not \"past_key_values\"")
|
| 949 |
+
else:
|
| 950 |
+
model_kwargs["past_key_values"] = past_output
|
| 951 |
|
| 952 |
# update attention mask
|
| 953 |
if "attention_mask" in model_kwargs:
|