import gc
import copy
from tenacity import RetryError
from tenacity import retry, stop_after_attempt, wait_fixed

import torch

from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    LogitsProcessorList,
    MinNewTokensLengthLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
)


def get_output_batch(
    model, tokenizer, prompts, generation_config
):
    """Generate completions for one or more prompts in a single non-streaming call."""
    if len(prompts) == 1:
        # Single prompt: no padding is needed.
        encoding = tokenizer(prompts, return_tensors="pt")
        input_ids = encoding["input_ids"].cuda()
        generated_id = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            max_new_tokens=256,
        )

        decoded = tokenizer.batch_decode(generated_id)
        del input_ids, generated_id
        torch.cuda.empty_cache()
        return decoded
    else:
        # Multiple prompts: pad them to a common length so they share one batch.
        encodings = tokenizer(prompts, padding=True, return_tensors="pt").to("cuda")
        generated_ids = model.generate(
            **encodings,
            generation_config=generation_config,
            max_new_tokens=256,
        )

        decoded = tokenizer.batch_decode(generated_ids)
        del encodings, generated_ids
        torch.cuda.empty_cache()
        return decoded
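

# Usage sketch for get_output_batch (the checkpoint name is illustrative, not
# something this module requires). The model must already live on CUDA because
# the encodings above are moved there unconditionally:
#
#   from transformers import GenerationConfig
#   model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").cuda()
#   tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", padding_side="left")
#   config = GenerationConfig(do_sample=True, temperature=0.7, top_p=0.9)
#   print(get_output_batch(model, tokenizer, ["Hello", "How are you?"], config))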
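
# StreamModel below runs the decoding loop by hand (see generate()) rather than
# calling model.generate(), so partially decoded text can be yielded to the
# caller while generation is still in progress.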
class StreamModel:
    """StreamModel wraps around a language model to provide stream decoding."""

    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __call__(
        self,
        prompt,
        min_tokens=0,
        max_tokens=16,
        temperature=1.0,
        top_p=1.0,
        n=1,
        logprobs=0,
    ):
        """Create a completion stream for the provided prompt."""
        input_ids = self.tokenize(prompt)
        logprobs = max(logprobs, 0)

        # Decode and yield text every `chunk_size` tokens rather than on every
        # step, to keep the decoding overhead small.
        chunk_size = 2
        chunk_count = 0

        # All completion tokens generated so far.
        final_tokens = torch.empty(0, dtype=torch.long, device=self.device)

        try:
            for tokens in self.generate(
                input_ids[None, :].repeat(n, 1),
                logprobs=logprobs,
                min_new_tokens=min_tokens,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            ):
                chunk_count += 1
                final_tokens = torch.cat((final_tokens, tokens))

                if chunk_count == chunk_size:
                    chunk_count = 0
                    yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)

            # Flush whatever is left over from an incomplete chunk.
            if chunk_count > 0:
                yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)

        except RetryError as e:
            print(e)

        del input_ids, final_tokens
        gc.collect()
        if self.device == "cuda":
            torch.cuda.empty_cache()

    @retry(stop=stop_after_attempt(5), wait=wait_fixed(1))
    def _infer(self, model_fn, **kwargs):
        """Call a model function in inference mode with auto retrying."""
        with torch.inference_mode():
            return model_fn(**kwargs)

    def _logits_processor(self, config, input_length):
        """Set up logits processor based on the generation config."""
        processor = LogitsProcessorList()

        # Suppress the EOS token until at least min_new_tokens have been generated.
        if (
            config.min_new_tokens is not None
            and config.min_new_tokens > 0
            and config.eos_token_id is not None
        ):
            processor.append(
                MinNewTokensLengthLogitsProcessor(
                    prompt_length_to_skip=input_length,
                    min_new_tokens=config.min_new_tokens,
                    eos_token_id=config.eos_token_id,
                )
            )

        # Rescale the logits by the sampling temperature.
        if (
            config.temperature is not None
            and config.temperature > 0
            and config.temperature != 1.0
        ):
            processor.append(TemperatureLogitsWarper(config.temperature))

        # Restrict sampling to the smallest nucleus of tokens whose cumulative
        # probability reaches top_p.
        if config.top_p is not None and config.top_p > 0 and config.top_p < 1:
            processor.append(TopPLogitsWarper(config.top_p))

        return processor

    def tokenize(self, text):
        """Tokenize a string into a tensor of token IDs."""
        batch = self.tokenizer.encode(text, return_tensors="pt")
        return batch[0].to(self.device)

    def generate(self, input_ids, logprobs=0, **kwargs):
        """Generate a stream of predicted tokens using the language model."""

        # Store the original batch size and input length.
        batch_size = input_ids.shape[0]
        input_length = input_ids.shape[-1]

        # Separate model arguments from the generation config; config.update()
        # returns whatever kwargs the config does not recognize.
        config = self.model.generation_config
        config = copy.deepcopy(config)
        kwargs = config.update(**kwargs)
        kwargs["output_attentions"] = False
        kwargs["output_hidden_states"] = False
        kwargs["use_cache"] = True

        # Collect special token IDs, falling back to EOS as the padding token.
        pad_token_id = config.pad_token_id
        bos_token_id = config.bos_token_id
        eos_token_id = config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        if pad_token_id is None and eos_token_id is not None:
            pad_token_id = eos_token_id[0]

        # Start generation from EOS if no input is provided.
        if input_length == 0:
            input_ids = input_ids.new_ones((batch_size, 1)).long()
            if eos_token_id is not None:
                input_ids = input_ids * eos_token_id[0]
            input_length = 1

        # Encoder-decoder models: encode the prompt once, then decode from the
        # decoder start token.
        if self.model.config.is_encoder_decoder:
            # Run the encoder over the input IDs and reuse its outputs.
            encoder = self.model.get_encoder()
            encoder_kwargs = kwargs.copy()
            encoder_kwargs.pop("use_cache", None)
            encoder_kwargs["input_ids"] = input_ids
            encoder_kwargs["return_dict"] = True
            encoder_outputs = self._infer(encoder, **encoder_kwargs)
            kwargs["encoder_outputs"] = encoder_outputs

            # The decoder starts from decoder_start_token_id (or BOS as a fallback).
            decoder_start_token_id = config.decoder_start_token_id
            if decoder_start_token_id is None:
                decoder_start_token_id = bos_token_id
            input_ids = input_ids.new_ones((batch_size, 1))
            input_ids = input_ids * decoder_start_token_id
            input_length = 1

        # Set up the logits processors from the generation config.
        processor = self._logits_processor(config, input_length)

        # Track which sequences in the batch are still unfinished.
        unfinished = input_ids.new_ones(batch_size)

        # Decode step by step. Note that the full sequence is re-encoded on every
        # step here, since past_key_values from previous iterations are not reused.
        while True:
            inputs = self.model.prepare_inputs_for_generation(
                input_ids, **kwargs
            )
            outputs = self._infer(
                self.model,
                **inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            # Pre-process the distribution over the next tokens.
            logits = outputs.logits[:, -1, :]
            with torch.inference_mode():
                logits = processor(input_ids, logits)
            probs = torch.nn.functional.softmax(logits, dim=-1)

            # Choose greedy or stochastic decoding.
            if (config.top_p is not None and config.top_p <= 0) or (
                config.temperature is not None and config.temperature <= 0
            ):
                tokens = torch.argmax(probs, dim=-1)[:, None]
            else:
                tokens = torch.multinomial(probs, num_samples=1)

            tokens = tokens.squeeze(1)

            # Finished sequences keep emitting the padding token.
            if pad_token_id is not None:
                tokens = tokens * unfinished + pad_token_id * (1 - unfinished)

            # Append the selected tokens to the running inputs.
            input_ids = torch.cat([input_ids, tokens[:, None]], dim=-1)

            # Mark sequences that produced an EOS token as finished.
            if eos_token_id is not None:
                is_eos = sum((tokens == i).long() for i in eos_token_id)
                unfinished = unfinished.mul((is_eos == 0).long())

            # Negate the status of unfinished sequences once max_new_tokens is reached.
            status = unfinished.clone()
            if input_ids.shape[-1] - input_length >= config.max_new_tokens:
                status = 0 - status

            # Yield the tokens predicted at this step.
            yield tokens

            # Stop once every sequence is finished or the length limit is hit.
            if status.max() <= 0:
                break
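

# A minimal, self-contained usage sketch follows. The checkpoint name is an
# illustrative assumption, not something this module depends on; any causal LM
# with a compatible tokenizer should behave the same way.
if __name__ == "__main__":
    model_name = "facebook/opt-125m"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Streaming generation: each yielded string is the decoded completion so far.
    stream = StreamModel(model, tokenizer)
    for partial in stream("Hello, my name is", max_tokens=32, temperature=0.7, top_p=0.9):
        print(partial)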
|
|