| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """Contains utilities used by both the sync and async inference clients.""" |
| |
|
| | import base64 |
| | import io |
| | import json |
| | import logging |
| | from contextlib import contextmanager |
| | from dataclasses import dataclass |
| | from pathlib import Path |
| | from typing import ( |
| | TYPE_CHECKING, |
| | Any, |
| | AsyncIterable, |
| | BinaryIO, |
| | ContextManager, |
| | Dict, |
| | Generator, |
| | Iterable, |
| | List, |
| | Literal, |
| | NoReturn, |
| | Optional, |
| | Union, |
| | overload, |
| | ) |
| |
|
| | from requests import HTTPError |
| |
|
| | from huggingface_hub.errors import ( |
| | GenerationError, |
| | IncompleteGenerationError, |
| | OverloadedError, |
| | TextGenerationError, |
| | UnknownError, |
| | ValidationError, |
| | ) |
| |
|
| | from ..utils import get_session, is_aiohttp_available, is_numpy_available, is_pillow_available |
| | from ._generated.types import ChatCompletionStreamOutput, TextGenerationStreamOutput |
| |
|
| |
|
| | if TYPE_CHECKING: |
| | from aiohttp import ClientResponse, ClientSession |
| | from PIL.Image import Image |
| |
|
| | |
# Type aliases for the content accepted by the inference clients: a URL, a local
# path, raw bytes, or an opened binary file.
UrlT = str
PathT = Union[str, Path]
BinaryT = Union[bytes, BinaryIO]
ContentT = Union[BinaryT, PathT, UrlT]


# Tasks whose server response is expected to be an image payload.
TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"}

logger = logging.getLogger(__name__)
| |
|
| |
|
@dataclass
class RequestParameters:
    """Fully-resolved parameters of a single inference request."""

    url: str  # full endpoint URL the request is sent to
    task: str  # name of the inference task being performed
    model: Optional[str]  # model id, if any
    json: Optional[Union[str, Dict, List]]  # JSON-serializable request body, if any
    data: Optional[ContentT]  # raw binary request body (e.g. image/audio), if any
    headers: Dict[str, Any]  # HTTP headers to send with the request
| |
|
| |
|
| | |
@dataclass
class ModelStatus:
    """
    This dataclass represents the model status in the HF Inference API.

    Args:
        loaded (`bool`):
            Whether the model is currently loaded into HF's Inference API. Models
            are loaded on-demand, leading to the user's first request taking longer.
            If a model is loaded, you can be assured that it is in a healthy state.
        state (`str`):
            The current state of the model. This can be 'Loaded', 'Loadable', 'TooBig'.
            If a model's state is 'Loadable', it's not too big and has a supported
            backend. Loadable models are automatically loaded when the user first
            requests inference on the endpoint. This means it is transparent for the
            user to load a model, except that the first call takes longer to complete.
        compute_type (`Dict`):
            Information about the compute resource the model is using or will use, such as 'gpu' type and number of
            replicas.
        framework (`str`):
            The name of the framework that the model was built with, such as 'transformers'
            or 'text-generation-inference'.
    """

    loaded: bool
    state: str
    compute_type: Dict
    framework: str
| |
|
| |
|
| | |
| |
|
| |
|
def _import_aiohttp():
    """Make sure `aiohttp` is installed on the machine (required by `AsyncInferenceClient`)."""
    if not is_aiohttp_available():
        raise ImportError("Please install aiohttp to use `AsyncInferenceClient` (`pip install aiohttp`).")
    import aiohttp

    return aiohttp
| |
|
| |
|
def _import_numpy():
    """Make sure `numpy` is installed on the machine and return the module.

    Raises:
        ImportError: if `numpy` is not installed.
    """
    if not is_numpy_available():
        # Fixed error message wording (was "to use deal with embeddings").
        raise ImportError("Please install numpy to deal with embeddings (`pip install numpy`).")
    import numpy

    return numpy
| |
|
| |
|
def _import_pil_image():
    """Make sure `PIL` is installed on the machine and return the `PIL.Image` module.

    Raises:
        ImportError: if `Pillow` is not installed.
    """
    if not is_pillow_available():
        # Fixed error message wording (was "to use deal with images").
        raise ImportError(
            "Please install Pillow to deal with images (`pip install Pillow`). If you don't want the image to be"
            " post-processed, use `client.post(...)` and get the raw response from the server."
        )
    from PIL import Image

    return Image
| |
|
| |
|
| | |
| |
|
| |
|
@overload
def _open_as_binary(
    content: ContentT,
) -> ContextManager[BinaryT]: ...


@overload
def _open_as_binary(
    content: Literal[None],
) -> ContextManager[Literal[None]]: ...


