import asyncio
from typing import List, Optional, Union

from lagent.llms.base_llm import AsyncBaseLLM, BaseLLM
from lagent.utils.util import filter_suffix


def asdict_completion(output):
    """Convert a vLLM ``CompletionOutput``-like object into a plain dict."""
    return {
        key: getattr(output, key)
        for key in [
            'text', 'token_ids', 'cumulative_logprob', 'logprobs',
            'finish_reason', 'stop_reason'
        ]
    }


class VllmModel(BaseLLM):
    """
    A wrapper of vLLM model.

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a huggingface model.
                - ii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        tp (int): tensor parallel size, i.e. the number of GPUs used for
            tensor parallelism. Defaults to 1.
        vllm_cfg (dict): Other kwargs for vllm model initialization.
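
    Example (illustrative sketch; the model path and settings below are
    placeholders, not values required by this class):
        >>> llm = VllmModel(
        ...     path='internlm/internlm2-chat-7b',
        ...     tp=1,
        ...     vllm_cfg=dict(gpu_memory_utilization=0.9))
        >>> llm.generate('Write a haiku about mountains.')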
| """ |
|
|
| def __init__(self, path: str, tp: int = 1, vllm_cfg=dict(), **kwargs): |
|
|
| super().__init__(path=path, **kwargs) |
| from vllm import LLM |
| self.model = LLM( |
| model=self.path, |
| trust_remote_code=True, |
| tensor_parallel_size=tp, |
| **vllm_cfg) |

    def generate(self,
                 inputs: Union[str, List[str]],
                 do_preprocess: Optional[bool] = None,
                 skip_special_tokens: bool = False,
                 return_dict: bool = False,
                 **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_preprocess (bool): whether to pre-process the messages.
                Defaults to True, which means chat_template will be applied.
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
            return_dict (bool): whether to return the full completion fields
                (see ``asdict_completion``) instead of plain text.
                Defaults to False.

        Returns:
            (a list of/batched) text/chat completion
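
        Example (illustrative sketch; the values shown are placeholders, not
        real model output):
            >>> llm.generate('Hi', return_dict=True)
            {'text': 'Hello!', 'token_ids': [9843, 0],
             'cumulative_logprob': -1.2, 'logprobs': None,
             'finish_reason': 'stop', 'stop_reason': None}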
| """ |
| from vllm import SamplingParams |
|
|
| batched = True |
| if isinstance(inputs, str): |
| inputs = [inputs] |
| batched = False |
| prompt = inputs |
| gen_params = self.update_gen_params(**kwargs) |
| max_new_tokens = gen_params.pop('max_new_tokens') |
| stop_words = gen_params.pop('stop_words') |
|
|
| sampling_config = SamplingParams( |
| skip_special_tokens=skip_special_tokens, |
| max_tokens=max_new_tokens, |
| stop=stop_words, |
| **gen_params) |
| response = self.model.generate(prompt, sampling_params=sampling_config) |
| texts = [resp.outputs[0].text for resp in response] |
| |
| texts = filter_suffix(texts, self.gen_params.get('stop_words')) |
| for resp, text in zip(response, texts): |
| resp.outputs[0].text = text |
| if batched: |
| return [asdict_completion(resp.outputs[0]) |
| for resp in response] if return_dict else texts |
| return asdict_completion( |
| response[0].outputs[0]) if return_dict else texts[0] |


class AsyncVllmModel(AsyncBaseLLM):
    """
    An asynchronous wrapper of vLLM model.

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a huggingface model.
                - ii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        tp (int): tensor parallel size, i.e. the number of GPUs used for
            tensor parallelism. Defaults to 1.
        vllm_cfg (dict): Other kwargs for vllm model initialization.
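
    Example (illustrative sketch; the model path and prompts are
    placeholders):
        >>> import asyncio
        >>> llm = AsyncVllmModel(path='internlm/internlm2-chat-7b', tp=1)
        >>> asyncio.run(llm.generate(['Hi there!', 'Tell me a joke.']))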
| """ |
|
|
| def __init__(self, path: str, tp: int = 1, vllm_cfg=dict(), **kwargs): |
| super().__init__(path=path, **kwargs) |
| from vllm import AsyncEngineArgs, AsyncLLMEngine |
|
|
| engine_args = AsyncEngineArgs( |
| model=self.path, |
| trust_remote_code=True, |
| tensor_parallel_size=tp, |
| **vllm_cfg) |
| self.model = AsyncLLMEngine.from_engine_args(engine_args) |
|
|

    async def generate(self,
                       inputs: Union[str, List[str]],
                       session_ids: Optional[Union[int, List[int]]] = None,
                       do_preprocess: Optional[bool] = None,
                       skip_special_tokens: bool = False,
                       return_dict: bool = False,
                       **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            session_ids (Union[int, List[int]]): per-input identifiers passed
                to the engine as request ids. Defaults to None, in which case
                0, 1, ... are used.
            do_preprocess (bool): whether to pre-process the messages.
                Defaults to True, which means chat_template will be applied.
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
            return_dict (bool): whether to return the full completion fields
                (see ``asdict_completion``) instead of plain text.
                Defaults to False.

        Returns:
            (a list of/batched) text/chat completion
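
        Example (illustrative sketch; prompts and session ids are
        placeholders):
            >>> await llm.generate(
            ...     ['Hi there!', 'Tell me a joke.'], session_ids=[101, 102])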
| """ |
| from vllm import SamplingParams |
|
|
| batched = True |
| if isinstance(inputs, str): |
| inputs = [inputs] |
| batched = False |
| if session_ids is None: |
| session_ids = list(range(len(inputs))) |
| elif isinstance(session_ids, (int, str)): |
| session_ids = [session_ids] |
| assert len(inputs) == len(session_ids) |
|
|
| prompt = inputs |
| gen_params = self.update_gen_params(**kwargs) |
| max_new_tokens = gen_params.pop('max_new_tokens') |
| stop_words = gen_params.pop('stop_words') |
|
|
| sampling_config = SamplingParams( |
| skip_special_tokens=skip_special_tokens, |
| max_tokens=max_new_tokens, |
| stop=stop_words, |
| **gen_params) |
|
|
| async def _inner_generate(uid, text): |
| resp, generator = '', self.model.generate( |
| text, sampling_params=sampling_config, request_id=uid) |
| async for out in generator: |
| resp = out.outputs[0] |
| return resp |
|
|
| response = await asyncio.gather(*[ |
| _inner_generate(sid, inp) for sid, inp in zip(session_ids, prompt) |
| ]) |
| texts = [resp.text for resp in response] |
| |
| texts = filter_suffix(texts, self.gen_params.get('stop_words')) |
| for resp, text in zip(response, texts): |
| resp.text = text |
| if batched: |
| return [asdict_completion(resp) |
| for resp in response] if return_dict else texts |
| return asdict_completion(response[0]) if return_dict else texts[0] |
|
|