import copy
import logging
from typing import List, Optional, Union

from lagent.llms.base_llm import BaseModel
from lagent.schema import ModelStatusCode
from lagent.utils.util import filter_suffix


class LMDeployServer(BaseModel):
    """

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a turbomind model which is
                    converted by the `lmdeploy convert` command or downloaded
                    from ii) and iii).
                - ii) The model_id of an lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        model_name (str): needed when `path` points to a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat", "Baichuan2-7B-Chat" and so on.
        server_name (str): host ip for serving
        server_port (int): server port
        tp (int): degree of tensor parallelism
        log_level (str): set the log level; one of
            [CRITICAL, ERROR, WARNING, INFO, DEBUG]
    """

    def __init__(self,
                 path: str,
                 model_name: Optional[str] = None,
                 server_name: str = '0.0.0.0',
                 server_port: int = 23333,
                 tp: int = 1,
                 log_level: str = 'WARNING',
                 serve_cfg=dict(),
                 **kwargs):
        super().__init__(path=path, **kwargs)
        self.model_name = model_name
        # TODO get_logger issue in multi processing
        import lmdeploy
        self.client = lmdeploy.serve(
            model_path=self.path,
            model_name=model_name,
            server_name=server_name,
            server_port=server_port,
            tp=tp,
            log_level=log_level,
            **serve_cfg)

    def generate(self,
                 inputs: Union[str, List[str]],
                 session_id: int = 2967,
                 sequence_start: bool = True,
                 sequence_end: bool = True,
                 ignore_eos: bool = False,
                 skip_special_tokens: Optional[bool] = False,
                 timeout: int = 30,
                 **kwargs) -> List[str]:
        """Start a new round of conversation in a session and return the chat
        completions in non-stream mode.

        Args:
            inputs (str, List[str]): user's prompt(s) in this round
            session_id (int): the unique id of a session
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special
                tokens during decoding. Defaults to False.
            timeout (int): max time to wait for a response

        Returns:
            (a list of / batched) text/chat completion
        """
        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False

        gen_params = self.update_gen_params(**kwargs)
        # the completions API expects `max_tokens` rather than `max_new_tokens`
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)

        resp = [''] * len(inputs)
        for text in self.client.completions_v1(
                self.model_name,
                inputs,
                session_id=session_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
                stream=False,
                ignore_eos=ignore_eos,
                skip_special_tokens=skip_special_tokens,
                timeout=timeout,
                **gen_params):
            # concatenate the returned text for each prompt in the batch
            resp = [
                resp[i] + item['text']
                for i, item in enumerate(text['choices'])
            ]
        # remove stop_words
        resp = filter_suffix(resp, self.gen_params.get('stop_words'))
        if not batched:
            return resp[0]
        return resp

    def stream_chat(self,
                    inputs: List[dict],
                    session_id=0,
                    sequence_start: bool = True,
                    sequence_end: bool = True,
                    stream: bool = True,
                    ignore_eos: bool = False,
                    skip_special_tokens: Optional[bool] = False,
                    timeout: int = 30,
                    **kwargs):
        """Start a new round of conversation in a session and yield the chat
        completions in stream mode.

        Args:
            inputs (List[dict]): user's inputs in this round conversation
            session_id (int): the unique id of a session
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            stream (bool): return in a streaming format if enabled
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special
                tokens during decoding. Defaults to False.
            timeout (int): max time to wait for a response

        Returns:
            tuple(Status, str, int): status, text/chat completion,
                number of generated tokens
        """
        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)
        prompt = self.template_parser(inputs)

        resp = ''
        finished = False
        # guard against stop_words being unset
        stop_words = self.gen_params.get('stop_words') or []
        for text in self.client.completions_v1(
                self.model_name,
                prompt,
                session_id=session_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
                stream=stream,
                ignore_eos=ignore_eos,
                skip_special_tokens=skip_special_tokens,
                timeout=timeout,
                **gen_params):
            resp += text['choices'][0]['text']
            if not resp:
                continue
            # stop streaming once a stop word appears and strip it from
            # the accumulated output
            for sw in stop_words:
                if sw in resp:
                    resp = filter_suffix(resp, stop_words)
                    finished = True
                    break
            yield ModelStatusCode.STREAM_ING, resp, None
            if finished:
                break
        yield ModelStatusCode.END, resp, None
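

# A minimal usage sketch, not part of the original module: the model path
# (taken from option iii in the class docstring) and the prompts are
# illustrative assumptions, and it presumes lmdeploy is installed and the
# model can be fetched. `generate` returns the full completion, while
# `stream_chat` yields (status, text, token_count) tuples until
# ModelStatusCode.END is reported.
if __name__ == '__main__':
    llm = LMDeployServer(path='internlm/internlm-chat-7b')

    # non-stream mode: a single prompt returns a single string,
    # a list of prompts returns a list of strings
    print(llm.generate('Hello, who are you?'))

    # stream mode: messages are assumed to follow the role/content dict
    # format expected by the configured template parser
    messages = [dict(role='user', content='Briefly introduce yourself.')]
    for status, text, _ in llm.stream_chat(messages):
        if status == ModelStatusCode.END:
            print(text)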