Update modeling_chatglm.py
Browse files- modeling_chatglm.py +211 -8
modeling_chatglm.py
CHANGED
|
@@ -1,14 +1,19 @@
|
|
| 1 |
""" PyTorch ChatGLM model. """
|
| 2 |
-
|
| 3 |
import math
|
|
|
|
|
|
|
|
|
|
| 4 |
import sys
|
|
|
|
| 5 |
import torch
|
| 6 |
import torch.utils.checkpoint
|
| 7 |
import torch.nn.functional as F
|
| 8 |
from torch import nn
|
| 9 |
from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
|
| 10 |
from torch.nn.utils import skip_init
|
| 11 |
-
from typing import Optional, Tuple, Union, List, Dict, Any
|
|
|
|
| 12 |
|
| 13 |
from transformers.modeling_outputs import (
|
| 14 |
BaseModelOutputWithPast,
|
|
@@ -18,19 +23,19 @@ from transformers.modeling_outputs import (
|
|
| 18 |
from transformers.modeling_utils import PreTrainedModel
|
| 19 |
from transformers.utils import logging, is_torch_npu_available
|
| 20 |
from transformers.generation.logits_process import LogitsProcessor
|
| 21 |
-
from transformers.generation.utils import ModelOutput
|
| 22 |
|
| 23 |
from .configuration_chatglm import ChatGLMConfig
|
| 24 |
|
| 25 |
try:
|
| 26 |
from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
|
| 27 |
-
|
| 28 |
if is_flash_attn_2_available():
|
| 29 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 30 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
| 31 |
except:
|
| 32 |
pass
|
| 33 |
|
|
|
|
| 34 |
# flags required to enable jit fusion kernels
|
| 35 |
|
| 36 |
if sys.platform != 'darwin' and not is_torch_npu_available():
|
|
@@ -349,8 +354,7 @@ class FlashAttention2(CoreAttention):
|
|
| 349 |
)
|
| 350 |
if query_length == kv_seq_len:
|
| 351 |
query_layer = index_first_axis(
|
| 352 |
-
query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim),
|
| 353 |
-
indices_k
|
| 354 |
)
|
| 355 |
cu_seqlens_q = cu_seqlens_k
|
| 356 |
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
|
@@ -793,6 +797,11 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
|
|
| 793 |
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
| 794 |
return position_ids
|
| 795 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 796 |
class Embedding(torch.nn.Module):
|
| 797 |
"""Language model embeddings."""
|
| 798 |
|
|
@@ -927,10 +936,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 927 |
standardize_cache_format: bool = False,
|
| 928 |
) -> Dict[str, Any]:
|
| 929 |
# update past_key_values
|
| 930 |
-
|
| 931 |
outputs, standardize_cache_format=standardize_cache_format
|
| 932 |
)
|
| 933 |
-
model_kwargs[cache_name] = cache
|
| 934 |
|
| 935 |
# update attention mask
|
| 936 |
if "attention_mask" in model_kwargs:
|
|
@@ -1055,6 +1063,201 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 1055 |
for layer_past in past
|
| 1056 |
)
|
| 1057 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1058 |
|
| 1059 |
class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
|
| 1060 |
def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
|
|
|
|
| 1 |
""" PyTorch ChatGLM model. """
|
| 2 |
+
import json
|
| 3 |
import math
|
| 4 |
+
import copy
|
| 5 |
+
import warnings
|
| 6 |
+
import re
|
| 7 |
import sys
|
| 8 |
+
|
| 9 |
import torch
|
| 10 |
import torch.utils.checkpoint
|
| 11 |
import torch.nn.functional as F
|
| 12 |
from torch import nn
|
| 13 |
from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
|
| 14 |
from torch.nn.utils import skip_init
|
| 15 |
+
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
| 16 |
+
from copy import deepcopy
|
| 17 |
|
| 18 |
from transformers.modeling_outputs import (
|
| 19 |
BaseModelOutputWithPast,
|
|
|
|
| 23 |
from transformers.modeling_utils import PreTrainedModel
|
| 24 |
from transformers.utils import logging, is_torch_npu_available
|
| 25 |
from transformers.generation.logits_process import LogitsProcessor
|
| 26 |
+
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
|
| 27 |
|
| 28 |
from .configuration_chatglm import ChatGLMConfig
|
| 29 |
|
| 30 |
try:
|
| 31 |
from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
|
|
|
|
| 32 |
if is_flash_attn_2_available():
|
| 33 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 34 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
| 35 |
except:
|
| 36 |
pass
|
| 37 |
|
| 38 |
+
|
| 39 |
# flags required to enable jit fusion kernels
|
| 40 |
|
| 41 |
if sys.platform != 'darwin' and not is_torch_npu_available():
|
|
|
|
| 354 |
)
|
| 355 |
if query_length == kv_seq_len:
|
| 356 |
query_layer = index_first_axis(
|
| 357 |
+
query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim), indices_k
|
|
|
|
| 358 |
)
|
| 359 |
cu_seqlens_q = cu_seqlens_k
|
| 360 |
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
|
|
|
| 797 |
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
| 798 |
return position_ids
|
| 799 |
|
| 800 |
+
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
|
| 801 |
+
if not self.supports_gradient_checkpointing:
|
| 802 |
+
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
| 803 |
+
|
| 804 |
+
|
| 805 |
class Embedding(torch.nn.Module):
|
| 806 |
"""Language model embeddings."""
|
| 807 |
|
|
|
|
| 936 |
standardize_cache_format: bool = False,
|
| 937 |
) -> Dict[str, Any]:
|
| 938 |
# update past_key_values
|
| 939 |
+
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
| 940 |
outputs, standardize_cache_format=standardize_cache_format
|
| 941 |
)
|
|
|
|
| 942 |
|
| 943 |
# update attention mask
|
| 944 |
if "attention_mask" in model_kwargs:
|
|
|
|
| 1063 |
for layer_past in past
|
| 1064 |
)
|
| 1065 |
|
| 1066 |
+
def process_response(self, output, history):
|
| 1067 |
+
content = ""
|
| 1068 |
+
history = deepcopy(history)
|
| 1069 |
+
for response in output.split("<|assistant|>"):
|
| 1070 |
+
if "\n" in response:
|
| 1071 |
+
metadata, content = response.split("\n", maxsplit=1)
|
| 1072 |
+
else:
|
| 1073 |
+
metadata, content = "", response
|
| 1074 |
+
if not metadata.strip():
|
| 1075 |
+
content = content.strip()
|
| 1076 |
+
history.append({"role": "assistant", "metadata": metadata, "content": content})
|
| 1077 |
+
content = content.replace("[[训练时间]]", "2023年")
|
| 1078 |
+
else:
|
| 1079 |
+
history.append({"role": "assistant", "metadata": metadata, "content": content})
|
| 1080 |
+
if history[0]["role"] == "system" and "tools" in history[0]:
|
| 1081 |
+
parameters = json.loads(content)
|
| 1082 |
+
content = {"name": metadata.strip(), "parameters": parameters}
|
| 1083 |
+
else:
|
| 1084 |
+
content = {"name": metadata.strip(), "content": content}
|
| 1085 |
+
return content, history
|
| 1086 |
+
|
| 1087 |
+
@torch.inference_mode()
|
| 1088 |
+
def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
|
| 1089 |
+
max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
|
| 1090 |
+
**kwargs):
|
| 1091 |
+
if history is None:
|
| 1092 |
+
history = []
|
| 1093 |
+
if logits_processor is None:
|
| 1094 |
+
logits_processor = LogitsProcessorList()
|
| 1095 |
+
logits_processor.append(InvalidScoreLogitsProcessor())
|
| 1096 |
+
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
| 1097 |
+
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
| 1098 |
+
history.append({"role": role, "content": query})
|
| 1099 |
+
inputs = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=True,
|
| 1100 |
+
return_tensors="pt", return_dict=True)
|
| 1101 |
+
inputs = inputs.to(self.device)
|
| 1102 |
+
eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|user|>"),
|
| 1103 |
+
tokenizer.convert_tokens_to_ids("<|observation|>")]
|
| 1104 |
+
outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
|
| 1105 |
+
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
|
| 1106 |
+
response = tokenizer.decode(outputs)
|
| 1107 |
+
response, history = self.process_response(response, history)
|
| 1108 |
+
return response, history
|
| 1109 |
+
|
| 1110 |
+
@torch.inference_mode()
|
| 1111 |
+
def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
|
| 1112 |
+
past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
|
| 1113 |
+
logits_processor=None, return_past_key_values=False, **kwargs):
|
| 1114 |
+
if history is None:
|
| 1115 |
+
history = []
|
| 1116 |
+
if logits_processor is None:
|
| 1117 |
+
logits_processor = LogitsProcessorList()
|
| 1118 |
+
logits_processor.append(InvalidScoreLogitsProcessor())
|
| 1119 |
+
eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|user|>"),
|
| 1120 |
+
tokenizer.convert_tokens_to_ids("<|observation|>")]
|
| 1121 |
+
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
|
| 1122 |
+
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
| 1123 |
+
if past_key_values is None:
|
| 1124 |
+
inputs = tokenizer.apply_chat_template(history + [{"role": role, "content": query}],
|
| 1125 |
+
add_generation_prompt=True, tokenize=True, return_tensors="pt",
|
| 1126 |
+
return_dict=True)
|
| 1127 |
+
else:
|
| 1128 |
+
inputs = tokenizer.apply_chat_template([{"role": role, "content": query}], add_special_tokens=False,
|
| 1129 |
+
add_generation_prompt=True, tokenize=True, return_tensors="pt",
|
| 1130 |
+
return_dict=True)
|
| 1131 |
+
inputs = inputs.to(self.device)
|
| 1132 |
+
if past_key_values is not None:
|
| 1133 |
+
past_length = past_key_values[0][0].shape[2]
|
| 1134 |
+
inputs.position_ids += past_length
|
| 1135 |
+
attention_mask = inputs.attention_mask
|
| 1136 |
+
attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
|
| 1137 |
+
inputs['attention_mask'] = attention_mask
|
| 1138 |
+
history.append({"role": role, "content": query})
|
| 1139 |
+
for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
|
| 1140 |
+
eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
|
| 1141 |
+
**gen_kwargs):
|
| 1142 |
+
if return_past_key_values:
|
| 1143 |
+
outputs, past_key_values = outputs
|
| 1144 |
+
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
|
| 1145 |
+
response = tokenizer.decode(outputs)
|
| 1146 |
+
if response and response[-1] != "�":
|
| 1147 |
+
response, new_history = self.process_response(response, history)
|
| 1148 |
+
if return_past_key_values:
|
| 1149 |
+
yield response, new_history, past_key_values
|
| 1150 |
+
else:
|
| 1151 |
+
yield response, new_history
|
| 1152 |
+
|
| 1153 |
+
@torch.inference_mode()
|
| 1154 |
+
def stream_generate(
|
| 1155 |
+
self,
|
| 1156 |
+
input_ids,
|
| 1157 |
+
generation_config: Optional[GenerationConfig] = None,
|
| 1158 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
| 1159 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
| 1160 |
+
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
| 1161 |
+
return_past_key_values=False,
|
| 1162 |
+
**kwargs,
|
| 1163 |
+
):
|
| 1164 |
+
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
|
| 1165 |
+
|
| 1166 |
+
if generation_config is None:
|
| 1167 |
+
generation_config = self.generation_config
|
| 1168 |
+
generation_config = copy.deepcopy(generation_config)
|
| 1169 |
+
model_kwargs = generation_config.update(**kwargs)
|
| 1170 |
+
model_kwargs["use_cache"] = generation_config.use_cache
|
| 1171 |
+
bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
|
| 1172 |
+
|
| 1173 |
+
if isinstance(eos_token_id, int):
|
| 1174 |
+
eos_token_id = [eos_token_id]
|
| 1175 |
+
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
| 1176 |
+
|
| 1177 |
+
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
| 1178 |
+
if has_default_max_length and generation_config.max_new_tokens is None:
|
| 1179 |
+
warnings.warn(
|
| 1180 |
+
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
|
| 1181 |
+
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
|
| 1182 |
+
" recommend using `max_new_tokens` to control the maximum length of the generation.",
|
| 1183 |
+
UserWarning,
|
| 1184 |
+
)
|
| 1185 |
+
elif generation_config.max_new_tokens is not None:
|
| 1186 |
+
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
| 1187 |
+
if not has_default_max_length:
|
| 1188 |
+
logger.warn(
|
| 1189 |
+
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
| 1190 |
+
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
| 1191 |
+
"Please refer to the documentation for more information. "
|
| 1192 |
+
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
|
| 1193 |
+
UserWarning,
|
| 1194 |
+
)
|
| 1195 |
+
|
| 1196 |
+
if input_ids_seq_length >= generation_config.max_length:
|
| 1197 |
+
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
| 1198 |
+
logger.warning(
|
| 1199 |
+
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
|
| 1200 |
+
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
| 1201 |
+
" increasing `max_new_tokens`."
|
| 1202 |
+
)
|
| 1203 |
+
|
| 1204 |
+
# 2. Set generation parameters if not already defined
|
| 1205 |
+
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
| 1206 |
+
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
| 1207 |
+
|
| 1208 |
+
logits_processor = self._get_logits_processor(
|
| 1209 |
+
generation_config=generation_config,
|
| 1210 |
+
input_ids_seq_length=input_ids_seq_length,
|
| 1211 |
+
encoder_input_ids=input_ids,
|
| 1212 |
+
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
| 1213 |
+
logits_processor=logits_processor,
|
| 1214 |
+
)
|
| 1215 |
+
|
| 1216 |
+
stopping_criteria = self._get_stopping_criteria(
|
| 1217 |
+
generation_config=generation_config, stopping_criteria=stopping_criteria
|
| 1218 |
+
)
|
| 1219 |
+
logits_warper = self._get_logits_warper(generation_config)
|
| 1220 |
+
|
| 1221 |
+
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
| 1222 |
+
scores = None
|
| 1223 |
+
while True:
|
| 1224 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
| 1225 |
+
# forward pass to get next token
|
| 1226 |
+
outputs = self(
|
| 1227 |
+
**model_inputs,
|
| 1228 |
+
return_dict=True,
|
| 1229 |
+
output_attentions=False,
|
| 1230 |
+
output_hidden_states=False,
|
| 1231 |
+
)
|
| 1232 |
+
|
| 1233 |
+
next_token_logits = outputs.logits[:, -1, :]
|
| 1234 |
+
|
| 1235 |
+
# pre-process distribution
|
| 1236 |
+
next_token_scores = logits_processor(input_ids, next_token_logits)
|
| 1237 |
+
next_token_scores = logits_warper(input_ids, next_token_scores)
|
| 1238 |
+
|
| 1239 |
+
# sample
|
| 1240 |
+
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
| 1241 |
+
if generation_config.do_sample:
|
| 1242 |
+
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
| 1243 |
+
else:
|
| 1244 |
+
next_tokens = torch.argmax(probs, dim=-1)
|
| 1245 |
+
# update generated ids, model inputs, and length for next step
|
| 1246 |
+
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
| 1247 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
| 1248 |
+
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
| 1249 |
+
)
|
| 1250 |
+
unfinished_sequences = unfinished_sequences.mul(
|
| 1251 |
+
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
| 1252 |
+
)
|
| 1253 |
+
if return_past_key_values:
|
| 1254 |
+
yield input_ids, outputs.past_key_values
|
| 1255 |
+
else:
|
| 1256 |
+
yield input_ids
|
| 1257 |
+
# stop when each sentence is finished, or if we exceed the maximum length
|
| 1258 |
+
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
| 1259 |
+
break
|
| 1260 |
+
|
| 1261 |
|
| 1262 |
class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
|
| 1263 |
def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
|