| import copy |
| import warnings |
| import logging |
| from typing import List, Tuple, Optional, Callable |
|
|
| import torch |
| from torch import nn |
| from transformers.utils import logging |
| from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig |
|
|
| from .modeling_chatglm import ChatGLMForConditionalGeneration, InvalidScoreLogitsProcessor |
| from .characterglm_generation_utils import CharacterGLMGenerationUtils, SessionMeta |
|
|
|
|
# Module logger obtained through `transformers.utils.logging` (the `logging`
# name in this module refers to the transformers helper, not the stdlib module).
logger = logging.get_logger(name=__name__)
# Fallback decoding settings used by `chat` and `stream_chat`.
# Those methods merge an entry in only when the key is not already present in
# their generation kwargs (method arguments and **kwargs take precedence).
default_generation_config = dict(
    do_sample=True,
    top_k=100,
    top_p=0.9,
    no_repeat_ngram_size=0,
    temperature=0.9,
    num_beams=1,
    length_penalty=1.6,
    repetition_penalty=1.3,
    # Overrides the model's stop token; presumably the "\n" token id in the
    # CharacterGLM vocabulary — TODO confirm.
    eos_token_id=13,
)
|
|
|
|
class CharacterGLMForConditionalGeneration(ChatGLMForConditionalGeneration):
    """ChatGLM model variant that speaks the CharacterGLM prompt format.

    The CharacterGLM prompt format differs from ChatGLM's.
    CharacterGLMForConditionalGeneration reuses the ``forward`` method of
    ChatGLMForConditionalGeneration, re-implements ``build_inputs`` and
    ``build_stream_inputs``, changes the signatures of ``chat`` and
    ``stream_chat`` to take an extra ``session_meta`` argument, and changes
    the default decoding parameters.
    """

    @staticmethod
    def _with_invalid_score_processor(logits_processor):
        """Return a new LogitsProcessorList: caller's processors + InvalidScoreLogitsProcessor.

        A fresh list is built so the caller's list is never mutated; appending in
        place made repeated calls with the same list accumulate processors.
        """
        processors = LogitsProcessorList(logits_processor if logits_processor is not None else [])
        processors.append(InvalidScoreLogitsProcessor())
        return processors

    def build_inputs(self, tokenizer, session_meta: SessionMeta, query: str,
                     history: Optional[List[Tuple[str, str]]] = None):
        """Tokenize the full CharacterGLM prompt (session meta + history + query).

        Args:
            tokenizer: tokenizer called on a one-element batch with ``return_tensors="pt"``.
            session_meta: session description used by ``CharacterGLMGenerationUtils.build_inputs``.
            query: the current user message.
            history: previous ChatGLM-style turns; ``None`` is treated as empty.

        Returns:
            The tokenizer's batch encoding moved to ``self.device``.
        """
        character_glm_history = CharacterGLMGenerationUtils.convert_chatglm_history_to_characterglm_history(
            query, history or []
        )
        prompt = CharacterGLMGenerationUtils.build_inputs(session_meta, character_glm_history)
        inputs = tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.device)
        return inputs

    def build_stream_inputs(self, tokenizer, session_meta: SessionMeta, query: str,
                            history: Optional[List[Tuple[str, str]]] = None):
        """Tokenize only the new turn, for continuing from cached ``past_key_values``.

        The prompt is ``\\n[<user_name>]<query>\\n[<bot_name>]`` with newlines inside
        ``query`` replaced by spaces. ``history`` is accepted for signature parity
        with ``build_inputs`` but is not used here.

        Returns:
            A batch encoding (batch of one) moved to ``self.device``.
        """
        prompt = "\n[{}]{}\n[{}]".format(
            session_meta['user_name'],
            query.replace('\n', ' '),
            session_meta['bot_name'],
        )
        input_ids = tokenizer.encode(prompt, add_special_tokens=False)
        # Drop the first token; presumably a prefix piece the tokenizer emits in
        # front of the leading "\n" — TODO confirm against the tokenizer.
        input_ids = input_ids[1:]
        inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False)
        inputs = inputs.to(self.device)
        return inputs

    @torch.inference_mode()
    def chat(self, tokenizer, session_meta: SessionMeta, query: str, history: List[Tuple[str, str]] = None,
             max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.9, temperature=0.9,
             repetition_penalty=1.6, logits_processor=None, **kwargs):
        """Generate one complete reply for ``query`` in a CharacterGLM session.

        Args:
            tokenizer: tokenizer used to encode the prompt and decode the reply.
            session_meta: session description used to build the prompt.
            query: the user's message.
            history: previous ``(query, response)`` turns; ``None`` means no history.
            max_length, num_beams, do_sample, top_p, temperature, repetition_penalty:
                decoding parameters forwarded to ``generate``.
            logits_processor: optional extra logits processors. An
                ``InvalidScoreLogitsProcessor`` is appended to a copy of it; the
                caller's list is left unchanged.
            **kwargs: further ``generate`` arguments. Keys not given here or as
                arguments are filled from ``default_generation_config``.

        Returns:
            ``(response, history)`` where ``history`` is a new list with
            ``(query, response)`` appended.
        """
        if history is None:
            history = []
        gen_kwargs = {
            "max_length": max_length,
            "num_beams": num_beams,
            "do_sample": do_sample,
            "top_p": top_p,
            "temperature": temperature,
            "logits_processor": self._with_invalid_score_processor(logits_processor),
            "repetition_penalty": repetition_penalty,
            **kwargs,
        }
        # Module-level defaults only fill keys that were not set explicitly.
        for key, value in default_generation_config.items():
            gen_kwargs.setdefault(key, value)
        inputs = self.build_inputs(tokenizer, session_meta, query, history=history)
        outputs = self.generate(**inputs, **gen_kwargs)
        # Keep only the newly generated tokens (strip the prompt prefix).
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
        response = tokenizer.decode(outputs)
        response = self.process_response(response)
        history = history + [(query, response)]
        return response, history

    @torch.inference_mode()
    def stream_chat(self, tokenizer, session_meta: SessionMeta, query: str, history: List[Tuple[str, str]] = None,
                    past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.9, temperature=0.9,
                    repetition_penalty=1.0, logits_processor=None, return_past_key_values=False, **kwargs):
        """Stream a reply for ``query``, yielding the partially decoded response.

        Without ``past_key_values`` the full prompt is built by ``build_inputs``.
        With ``past_key_values`` only the new turn is tokenized
        (``build_stream_inputs``), position ids are shifted by the cached length
        and the attention mask is left-extended with ones to cover the cache.

        Yields:
            ``(response, new_history)``, or ``(response, new_history, past_key_values)``
            when ``return_past_key_values`` is true. Steps whose decoded text is
            empty or ends with U+FFFD (presumably an incomplete multi-byte
            character) yield nothing.
        """
        if history is None:
            history = []
        gen_kwargs = {
            "max_length": max_length,
            "do_sample": do_sample,
            "top_p": top_p,
            "temperature": temperature,
            "logits_processor": self._with_invalid_score_processor(logits_processor),
            "repetition_penalty": repetition_penalty,
            **kwargs,
        }
        for key, value in default_generation_config.items():
            gen_kwargs.setdefault(key, value)
        # NOTE(review): repetition_penalty is always removed, so the argument has
        # no effect in streaming mode. Kept to preserve existing behaviour —
        # confirm whether this is intended.
        gen_kwargs.pop('repetition_penalty', None)

        if past_key_values is None:
            inputs = self.build_inputs(tokenizer, session_meta, query, history=history)
        else:
            inputs = self.build_stream_inputs(tokenizer, session_meta, query, history=history)
            # Assumes a ChatGLM-style cache layout with the sequence dimension
            # first — TODO confirm.
            past_length = past_key_values[0][0].shape[0]
            if self.transformer.pre_seq_len is not None:
                # Prefix-tuning tokens (if any) are not part of the visible sequence.
                past_length -= self.transformer.pre_seq_len
            inputs.position_ids += past_length
            attention_mask = inputs.attention_mask
            attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
            inputs['attention_mask'] = attention_mask

        prompt_length = len(inputs["input_ids"][0])
        for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
                                            return_past_key_values=return_past_key_values, **gen_kwargs):
            if return_past_key_values:
                outputs, past_key_values = outputs
            outputs = outputs.tolist()[0][prompt_length:]
            response = tokenizer.decode(outputs)
            if response and response[-1] != "�":
                response = self.process_response(response)
                new_history = history + [(query, response)]
                if return_past_key_values:
                    yield response, new_history, past_key_values
                else:
                    yield response, new_history

    @torch.inference_mode()
    def stream_generate(
            self,
            input_ids,
            generation_config: Optional[GenerationConfig] = None,
            logits_processor: Optional[LogitsProcessorList] = None,
            stopping_criteria: Optional[StoppingCriteriaList] = None,
            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
            return_past_key_values=False,
            **kwargs,
    ):
        """Decode one token per step, yielding the growing sequence after each step.

        Args:
            input_ids: prompt token ids; dim 0 is the batch, the last dim the sequence.
            generation_config: base config (deep-copied); defaults to ``self.generation_config``.
            logits_processor, stopping_criteria, prefix_allowed_tokens_fn: extra
                processors/criteria merged with those derived from the config.
            return_past_key_values: if true, yield ``(input_ids, past_key_values)``.
            **kwargs: generation-config overrides; keys that are not config fields
                are forwarded to the model as model kwargs.

        Yields:
            ``input_ids`` (prompt + tokens generated so far), or
            ``(input_ids, past_key_values)``. Stops once every row has produced an
            EOS token or a stopping criterion fires.
        """
        input_ids_seq_length = input_ids.shape[-1]

        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        # `update` returns the kwargs that are not config fields; they become model
        # kwargs (e.g. attention_mask / position_ids / past_key_values from stream_chat).
        model_kwargs = generation_config.update(**kwargs)
        model_kwargs["use_cache"] = generation_config.use_cache

        # Normalise EOS to a list; None means "never stop on EOS" instead of crashing.
        eos_token_id = generation_config.eos_token_id
        if eos_token_id is None:
            eos_token_id = []
        elif isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]

        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
            if not has_default_max_length:
                # Logger methods take %-format args, not a warning category; passing
                # UserWarning here used to break message formatting.
                logger.warning(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                )

        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )
        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)

        # 1 = row still generating, 0 = row has emitted EOS.
        # NOTE(review): finished rows are not replaced with a pad token, so with
        # batch size > 1 they keep receiving tokens; the callers above use batch size 1.
        unfinished_sequences = input_ids.new_ones(input_ids.shape[0])
        scores = None
        while True:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            # Only the logits of the last position matter for the next token.
            next_token_logits = outputs.logits[:, -1, :]
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            probs = nn.functional.softmax(next_token_scores, dim=-1)
            if generation_config.do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(probs, dim=-1)

            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            # A row is finished if its token matches ANY eos id. (Summing `!=`
            # comparisons was wrong for more than one eos id: a match on one id
            # still counted as "unfinished" because it differs from the others.)
            hit_eos = torch.zeros_like(next_tokens, dtype=torch.bool)
            for eos_id in eos_token_id:
                hit_eos |= next_tokens == eos_id
            unfinished_sequences = unfinished_sequences.mul((~hit_eos).long())

            if return_past_key_values:
                yield input_ids, outputs.past_key_values
            else:
                yield input_ids

            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                break
|
|