Update modeling_chatglm.py
Browse files — modeling_chatglm.py: +12 lines, −3 lines
modeling_chatglm.py
CHANGED
|
@@ -422,7 +422,7 @@ class SelfAttention(torch.nn.Module):
Before (lines 422–428):

    def _config_to_kwargs(args):
        common_kwargs = {
   -        "dtype": args.torch_dtype,
        }
        return common_kwargs
@@ -720,7 +720,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
Before (lines 720–726):

        init_method = default_init
        init_kwargs = {}
        if device is not None:
   -        init_kwargs["device"] = device
        self.embedding = init_method(Embedding, config, **init_kwargs)
        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
|
@@ -954,6 +954,15 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
Before (lines 954–959; nine new lines are inserted here in the updated file):

            for layer_past in past
        )

        def process_response(self, output, history):
            content = ""
            history = deepcopy(history)
|
@@ -1231,4 +1240,4 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
Before (lines 1231–1234; the closing `)` line is rewritten, likely a trailing-newline fix):

            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
   -    )
|
|
|
After (lines 422–428): coerce a string-valued `torch_dtype` (e.g. "float16") to the real torch dtype object.

    def _config_to_kwargs(args):
        common_kwargs = {
   +        "dtype": args.torch_dtype if not isinstance(args.torch_dtype, str) else getattr(torch, args.torch_dtype)
        }
        return common_kwargs
|
|
|
|
After (lines 720–726): coerce a string-valued `device` (e.g. "cuda") to a `torch.device` object.

        init_method = default_init
        init_kwargs = {}
        if device is not None:
   +        init_kwargs["device"] = device if not isinstance(device, str) else torch.device(device)
        self.embedding = init_method(Embedding, config, **init_kwargs)
        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
|
|
|
|
After (lines 954–968): add a `_extract_past_from_model_output` override whose return shape matches the transformers version in use (v4.42+ expects a 2-tuple).

            for layer_past in past
        )

   +    @staticmethod
   +    def _extract_past_from_model_output(outputs: ModelOutput, *args, **kwargs):
   +        past_key_values = None
   +        if "past_key_values" in outputs:
   +            past_key_values = outputs.past_key_values
   +        if is_transformers_4_42_or_higher:
   +            return None, past_key_values
   +        return past_key_values
   +
        def process_response(self, output, history):
            content = ""
            history = deepcopy(history)
|
|
|
|
After (lines 1240–1243):

            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
   +    )
|