| import copy |
| import logging |
| from typing import Dict, List, Optional, Union |
|
|
| from lagent.schema import ModelStatusCode |
| from .base_api import APITemplateParser |
| from .base_llm import BaseLLM |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| class HFTransformer(BaseLLM): |
| """Model wrapper around HuggingFace general models. |
| |
| Adapted from Internlm (https://github.com/InternLM/InternLM/blob/main/ |
| chat/web_demo.py) |
| |
| Args: |
| path (str): The name or path to HuggingFace's model. |
| tokenizer_path (str): The path to the tokenizer. Defaults to None. |
| tokenizer_kwargs (dict): Keyword arguments for the tokenizer. |
| Defaults to {}. |
| tokenizer_only (bool): If True, only the tokenizer will be initialized. |
| Defaults to False. |
| model_kwargs (dict): Keyword arguments for the model, used in loader. |
| Defaults to dict(device_map='auto'). |
| meta_template (Dict, optional): The model's meta prompt template, |
| used when meta instructions need to be injected into or wrapped |
| around the input. Defaults to None. |
| stop_words_id (Union[List[int], int], optional): Token id(s) that |
| stop generation in addition to the model's own EOS token. |
| Defaults to None. |
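| |
| Example: |
| A minimal usage sketch; the model path below is only an |
| illustration, and decoder-only chat models are usually loaded |
| through the ``HFTransformerCasualLM`` subclass instead:: |
| |
| >>> llm = HFTransformerCasualLM(path='internlm/internlm-chat-7b') |
| >>> for status, text, _ in llm.stream_generate('Hi'): |
| ...     print(status, text) |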
| """ |
|
|
| def __init__(self, |
| path: str, |
| tokenizer_path: Optional[str] = None, |
| tokenizer_kwargs: dict = dict(), |
| tokenizer_only: bool = False, |
| model_kwargs: dict = dict(device_map='auto'), |
| meta_template: Optional[Dict] = None, |
| stop_words_id: Optional[Union[List[int], int]] = None, |
| **kwargs): |
| super().__init__( |
| path=path, |
| tokenizer_only=tokenizer_only, |
| meta_template=meta_template, |
| **kwargs) |
| if isinstance(stop_words_id, int): |
| stop_words_id = [stop_words_id] |
| self.gen_params.update(stop_words_id=stop_words_id) |
| if self.gen_params['stop_words'] is not None and \ |
| self.gen_params['stop_words_id'] is not None: |
| logger.warning('Both stop_words and stop_words_id are specified, ' |
| 'only stop_words_id will be used.') |
|
|
| self._load_tokenizer( |
| path=path, |
| tokenizer_path=tokenizer_path, |
| tokenizer_kwargs=tokenizer_kwargs) |
| if not tokenizer_only: |
| self._load_model(path=path, model_kwargs=model_kwargs) |
|
|
| from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList |
| self.logits_processor = LogitsProcessorList() |
| self.stopping_criteria = StoppingCriteriaList() |
| self.prefix_allowed_tokens_fn = None |
|
|
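| # Collect extra EOS ids for generation: prefer an explicit `stop_words_id`; |
| # otherwise fall back to the last token id of each encoded stop word. |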
| stop_words_id = [] |
| if self.gen_params.get('stop_words_id'): |
| stop_words_id = self.gen_params.get('stop_words_id') |
| elif self.gen_params.get('stop_words'): |
| for sw in self.gen_params.get('stop_words'): |
| stop_words_id.append(self.tokenizer(sw)['input_ids'][-1]) |
| self.additional_eos_token_id = stop_words_id |
|
|
| def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], |
| tokenizer_kwargs: dict): |
| from transformers import AutoTokenizer |
| self.tokenizer = AutoTokenizer.from_pretrained( |
| tokenizer_path if tokenizer_path else path, |
| trust_remote_code=True, |
| **tokenizer_kwargs) |
|
|
| # Default `self.gcfg` so that later lookups (e.g. in stream_generate) |
| # never hit a missing attribute. |
| self.gcfg = None |
| if self.tokenizer.pad_token_id is None: |
| if self.tokenizer.eos_token is not None: |
| logger.warning( |
| f'Using eos_token {self.tokenizer.eos_token} ' |
| 'as pad_token.') |
| self.tokenizer.pad_token = self.tokenizer.eos_token |
| else: |
| from transformers.generation import GenerationConfig |
| self.gcfg = GenerationConfig.from_pretrained(path) |
|
|
| if self.gcfg.pad_token_id is not None: |
| logger.warning( |
| f'Using pad_token_id {self.gcfg.pad_token_id} ' |
| 'as pad_token_id.') |
| self.tokenizer.pad_token_id = self.gcfg.pad_token_id |
| else: |
| raise ValueError( |
| 'pad_token_id is not set for this tokenizer. Try to ' |
| 'set pad_token_id via passing ' |
| '`pad_token_id={PAD_TOKEN_ID}` in model_cfg.') |
|
|
| def _load_model(self, path: str, model_kwargs: dict): |
| import torch |
| from transformers import AutoModel |
| model_kwargs.setdefault('torch_dtype', torch.float16) |
| self.model = AutoModel.from_pretrained( |
| path, trust_remote_code=True, **model_kwargs) |
| self.model.eval() |
|
|
| def tokenize(self, inputs: str): |
| assert isinstance(inputs, str) |
| inputs = self.tokenizer( |
| inputs, return_tensors='pt', return_length=True) |
| return inputs['input_ids'].tolist() |
|
|
| def generate( |
| self, |
| inputs: Union[str, List[str]], |
| do_sample: bool = True, |
| **kwargs, |
| ): |
| """Return the chat completions in non-stream mode. |
| |
| Args: |
| inputs (Union[str, List[str]]): input texts to be completed. |
| do_sample (bool): do sampling if enabled |
| Returns: |
| (a list of/batched) text/chat completion |
| """ |
| for status, chunk, _ in self.stream_generate(inputs, do_sample, |
| **kwargs): |
| response = chunk |
| return response |
|
|
| def stream_generate( |
| self, |
| inputs: Union[str, List[str]], |
| do_sample: bool = True, |
| **kwargs, |
| ): |
| """Return the chat completions in stream mode. |
| |
| Args: |
| inputs (Union[str, List[str]]): input texts to be completed. |
| do_sample (bool): do sampling if enabled |
| Returns: |
| tuple(Status, str, int): status, text/chat completion, |
| generated token number |
| """ |
| import torch |
| from torch import nn |
| with torch.no_grad(): |
| batched = True |
| if isinstance(inputs, str): |
| inputs = [inputs] |
| batched = False |
| inputs = self.tokenizer( |
| inputs, padding=True, return_tensors='pt', return_length=True) |
| input_length = inputs['length'] |
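| # Move the tokenized batch onto the GPU; generation below assumes CUDA. |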
| for k, v in inputs.items(): |
| inputs[k] = v.cuda() |
| input_ids = inputs['input_ids'] |
| attention_mask = inputs['attention_mask'] |
| batch_size = input_ids.shape[0] |
| input_ids_seq_length = input_ids.shape[-1] |
| generation_config = self.model.generation_config |
| generation_config = copy.deepcopy(generation_config) |
| new_gen_params = self.update_gen_params(**kwargs) |
| generation_config.update(**new_gen_params) |
| generation_config.update(**kwargs) |
| model_kwargs = generation_config.to_dict() |
| model_kwargs['attention_mask'] = attention_mask |
| eos_token_id = generation_config.eos_token_id |
| if eos_token_id is None: |
| gcfg = getattr(self, 'gcfg', None) |
| if gcfg is not None and gcfg.eos_token_id is not None: |
| eos_token_id = gcfg.eos_token_id |
| else: |
| eos_token_id = [] |
| if isinstance(eos_token_id, int): |
| eos_token_id = [eos_token_id] |
| if self.additional_eos_token_id is not None: |
| eos_token_id.extend(self.additional_eos_token_id) |
| eos_token_id_tensor = torch.tensor(eos_token_id).to( |
| input_ids.device) if eos_token_id is not None else None |
| generation_config.max_length = ( |
| generation_config.max_new_tokens + input_ids_seq_length) |
| |
| logits_processor = self.logits_processor |
| stopping_criteria = self.stopping_criteria |
|
|
| logits_processor = self.model._get_logits_processor( |
| generation_config=generation_config, |
| input_ids_seq_length=input_ids_seq_length, |
| encoder_input_ids=input_ids, |
| prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn, |
| logits_processor=logits_processor, |
| ) |
|
|
| stopping_criteria = self.model._get_stopping_criteria( |
| generation_config=generation_config, |
| stopping_criteria=stopping_criteria) |
| logits_warper = self.model._get_logits_warper(generation_config) |
|
|
| unfinished_sequences = input_ids.new(batch_size).fill_(1) |
| scores = None |
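| # Decode token by token (instead of calling `model.generate`) so that |
| # partial text can be yielded after every step for streaming. |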
| while True: |
| model_inputs = self.model.prepare_inputs_for_generation( |
| input_ids, **model_kwargs) |
| |
| outputs = self.model( |
| **model_inputs, |
| return_dict=True, |
| output_attentions=False, |
| output_hidden_states=False, |
| ) |
|
|
| next_token_logits = outputs.logits[:, -1, :] |
|
|
| |
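| # Apply the configured logits processors (e.g. repetition penalty) and |
| # warpers (temperature, top-k/top-p) to the raw next-token logits. |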
| next_token_scores = logits_processor(input_ids, |
| next_token_logits) |
| next_token_scores = logits_warper(input_ids, next_token_scores) |
|
|
| |
| probs = nn.functional.softmax(next_token_scores, dim=-1) |
| if do_sample: |
| next_tokens = torch.multinomial( |
| probs, num_samples=1).squeeze(1) |
| else: |
| next_tokens = torch.argmax(probs, dim=-1) |
|
|
| |
| |
| input_ids = torch.cat([input_ids, next_tokens[:, None]], |
| dim=-1) |
| model_kwargs = self.model._update_model_kwargs_for_generation( |
| outputs, |
| model_kwargs, |
| is_encoder_decoder=False) |
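| # Mark sequences as finished once their newest token is any EOS id. |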
| unfinished_sequences = unfinished_sequences.mul( |
| next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne( |
| eos_token_id_tensor.unsqueeze(1)).prod(dim=0)) |
| output_token_ids = input_ids.cpu().tolist() |
| for i in range(len(output_token_ids)): |
| output_token_ids[i] = output_token_ids[i][input_length[i]:] |
| |
| |
| first_eos_idx = next( |
| (idx |
| for idx, token_id in enumerate(output_token_ids[i]) |
| if token_id in eos_token_id), None) |
| |
| |
| if first_eos_idx is not None: |
| output_token_ids[i] = output_token_ids[ |
| i][:first_eos_idx] |
|
|
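| # Decode everything generated so far and yield it as a partial result. |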
| response = self.tokenizer.batch_decode(output_token_ids) |
| |
| if not batched: |
| response = response[0] |
| yield ModelStatusCode.STREAM_ING, response, None |
| |
| |
| if (unfinished_sequences.max() == 0 |
| or stopping_criteria(input_ids, scores)): |
| break |
| yield ModelStatusCode.END, response, None |
|
|
| def stream_chat( |
| self, |
| inputs: List[dict], |
| do_sample: bool = True, |
| **kwargs, |
| ): |
| """Return the chat completions in stream mode. |
| |
| Args: |
| inputs (List[dict]): input messages to be completed. |
| do_sample (bool): do sampling if enabled |
| Returns: |
| tuple(Status, str, int): status, text/chat completion, |
| generated token number |
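| |
| Example (illustrative; assumes an initialized ``llm`` instance):: |
| |
| >>> messages = [dict(role='user', content='Hello')] |
| >>> for status, text, _ in llm.stream_chat(messages): |
| ...     print(text) |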
| """ |
| prompt = self.template_parser(inputs) |
| yield from self.stream_generate(prompt, do_sample, **kwargs) |
|
|
|
|
| class HFTransformerCasualLM(HFTransformer): |
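| """HFTransformer variant that loads the model with |
| ``AutoModelForCausalLM`` instead of ``AutoModel``.""" |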
|
|
| def _load_model(self, path: str, model_kwargs: dict): |
| import torch |
| from transformers import AutoModelForCausalLM |
| model_kwargs.setdefault('torch_dtype', torch.float16) |
| self.model = AutoModelForCausalLM.from_pretrained( |
| path, trust_remote_code=True, **model_kwargs) |
| self.model.eval() |
|
|
|
|
| class HFTransformerChat(HFTransformerCasualLM): |
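| """Causal-LM wrapper for models exposing a native ``model.chat`` |
| interface; messages are formatted with ``APITemplateParser``.""" |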
|
|
| def __init__(self, template_parser=APITemplateParser, **kwargs): |
| super().__init__(template_parser=template_parser, **kwargs) |
|
|
| def chat(self, |
| inputs: Union[List[dict], List[List[dict]]], |
| do_sample: bool = True, |
| **kwargs): |
| """Return the chat completions in stream mode. |
| |
| Args: |
| inputs (Union[List[dict], List[List[dict]]]): input messages to be completed. |
| do_sample (bool): do sampling if enabled |
| Returns: |
| the text/chat completion |
| """ |
| |
| if isinstance(inputs[0], list): |
| resps = [] |
| for single_inputs in inputs: |
| resps.append(self.chat(single_inputs, do_sample, **kwargs)) |
| return resps |
| prompt = self.template_parser(inputs) |
| query = prompt[-1]['content'] |
| history = prompt[:-1] |
| try: |
| response, history = self.model.chat( |
| self.tokenizer, query, history=history) |
| except Exception as e: |
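| # Fall back to an empty response if the underlying `model.chat` call fails. |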
| |
| logger.warning(str(e)) |
| response = '' |
| return response |
|
|