import argparse
import asyncio
import json
from contextlib import asynccontextmanager
from http import HTTPStatus

import fastapi
import uvicorn
from aioprometheus import MetricsMiddleware
from aioprometheus.asgi.starlette import metrics
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.metrics import add_global_metrics_labels
from vllm.logger import init_logger

from protocol import ChatCompletionRequest, CompletionRequest, ErrorResponse
from serving_chat import OpenAIServingChat
from serving_completion import OpenAIServingCompletion

TIMEOUT_KEEP_ALIVE = 5  # seconds

openai_serving_chat: OpenAIServingChat = None
openai_serving_completion: OpenAIServingCompletion = None
logger = init_logger(__name__)
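

# `lifespan` closes over `engine` and `engine_args`, module-level names that
# are assigned in the `__main__` block below. Once the app starts, it spawns
# a background task that logs engine stats every 10 seconds unless
# --disable-log-stats is set.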
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()

    if not engine_args.disable_log_stats:
        asyncio.create_task(_force_log())

    yield


app = fastapi.FastAPI(lifespan=lifespan)


def parse_args():
    parser = argparse.ArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
    parser.add_argument("--host", type=str, default=None, help="host name")
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument("--allow-credentials",
                        action="store_true",
                        help="allow credentials")
    parser.add_argument("--allowed-origins",
                        type=json.loads,
                        default=["*"],
                        help="allowed origins")
    parser.add_argument("--allowed-methods",
                        type=json.loads,
                        default=["*"],
                        help="allowed methods")
    parser.add_argument("--allowed-headers",
                        type=json.loads,
                        default=["*"],
                        help="allowed headers")
    parser.add_argument("--served-model-name",
                        type=str,
                        default=None,
                        help="The model name used in the API. If not "
                        "specified, the model name will be the same as "
                        "the huggingface name.")
    parser.add_argument("--chat-template",
                        type=str,
                        default=None,
                        help="The file path to the chat template, "
                        "or the template in single-line form "
                        "for the specified model")
    parser.add_argument("--response-role",
                        type=str,
                        default="assistant",
                        help="The role name to return if "
                        "`request.add_generation_prompt=true`.")
    parser.add_argument("--ssl-keyfile",
                        type=str,
                        default=None,
                        help="The file path to the SSL key file")
    parser.add_argument("--ssl-certfile",
                        type=str,
                        default=None,
                        help="The file path to the SSL cert file")
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy")
    parser = AsyncEngineArgs.add_cli_args(parser)
    return parser.parse_args()
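

# AsyncEngineArgs.add_cli_args merges vLLM's engine-level flags (for example
# --model and --disable-log-stats) into the same parser, so a single command
# line configures both the HTTP layer and the engine.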

app.add_middleware(MetricsMiddleware)  # Trace HTTP server metrics
app.add_route("/metrics", metrics)  # Exposes HTTP metrics
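
# The aioprometheus middleware counts and times HTTP requests, and /metrics
# serves them in Prometheus text format. Example scrape, assuming the
# default host and port:
#
#   curl http://localhost:8000/metrics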


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
    err = openai_serving_chat.create_error_response(message=str(exc))
    return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.get("/v1/models")
async def show_available_models():
    models = await openai_serving_chat.show_available_models()
    return JSONResponse(content=models.model_dump())


@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    generator = await openai_serving_chat.create_chat_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())
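
# Example chat request (the model name is illustrative; it must match the
# served model, i.e. --served-model-name or the HuggingFace model name):
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "facebook/opt-125m",
#          "messages": [{"role": "user", "content": "Hello!"}]}'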


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    generator = await openai_serving_completion.create_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())
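
# Example completion request; setting "stream": true switches the response
# to server-sent events (model name again illustrative):
#
#   curl http://localhost:8000/v1/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "facebook/opt-125m", "prompt": "Once upon a time",
#          "stream": true}'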


if __name__ == "__main__":
    args = parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    logger.info(f"args: {args}")

    if args.served_model_name is not None:
        served_model = args.served_model_name
    else:
        served_model = args.model

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    openai_serving_chat = OpenAIServingChat(engine, served_model,
                                            args.response_role,
                                            args.chat_template)
    openai_serving_completion = OpenAIServingCompletion(engine, served_model)

    # Register labels for metrics
    add_global_metrics_labels(model_name=engine_args.model)

    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="info",
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile)
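
# A minimal sketch of a launch, assuming this file is saved as api_server.py
# alongside the local protocol/serving_chat/serving_completion modules
# (filename and model name are illustrative):
#
#   python api_server.py --model facebook/opt-125m --host 0.0.0.0 --port 8000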