Update modeling_chatglm.py
Browse files- modeling_chatglm.py +90 -29
modeling_chatglm.py
CHANGED
|
@@ -157,7 +157,7 @@ class RotaryEmbedding(nn.Module):
|
|
| 157 |
)
|
| 158 |
|
| 159 |
|
| 160 |
-
@torch.jit.script
|
| 161 |
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
|
| 162 |
# x: [sq, b, np, hn]
|
| 163 |
sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
|
|
@@ -223,8 +223,7 @@ class CoreAttention(torch.nn.Module):
|
|
| 223 |
if pytorch_major_version >= 2:
|
| 224 |
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
|
| 225 |
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
| 226 |
-
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
|
| 227 |
-
is_causal=True)
|
| 228 |
else:
|
| 229 |
if attention_mask is not None:
|
| 230 |
attention_mask = ~attention_mask
|
|
@@ -237,7 +236,7 @@ class CoreAttention(torch.nn.Module):
|
|
| 237 |
# Raw attention scores
|
| 238 |
|
| 239 |
# [b, np, sq, sk]
|
| 240 |
-
output_size = (query_layer.size(
|
| 241 |
|
| 242 |
# [sq, b, np, hn] -> [sq, b * np, hn]
|
| 243 |
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
|
|
@@ -312,7 +311,6 @@ class CoreAttention(torch.nn.Module):
|
|
| 312 |
|
| 313 |
class SelfAttention(torch.nn.Module):
|
| 314 |
"""Parallel self-attention layer abstract class.
|
| 315 |
-
|
| 316 |
Self-attention layer takes input with size [s, b, h]
|
| 317 |
and returns output of the same size.
|
| 318 |
"""
|
|
@@ -448,7 +446,6 @@ class SelfAttention(torch.nn.Module):
|
|
| 448 |
|
| 449 |
return output, kv_cache
|
| 450 |
|
| 451 |
-
|
| 452 |
def _config_to_kwargs(args):
|
| 453 |
common_kwargs = {
|
| 454 |
"dtype": args.torch_dtype,
|
|
@@ -504,7 +501,6 @@ class MLP(torch.nn.Module):
|
|
| 504 |
|
| 505 |
class GLMBlock(torch.nn.Module):
|
| 506 |
"""A single transformer layer.
|
| 507 |
-
|
| 508 |
Transformer layer takes input with size [s, b, h] and returns an
|
| 509 |
output of the same size.
|
| 510 |
"""
|
|
@@ -597,7 +593,7 @@ class GLMTransformer(torch.nn.Module):
|
|
| 597 |
if self.post_layer_norm:
|
| 598 |
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
| 599 |
# Final layer norm before output.
|
| 600 |
-
self.
|
| 601 |
dtype=config.torch_dtype)
|
| 602 |
|
| 603 |
self.gradient_checkpointing = False
|
|
@@ -653,7 +649,7 @@ class GLMTransformer(torch.nn.Module):
|
|
| 653 |
|
| 654 |
# Final layer norm.
|
| 655 |
if self.post_layer_norm:
|
| 656 |
-
hidden_states = self.
|
| 657 |
|
| 658 |
return hidden_states, presents, all_hidden_states, all_self_attentions
|
| 659 |
|
|
@@ -740,7 +736,14 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 740 |
init_kwargs = {}
|
| 741 |
if device is not None:
|
| 742 |
init_kwargs["device"] = device
|
| 743 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 744 |
self.num_layers = config.num_layers
|
| 745 |
self.multi_query_group_num = config.multi_query_group_num
|
| 746 |
self.kv_channels = config.kv_channels
|
|
@@ -753,9 +756,21 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 753 |
|
| 754 |
self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
|
| 755 |
dtype=config.torch_dtype)
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 759 |
self.pre_seq_len = config.pre_seq_len
|
| 760 |
self.prefix_projection = config.prefix_projection
|
| 761 |
if self.pre_seq_len is not None:
|
|
@@ -765,6 +780,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 765 |
self.prefix_encoder = PrefixEncoder(config)
|
| 766 |
self.dropout = torch.nn.Dropout(0.1)
|
| 767 |
|
|
|
|
|
|
|
| 768 |
def get_input_embeddings(self):
|
| 769 |
return self.embedding.word_embeddings
|
| 770 |
|
|
@@ -804,7 +821,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 804 |
batch_size, seq_length = input_ids.shape
|
| 805 |
|
| 806 |
if inputs_embeds is None:
|
| 807 |
-
inputs_embeds = self.
|
| 808 |
|
| 809 |
if self.pre_seq_len is not None:
|
| 810 |
if past_key_values is None:
|
|
@@ -827,10 +844,54 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 827 |
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
|
| 828 |
|
| 829 |
# Run encoder.
|
| 830 |
-
|
| 831 |
-
|
| 832 |
-
|
| 833 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 834 |
|
| 835 |
if not return_dict:
|
| 836 |
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
|
@@ -844,7 +905,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 844 |
|
| 845 |
def quantize(self, weight_bit_width: int):
|
| 846 |
from .quantization import quantize
|
| 847 |
-
quantize(self
|
| 848 |
return self
|
| 849 |
|
| 850 |
|
|
@@ -853,7 +914,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 853 |
super().__init__(config)
|
| 854 |
|
| 855 |
self.max_sequence_length = config.max_length
|
| 856 |
-
self.
|
|
|
|
| 857 |
self.config = config
|
| 858 |
self.quantized = False
|
| 859 |
|
|
@@ -934,7 +996,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 934 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 935 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 936 |
|
| 937 |
-
transformer_outputs = self.
|
| 938 |
input_ids=input_ids,
|
| 939 |
position_ids=position_ids,
|
| 940 |
attention_mask=attention_mask,
|
|
@@ -948,8 +1010,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 948 |
hidden_states = transformer_outputs[0]
|
| 949 |
if return_last_logit:
|
| 950 |
hidden_states = hidden_states[-1:]
|
| 951 |
-
lm_logits = self.
|
| 952 |
-
lm_logits = lm_logits.transpose(0, 1).contiguous()
|
| 953 |
|
| 954 |
loss = None
|
| 955 |
if labels is not None:
|
|
@@ -1062,8 +1123,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 1062 |
inputs = inputs.to(self.device)
|
| 1063 |
if past_key_values is not None:
|
| 1064 |
past_length = past_key_values[0][0].shape[0]
|
| 1065 |
-
if self.
|
| 1066 |
-
past_length -= self.
|
| 1067 |
inputs.position_ids += past_length
|
| 1068 |
attention_mask = inputs.attention_mask
|
| 1069 |
attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
|
|
@@ -1205,7 +1266,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 1205 |
|
| 1206 |
self.config.quantization_bit = bits
|
| 1207 |
|
| 1208 |
-
self.
|
| 1209 |
**kwargs)
|
| 1210 |
return self
|
| 1211 |
|
|
@@ -1215,7 +1276,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
|
|
| 1215 |
super().__init__(config)
|
| 1216 |
|
| 1217 |
self.num_labels = config.num_labels
|
| 1218 |
-
self.
|
| 1219 |
|
| 1220 |
self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
|
| 1221 |
if config.classifier_dropout is not None:
|
|
@@ -1242,7 +1303,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
|
|
| 1242 |
) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
|
| 1243 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1244 |
|
| 1245 |
-
transformer_outputs = self.
|
| 1246 |
input_ids=input_ids,
|
| 1247 |
position_ids=position_ids,
|
| 1248 |
attention_mask=attention_mask,
|
|
@@ -1293,4 +1354,4 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
|
|
| 1293 |
past_key_values=transformer_outputs.past_key_values,
|
| 1294 |
hidden_states=transformer_outputs.hidden_states,
|
| 1295 |
attentions=transformer_outputs.attentions,
|
| 1296 |
-
)
|
|
|
|
| 157 |
)
|
| 158 |
|
| 159 |
|
| 160 |
+
# @torch.jit.script
|
| 161 |
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
|
| 162 |
# x: [sq, b, np, hn]
|
| 163 |
sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
|
|
|
|
| 223 |
if pytorch_major_version >= 2:
|
| 224 |
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
|
| 225 |
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
| 226 |
+
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,is_causal=True)
|
|
|
|
| 227 |
else:
|
| 228 |
if attention_mask is not None:
|
| 229 |
attention_mask = ~attention_mask
|
|
|
|
| 236 |
# Raw attention scores
|
| 237 |
|
| 238 |
# [b, np, sq, sk]
|
| 239 |
+
output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(0))
|
| 240 |
|
| 241 |
# [sq, b, np, hn] -> [sq, b * np, hn]
|
| 242 |
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
|
|
|
|
| 311 |
|
| 312 |
class SelfAttention(torch.nn.Module):
|
| 313 |
"""Parallel self-attention layer abstract class.
|
|
|
|
| 314 |
Self-attention layer takes input with size [s, b, h]
|
| 315 |
and returns output of the same size.
|
| 316 |
"""
|
|
|
|
| 446 |
|
| 447 |
return output, kv_cache
|
| 448 |
|
|
|
|
| 449 |
def _config_to_kwargs(args):
|
| 450 |
common_kwargs = {
|
| 451 |
"dtype": args.torch_dtype,
|
|
|
|
| 501 |
|
| 502 |
class GLMBlock(torch.nn.Module):
|
| 503 |
"""A single transformer layer.
|
|
|
|
| 504 |
Transformer layer takes input with size [s, b, h] and returns an
|
| 505 |
output of the same size.
|
| 506 |
"""
|
|
|
|
| 593 |
if self.post_layer_norm:
|
| 594 |
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
| 595 |
# Final layer norm before output.
|
| 596 |
+
self.norm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
| 597 |
dtype=config.torch_dtype)
|
| 598 |
|
| 599 |
self.gradient_checkpointing = False
|
|
|
|
| 649 |
|
| 650 |
# Final layer norm.
|
| 651 |
if self.post_layer_norm:
|
| 652 |
+
hidden_states = self.norm(hidden_states)
|
| 653 |
|
| 654 |
return hidden_states, presents, all_hidden_states, all_self_attentions
|
| 655 |
|
|
|
|
| 736 |
init_kwargs = {}
|
| 737 |
if device is not None:
|
| 738 |
init_kwargs["device"] = device
|
| 739 |
+
|
| 740 |
+
self.embed_tokens = nn.Embedding(
|
| 741 |
+
config.padded_vocab_size,
|
| 742 |
+
config.hidden_size,
|
| 743 |
+
dtype=config.torch_dtype,
|
| 744 |
+
device=device
|
| 745 |
+
)
|
| 746 |
+
|
| 747 |
self.num_layers = config.num_layers
|
| 748 |
self.multi_query_group_num = config.multi_query_group_num
|
| 749 |
self.kv_channels = config.kv_channels
|
|
|
|
| 756 |
|
| 757 |
self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
|
| 758 |
dtype=config.torch_dtype)
|
| 759 |
+
|
| 760 |
+
# Transformer layers.
|
| 761 |
+
def build_layer(layer_number):
|
| 762 |
+
return GLMBlock(config, layer_number, device=device)
|
| 763 |
+
|
| 764 |
+
self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
|
| 765 |
+
self.num_layers = config.num_layers
|
| 766 |
+
self.post_layer_norm = config.post_layer_norm
|
| 767 |
+
|
| 768 |
+
if self.post_layer_norm:
|
| 769 |
+
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
| 770 |
+
# Final layer norm before output.
|
| 771 |
+
self.norm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
| 772 |
+
dtype=config.torch_dtype)
|
| 773 |
+
|
| 774 |
self.pre_seq_len = config.pre_seq_len
|
| 775 |
self.prefix_projection = config.prefix_projection
|
| 776 |
if self.pre_seq_len is not None:
|
|
|
|
| 780 |
self.prefix_encoder = PrefixEncoder(config)
|
| 781 |
self.dropout = torch.nn.Dropout(0.1)
|
| 782 |
|
| 783 |
+
self.gradient_checkpointing = False
|
| 784 |
+
|
| 785 |
def get_input_embeddings(self):
|
| 786 |
return self.embedding.word_embeddings
|
| 787 |
|
|
|
|
| 821 |
batch_size, seq_length = input_ids.shape
|
| 822 |
|
| 823 |
if inputs_embeds is None:
|
| 824 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 825 |
|
| 826 |
if self.pre_seq_len is not None:
|
| 827 |
if past_key_values is None:
|
|
|
|
| 844 |
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
|
| 845 |
|
| 846 |
# Run encoder.
|
| 847 |
+
if not past_key_values:
|
| 848 |
+
past_key_values = [None for _ in range(self.num_layers)]
|
| 849 |
+
presents = () if use_cache else None
|
| 850 |
+
if self.gradient_checkpointing and self.training:
|
| 851 |
+
if use_cache:
|
| 852 |
+
logger.warning_once(
|
| 853 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 854 |
+
)
|
| 855 |
+
use_cache = False
|
| 856 |
+
|
| 857 |
+
all_self_attentions = None
|
| 858 |
+
all_hidden_states = () if output_hidden_states else None
|
| 859 |
+
|
| 860 |
+
hidden_states = inputs_embeds
|
| 861 |
+
# To comply with former chat-glm format that expects (seqlen, bs, hd)
|
| 862 |
+
hidden_states = hidden_states.permute(1, 0, 2)
|
| 863 |
+
|
| 864 |
+
for index, layer in enumerate(self.layers):
|
| 865 |
+
if output_hidden_states:
|
| 866 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 867 |
+
|
| 868 |
+
if self.gradient_checkpointing and self.training:
|
| 869 |
+
layer_ret = torch.utils.checkpoint.checkpoint(
|
| 870 |
+
layer,
|
| 871 |
+
hidden_states,
|
| 872 |
+
full_attention_mask,
|
| 873 |
+
rotary_pos_emb,
|
| 874 |
+
past_key_values[index],
|
| 875 |
+
use_cache
|
| 876 |
+
)
|
| 877 |
+
else:
|
| 878 |
+
layer_ret = layer(
|
| 879 |
+
hidden_states,
|
| 880 |
+
full_attention_mask,
|
| 881 |
+
rotary_pos_emb,
|
| 882 |
+
kv_cache=past_key_values[index],
|
| 883 |
+
use_cache=use_cache
|
| 884 |
+
)
|
| 885 |
+
hidden_states, kv_cache = layer_ret
|
| 886 |
+
if use_cache:
|
| 887 |
+
presents = presents + (kv_cache,)
|
| 888 |
+
|
| 889 |
+
if output_hidden_states:
|
| 890 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 891 |
+
|
| 892 |
+
# Final layer norm.
|
| 893 |
+
if self.post_layer_norm:
|
| 894 |
+
hidden_states = self.norm(hidden_states)
|
| 895 |
|
| 896 |
if not return_dict:
|
| 897 |
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
|
|
|
| 905 |
|
| 906 |
def quantize(self, weight_bit_width: int):
|
| 907 |
from .quantization import quantize
|
| 908 |
+
quantize(self, weight_bit_width)
|
| 909 |
return self
|
| 910 |
|
| 911 |
|
|
|
|
| 914 |
super().__init__(config)
|
| 915 |
|
| 916 |
self.max_sequence_length = config.max_length
|
| 917 |
+
self.model = ChatGLMModel(config, empty_init=empty_init, device=device)
|
| 918 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 919 |
self.config = config
|
| 920 |
self.quantized = False
|
| 921 |
|
|
|
|
| 996 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 997 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 998 |
|
| 999 |
+
transformer_outputs = self.model(
|
| 1000 |
input_ids=input_ids,
|
| 1001 |
position_ids=position_ids,
|
| 1002 |
attention_mask=attention_mask,
|
|
|
|
| 1010 |
hidden_states = transformer_outputs[0]
|
| 1011 |
if return_last_logit:
|
| 1012 |
hidden_states = hidden_states[-1:]
|
| 1013 |
+
lm_logits = self.lm_head(hidden_states)
|
|
|
|
| 1014 |
|
| 1015 |
loss = None
|
| 1016 |
if labels is not None:
|
|
|
|
| 1123 |
inputs = inputs.to(self.device)
|
| 1124 |
if past_key_values is not None:
|
| 1125 |
past_length = past_key_values[0][0].shape[0]
|
| 1126 |
+
if self.model.pre_seq_len is not None:
|
| 1127 |
+
past_length -= self.model.pre_seq_len
|
| 1128 |
inputs.position_ids += past_length
|
| 1129 |
attention_mask = inputs.attention_mask
|
| 1130 |
attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
|
|
|
|
| 1266 |
|
| 1267 |
self.config.quantization_bit = bits
|
| 1268 |
|
| 1269 |
+
self.model = quantize(self.model, bits, empty_init=empty_init, device=device,
|
| 1270 |
**kwargs)
|
| 1271 |
return self
|
| 1272 |
|
|
|
|
| 1276 |
super().__init__(config)
|
| 1277 |
|
| 1278 |
self.num_labels = config.num_labels
|
| 1279 |
+
self.model = ChatGLMModel(config, empty_init=empty_init, device=device)
|
| 1280 |
|
| 1281 |
self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
|
| 1282 |
if config.classifier_dropout is not None:
|
|
|
|
| 1303 |
) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
|
| 1304 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1305 |
|
| 1306 |
+
transformer_outputs = self.model(
|
| 1307 |
input_ids=input_ids,
|
| 1308 |
position_ids=position_ids,
|
| 1309 |
attention_mask=attention_mask,
|
|
|
|
| 1354 |
past_key_values=transformer_outputs.past_key_values,
|
| 1355 |
hidden_states=transformer_outputs.hidden_states,
|
| 1356 |
attentions=transformer_outputs.attentions,
|
| 1357 |
+
)
|