File size: 9,823 Bytes
cb2428f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
# Copyright (c) Alibaba, Inc. and its affiliates.
import asyncio
import hmac
import inspect
import multiprocessing
import time
from contextlib import asynccontextmanager, contextmanager
from dataclasses import asdict
from http import HTTPStatus
from threading import Thread
from typing import List, Optional, Union

import json
import uvicorn
from aiohttp import ClientConnectorError
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

from swift.llm import AdapterRequest, DeployArguments
from swift.llm.infer.protocol import MultiModalRequestMixin
from swift.plugin import InferStats
from swift.utils import JsonlWriter, get_logger
from .infer import SwiftInfer
from .infer_engine import InferClient
from .protocol import ChatCompletionRequest, CompletionRequest, Model, ModelList

# Module-level logger shared by the deployment server and the helpers below.
logger = get_logger()


class SwiftDeploy(SwiftInfer):
    """HTTP deployment front-end built on top of ``SwiftInfer``.

    A FastAPI application is created with three routes (``/v1/models``,
    ``/v1/chat/completions`` and ``/v1/completions``). Incoming requests are
    validated (model name, API key, ``top_logprobs``), forwarded to
    ``self.infer_async`` and post-processed before being returned.
    """
    args_class = DeployArguments
    args: args_class

    def _register_app(self):
        """Bind the HTTP routes of ``self.app`` to their handler methods."""
        self.app.get('/v1/models')(self.get_available_models)
        self.app.post('/v1/chat/completions')(self.create_chat_completion)
        self.app.post('/v1/completions')(self.create_completion)

    def __init__(self, args: Union[List[str], DeployArguments, None] = None) -> None:
        """Initialise the parent, then create the stats collector and the FastAPI app.

        Args:
            args: Forwarded unchanged to the parent ``__init__``.
        """
        super().__init__(args)

        self.infer_engine.strict = True
        self.infer_stats = InferStats()
        self.app = FastAPI(lifespan=self.lifespan)
        self._register_app()

    async def _log_stats_hook(self):
        """Log and then reset the inference statistics every ``args.log_interval`` seconds."""
        while True:
            await asyncio.sleep(self.args.log_interval)
            self._compute_infer_stats()
            self.infer_stats.reset()

    def _compute_infer_stats(self):
        """Compute the current statistics, round each value to 8 decimals and log them."""
        global_stats = self.infer_stats.compute()
        for k, v in global_stats.items():
            global_stats[k] = round(v, 8)
        logger.info(global_stats)

    @asynccontextmanager
    async def lifespan(self, app: FastAPI):
        """Application lifespan handler passed to ``FastAPI(lifespan=...)``.

        When ``args.log_interval`` is positive, a daemon thread running
        ``_log_stats_hook`` in its own event loop is started before the
        application begins serving, and the statistics are logged one last
        time on shutdown.

        Implemented as an async context manager because plain generator
        lifespans are deprecated by Starlette.
        """
        args = self.args
        if args.log_interval > 0:
            thread = Thread(target=lambda: asyncio.run(self._log_stats_hook()), daemon=True)
            thread.start()
        try:
            yield
        finally:
            if args.log_interval > 0:
                self._compute_infer_stats()

    def _get_model_list(self):
        """Return the served model name followed by the names of all mapped adapters."""
        args = self.args
        model_list = [args.served_model_name or args.model_suffix]
        if args.adapter_mapping:
            # Iterating a mapping yields its keys, i.e. the adapter names.
            model_list.extend(args.adapter_mapping)
        return model_list

    async def get_available_models(self):
        """Handler for ``GET /v1/models``: wrap every model id in a ``ModelList``."""
        model_list = self._get_model_list()
        data = [Model(id=model_id, owned_by=self.args.owned_by) for model_id in model_list]
        return ModelList(data=data)

    async def _check_model(self, request: ChatCompletionRequest) -> Optional[str]:
        """Return an error message if ``request.model`` is not served, else ``None``."""
        available_models = await self.get_available_models()
        model_list = [model.id for model in available_models.data]
        if request.model not in model_list:
            return f'`{request.model}` is not in the model_list: `{model_list}`.'
        return None

    def _check_api_key(self, raw_request: Request) -> Optional[str]:
        """Validate the ``Authorization: Bearer <key>`` header against ``args.api_key``.

        Returns:
            ``None`` when no API key is configured or the supplied key matches,
            otherwise the error message ``'API key error'``.
        """
        api_key = self.args.api_key
        if api_key is None:
            return None
        authorization = dict(raw_request.headers).get('authorization')
        error_msg = 'API key error'
        if authorization is None or not authorization.startswith('Bearer '):
            return error_msg
        request_api_key = authorization[len('Bearer '):]
        # Constant-time comparison so response timing does not leak key contents.
        # Bytes are compared because compare_digest rejects non-ASCII str input.
        if not hmac.compare_digest(request_api_key.encode('utf-8'), api_key.encode('utf-8')):
            return error_msg
        return None

    def _check_max_logprobs(self, request):
        """Return an error message if ``request.top_logprobs`` exceeds ``args.max_logprobs``."""
        args = self.args
        if isinstance(request.top_logprobs, int) and request.top_logprobs > args.max_logprobs:
            return (f'The value of top_logprobs({request.top_logprobs}) is greater than '
                    f'the server\'s max_logprobs({args.max_logprobs}).')
        return None

    @staticmethod
    def create_error_response(status_code: Union[int, str, HTTPStatus], message: str) -> JSONResponse:
        """Build a JSON error response with the given HTTP status code and message."""
        status_code = int(status_code)
        return JSONResponse({'message': message, 'object': 'error'}, status_code)

    def _post_process(self, request_info, response, return_cmpl_response: bool = False):
        """Finalise one response (or one stream chunk) before it is sent to the client.

        Image items inside list/tuple message contents are replaced by base64
        data URLs, the text of the first choice is recorded in
        ``request_info['response']`` and, once every choice has a finish
        reason, statistics are updated and the request is written/logged.

        Args:
            request_info: Mutable dict describing the request; its
                ``'response'`` entry is updated in place.
            response: Response object exposing ``choices``; treated as a
                stream chunk when its class name contains ``'stream'``.
            return_cmpl_response: When true, the response is converted with
                ``to_cmpl_response()`` before being returned.

        Returns:
            The (possibly converted) response.
        """
        args = self.args

        for choice in response.choices:
            if not hasattr(choice, 'message') or not isinstance(choice.message.content, (tuple, list)):
                continue
            for content in choice.message.content:
                if content['type'] == 'image':
                    b64_image = MultiModalRequestMixin.to_base64(content['image'])
                    content['image'] = f'data:image/jpg;base64,{b64_image}'

        is_finished = all(choice.finish_reason for choice in response.choices)
        if 'stream' in response.__class__.__name__.lower():
            # A chunk may carry no text (content is None); append nothing then.
            request_info['response'] += response.choices[0].delta.content or ''
        else:
            request_info['response'] = response.choices[0].message.content
        if return_cmpl_response:
            response = response.to_cmpl_response()
        if is_finished:
            if args.log_interval > 0:
                self.infer_stats.update(response)
            # `jsonl_writer` is only assigned in `run()`; tolerate the app being
            # served without `run()` having been called.
            jsonl_writer = getattr(self, 'jsonl_writer', None)
            if jsonl_writer:
                jsonl_writer.append(request_info)
            if self.args.verbose:
                logger.info(request_info)
        return response

    def _set_request_config(self, request_config) -> None:
        """Fill unset fields of ``request_config`` from the server defaults.

        A field counts as unset when it is ``None`` or an empty list/tuple.
        Nothing happens when ``args.get_request_config()`` returns ``None``.
        """
        default_request_config = self.args.get_request_config()
        if default_request_config is None:
            return
        for key, val in asdict(request_config).items():
            default_val = getattr(default_request_config, key)
            is_unset = val is None or (isinstance(val, (list, tuple)) and len(val) == 0)
            if default_val is not None and is_unset:
                setattr(request_config, key, default_val)

    async def create_chat_completion(self,
                                     request: ChatCompletionRequest,
                                     raw_request: Request,
                                     *,
                                     return_cmpl_response: bool = False):
        """Handler for ``POST /v1/chat/completions``.

        Args:
            request: Parsed chat-completion request body.
            raw_request: Underlying HTTP request, used for the API-key check.
            return_cmpl_response: When true, responses are converted to the
                completion format (used by ``create_completion``).

        Returns:
            A 400 JSON error response when validation or inference fails, a
            ``StreamingResponse`` of server-sent events when streaming is
            requested, otherwise the post-processed response object.
        """
        args = self.args
        error_msg = (await self._check_model(request) or self._check_api_key(raw_request)
                     or self._check_max_logprobs(request))
        if error_msg:
            return self.create_error_response(HTTPStatus.BAD_REQUEST, error_msg)
        infer_kwargs = self.infer_kwargs.copy()
        # Guard against an unset mapping, mirroring `_get_model_list`.
        adapter_path = (args.adapter_mapping or {}).get(request.model)
        if adapter_path:
            infer_kwargs['adapter_request'] = AdapterRequest(request.model, adapter_path)

        infer_request, request_config = request.parse()
        self._set_request_config(request_config)
        request_info = {'response': '', 'infer_request': infer_request.to_printable()}

        def pre_infer_hook(kwargs):
            # Record the generation config that is actually used, for logging.
            request_info['generation_config'] = kwargs['generation_config']
            return kwargs

        infer_kwargs['pre_infer_hook'] = pre_infer_hook
        try:
            res_or_gen = await self.infer_async(infer_request, request_config, template=self.template, **infer_kwargs)
        except Exception as e:
            import traceback
            # Failures are reported to the client; keep the traceback at error level.
            logger.error(traceback.format_exc())
            return self.create_error_response(HTTPStatus.BAD_REQUEST, str(e))
        if request_config.stream:

            async def _gen_wrapper():
                # Emit each chunk as a server-sent event, then the terminator.
                async for res in res_or_gen:
                    res = self._post_process(request_info, res, return_cmpl_response)
                    yield f'data: {json.dumps(asdict(res), ensure_ascii=False)}\n\n'
                yield 'data: [DONE]\n\n'

            return StreamingResponse(_gen_wrapper(), media_type='text/event-stream')
        else:
            return self._post_process(request_info, res_or_gen, return_cmpl_response)

    async def create_completion(self, request: CompletionRequest, raw_request: Request):
        """Handler for ``POST /v1/completions``: convert to a chat request and delegate."""
        chat_request = ChatCompletionRequest.from_cmpl_request(request)
        return await self.create_chat_completion(chat_request, raw_request, return_cmpl_response=True)

    def run(self):
        """Create the optional JSONL result writer and serve ``self.app`` with uvicorn."""
        args = self.args
        self.jsonl_writer = JsonlWriter(args.result_path) if args.result_path else None
        logger.info(f'model_list: {self._get_model_list()}')
        uvicorn.run(
            self.app,
            host=args.host,
            port=args.port,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            log_level=args.log_level)


