Update modeling_chatglm.py
Browse files — modeling_chatglm.py (+13, −3)
modeling_chatglm.py
CHANGED
|
@@ -455,7 +455,7 @@ class SelfAttention(torch.nn.Module):
|
|
| 455 |
|
| 456 |
def _config_to_kwargs(args):
|
| 457 |
common_kwargs = {
|
| 458 |
-
"dtype": args.torch_dtype,
|
| 459 |
}
|
| 460 |
return common_kwargs
|
| 461 |
|
|
@@ -746,7 +746,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 746 |
init_method = default_init
|
| 747 |
init_kwargs = {}
|
| 748 |
if device is not None:
|
| 749 |
-
init_kwargs["device"] = device
|
| 750 |
self.embedding = init_method(Embedding, config, **init_kwargs)
|
| 751 |
self.num_layers = config.num_layers
|
| 752 |
self.multi_query_group_num = config.multi_query_group_num
|
|
@@ -868,6 +868,16 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 868 |
if self.config.quantization_bit:
|
| 869 |
self.quantize(self.config.quantization_bit, empty_init=True)
|
| 870 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 871 |
def _update_model_kwargs_for_generation(
|
| 872 |
self,
|
| 873 |
outputs: ModelOutput,
|
|
@@ -1300,4 +1310,4 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
|
|
| 1300 |
past_key_values=transformer_outputs.past_key_values,
|
| 1301 |
hidden_states=transformer_outputs.hidden_states,
|
| 1302 |
attentions=transformer_outputs.attentions,
|
| 1303 |
-
)
|
|
|
|
| 455 |
|
| 456 |
def _config_to_kwargs(args):
|
| 457 |
common_kwargs = {
|
| 458 |
+
"dtype": args.torch_dtype if not isinstance(args.torch_dtype, str) else getattr(torch, args.torch_dtype)
|
| 459 |
}
|
| 460 |
return common_kwargs
|
| 461 |
|
|
|
|
| 746 |
init_method = default_init
|
| 747 |
init_kwargs = {}
|
| 748 |
if device is not None:
|
| 749 |
+
init_kwargs["device"] = device if not isinstance(device, str) else torch.device(device)
|
| 750 |
self.embedding = init_method(Embedding, config, **init_kwargs)
|
| 751 |
self.num_layers = config.num_layers
|
| 752 |
self.multi_query_group_num = config.multi_query_group_num
|
|
|
|
| 868 |
if self.config.quantization_bit:
|
| 869 |
self.quantize(self.config.quantization_bit, empty_init=True)
|
| 870 |
|
| 871 |
+
|
| 872 |
+
@staticmethod
def _extract_past_from_model_output(outputs: ModelOutput, *args, **kwargs):
    """Pull the KV cache (``past_key_values``) out of a forward-pass output.

    transformers >= 4.42 expects this hook to return a ``(cache_name, cache)``
    pair, while older versions take the cache alone — the module-level
    ``is_transformers_4_42_or_higher`` flag selects which shape to return.
    Extra ``*args``/``**kwargs`` are accepted to stay signature-compatible
    across transformers versions and are ignored.
    """
    past = outputs.past_key_values if "past_key_values" in outputs else None
    if is_transformers_4_42_or_higher:
        # Newer API: (cache_name, cache); cache_name is unused here.
        return None, past
    return past
|
| 880 |
+
|
| 881 |
def _update_model_kwargs_for_generation(
|
| 882 |
self,
|
| 883 |
outputs: ModelOutput,
|
|
|
|
| 1310 |
past_key_values=transformer_outputs.past_key_values,
|
| 1311 |
hidden_states=transformer_outputs.hidden_states,
|
| 1312 |
attentions=transformer_outputs.attentions,
|
| 1313 |
+
)
|