belztjti committed
Commit 5f7087e · verified · 1 Parent(s): d68f70c

Update modeling_chatglm.py

Files changed (1)
  1. modeling_chatglm.py +211 -8
modeling_chatglm.py CHANGED
@@ -1,14 +1,19 @@
 """ PyTorch ChatGLM model. """
-
+import json
 import math
+import copy
+import warnings
+import re
 import sys
+
 import torch
 import torch.utils.checkpoint
 import torch.nn.functional as F
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
 from torch.nn.utils import skip_init
-from typing import Optional, Tuple, Union, List, Dict, Any
+from typing import Optional, Tuple, Union, List, Callable, Dict, Any
+from copy import deepcopy
 
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -18,19 +23,19 @@ from transformers.modeling_outputs import (
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging, is_torch_npu_available
 from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import ModelOutput
+from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
 
 from .configuration_chatglm import ChatGLMConfig
 
 try:
     from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
-
     if is_flash_attn_2_available():
         from flash_attn import flash_attn_func, flash_attn_varlen_func
         from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 except:
     pass
 
+
 # flags required to enable jit fusion kernels
 
 if sys.platform != 'darwin' and not is_torch_npu_available():
@@ -349,8 +354,7 @@ class FlashAttention2(CoreAttention):
             )
             if query_length == kv_seq_len:
                 query_layer = index_first_axis(
-                    query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim),
-                    indices_k
+                    query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim), indices_k
                 )
                 cu_seqlens_q = cu_seqlens_k
                 max_seqlen_in_batch_q = max_seqlen_in_batch_k
@@ -793,6 +797,11 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
         return position_ids
 
+    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
+        if not self.supports_gradient_checkpointing:
+            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+
+
 class Embedding(torch.nn.Module):
     """Language model embeddings."""
 
@@ -927,10 +936,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         standardize_cache_format: bool = False,
     ) -> Dict[str, Any]:
         # update past_key_values
-        cache_name, cache = self._extract_past_from_model_output(
+        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
             outputs, standardize_cache_format=standardize_cache_format
         )
-        model_kwargs[cache_name] = cache
 
         # update attention mask
         if "attention_mask" in model_kwargs:
@@ -1055,6 +1063,201 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             for layer_past in past
         )
 
+    def process_response(self, output, history):
+        content = ""
+        history = deepcopy(history)
+        for response in output.split("<|assistant|>"):
+            if "\n" in response:
+                metadata, content = response.split("\n", maxsplit=1)
+            else:
+                metadata, content = "", response
+            if not metadata.strip():
+                content = content.strip()
+                history.append({"role": "assistant", "metadata": metadata, "content": content})
+                content = content.replace("[[训练时间]]", "2023年")
+            else:
+                history.append({"role": "assistant", "metadata": metadata, "content": content})
+                if history[0]["role"] == "system" and "tools" in history[0]:
+                    parameters = json.loads(content)
+                    content = {"name": metadata.strip(), "parameters": parameters}
+                else:
+                    content = {"name": metadata.strip(), "content": content}
+        return content, history
+
+    @torch.inference_mode()
+    def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
+             max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
+             **kwargs):
+        if history is None:
+            history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
+                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+        history.append({"role": role, "content": query})
+        inputs = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=True,
+                                               return_tensors="pt", return_dict=True)
+        inputs = inputs.to(self.device)
+        eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|user|>"),
+                        tokenizer.convert_tokens_to_ids("<|observation|>")]
+        outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
+        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
+        response = tokenizer.decode(outputs)
+        response, history = self.process_response(response, history)
+        return response, history
+
+    @torch.inference_mode()
+    def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
+                    past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
+                    logits_processor=None, return_past_key_values=False, **kwargs):
+        if history is None:
+            history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|user|>"),
+                        tokenizer.convert_tokens_to_ids("<|observation|>")]
+        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
+                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+        if past_key_values is None:
+            inputs = tokenizer.apply_chat_template(history + [{"role": role, "content": query}],
+                                                   add_generation_prompt=True, tokenize=True, return_tensors="pt",
+                                                   return_dict=True)
+        else:
+            inputs = tokenizer.apply_chat_template([{"role": role, "content": query}], add_special_tokens=False,
+                                                   add_generation_prompt=True, tokenize=True, return_tensors="pt",
+                                                   return_dict=True)
+        inputs = inputs.to(self.device)
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+            inputs.position_ids += past_length
+            attention_mask = inputs.attention_mask
+            attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
+            inputs['attention_mask'] = attention_mask
+        history.append({"role": role, "content": query})
+        for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
+                                            eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
+                                            **gen_kwargs):
+            if return_past_key_values:
+                outputs, past_key_values = outputs
+            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
+            response = tokenizer.decode(outputs)
+            if response and response[-1] != "�":
+                response, new_history = self.process_response(response, history)
+                if return_past_key_values:
+                    yield response, new_history, past_key_values
+                else:
+                    yield response, new_history
+
+    @torch.inference_mode()
+    def stream_generate(
+            self,
+            input_ids,
+            generation_config: Optional[GenerationConfig] = None,
+            logits_processor: Optional[LogitsProcessorList] = None,
+            stopping_criteria: Optional[StoppingCriteriaList] = None,
+            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+            return_past_key_values=False,
+            **kwargs,
+    ):
+        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+
+        if generation_config is None:
+            generation_config = self.generation_config
+        generation_config = copy.deepcopy(generation_config)
+        model_kwargs = generation_config.update(**kwargs)
+        model_kwargs["use_cache"] = generation_config.use_cache
+        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
+
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
+
+        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        if has_default_max_length and generation_config.max_new_tokens is None:
+            warnings.warn(
+                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+                " recommend using `max_new_tokens` to control the maximum length of the generation.",
+                UserWarning,
+            )
+        elif generation_config.max_new_tokens is not None:
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+            if not has_default_max_length:
+                logger.warn(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
+                    UserWarning,
+                )
+
+        if input_ids_seq_length >= generation_config.max_length:
+            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+            logger.warning(
+                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+                " increasing `max_new_tokens`."
+            )
+
+        # 2. Set generation parameters if not already defined
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+        logits_processor = self._get_logits_processor(
+            generation_config=generation_config,
+            input_ids_seq_length=input_ids_seq_length,
+            encoder_input_ids=input_ids,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            logits_processor=logits_processor,
+        )
+
+        stopping_criteria = self._get_stopping_criteria(
+            generation_config=generation_config, stopping_criteria=stopping_criteria
+        )
+        logits_warper = self._get_logits_warper(generation_config)
+
+        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        scores = None
+        while True:
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            # forward pass to get next token
+            outputs = self(
+                **model_inputs,
+                return_dict=True,
+                output_attentions=False,
+                output_hidden_states=False,
+            )
+
+            next_token_logits = outputs.logits[:, -1, :]
+
+            # pre-process distribution
+            next_token_scores = logits_processor(input_ids, next_token_logits)
+            next_token_scores = logits_warper(input_ids, next_token_scores)
+
+            # sample
+            probs = nn.functional.softmax(next_token_scores, dim=-1)
+            if generation_config.do_sample:
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(probs, dim=-1)
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+            )
+            unfinished_sequences = unfinished_sequences.mul(
+                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+            )
+            if return_past_key_values:
+                yield input_ids, outputs.past_key_values
+            else:
+                yield input_ids
+            # stop when each sentence is finished, or if we exceed the maximum length
+            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+                break
+
 
 class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
     def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
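
For reference, a minimal usage sketch of the chat() / stream_chat() helpers restored by this commit. The checkpoint name, device, dtype, and sampling settings below are illustrative assumptions and not part of the diff; any repository that ships this modeling_chatglm.py through trust_remote_code should behave the same way.

# Usage sketch (assumed checkpoint and settings; not specified by this commit).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "THUDM/glm-4-9b-chat"  # hypothetical example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
).eval().to("cuda")

# chat(): applies the chat template, generates, and post-processes the reply
# through process_response(); returns the reply plus the updated history.
response, history = model.chat(tokenizer, "What is gradient checkpointing?",
                               history=[], role="user", do_sample=True,
                               top_p=0.8, temperature=0.8)
print(response)

# stream_chat(): yields progressively longer responses as tokens are sampled.
for partial, history in model.stream_chat(tokenizer, "Summarize that in one sentence.",
                                          history=history):
    pass  # `partial` holds the text decoded so far
print(partial)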