import json
import pathlib
from dataclasses import dataclass
from http import HTTPStatus
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union

from pydantic import Field
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Annotated

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest,
                                              DetokenizeRequest,
                                              EmbeddingRequest, ErrorResponse,
                                              ModelCard, ModelList,
                                              ModelPermission,
                                              TokenizeChatRequest,
                                              TokenizeCompletionRequest,
                                              TokenizeRequest)
# yapf: enable
from vllm.inputs import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob

logger = init_logger(__name__)


@dataclass
class PromptAdapterPath:
    name: str
    local_path: str


@dataclass
class LoRAModulePath:
    name: str
    path: str


AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
                   EmbeddingRequest, TokenizeRequest]

AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


class TextTokensPrompt(TypedDict):
    prompt: str
    prompt_token_ids: List[int]
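# Illustrative note (not part of the original source): a TextTokensPrompt pairs
# the raw prompt text with its token IDs, e.g.
#   {"prompt": "Hello world", "prompt_token_ids": [15496, 995]}
# where the IDs shown are made up for the example.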


class OpenAIServing:

    def __init__(
        self,
        engine: AsyncLLMEngine,
        model_config: ModelConfig,
        served_model_names: List[str],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        prompt_adapters: Optional[List[PromptAdapterPath]],
        request_logger: Optional[RequestLogger],
    ):
        super().__init__()

        self.engine = engine
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        self.served_model_names = served_model_names

        self.lora_requests = []
        if lora_modules is not None:
            self.lora_requests = [
                LoRARequest(
                    lora_name=lora.name,
                    lora_int_id=i,
                    lora_path=lora.path,
                ) for i, lora in enumerate(lora_modules, start=1)
            ]

        self.prompt_adapter_requests = []
        if prompt_adapters is not None:
            for i, prompt_adapter in enumerate(prompt_adapters, start=1):
                with pathlib.Path(prompt_adapter.local_path,
                                  "adapter_config.json").open() as f:
                    adapter_config = json.load(f)
                    num_virtual_tokens = adapter_config["num_virtual_tokens"]
                self.prompt_adapter_requests.append(
                    PromptAdapterRequest(
                        prompt_adapter_name=prompt_adapter.name,
                        prompt_adapter_id=i,
                        prompt_adapter_local_path=prompt_adapter.local_path,
                        prompt_adapter_num_virtual_tokens=num_virtual_tokens))

        self.request_logger = request_logger

    async def show_available_models(self) -> ModelList:
        """Show available models.

        Besides the base model, this also lists any LoRA adapters and prompt
        adapters registered at startup.
        """
        model_cards = [
            ModelCard(id=served_model_name,
                      max_model_len=self.max_model_len,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for served_model_name in self.served_model_names
        ]
        lora_cards = [
            ModelCard(id=lora.lora_name,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for lora in self.lora_requests
        ]
        prompt_adapter_cards = [
            ModelCard(id=prompt_adapter.prompt_adapter_name,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for prompt_adapter in self.prompt_adapter_requests
        ]
        model_cards.extend(lora_cards)
        model_cards.extend(prompt_adapter_cards)
        return ModelList(data=model_cards)

    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        return ErrorResponse(message=message,
                             type=err_type,
                             code=status_code.value)

    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        json_str = json.dumps({
            "error":
            self.create_error_response(message=message,
                                       err_type=err_type,
                                       status_code=status_code).model_dump()
        })
        return json_str
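    # For reference (shape assumed from ErrorResponse above, values are
    # illustrative), the streaming error string serializes to something like:
    #   {"error": {"message": "...", "type": "BadRequestError", "code": 400}}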

    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
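        """Return None if ``request.model`` is being served (as the base model,
        a LoRA adapter, or a prompt adapter); otherwise return a 404
        ErrorResponse."""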
        if request.model in self.served_model_names:
            return None
        if request.model in [lora.lora_name for lora in self.lora_requests]:
            return None
        if request.model in [
                prompt_adapter.prompt_adapter_name
                for prompt_adapter in self.prompt_adapter_requests
        ]:
            return None
        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

    def _maybe_get_adapters(
        self, request: AnyRequest
    ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
            None, PromptAdapterRequest]]:
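        """Resolve ``request.model`` to ``(lora_request, prompt_adapter_request)``.

        Returns ``(None, None)`` for the base model, a LoRA request or a
        prompt adapter request when the name matches one of those, and raises
        if the name is unknown."""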
        if request.model in self.served_model_names:
            return None, None
        for lora in self.lora_requests:
            if request.model == lora.lora_name:
                return lora, None
        for prompt_adapter in self.prompt_adapter_requests:
            if request.model == prompt_adapter.prompt_adapter_name:
                return None, prompt_adapter
        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
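        """Tokenize a text prompt (optionally truncating to at most
        ``truncate_prompt_tokens`` tokens) and validate it against the model's
        context length."""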
        if truncate_prompt_tokens is None:
            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
        else:
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
                                truncation=True,
                                max_length=truncate_prompt_tokens)

        input_ids = encoded.input_ids
        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

    def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: List[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
    ) -> TextTokensPrompt:
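        """Accept an already-tokenized prompt (optionally keeping only the
        last ``truncate_prompt_tokens`` tokens), decode it back to text, and
        validate it against the model's context length."""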
        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        input_text = tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: List[int],
        input_text: str,
    ) -> TextTokensPrompt:
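        """Check the tokenized prompt against ``max_model_len`` and, for
        completion-style requests, reserve room for ``max_tokens`` output
        tokens (defaulting ``max_tokens`` to the remaining context when it is
        not set)."""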
        token_num = len(input_ids)

        # Note: EmbeddingRequest doesn't have max_tokens
        if isinstance(request, EmbeddingRequest):
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for embedding "
                    f"generation. Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        if request.max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages. "
                    f"Please reduce the length of the messages.")
            request.max_tokens = self.max_model_len - token_num

        if token_num + request.max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{request.max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{request.max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _tokenize_prompt_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, List[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes a single input.
        """
        return next(
            self._tokenize_prompt_inputs(
                request,
                tokenizer,
                [prompt_input],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            ))

    def _tokenize_prompt_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes multiple inputs.
        """
        for text in prompt_inputs:
            if isinstance(text, str):
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

    def _tokenize_prompt_input_or_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        Tokenize/detokenize depending on the input format.

        According to the `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_,
        each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.
        """
        for prompt_input in parse_and_batch_prompt(input_or_inputs):
            # Although our type checking is based on mypy,
            # VSCode Pyright extension should still work properly
            # "is True" is required for Pyright to perform type narrowing
            # See: https://github.com/microsoft/pyright/issues/7672
            if prompt_input["is_tokens"] is False:
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )
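    # Illustrative usage (assumed, not taken from the original file): callers
    # may pass a single string, a list of strings, a single token-ID list, or
    # a list of token-ID lists, e.g.
    #   prompts = list(self._tokenize_prompt_input_or_inputs(
    #       request, tokenizer, ["Hello", "World"]))
    # which yields one TextTokensPrompt per input.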

    def _log_inputs(
        self,
        request_id: str,
        inputs: Union[str, List[int], TextTokensPrompt],
        params: Optional[Union[SamplingParams, PoolingParams]],
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
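        """Forward the request inputs to the configured RequestLogger, if any,
        accepting raw text, token IDs, or a TextTokensPrompt."""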
        if self.request_logger is None:
            return

        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = None
        elif isinstance(inputs, list):
            prompt = None
            prompt_token_ids = inputs
        else:
            prompt = inputs["prompt"]
            prompt_token_ids = inputs["prompt_token_ids"]

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    @staticmethod
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: AnyTokenizer,
    ) -> str:
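        """Return the decoded text for a logprob entry, falling back to
        decoding ``token_id`` when the Logprob has no cached decoded_token."""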
        if logprob.decoded_token is not None:
            return logprob.decoded_token
        return tokenizer.decode(token_id)