| |
| import asyncio |
| import inspect |
| import multiprocessing |
| import time |
| from contextlib import contextmanager |
| from dataclasses import asdict |
| from http import HTTPStatus |
| from threading import Thread |
| from typing import List, Optional, Union |
|
|
| import json |
| import uvicorn |
| from aiohttp import ClientConnectorError |
| from fastapi import FastAPI, Request |
| from fastapi.responses import JSONResponse, StreamingResponse |
|
|
| from swift.llm import AdapterRequest, DeployArguments |
| from swift.llm.infer.protocol import MultiModalRequestMixin |
| from swift.plugin import InferStats |
| from swift.utils import JsonlWriter, get_logger |
| from .infer import SwiftInfer |
| from .infer_engine import InferClient |
| from .protocol import ChatCompletionRequest, CompletionRequest, Model, ModelList |
|
|
# Module-level logger from swift.utils; used for model-list, stats, request and termination logging below.
logger = get_logger()
|
|
|
|
class SwiftDeploy(SwiftInfer):
    """Expose a SwiftInfer engine through an OpenAI-style FastAPI application.

    Routes registered by `_register_app`:
        GET  /v1/models           -> get_available_models
        POST /v1/chat/completions -> create_chat_completion
        POST /v1/completions      -> create_completion
    """
    args_class = DeployArguments
    args: args_class

    def _register_app(self):
        """Attach this instance's request handlers to `self.app`."""
        self.app.get('/v1/models')(self.get_available_models)
        self.app.post('/v1/chat/completions')(self.create_chat_completion)
        self.app.post('/v1/completions')(self.create_completion)

    def __init__(self, args: Union[List[str], DeployArguments, None] = None) -> None:
        super().__init__(args)

        # NOTE(review): `strict` is an engine flag; presumably it makes invalid requests raise instead of
        # being silently adjusted -- confirm in the infer engine implementation.
        self.infer_engine.strict = True
        self.infer_stats = InferStats()
        self.app = FastAPI(lifespan=self.lifespan)
        self._register_app()

    async def _log_stats_hook(self):
        """Every `args.log_interval` seconds, log the accumulated stats and reset them. Never returns."""
        while True:
            await asyncio.sleep(self.args.log_interval)
            self._compute_infer_stats()
            self.infer_stats.reset()

    def _compute_infer_stats(self):
        """Log the current inference statistics with every value rounded to 8 decimal places."""
        # Build a new dict rather than rounding the dict returned by compute() in place.
        global_stats = {k: round(v, 8) for k, v in self.infer_stats.compute().items()}
        logger.info(global_stats)

    def lifespan(self, app: FastAPI):
        """FastAPI lifespan hook.

        When `args.log_interval > 0`, start a daemon thread that runs `_log_stats_hook` in its own event
        loop, and log the stats one final time when the application shuts down.
        """
        # NOTE(review): this is a plain generator; recent Starlette versions warn about generator
        # lifespans and prefer an @asynccontextmanager function -- consider migrating.
        # NOTE(review): infer_stats is updated by request handlers and read/reset from the logging
        # thread with no visible lock -- confirm InferStats tolerates concurrent access.
        args = self.args
        if args.log_interval > 0:
            thread = Thread(target=lambda: asyncio.run(self._log_stats_hook()), daemon=True)
            thread.start()
        try:
            yield
        finally:
            if args.log_interval > 0:
                self._compute_infer_stats()

    def _get_model_list(self):
        """Return the served model name (falling back to `args.model_suffix`) followed by adapter names."""
        args = self.args
        model_list = [args.served_model_name or args.model_suffix]
        if args.adapter_mapping:
            model_list.extend(args.adapter_mapping)
        return model_list

    async def get_available_models(self):
        """Handler for GET /v1/models: wrap `_get_model_list()` in a ModelList."""
        model_list = self._get_model_list()
        data = [Model(id=model_id, owned_by=self.args.owned_by) for model_id in model_list]
        return ModelList(data=data)

    async def _check_model(self, request: ChatCompletionRequest) -> Optional[str]:
        """Return an error message if `request.model` is not served, else None."""
        available_models = await self.get_available_models()
        model_list = [model.id for model in available_models.data]
        if request.model not in model_list:
            return f'`{request.model}` is not in the model_list: `{model_list}`.'
        return None

    def _check_api_key(self, raw_request: Request) -> Optional[str]:
        """Validate an `Authorization: Bearer <key>` header against `args.api_key`.

        Returns:
            None when no API key is configured or the supplied key matches; otherwise an error message.
        """
        import secrets

        api_key = self.args.api_key
        if api_key is None:
            return None
        error_msg = 'API key error'
        prefix = 'Bearer '
        authorization = raw_request.headers.get('authorization')
        if authorization is None or not authorization.startswith(prefix):
            return error_msg
        request_api_key = authorization[len(prefix):]
        # Constant-time comparison so response timing does not reveal how much of the key matched.
        if not secrets.compare_digest(request_api_key.encode('utf-8'), api_key.encode('utf-8')):
            return error_msg
        return None

    def _check_max_logprobs(self, request):
        """Return an error message if `request.top_logprobs` exceeds `args.max_logprobs`, else None."""
        args = self.args
        if isinstance(request.top_logprobs, int) and request.top_logprobs > args.max_logprobs:
            return (f'The value of top_logprobs({request.top_logprobs}) is greater than '
                    f'the server\'s max_logprobs({args.max_logprobs}).')
        return None

    @staticmethod
    def create_error_response(status_code: Union[int, str, HTTPStatus], message: str) -> JSONResponse:
        """Build a JSON error body `{'message': ..., 'object': 'error'}` with the given HTTP status."""
        status_code = int(status_code)
        return JSONResponse({'message': message, 'object': 'error'}, status_code)

    def _post_process(self, request_info, response, return_cmpl_response: bool = False):
        """Prepare one (full or streamed) response for the client and record it once finished.

        - Image parts in a message's list-valued content are replaced with base64 `data:` URLs in place.
        - `request_info['response']` accumulates streamed delta text, or is set to the full message content.
        - When every choice has a finish_reason, stats, the JSONL result writer and verbose logging are
          updated with `request_info`.

        Returns:
            The response, converted with `to_cmpl_response()` when `return_cmpl_response` is True.
        """
        args = self.args

        for choice in response.choices:
            # Stream chunks carry `delta` rather than `message`; only list/tuple content can hold images.
            message = getattr(choice, 'message', None)
            if message is None or not isinstance(message.content, (tuple, list)):
                continue
            for content in message.content:
                if content['type'] == 'image':
                    b64_image = MultiModalRequestMixin.to_base64(content['image'])
                    content['image'] = f'data:image/jpg;base64,{b64_image}'

        is_finished = all(choice.finish_reason for choice in response.choices)
        if 'stream' in response.__class__.__name__.lower():
            # Treat a missing delta text as empty so accumulation cannot raise TypeError.
            request_info['response'] += response.choices[0].delta.content or ''
        else:
            request_info['response'] = response.choices[0].message.content
        if return_cmpl_response:
            response = response.to_cmpl_response()
        if is_finished:
            if args.log_interval > 0:
                self.infer_stats.update(response)
            # `jsonl_writer` is assigned in run(); tolerate serving without it.
            jsonl_writer = getattr(self, 'jsonl_writer', None)
            if jsonl_writer:
                jsonl_writer.append(request_info)
            if args.verbose:
                logger.info(request_info)
        return response

    def _set_request_config(self, request_config) -> None:
        """Fill unset fields of `request_config` in place from `args.get_request_config()`.

        A field counts as unset when it is None or an empty list/tuple; a default is only copied
        when that default itself is not None.
        """
        default_request_config = self.args.get_request_config()
        if default_request_config is None:
            return
        for key, val in asdict(request_config).items():
            default_val = getattr(default_request_config, key)
            if default_val is None:
                continue
            is_unset = val is None or (isinstance(val, (list, tuple)) and len(val) == 0)
            if is_unset:
                setattr(request_config, key, default_val)

    async def create_chat_completion(self,
                                     request: ChatCompletionRequest,
                                     raw_request: Request,
                                     *,
                                     return_cmpl_response: bool = False):
        """Handler for POST /v1/chat/completions (also used by `create_completion`).

        Validation errors and exceptions raised while starting inference are returned as 400 JSON errors.
        If `request.model` names an entry of `args.adapter_mapping`, the request is routed to that adapter.

        Returns:
            A StreamingResponse of `data: <json>` SSE lines ending with `data: [DONE]` when the request
            asks for streaming; otherwise the post-processed response object.
        """
        args = self.args
        # Authenticate first so an unauthenticated caller never sees the served model list in an error.
        error_msg = (self._check_api_key(raw_request) or await self._check_model(request)
                     or self._check_max_logprobs(request))
        if error_msg:
            return self.create_error_response(HTTPStatus.BAD_REQUEST, error_msg)
        infer_kwargs = self.infer_kwargs.copy()
        adapter_path = (args.adapter_mapping or {}).get(request.model)
        if adapter_path:
            infer_kwargs['adapter_request'] = AdapterRequest(request.model, adapter_path)

        infer_request, request_config = request.parse()
        self._set_request_config(request_config)
        request_info = {'response': '', 'infer_request': infer_request.to_printable()}

        def pre_infer_hook(kwargs):
            # Capture the effective generation config so it is recorded alongside the response.
            request_info['generation_config'] = kwargs['generation_config']
            return kwargs

        infer_kwargs['pre_infer_hook'] = pre_infer_hook
        try:
            res_or_gen = await self.infer_async(infer_request, request_config, template=self.template, **infer_kwargs)
        except Exception as e:
            import traceback
            logger.info(traceback.format_exc())
            return self.create_error_response(HTTPStatus.BAD_REQUEST, str(e))
        if request_config.stream:

            async def _gen_wrapper():
                async for res in res_or_gen:
                    res = self._post_process(request_info, res, return_cmpl_response)
                    yield f'data: {json.dumps(asdict(res), ensure_ascii=False)}\n\n'
                yield 'data: [DONE]\n\n'

            return StreamingResponse(_gen_wrapper(), media_type='text/event-stream')
        return self._post_process(request_info, res_or_gen, return_cmpl_response)

    async def create_completion(self, request: CompletionRequest, raw_request: Request):
        """Handler for POST /v1/completions: convert to a chat request and return a completion-shaped response."""
        chat_request = ChatCompletionRequest.from_cmpl_request(request)
        return await self.create_chat_completion(chat_request, raw_request, return_cmpl_response=True)

    def run(self):
        """Create the optional JSONL result writer, log the model list, and serve `self.app` with uvicorn."""
        args = self.args
        self.jsonl_writer = JsonlWriter(args.result_path) if args.result_path else None
        logger.info(f'model_list: {self._get_model_list()}')
        uvicorn.run(
            self.app,
            host=args.host,
            port=args.port,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            log_level=args.log_level)