def deploy_main(args: Union[List[str], DeployArguments, None] = None) -> None:
    """Entry point: build a ``SwiftDeploy`` from ``args`` and invoke its ``main()``."""
    deployer = SwiftDeploy(args)
    deployer.main()


def is_accessible(port: int):
    """Report whether a model-list query against ``port`` can connect.

    Returns ``False`` when the query raises ``ClientConnectorError`` and
    ``True`` when it completes; any other exception propagates.
    """
    client = InferClient(port=port)
    try:
        client.get_model_list()
    except ClientConnectorError:
        reachable = False
    else:
        reachable = True
    return reachable


@contextmanager
def run_deploy(args: DeployArguments, return_url: bool = False):
    """Run a deployment server in a spawned subprocess for the duration of a ``with`` block.

    Args:
        args: A ``DeployArguments`` instance, used as is, or another dataclass
            instance whose non-``None`` fields accepted by ``DeployArguments``
            are copied into a new ``DeployArguments``.
        return_url: When true, yield ``http://127.0.0.1:<port>/v1`` instead of
            the bare port number.

    Yields:
        The port of the running server, or its base URL if ``return_url``.

    Raises:
        RuntimeError: If the server process exits before it becomes reachable.
    """
    if isinstance(args, DeployArguments) and args.__class__.__name__ == 'DeployArguments':
        deploy_args = args
    else:
        parameters = inspect.signature(DeployArguments).parameters
        # Keep only the fields DeployArguments accepts and that are actually set.
        args_dict = {k: v for k, v in asdict(args).items() if k in parameters and v is not None}
        deploy_args = DeployArguments(**args_dict)

    mp = multiprocessing.get_context('spawn')
    process = mp.Process(target=deploy_main, args=(deploy_args, ))
    process.start()
    try:
        while not is_accessible(deploy_args.port):
            # Without this check a server that crashed during startup would
            # leave the caller polling forever.
            if not process.is_alive():
                raise RuntimeError('The deployment process exited before becoming accessible '
                                   f'(exit code: {process.exitcode}).')
            time.sleep(1)
        yield f'http://127.0.0.1:{deploy_args.port}/v1' if return_url else deploy_args.port
    finally:
        process.terminate()
        # Reap the child so it does not linger as a zombie; bounded so that a
        # slow shutdown cannot block the caller indefinitely.
        process.join(timeout=10)
        logger.info('The deployment process has been terminated.')