| | from __future__ import annotations |
| |
|
| | import sys |
| | import traceback |
| | import time |
| | from re import compile, Match, Pattern |
| | from typing import Callable, Coroutine, Optional, Tuple, Union, Dict |
| | from typing_extensions import TypedDict |
| |
|
| |
|
| | from fastapi import ( |
| | Request, |
| | Response, |
| | HTTPException, |
| | ) |
| | from fastapi.responses import JSONResponse |
| | from fastapi.routing import APIRoute |
| |
|
| | from llama_cpp.server.types import ( |
| | CreateCompletionRequest, |
| | CreateEmbeddingRequest, |
| | CreateChatCompletionRequest, |
| | ) |
| |
|
| |
|
class ErrorResponse(TypedDict):
    """Payload describing a single error, shaped like an OpenAI API error.

    Being a TypedDict, constructing one simply yields a plain ``dict`` with
    exactly these four keys.
    """

    # Human-readable description of what went wrong.
    message: str
    # Error category, e.g. "invalid_request_error" or "internal_server_error".
    type: str
    # Name of the offending request field, or None when not applicable.
    param: Union[str, None]
    # Machine-readable error code (e.g. "context_length_exceeded"), or None.
    code: Union[str, None]
| |
|
| |
|
class ErrorResponseFormatters:
    """Collection of formatters for error responses.

    Every formatter shares one signature so it can be paired with a regex
    and applied to the text of an exception:

    Args:
        request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
            Request body
        match (Match[str]): Match object from regex pattern

    Returns:
        Tuple[int, ErrorResponse]: Status code and error response
    """

    @staticmethod
    def context_length_exceeded(
        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
        match: "Match[str]",
    ) -> Tuple[int, ErrorResponse]:
        """Formatter for context length exceeded error.

        ``match.group(1)`` holds the number of prompt tokens and
        ``match.group(2)`` the size of the model's context window. The message
        wording depends on whether the request is a chat request (has a
        ``messages`` attribute) or a plain completion request.

        Returns:
            ``(400, ErrorResponse)`` with code ``"context_length_exceeded"``.
        """
        context_window = int(match.group(2))
        prompt_tokens = int(match.group(1))
        # ``max_tokens`` may be None. Normalise it once so the reported total
        # and its per-part breakdown agree, instead of rendering the literal
        # text "None" next to a total that already counted it as 0.
        completion_tokens = request.max_tokens or 0
        if hasattr(request, "messages"):
            # Chat completion request
            message = (
                "This model's maximum context length is {} tokens. "
                "However, you requested {} tokens "
                "({} in the messages, {} in the completion). "
                "Please reduce the length of the messages or completion."
            )
        else:
            # Plain completion request
            message = (
                "This model's maximum context length is {} tokens, "
                "however you requested {} tokens "
                "({} in your prompt; {} for the completion). "
                "Please reduce your prompt; or completion length."
            )
        return 400, ErrorResponse(
            message=message.format(
                context_window,
                completion_tokens + prompt_tokens,
                prompt_tokens,
                completion_tokens,
            ),
            type="invalid_request_error",
            # NOTE(review): "messages" is also reported for plain completion
            # requests, which have no such field — confirm this is intended.
            param="messages",
            code="context_length_exceeded",
        )

    @staticmethod
    def model_not_found(
        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
        match: "Match[str]",
    ) -> Tuple[int, ErrorResponse]:
        """Formatter for model_not_found error.

        ``match.group(1)`` is the model path that could not be found; the
        request body is not consulted.

        Returns:
            ``(400, ErrorResponse)`` with code ``"model_not_found"``.
        """
        model_path = match.group(1)
        message = f"The model `{model_path}` does not exist"
        return 400, ErrorResponse(
            message=message,
            type="invalid_request_error",
            param=None,
            code="model_not_found",
        )