@contextmanager
def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT], None, None]:
    """Open `content` as binary data, whether it is a URL, a local path, or raw bytes.

    `None` is passed through unchanged.

    TODO: handle a PIL.Image as input
    TODO: handle base64 as input
    """
    # A string is either a URL (fetched eagerly) or a local file path.
    if isinstance(content, str):
        if content.startswith("https://") or content.startswith("http://"):
            logger.debug(f"Downloading content from {content}")
            yield get_session().get(content).content
            return
        content = Path(content)
        if not content.exists():
            raise FileNotFoundError(
                f"File not found at {content}. If `data` is a string, it must either be a URL or a path to a local"
                " file. To pass raw content, please encode it as bytes first."
            )

    # A path (given directly or converted above) is opened for the duration of the context.
    if isinstance(content, Path):
        logger.debug(f"Opening content from {content}")
        with content.open("rb") as binary_file:
            yield binary_file
    else:
        # Raw bytes, a file-like object, or None: pass through as-is.
        yield content
| |
|
| |
|
def _b64_encode(content: ContentT) -> str:
    """Return the base64 representation of `content` (bytes, opened file, local path or URL)."""
    with _open_as_binary(content) as binary:
        raw_bytes = binary if isinstance(binary, bytes) else binary.read()
        return base64.b64encode(raw_bytes).decode()
| |
|
| |
|
def _b64_to_image(encoded_image: str) -> "Image":
    """Decode a base64-encoded string and load it as a PIL Image."""
    Image = _import_pil_image()
    buffer = io.BytesIO(base64.b64decode(encoded_image))
    return Image.open(buffer)
| |
|
| |
|
| | def _bytes_to_list(content: bytes) -> List: |
| | """Parse bytes from a Response object into a Python list. |
| | |
| | Expects the response body to be JSON-encoded data. |
| | |
| | NOTE: This is exactly the same implementation as `_bytes_to_dict` and will not complain if the returned data is a |
| | dictionary. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect. |
| | """ |
| | return json.loads(content.decode()) |
| |
|
| |
|
| | def _bytes_to_dict(content: bytes) -> Dict: |
| | """Parse bytes from a Response object into a Python dictionary. |
| | |
| | Expects the response body to be JSON-encoded data. |
| | |
| | NOTE: This is exactly the same implementation as `_bytes_to_list` and will not complain if the returned data is a |
| | list. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect. |
| | """ |
| | return json.loads(content.decode()) |
| |
|
| |
|
def _bytes_to_image(content: bytes) -> "Image":
    """Load raw image bytes into a PIL Image.

    Expects `content` to be the raw bytes of an image file. For base64-encoded
    images, use `_b64_to_image` instead.
    """
    Image = _import_pil_image()
    buffer = io.BytesIO(content)
    return Image.open(buffer)
| |
|
| |
|
| | def _as_dict(response: Union[bytes, Dict]) -> Dict: |
| | return json.loads(response) if isinstance(response, bytes) else response |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| |
|
def _stream_text_generation_response(
    bytes_output_as_lines: Iterable[bytes], details: bool
) -> Union[Iterable[str], Iterable[TextGenerationStreamOutput]]:
    """Used in `InferenceClient.text_generation`."""
    for raw_line in bytes_output_as_lines:
        try:
            item = _format_text_generation_stream_output(raw_line, details)
        except StopIteration:
            return  # "[DONE]" marker reached: end of stream
        if item is not None:
            yield item
| |
|
| |
|
async def _async_stream_text_generation_response(
    bytes_output_as_lines: AsyncIterable[bytes], details: bool
) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]:
    """Used in `AsyncInferenceClient.text_generation`."""
    async for raw_line in bytes_output_as_lines:
        try:
            item = _format_text_generation_stream_output(raw_line, details)
        except StopIteration:
            return  # "[DONE]" marker reached: end of stream
        if item is not None:
            yield item
| |
|
| |
|
def _format_text_generation_stream_output(
    byte_payload: bytes, details: bool
) -> Optional[Union[str, TextGenerationStreamOutput]]:
    """Parse a single server-sent-event line from a text-generation stream.

    Returns `None` for non-data lines (keep-alive / comments), the generated token
    text (or the full `TextGenerationStreamOutput` if `details=True`) for data lines,
    and raises `StopIteration` when the "[DONE]" marker is received.
    """
    if not byte_payload.startswith(b"data:"):
        return None  # empty line or SSE comment

    if byte_payload.strip() == b"data: [DONE]":
        raise StopIteration("[DONE] signal received.")

    # Decode payload. Slice off the "data:" prefix instead of `lstrip("data:")`
    # (which strips any of the *characters* {d,a,t,:}, not the prefix) and strip
    # the trailing newline (the previous `rstrip("/n")` stripped "/" and "n"
    # characters by mistake).
    payload = byte_payload.decode("utf-8")
    json_payload = json.loads(payload[len("data:") :].strip())

    # Either an error is returned...
    if json_payload.get("error") is not None:
        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))

    # ...or a valid generation chunk.
    output = TextGenerationStreamOutput.parse_obj_as_instance(json_payload)
    return output.token.text if not details else output
