# Copyright (c) Alibaba, Inc. and its affiliates.
import asyncio
import inspect
import os
import time
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

import lmdeploy
import torch
from lmdeploy import PytorchEngineConfig, TurbomindEngineConfig, VisionConfig, pipeline
from lmdeploy.api import autoget_backend_config
from lmdeploy.serve import async_engine
from packaging import version
from transformers import GenerationConfig

from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer
from swift.plugin import Metric
from swift.utils import get_logger, get_seed
from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
                        ChatCompletionStreamResponse, ChatMessage, DeltaMessage, RequestConfig)
from .infer_engine import InferEngine
from .patch import patch_auto_config, patch_auto_tokenizer
from .utils import InferStreamer, patch_lmdeploy

try:
    from lmdeploy import EngineGenerationConfig as LmdeployGenerationConfig
except ImportError:
    # compat lmdeploy >= 0.6.*
    from lmdeploy import GenerationConfig as LmdeployGenerationConfig

logger = get_logger()


def _lmdeploy_version_at_least(minimum: str) -> bool:
    """Return True when the installed lmdeploy version is >= ``minimum``."""
    return version.parse(lmdeploy.__version__) >= version.parse(minimum)


class LmdeployEngine(InferEngine):
    """Inference engine that serves requests through an lmdeploy ``pipeline``."""

    def __init__(
        self,
        model_id_or_path: str,
        torch_dtype: Optional[torch.dtype] = None,
        *,
        model_type: Optional[str] = None,
        use_hf: Optional[bool] = None,
        hub_token: Optional[str] = None,
        revision: Optional[str] = None,
        # engine_kwargs
        tp: int = 1,
        session_len: Optional[int] = None,
        cache_max_entry_count: float = 0.8,
        quant_policy: int = 0,  # e.g. 4, 8
        vision_batch_size: int = 1,  # max_batch_size in VisionConfig
        devices: Optional[List[int]] = None,
        reload_weights: bool = False,
        engine_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Load the tokenizer/processor, build the backend config and start the lmdeploy pipeline.

        Args:
            model_id_or_path: Model id or local path forwarded to ``get_model_tokenizer``.
            torch_dtype: Desired dtype; falls back to ``self.model_info.torch_dtype`` when falsy.
            model_type: Forwarded to ``get_model_tokenizer``.
            use_hf: Forwarded to ``get_model_tokenizer``.
            hub_token: Forwarded to ``get_model_tokenizer``.
            revision: Forwarded to ``get_model_tokenizer``.
            tp: Stored as ``tp`` in the engine config; ``patch_lmdeploy`` is only applied when it is 1.
            session_len: Stored as ``session_len`` in the engine config.
            cache_max_entry_count: Stored as ``cache_max_entry_count`` in the engine config.
            quant_policy: Stored as ``quant_policy`` in the engine config (e.g. 4, 8).
            vision_batch_size: ``max_batch_size`` of the ``VisionConfig`` used for multimodal models.
            devices: Assigned to ``backend_config.devices`` when that attribute exists (default ``[0]``).
            reload_weights: Forwarded to ``patch_lmdeploy``; requires lmdeploy>=0.7.0.
            engine_kwargs: Extra keyword arguments for ``TurbomindEngineConfig``. The dict is not modified.

        Raises:
            AssertionError: If ``reload_weights`` is set with lmdeploy < 0.7.0.
        """
        version_7 = _lmdeploy_version_at_least('0.7.0')
        if reload_weights:
            assert version_7, 'grpo or reload_weights need lmdeploy>=0.7.0'
        if version_7 and tp == 1:
            patch_lmdeploy(reload_weights)
        # Only the second element of the returned pair is kept; the model itself is not loaded here.
        self.processor = get_model_tokenizer(
            model_id_or_path,
            torch_dtype,
            load_model=False,
            download_model=True,
            model_type=model_type,
            use_hf=use_hf,
            hub_token=hub_token,
            revision=revision)[1]
        self._post_init()

        if self.max_model_len is not None:
            self.max_model_len -= 1
        self._prepare_engine_kwargs(
            tp=tp,
            session_len=session_len,
            cache_max_entry_count=cache_max_entry_count,
            quant_policy=quant_policy,
            vision_batch_size=vision_batch_size,
            devices=devices,
            engine_kwargs=engine_kwargs)

        self.config.torch_dtype = torch_dtype or self.model_info.torch_dtype

        @contextmanager
        def disable_deepspeed():
            """Temporarily make ``transformers`` report that DeepSpeed ZeRO-3 is disabled."""
            from transformers import modeling_utils
            modeling_utils.is_deepspeed_zero3_enabled_origin = modeling_utils.is_deepspeed_zero3_enabled
            modeling_utils.is_deepspeed_zero3_enabled = lambda: False
            try:
                yield
            finally:
                # Restore even when engine construction fails, otherwise the patch would leak
                # into the rest of the process.
                modeling_utils.is_deepspeed_zero3_enabled = modeling_utils.is_deepspeed_zero3_enabled_origin
                del modeling_utils.is_deepspeed_zero3_enabled_origin

        with disable_deepspeed():
            self._prepare_engine()
        self._load_generation_config()

    def _prepare_engine_kwargs(self,
                               tp: int = 1,
                               session_len: Optional[int] = None,
                               cache_max_entry_count: float = 0.8,
                               quant_policy: int = 0,
                               vision_batch_size: int = 1,
                               devices: Optional[List[int]] = None,
                               engine_kwargs: Optional[Dict[str, Any]] = None):
        """Build ``self.backend_config`` and ``self.pipeline_kwargs``.

        The explicit arguments override same-named keys in ``engine_kwargs``. The merged mapping is
        passed to ``TurbomindEngineConfig`` and then to ``autoget_backend_config``, whose return
        value becomes the final backend config.

        Args:
            tp: Value for the ``tp`` config field.
            session_len: Value for the ``session_len`` config field.
            cache_max_entry_count: Value for the ``cache_max_entry_count`` config field.
            quant_policy: Value for the ``quant_policy`` config field.
            vision_batch_size: ``max_batch_size`` for ``VisionConfig`` (multimodal models only).
            devices: Device list for backend configs exposing a ``devices`` attribute; default ``[0]``.
            engine_kwargs: Additional ``TurbomindEngineConfig`` keyword arguments (left unmodified).
        """
        # Work on a copy so the caller's dict is not mutated.
        engine_kwargs = dict(engine_kwargs) if engine_kwargs else {}
        engine_kwargs.update(
            tp=tp,
            session_len=session_len,
            cache_max_entry_count=cache_max_entry_count,
            quant_policy=quant_policy,
        )
        backend_config = TurbomindEngineConfig(**engine_kwargs)
        backend_config = autoget_backend_config(self.model_dir, backend_config)
        if hasattr(backend_config, 'devices'):
            if devices is None:
                devices = [0]
            backend_config.devices = devices
        self.backend_config = backend_config
        logger.info(f'backend_config: {backend_config}')

        pipeline_kwargs = {}
        if self.model_meta.is_multimodal:
            vision_config = VisionConfig(max_batch_size=vision_batch_size)
            pipeline_kwargs['vision_config'] = vision_config
            logger.info(f'vision_config: {vision_config}')
        self.pipeline_kwargs = pipeline_kwargs

    @contextmanager
    def _patch_pipeline(self):
        """Temporarily replace ``async_engine.best_match_model`` so it returns this model's type."""
        _old_best_match_model = async_engine.best_match_model

        def _best_match_model(*args, **kwargs) -> Optional[str]:
            return self.model_info.model_type

        async_engine.best_match_model = _best_match_model
        try:
            yield
        finally:
            async_engine.best_match_model = _old_best_match_model

    def _prepare_engine(self):
        """Create the lmdeploy pipeline with the tokenizer/config/model-matching patches active."""
        with patch_auto_tokenizer(self.tokenizer), patch_auto_config(self.config), self._patch_pipeline():
            engine = pipeline(self.model_dir, backend_config=self.backend_config, **self.pipeline_kwargs)
        self.engine = engine

    def _load_generation_config(self):
        """Initialise ``self.generation_config`` from ``generation_config.json`` when present.

        Only non-None entries whose names are parameters of ``LmdeployGenerationConfig`` are kept.
        Without the file, a default ``LmdeployGenerationConfig`` is used.
        """
        generation_config_path = os.path.join(self.model_dir, 'generation_config.json')
        if not os.path.isfile(generation_config_path):
            self.generation_config = LmdeployGenerationConfig()
            return

        generation_config = GenerationConfig.from_pretrained(self.model_dir)
        parameters = inspect.signature(LmdeployGenerationConfig).parameters
        kwargs = {k: v for k, v in generation_config.to_dict().items() if k in parameters and v is not None}
        self.generation_config = LmdeployGenerationConfig(**kwargs)

    def _get_stop_token_ids(self, stop_words: List[Union[str, List[int], None]]) -> List[int]:
        """Convert stop words into a de-duplicated list of single stop token ids.

        Strings are encoded with ``self.tokenizer`` (no special tokens). Only stop words that map
        to exactly one token are kept; ``None`` entries, multi-token stop words and entries of an
        unsupported type are skipped. Bare ``int`` entries are accepted as token ids.

        Args:
            stop_words: Stop words as strings, token-id lists, ints or ``None``.

        Returns:
            Stop token ids in first-seen order without duplicates.
        """
        stop_token_ids: List[int] = []
        for stop_word in stop_words:
            if stop_word is None:
                continue
            if isinstance(stop_word, str):
                stop_word = self.tokenizer.encode(stop_word, add_special_tokens=False)
            if isinstance(stop_word, list):
                if len(stop_word) != 1:
                    continue
                stop_token = stop_word[0]
            elif isinstance(stop_word, int):
                stop_token = stop_word
            else:
                # Unsupported entry type: skip instead of reusing the previous iteration's token
                # (or hitting an unbound local when it is the first entry).
                continue
            assert isinstance(stop_token, int)
            if stop_token not in stop_token_ids:
                stop_token_ids.append(stop_token)
        return stop_token_ids

    def _add_stop_words(self, generation_config: LmdeployGenerationConfig, request_config: RequestConfig,
                        template_meta: TemplateMeta) -> None:
        """Set the stop token ids on ``generation_config`` in place.

        Stop words are gathered from the request, the engine's default generation config and the
        template meta, in that order.
        """
        stop_words = (request_config.stop or []) + (self.generation_config.stop_words or []) + template_meta.stop_words
        generation_config.stop_words = self._get_stop_token_ids(stop_words)
        # compat lmdeploy >= 0.6.*
        generation_config.stop_token_ids = generation_config.stop_words

    def _prepare_generation_config(self, request_config: RequestConfig) -> LmdeployGenerationConfig:
        """Translate a ``RequestConfig`` into an ``LmdeployGenerationConfig``.

        Sampling parameters left as ``None`` on the request fall back to ``self.generation_config``.
        When the request has no seed, one is drawn with ``get_seed()`` and written back to
        ``request_config.seed``. The returned config also carries a ``top_logprobs`` attribute
        copied from the request.
        """
        kwargs = {'max_new_tokens': request_config.max_tokens}
        for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:
            new_value = getattr(request_config, key)
            if new_value is None:
                kwargs[key] = getattr(self.generation_config, key)
            else:
                kwargs[key] = new_value
        if request_config.seed is None:
            request_config.seed = get_seed()
        kwargs['random_seed'] = request_config.seed
        if request_config.temperature == 0:
            kwargs['temperature'] = 1  # avoid unnecessary process
            kwargs['top_k'] = 1

        if request_config.logprobs:
            kwargs['logprobs'] = 1
            if request_config.top_logprobs is not None:
                kwargs['logprobs'] = max(1, request_config.top_logprobs)

        res = LmdeployGenerationConfig(**kwargs)
        res.top_logprobs = request_config.top_logprobs
        return res

    async def _infer_stream_async(
            self, template: Template, inputs: Dict[str, Any],
            generation_config: LmdeployGenerationConfig) -> AsyncIterator[ChatCompletionStreamResponse]:
        """Run one request and yield a stream chunk for every new piece of printable text.

        After the underlying generator is exhausted, one more chunk is produced from the last
        output with ``is_finished=True``; tool calls are only parsed for that final chunk.
        """
        session_id = time.time_ns()
        kwargs = {'stream_output': True, 'gen_config': generation_config, 'sequence_start': True, 'sequence_end': True}
        is_new_api = _lmdeploy_version_at_least('0.6.5')
        if is_new_api:
            # NOTE(review): the `model_inst` context is exited before `context` is entered below;
            # confirm this is intended for lmdeploy >= 0.6.5.
            async with self.engine.model_inst(session_id) as inst:
                context = self.engine.safe_run(inst, session_id, **inputs, **kwargs)
        else:
            context = self.engine.safe_run(session_id)

        infer_streamer = InferStreamer(template)
        token_idx = 0
        async with context as gen:
            if not is_new_api:
                generator = await self.engine.get_generator(False, session_id)
                gen = generator.async_stream_infer(session_id=session_id, **inputs, **kwargs)
            is_finished = False
            while not is_finished:
                try:
                    output = await gen.__anext__()
                except StopAsyncIteration:
                    # Fall through with the last `output` to emit the final chunk.
                    # NOTE(review): `output` is unbound if the generator never yielded.
                    is_finished = True
                delta_text = infer_streamer.get_printable_text(output.token_ids, is_finished)
                if not delta_text and not is_finished:
                    continue

                # Only the tokens produced since the previous emitted chunk are passed on.
                logprobs = self._get_logprobs(output.logprobs, output.token_ids[token_idx:],
                                              generation_config.top_logprobs)
                token_idx = len(output.token_ids)

                usage_info = self._get_usage_info(len(inputs['input_ids']), output.num_token)
                toolcall = None
                if is_finished:
                    toolcall = self._get_toolcall(template.decode(output.token_ids), template)
                finish_reason = self._get_finish_reason(generation_config.max_new_tokens, output.num_token,
                                                        output.status.name == 'FINISH')
                choices = [
                    ChatCompletionResponseStreamChoice(
                        index=0,
                        delta=DeltaMessage(role='assistant', content=delta_text, tool_calls=toolcall),
                        finish_reason=finish_reason,
                        logprobs=logprobs)
                ]
                yield ChatCompletionStreamResponse(model=self.model_name, choices=choices, usage=usage_info)

    async def _infer_full_async(self, template: Template, inputs: Dict[str, Any],
                                generation_config: LmdeployGenerationConfig) -> ChatCompletionResponse:
        """Run one request to completion and build the response from the last engine output."""
        session_id = time.time_ns()
        kwargs = {'stream_output': False, 'gen_config': generation_config, 'sequence_start': True, 'sequence_end': True}
        # Both branches drain the generator; only the last `output` is used afterwards.
        # NOTE(review): `output` is unbound below if the generator never yielded.
        if _lmdeploy_version_at_least('0.6.5'):
            async with self.engine.model_inst(session_id) as inst:
                async with self.engine.safe_run(inst, session_id, **inputs, **kwargs) as gen:
                    async for output in gen:
                        pass
                if self.engine.backend == 'pytorch':
                    # manually end pytorch session
                    await inst.async_end(session_id)
        else:
            async with self.engine.safe_run(session_id):
                generator = await self.engine.get_generator(False, session_id)
                async for output in generator.async_stream_infer(session_id=session_id, **inputs, **kwargs):
                    pass

        response = template.decode(output.token_ids)
        logprobs = self._get_logprobs(output.logprobs, output.token_ids, generation_config.top_logprobs)

        usage_info = self._get_usage_info(len(inputs['input_ids']), output.num_token)
        toolcall = self._get_toolcall(response, template)
        finish_reason = self._get_finish_reason(generation_config.max_new_tokens, output.num_token,
                                                output.status.name == 'FINISH')
        choices = [
            ChatCompletionResponseChoice(
                index=0,
                message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
                finish_reason=finish_reason,
                logprobs=logprobs)
        ]
        return ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info)

    async def infer_async(self,
                          infer_request: InferRequest,
                          request_config: Optional[RequestConfig] = None,
                          *,
                          template: Optional[Template] = None,
                          pre_infer_hook=None,
                          **kwargs) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
        """Encode a request, prepare generation settings and run it on the engine.

        Args:
            infer_request: Request to encode with the template.
            request_config: Generation settings; a deep copy is used so the caller's object is
                not modified. Defaults to ``RequestConfig()``.
            template: Template to use; defaults to ``self.default_template``. Its mode is set to
                ``'lmdeploy'``.
            pre_infer_hook: Optional callable that receives the keyword arguments about to be
                passed to the inference coroutine and returns the ones to use.
            **kwargs: Extra keyword arguments forwarded to the inference coroutine.

        Returns:
            An async iterator of stream chunks when ``request_config.stream`` is truthy,
            otherwise the complete response.
        """
        request_config = deepcopy(request_config or RequestConfig())
        if template is None:
            template = self.default_template

        template.set_mode('lmdeploy')

        loop = asyncio.get_running_loop()
        with torch.inference_mode():
            inputs = await loop.run_in_executor(None, template.encode, infer_request)
        images = inputs.pop('images', None)
        if images:
            if _lmdeploy_version_at_least('0.6.5'):
                messages = self.engine._convert_prompts(('', images))
                messages = await self.engine.async_convert_to_pil_images(messages)
                results = await self.engine.vl_encoder.preprocess(messages)
                if self.engine.backend == 'turbomind':
                    results = await self.engine.vl_encoder.async_infer(results)
                    inputs['images'] = [result['content'] for result in results if result['role'] == 'forward'][0]
                    await template.prepare_lmdeploy_turbomind_inputs(inputs)
                else:
                    inputs['images'] = results[1]['content']
                    await template.prepare_lmdeploy_pytorch_inputs(inputs)
            else:
                inputs['images'] = await self.engine.vl_encoder.async_infer(images)
                await template.prepare_lmdeploy_turbomind_inputs(inputs)

        self.set_default_max_tokens(request_config, inputs)
        generation_config = self._prepare_generation_config(request_config)
        self._add_stop_words(generation_config, request_config, template.template_meta)
        kwargs.update({'template': template, 'inputs': inputs, 'generation_config': generation_config})
        if pre_infer_hook:
            kwargs = pre_infer_hook(kwargs)
        if request_config.stream:
            return self._infer_stream_async(**kwargs)
        else:
            return await self._infer_full_async(**kwargs)

    def _batch_infer_stream(self, *args, **kwargs):
        """Reset per-loop engine state, then delegate to the base implementation.

        ``vl_encoder._loop_task`` and ``free_insts`` are set to ``None`` when the engine has them.
        """
        # NOTE(review): presumably these are bound to a previous event loop — confirm against lmdeploy.
        if hasattr(self.engine, 'vl_encoder'):
            self.engine.vl_encoder._loop_task = None
        if hasattr(self.engine, 'free_insts'):
            self.engine.free_insts = None
        return super()._batch_infer_stream(*args, **kwargs)

    def infer(
        self,
        infer_requests: List[InferRequest],
        request_config: Optional[RequestConfig] = None,
        metrics: Optional[List[Metric]] = None,
        *,
        template: Optional[Template] = None,
        use_tqdm: Optional[bool] = None,
    ) -> List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]:
        """Run a batch of requests by delegating unchanged to ``InferEngine.infer``."""
        return super().infer(infer_requests, request_config, metrics, template=template, use_tqdm=use_tqdm)