from __future__ import annotations

import sys
import traceback
import time
from re import compile, Match, Pattern
from typing import Callable, Coroutine, Optional, Tuple, Union, Dict

from typing_extensions import TypedDict
from fastapi import (
    Request,
    Response,
    HTTPException,
)
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute

from llama_cpp.server.types import (
    CreateCompletionRequest,
    CreateEmbeddingRequest,
    CreateChatCompletionRequest,
)


class ErrorResponse(TypedDict):
    """OpenAI style error response"""

    message: str
    type: str
    param: Optional[str]
    code: Optional[str]
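
# Illustrative note (not part of the original module): the route handler below
# serializes an ErrorResponse to the client wrapped under an "error" key, e.g.
#     {"error": {"message": "...", "type": "invalid_request_error",
#                "param": "messages", "code": "context_length_exceeded"}}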


class ErrorResponseFormatters:
    """Collection of formatters for error responses.

    Args:
        request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
            Request body
        match (Match[str]): Match object from regex pattern

    Returns:
        Tuple[int, ErrorResponse]: Status code and error response
    """

    @staticmethod
    def context_length_exceeded(
        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
        match: "Match[str]",
    ) -> Tuple[int, ErrorResponse]:
        """Formatter for context length exceeded error"""

        context_window = int(match.group(2))
        prompt_tokens = int(match.group(1))
        completion_tokens = request.max_tokens
        if hasattr(request, "messages"):
            # Chat completion
            message = (
                "This model's maximum context length is {} tokens. "
                "However, you requested {} tokens "
                "({} in the messages, {} in the completion). "
                "Please reduce the length of the messages or completion."
            )
        else:
            # Text completion
            message = (
                "This model's maximum context length is {} tokens, "
                "however you requested {} tokens "
                "({} in your prompt; {} for the completion). "
                "Please reduce your prompt; or completion length."
            )
        return 400, ErrorResponse(
            message=message.format(
                context_window,
                (completion_tokens or 0) + prompt_tokens,
                prompt_tokens,
                completion_tokens,
            ),
            type="invalid_request_error",
            param="messages",
            code="context_length_exceeded",
        )

    @staticmethod
    def model_not_found(
        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
        match: "Match[str]",
    ) -> Tuple[int, ErrorResponse]:
        """Formatter for model_not_found error"""

        model_path = str(match.group(1))
        message = f"The model `{model_path}` does not exist"
        return 400, ErrorResponse(
            message=message,
            type="invalid_request_error",
            param=None,
            code="model_not_found",
        )
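
# Illustrative example (assumption, not in the original module): given a
# llama_cpp error string such as
#     "Requested tokens (5000) exceed context window of 4096"
# and a chat request with max_tokens=256, context_length_exceeded returns
# roughly:
#     (400, {"message": "This model's maximum context length is 4096 tokens. "
#                       "However, you requested 5256 tokens (5000 in the "
#                       "messages, 256 in the completion). Please reduce the "
#                       "length of the messages or completion.",
#            "type": "invalid_request_error",
#            "param": "messages",
#            "code": "context_length_exceeded"})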


class RouteErrorHandler(APIRoute):
    """Custom APIRoute that handles application errors and exceptions"""

    # key: regex pattern for original error message from llama_cpp
    # value: formatter function
    pattern_and_formatters: Dict[
        "Pattern[str]",
        Callable[
            [
                Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
                "Match[str]",
            ],
            Tuple[int, ErrorResponse],
        ],
    ] = {
        compile(
            r"Requested tokens \((\d+)\) exceed context window of (\d+)"
        ): ErrorResponseFormatters.context_length_exceeded,
        compile(
            r"Model path does not exist: (.+)"
        ): ErrorResponseFormatters.model_not_found,
    }
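
    # Illustrative example (assumption, not in the original module): an error
    # such as "Model path does not exist: /models/ggml-model.gguf" matches the
    # second pattern above and is formatted by model_not_found into a 400
    # "model_not_found" response; the path shown here is hypothetical.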

    def error_message_wrapper(
        self,
        error: Exception,
        body: Optional[
            Union[
                "CreateChatCompletionRequest",
                "CreateCompletionRequest",
                "CreateEmbeddingRequest",
            ]
        ] = None,
    ) -> Tuple[int, ErrorResponse]:
        """Wraps error message in OpenAI style error response"""

        print(f"Exception: {str(error)}", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)
        if body is not None and isinstance(
            body,
            (
                CreateCompletionRequest,
                CreateChatCompletionRequest,
            ),
        ):
            # When text completion or chat completion
            for pattern, callback in self.pattern_and_formatters.items():
                match = pattern.search(str(error))
                if match is not None:
                    return callback(body, match)

        # Wrap other errors as internal server error
        return 500, ErrorResponse(
            message=str(error),
            type="internal_server_error",
            param=None,
            code=None,
        )
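
    # Illustrative example (assumption, not in the original module): an
    # exception matching no pattern, e.g. ValueError("unexpected"), falls
    # through to the branch above and is returned as
    #     (500, {"message": "unexpected", "type": "internal_server_error",
    #            "param": None, "code": None})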

    def get_route_handler(
        self,
    ) -> Callable[[Request], Coroutine[None, None, Response]]:
        """Defines a custom route handler that catches exceptions and formats
        them as OpenAI style error responses"""

        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            try:
                start_sec = time.perf_counter()
                response = await original_route_handler(request)
                elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000)
                response.headers["openai-processing-ms"] = f"{elapsed_time_ms}"
                return response
            except HTTPException as unauthorized:
                # api key check failed
                raise unauthorized
            except Exception as exc:
                json_body = await request.json()
                try:
                    if "messages" in json_body:
                        # Chat completion
                        body: Optional[
                            Union[
                                CreateChatCompletionRequest,
                                CreateCompletionRequest,
                                CreateEmbeddingRequest,
                            ]
                        ] = CreateChatCompletionRequest(**json_body)
                    elif "prompt" in json_body:
                        # Text completion
                        body = CreateCompletionRequest(**json_body)
                    else:
                        # Embedding
                        body = CreateEmbeddingRequest(**json_body)
                except Exception:
                    # Invalid request body
                    body = None
                # Get proper error message from the exception
                (
                    status_code,
                    error_message,
                ) = self.error_message_wrapper(error=exc, body=body)
                return JSONResponse(
                    {"error": error_message},
                    status_code=status_code,
                )

        return custom_route_handler
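

# ---------------------------------------------------------------------------
# Usage sketch (assumption, not part of the original module): RouteErrorHandler
# is intended to be installed as the route class of a FastAPI router so that
# exceptions raised inside handlers come back as OpenAI-style {"error": ...}
# payloads. The router and endpoint below are hypothetical examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from fastapi import APIRouter, FastAPI

    router = APIRouter(route_class=RouteErrorHandler)

    @router.get("/v1/health")
    async def health() -> dict:
        # Any exception raised here would be caught by custom_route_handler
        # and returned as a JSON error body with an appropriate status code.
        return {"status": "ok"}

    app = FastAPI()
    app.include_router(router)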