| |
|
| |
|
def _stream_chat_completion_response(
    bytes_lines: Iterable[bytes],
) -> Iterable[ChatCompletionStreamOutput]:
    """Used in `InferenceClient.chat_completion` if model is served with TGI."""
    for raw_line in bytes_lines:
        try:
            chunk = _format_chat_completion_stream_output(raw_line)
        except StopIteration:
            return  # "[DONE]" marker reached: end of stream
        if chunk is not None:
            yield chunk
| |
|
| |
|
async def _async_stream_chat_completion_response(
    bytes_lines: AsyncIterable[bytes],
) -> AsyncIterable[ChatCompletionStreamOutput]:
    """Used in `AsyncInferenceClient.chat_completion`."""
    async for raw_line in bytes_lines:
        try:
            chunk = _format_chat_completion_stream_output(raw_line)
        except StopIteration:
            return  # "[DONE]" marker reached: end of stream
        if chunk is not None:
            yield chunk
| |
|
| |
|
def _format_chat_completion_stream_output(
    byte_payload: bytes,
) -> Optional[ChatCompletionStreamOutput]:
    """Parse a single server-sent-event line from a chat-completion stream.

    Returns `None` for non-data lines (keep-alive / comments), a
    `ChatCompletionStreamOutput` for data lines, and raises `StopIteration` when
    the "[DONE]" marker is received.
    """
    if not byte_payload.startswith(b"data:"):
        return None  # empty line or SSE comment

    if byte_payload.strip() == b"data: [DONE]":
        raise StopIteration("[DONE] signal received.")

    # Decode payload. Slice off the "data:" prefix instead of `lstrip("data:")`
    # (which strips any of the *characters* {d,a,t,:}, not the prefix) and strip
    # the trailing newline (the previous `rstrip("/n")` stripped "/" and "n"
    # characters by mistake).
    payload = byte_payload.decode("utf-8")
    json_payload = json.loads(payload[len("data:") :].strip())

    # Either an error is returned...
    if json_payload.get("error") is not None:
        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))

    # ...or a valid chat-completion chunk.
    return ChatCompletionStreamOutput.parse_obj_as_instance(json_payload)
| |
|
| |
|
| | async def _async_yield_from(client: "ClientSession", response: "ClientResponse") -> AsyncIterable[bytes]: |
| | async for byte_payload in response.content: |
| | yield byte_payload.strip() |
| | await client.close() |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | _UNSUPPORTED_TEXT_GENERATION_KWARGS: Dict[Optional[str], List[str]] = {} |
| |
|
| |
|
| | def _set_unsupported_text_generation_kwargs(model: Optional[str], unsupported_kwargs: List[str]) -> None: |
| | _UNSUPPORTED_TEXT_GENERATION_KWARGS.setdefault(model, []).extend(unsupported_kwargs) |
| |
|
| |
|
| | def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> List[str]: |
| | return _UNSUPPORTED_TEXT_GENERATION_KWARGS.get(model, []) |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
def raise_text_generation_error(http_error: HTTPError) -> NoReturn:
    """
    Raise a typed `TextGenerationError` parsed from a text-generation-inference
    error response, falling back to re-raising the original `HTTPError`.

    Args:
        http_error (`HTTPError`):
            The HTTPError that has been raised.
    """
    try:
        # A payload may have been attached to the error by the caller; otherwise decode the body.
        server_payload = getattr(http_error, "response_error_payload", None) or http_error.response.json()
        error_message = server_payload.get("error")
        error_type = server_payload.get("error_type")
    except Exception:
        # Body is not JSON or has no error info: nothing smarter to do than re-raise.
        raise http_error

    if error_type is not None:
        # A typed TGI error was found: raise the matching exception, keeping the HTTP error as cause.
        raise _parse_text_generation_error(error_message, error_type) from http_error

    # No error type reported: re-raise the original error.
    raise http_error
| |
|
| |
|
def _parse_text_generation_error(error: Optional[str], error_type: Optional[str]) -> TextGenerationError:
    """Map a TGI `error_type` string to the matching exception class, defaulting to `UnknownError`."""
    exception_class = {
        "generation": GenerationError,
        "incomplete_generation": IncompleteGenerationError,
        "overloaded": OverloadedError,
        "validation": ValidationError,
    }.get(error_type, UnknownError)
    return exception_class(error)
| |
|