# Async engine wrapper for a Text Generation Inference (TGI) backend.
| import json | |
| from typing import Optional, List, AsyncIterator | |
| from aiohttp import ClientSession | |
| from openai.types.chat import ChatCompletionMessageParam | |
| from pydantic import ValidationError | |
| from text_generation import AsyncClient | |
| from text_generation.errors import parse_error | |
| from text_generation.types import Request, Parameters | |
| from text_generation.types import Response, StreamResponse | |
| from api.adapter import get_prompt_adapter | |
| from api.utils.compat import model_dump | |
class TGIEngine:
    """Async wrapper around a Text Generation Inference (TGI) server.

    Combines a ``text_generation.AsyncClient`` (which carries ``base_url``,
    ``headers``, ``cookies`` and ``timeout``) with a prompt adapter that turns
    OpenAI-style chat messages into the raw prompt string the model expects.
    """

    def __init__(
        self,
        model: AsyncClient,
        model_name: str,
        prompt_name: Optional[str] = None,
    ):
        """
        Initializes the TGIEngine object.

        Args:
            model: The text-generation ``AsyncClient`` holding the connection
                settings (base_url, headers, cookies, timeout) for the TGI server.
            model_name: The name of the model.
            prompt_name: The name of the prompt template (optional).
        """
        self.model = model
        # Names are lowercased so adapter lookup is case-insensitive.
        self.model_name = model_name.lower()
        self.prompt_name = prompt_name.lower() if prompt_name is not None else None
        self.prompt_adapter = get_prompt_adapter(self.model_name, prompt_name=self.prompt_name)

    def apply_chat_template(
        self, messages: List[ChatCompletionMessageParam],
    ) -> str:
        """
        Applies a chat template to the given messages and returns the processed output.

        Args:
            messages: A list of ChatCompletionMessageParam objects representing the chat messages.

        Returns:
            str: The prompt string produced by the adapter's chat template.
        """
        return self.prompt_adapter.apply_chat_template(messages)

    async def generate(
        self,
        prompt: str,
        do_sample: bool = True,
        max_new_tokens: int = 20,
        best_of: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        decoder_input_details: bool = True,
        top_n_tokens: Optional[int] = None,
    ) -> Response:
        """
        Given a prompt, generate the following text asynchronously.

        Args:
            prompt (`str`):
                Input text
            do_sample (`bool`):
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
            best_of (`int`):
                Generate best_of sequences and return the one with the highest token logprobs
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            return_full_text (`bool`):
                Whether to prepend the prompt to the generated text
            seed (`int`):
                Random sampling seed
            stop_sequences (`List[str]`):
                Stop generating tokens if a member of `stop_sequences` is generated
            temperature (`float`):
                The value used to module the logits distribution.
            top_k (`int`):
                The number of the highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
            truncate (`int`):
                Truncate inputs tokens to the given size
            typical_p (`float`):
                Typical Decoding mass
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            decoder_input_details (`bool`):
                Return the decoder input token logprobs and ids
            top_n_tokens (`int`):
                Return the `n` most likely tokens at each step

        Returns:
            Response: generated response
        """
        # Validate parameters via the pydantic model before sending them over the wire.
        parameters = Parameters(
            best_of=best_of,
            details=True,
            decoder_input_details=decoder_input_details,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop_sequences if stop_sequences is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
            top_n_tokens=top_n_tokens,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

        async with ClientSession(
            headers=self.model.headers, cookies=self.model.cookies, timeout=self.model.timeout
        ) as session:
            async with session.post(f"{self.model.base_url}/generate", json=model_dump(request)) as resp:
                payload = await resp.json()

                if resp.status != 200:
                    raise parse_error(resp.status, payload)
                return Response(**payload)

    async def generate_stream(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        best_of: Optional[int] = 1,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        top_n_tokens: Optional[int] = None,
    ) -> AsyncIterator[StreamResponse]:
        """
        Given a prompt, generate the following stream of tokens asynchronously.

        Args:
            prompt (`str`):
                Input text
            do_sample (`bool`):
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
            best_of (`int`):
                Generate best_of sequences and return the one with the highest token logprobs
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            return_full_text (`bool`):
                Whether to prepend the prompt to the generated text
            seed (`int`):
                Random sampling seed
            stop_sequences (`List[str]`):
                Stop generating tokens if a member of `stop_sequences` is generated
            temperature (`float`):
                The value used to module the logits distribution.
            top_k (`int`):
                The number of the highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
            truncate (`int`):
                Truncate inputs tokens to the given size
            typical_p (`float`):
                Typical Decoding mass
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            top_n_tokens (`int`):
                Return the `n` most likely tokens at each step

        Returns:
            AsyncIterator[StreamResponse]: stream of generated tokens
        """
        # Validate parameters via the pydantic model before sending them over the wire.
        parameters = Parameters(
            best_of=best_of,
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop_sequences if stop_sequences is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
            top_n_tokens=top_n_tokens,
        )
        # FIX: stream=True was missing, making this request inconsistent with
        # `generate` (which explicitly passes stream=False).
        request = Request(inputs=prompt, stream=True, parameters=parameters)

        async with ClientSession(
            headers=self.model.headers, cookies=self.model.cookies, timeout=self.model.timeout
        ) as session:
            async with session.post(f"{self.model.base_url}/generate_stream", json=model_dump(request)) as resp:
                if resp.status != 200:
                    raise parse_error(resp.status, await resp.json())

                # Parse Server-Sent Events from the response body.
                async for byte_payload in resp.content:
                    # Skip keep-alive / separator lines
                    if byte_payload == b"\n":
                        continue

                    payload = byte_payload.decode("utf-8")

                    # Event data
                    if payload.startswith("data:"):
                        # FIX: str.lstrip/rstrip take *character sets*, not
                        # prefixes/suffixes — lstrip("data:") would also eat a
                        # leading 'd'/'a'/'t'/':' of the JSON body, and "/n" was
                        # a typo for "\n". Slice off the literal prefix and trim
                        # surrounding whitespace instead.
                        json_payload = json.loads(payload[len("data:"):].strip())
                        # Parse payload
                        try:
                            response = StreamResponse(**json_payload)
                        except ValidationError:
                            # If we failed to parse the payload, then it is an error payload
                            raise parse_error(resp.status, json_payload)
                        yield response

    def stop(self):
        """
        Gets the stop property of the prompt adapter.

        Returns:
            The stop property of the prompt adapter, or None if it does not exist.
        """
        return self.prompt_adapter.stop if hasattr(self.prompt_adapter, "stop") else None