|
|
|
|
def deploy_main(args: Union[List[str], DeployArguments, None] = None) -> None:
    """Construct a SwiftDeploy from `args` and run its `main()` entry point."""
    deployer = SwiftDeploy(args)
    deployer.main()
|
|
|
|
def is_accessible(port: int):
    """Return True if a server on `port` answers a model-list request, False on a connection error."""
    client = InferClient(port=port)
    try:
        client.get_model_list()
    except ClientConnectorError:
        return False
    else:
        return True
|
|
|
|
@contextmanager
def run_deploy(args: DeployArguments, return_url: bool = False):
    """Run `deploy_main` in a spawned subprocess and yield once the server is reachable.

    Args:
        args: A DeployArguments instance, or another dataclass arguments object whose fields are
            filtered down to those DeployArguments accepts (fields set to None are dropped).
        return_url: If True, yield ``http://127.0.0.1:<port>/v1``; otherwise yield the port number.

    Raises:
        RuntimeError: If the deployment process exits before the server becomes reachable.

    On exit the subprocess is terminated (killed if it does not stop in time) and reaped.
    """
    if isinstance(args, DeployArguments) and args.__class__.__name__ == 'DeployArguments':
        deploy_args = args
    else:
        args_dict = asdict(args)
        parameters = inspect.signature(DeployArguments).parameters
        for k in list(args_dict.keys()):
            if k not in parameters or args_dict[k] is None:
                args_dict.pop(k)
        deploy_args = DeployArguments(**args_dict)

    mp = multiprocessing.get_context('spawn')
    process = mp.Process(target=deploy_main, args=(deploy_args, ))
    process.start()
    try:
        while not is_accessible(deploy_args.port):
            # Without this check, a child that crashed during start-up would leave us polling forever.
            if not process.is_alive():
                raise RuntimeError('The deployment process exited before the server became accessible '
                                   f'(exitcode={process.exitcode}).')
            time.sleep(1)
        yield f'http://127.0.0.1:{deploy_args.port}/v1' if return_url else deploy_args.port
    finally:
        process.terminate()
        # Reap the child so the server (and its port) is really gone before we return; fall back to
        # SIGKILL if a graceful shutdown takes too long.
        process.join(timeout=60)
        if process.is_alive():
            process.kill()
            process.join()
        logger.info('The deployment process has been terminated.')
|
|