| |
|
| |
|
class RouteErrorHandler(APIRoute):
    """Custom APIRoute that handles application errors and exceptions.

    Successful responses get an ``openai-processing-ms`` header carrying the
    wrapped handler's wall-clock time. ``HTTPException`` is re-raised
    untouched; any other exception is turned into an OpenAI style JSON error
    body of the form ``{"error": ErrorResponse}``.
    """

    # Regexes searched against ``str(exception)``, each mapped to the
    # formatter that turns a match into a ``(status_code, ErrorResponse)``
    # pair. Only consulted for completion / chat-completion request bodies.
    pattern_and_formatters: Dict[
        "Pattern[str]",
        Callable[
            [
                Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
                "Match[str]",
            ],
            Tuple[int, ErrorResponse],
        ],
    ] = {
        compile(
            r"Requested tokens \((\d+)\) exceed context window of (\d+)"
        ): ErrorResponseFormatters.context_length_exceeded,
        compile(
            r"Model path does not exist: (.+)"
        ): ErrorResponseFormatters.model_not_found,
    }

    def error_message_wrapper(
        self,
        error: Exception,
        body: Optional[
            Union[
                "CreateChatCompletionRequest",
                "CreateCompletionRequest",
                "CreateEmbeddingRequest",
            ]
        ] = None,
    ) -> Tuple[int, ErrorResponse]:
        """Wraps error message in OpenAI style error response.

        Args:
            error: The exception raised while handling the request.
            body: The parsed request body, or None if it could not be parsed.

        Returns:
            ``(status_code, error_response)``. For completion and
            chat-completion bodies, the first pattern in
            ``pattern_and_formatters`` that matches the error text decides
            the result. Otherwise the error is logged to stderr and reported
            as a 500 ``internal_server_error``.
        """
        error_text = str(error)
        # isinstance() is False for None, so no separate None check is needed.
        if isinstance(body, (CreateCompletionRequest, CreateChatCompletionRequest)):
            for pattern, callback in self.pattern_and_formatters.items():
                match = pattern.search(error_text)
                if match is not None:
                    return callback(body, match)

        # Unexpected error: log it. Take the traceback from the exception
        # object itself rather than sys.exc_info(), so the right exception is
        # printed no matter where this method is called from.
        print(f"Exception: {error_text}", file=sys.stderr)
        traceback.print_exception(
            type(error), error, error.__traceback__, file=sys.stderr
        )

        return 500, ErrorResponse(
            message=error_text,
            type="internal_server_error",
            param=None,
            code=None,
        )

    def get_route_handler(
        self,
    ) -> Callable[[Request], Coroutine[None, None, Response]]:
        """Defines custom route handler that catches exceptions and formats
        in OpenAI style error response"""
        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            start_sec = time.perf_counter()
            try:
                response = await original_route_handler(request)
            except HTTPException:
                # e.g. the API key check failed: let FastAPI render it as-is.
                raise
            except Exception as exc:
                # NOTE(review): this also catches any request-validation
                # exception the wrapped handler raises, reporting it as a 500.
                # Consider re-raising those — confirm.
                #
                # Best-effort parse of the body so the formatter can tailor
                # its message. Reading the JSON is inside the try on purpose:
                # a missing or non-JSON body must not raise again here and
                # bypass the error response.
                body: Optional[
                    Union[
                        CreateChatCompletionRequest,
                        CreateCompletionRequest,
                        CreateEmbeddingRequest,
                    ]
                ]
                try:
                    json_body = await request.json()
                    if "messages" in json_body:
                        # Chat completion request
                        body = CreateChatCompletionRequest(**json_body)
                    elif "prompt" in json_body:
                        # Text completion request
                        body = CreateCompletionRequest(**json_body)
                    else:
                        # Embedding request
                        body = CreateEmbeddingRequest(**json_body)
                except Exception:
                    # Unparseable or invalid body: fall back to a generic error.
                    body = None

                status_code, error_message = self.error_message_wrapper(
                    error=exc, body=body
                )
                return JSONResponse(
                    {"error": error_message},
                    status_code=status_code,
                )

            elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000)
            response.headers["openai-processing-ms"] = f"{elapsed_time_ms}"
            return response

        return custom_route_handler
| |
|